Repository: crazyguitar/pysheeet Branch: master Commit: eca8a13c450c Files: 178 Total size: 1.3 MB Directory structure: gitextract_168c70l7/ ├── .clang-format ├── .claude-plugin/ │ ├── marketplace.json │ └── plugin.json ├── .coveragerc ├── .gitattributes ├── .github/ │ ├── FUNDING.yml │ ├── dependabot.yml │ └── workflows/ │ └── pythonpackage.yml ├── .gitignore ├── CITATION.cff ├── LICENSE ├── Makefile ├── Procfile ├── README.rst ├── app.py ├── app_test.py ├── docs/ │ ├── 404.rst │ ├── Makefile │ ├── _extra/ │ │ └── robots.txt │ ├── _static/ │ │ ├── .gitignore │ │ ├── carbonad.css │ │ └── style.css │ ├── _templates/ │ │ ├── carbonad.html │ │ ├── cheatsheets.html │ │ ├── layout.html │ │ ├── link.html │ │ └── sidebarintro.html │ ├── conf.py │ ├── index.rst │ └── notes/ │ ├── appendix/ │ │ ├── disaggregated-prefill-decode.rst │ │ ├── index.rst │ │ ├── megatron-efa-monitoring.rst │ │ ├── nccl-gin.rst │ │ ├── python-gdb.rst │ │ └── python-walrus.rst │ ├── asyncio/ │ │ ├── index.rst │ │ ├── python-asyncio-advanced.rst │ │ ├── python-asyncio-basic.rst │ │ ├── python-asyncio-guide.rst │ │ └── python-asyncio-server.rst │ ├── basic/ │ │ ├── index.rst │ │ ├── python-basic.rst │ │ ├── python-dict.rst │ │ ├── python-func.rst │ │ ├── python-future.rst │ │ ├── python-generator.rst │ │ ├── python-heap.rst │ │ ├── python-list.rst │ │ ├── python-object.rst │ │ ├── python-rexp.rst │ │ ├── python-set.rst │ │ ├── python-typing.rst │ │ └── python-unicode.rst │ ├── concurrency/ │ │ ├── index.rst │ │ ├── python-futures.rst │ │ ├── python-multiprocessing.rst │ │ └── python-threading.rst │ ├── database/ │ │ ├── index.rst │ │ ├── python-sqlalchemy-orm.rst │ │ ├── python-sqlalchemy-query.rst │ │ └── python-sqlalchemy.rst │ ├── extension/ │ │ ├── cpp-from-python.rst │ │ ├── index.rst │ │ ├── python-capi.rst │ │ ├── python-cext-modern.rst │ │ └── python-ctypes.rst │ ├── hpc/ │ │ ├── index.rst │ │ └── slurm.rst │ ├── llm/ │ │ ├── index.rst │ │ ├── llm-bench.rst │ │ ├── llm-serving.rst │ │ ├── 
megatron.rst │ │ └── pytorch.rst │ ├── network/ │ │ ├── index.rst │ │ ├── python-socket-async.rst │ │ ├── python-socket-server.rst │ │ ├── python-socket-sniffer.rst │ │ ├── python-socket-ssl.rst │ │ ├── python-socket.rst │ │ └── python-ssh.rst │ ├── os/ │ │ ├── index.rst │ │ ├── python-date.rst │ │ ├── python-io.rst │ │ └── python-os.rst │ ├── python-new-py3.rst │ └── security/ │ ├── index.rst │ ├── python-crypto.rst │ ├── python-tls.rst │ └── python-vulnerability.rst ├── requirements.txt ├── runtime.txt ├── skills/ │ └── py/ │ ├── SKILL.md │ └── references/ │ ├── guidelines.md │ └── structure.md └── src/ ├── basic/ │ ├── asyncio_.py │ ├── basic.py │ ├── cext_.py │ ├── concurrency_.py │ ├── crypto_.py │ ├── datetime_.py │ ├── dict.py │ ├── fileio_.py │ ├── func.py │ ├── future_.py │ ├── generator.py │ ├── heap.py │ ├── list.py │ ├── object.py │ ├── os_.py │ ├── rexp.py │ ├── set.py │ ├── socket_.py │ ├── sqlalchemy_core.py │ ├── sqlalchemy_orm.py │ ├── sqlalchemy_query.py │ ├── typing_.py │ └── unicode_.py ├── cext/ │ ├── CMakeLists.txt │ ├── README.md │ ├── capi/ │ │ ├── args.c │ │ ├── errors.c │ │ ├── gil.c │ │ ├── setup.py │ │ ├── simple.c │ │ ├── test_capi.py │ │ └── types_demo.c │ ├── conftest.py │ ├── example.cpp │ ├── fib.c │ ├── gil_example.cpp │ ├── numpy_example.cpp │ ├── setup.py │ ├── test_cext.py │ └── vector.cpp ├── cpp_from_python/ │ ├── CMakeLists.txt │ └── cpp_from_py.cpp ├── gin/ │ ├── Dockerfile │ ├── Makefile │ ├── run.enroot │ └── run.sbatch ├── llm/ │ ├── sglang/ │ │ ├── Dockerfile │ │ ├── Makefile │ │ ├── README.rst │ │ ├── bench.sh │ │ ├── run.sbatch │ │ └── test.sh │ ├── tensorrt-llm/ │ │ ├── Dockerfile │ │ ├── Makefile │ │ ├── README.rst │ │ ├── bench.sh │ │ ├── run.sbatch │ │ └── test.sh │ └── vllm/ │ ├── Dockerfile │ ├── Makefile │ ├── README.rst │ ├── bench.sh │ ├── offline_bench.py │ ├── offline_bench.sh │ ├── run.sbatch │ ├── run.sh │ ├── sweep.sbatch │ ├── sweep.sh │ └── test.sh ├── megatron/ │ ├── Dockerfile │ ├── Makefile │ ├── 
README.md │ ├── enroot.sh │ ├── entrypoint.py │ ├── recipes/ │ │ └── deepseek_v2_lite_pretrain.py │ ├── srun.sh │ └── viztracer_plugin.py ├── new_py3/ │ └── py3.py ├── nixl/ │ ├── Dockerfile │ ├── Makefile │ ├── bench.sh │ ├── nixl.sbatch │ └── vllm.sbatch └── security/ └── vulnerability_.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .clang-format ================================================ BasedOnStyle: Google AlignAfterOpenBracket: BlockIndent AllowAllParametersOfDeclarationOnNextLine: false BinPackParameters: false ColumnLimit: 150 ================================================ FILE: .claude-plugin/marketplace.json ================================================ { "name": "pysheeet", "owner": { "name": "crazyguitar" }, "plugins": [ { "name": "pysheeet", "source": { "source": "github", "repo": "crazyguitar/pysheeet" }, "description": "Comprehensive Python programming reference covering syntax, concurrency, networking, databases, ML/LLM development, and HPC", "version": "1.0.0", "author": { "name": "crazyguitar" } } ] } ================================================ FILE: .claude-plugin/plugin.json ================================================ { "name": "pysheeet", "description": "Comprehensive Python programming reference covering syntax, concurrency, networking, databases, ML/LLM development, and HPC", "version": "1.0.0", "author": { "name": "crazyguitar" }, "homepage": "https://www.pythonsheets.com", "repository": "https://github.com/crazyguitar/pysheeet", "license": "MIT" } ================================================ FILE: .coveragerc ================================================ [report] omit = */python?.?/* */site-packages/* app_test.py exclude_lines = if __name__ == .__main__.: if .DYNO. 
in os.environ: ================================================ FILE: .gitattributes ================================================ *.png filter=lfs diff=lfs merge=lfs -text docs/_static/appendix/*.png !filter !diff !merge docs/_static/appendix/nixl/*.png !filter !diff !merge ================================================ FILE: .github/FUNDING.yml ================================================ github: crazyguitar ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: pip directory: "/" schedule: interval: daily time: "21:00" open-pull-requests-limit: 10 ignore: - dependency-name: sphinx versions: - 3.5.3 ================================================ FILE: .github/workflows/pythonpackage.yml ================================================ name: Build on: [push, pull_request] jobs: build: runs-on: ubuntu-latest strategy: max-parallel: 4 matrix: python: ["3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} - name: Install clang-format and cmake run: sudo apt-get install -y clang-format cmake - name: Install run: make deps - name: Test run: make clean && make test ================================================ FILE: .gitignore ================================================ _build/ # Created by https://www.gitignore.io/api/vim,python ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ ### Vim ### # swap [._]*.s[a-v][a-z] [._]*.sw[a-p] [._]s[a-v][a-z] [._]sw[a-p] # session Session.vim # temporary .netrwhist *~ # auto-generated tag files tags # End of https://www.gitignore.io/api/vim,python ================================================ FILE: CITATION.cff ================================================ cff-version: 1.2.0 message: "If you use this software, please cite it as below." title: "Python Cheatsheet" authors: - family-names: "Tsai" given-names: "Chang-Ning" orcid: "https://orcid.org/0009-0000-5297-5940" github: "crazyguitar" abstract: "A comprehensive Python cheat sheet covering Python 2 and 3 syntax, tips, and code snippets." 
version: "master" repository-code: "https://github.com/crazyguitar/pysheeet" license: "MIT" date-released: "2016-02-29" url: "https://www.pythonsheets.com" keywords: - python - cheatsheet - python-2 - python-3 ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2015-2026 Chang Ning Tsai Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
# Top-level build/test driver for the cheat sheet site and its examples.

REQUIREMENT = requirements.txt

VER = $(word 2, $(shell python --version 2>&1))
SRC = app.py app_test.py
# 1 when the interpreter is Python >= 3.6, else 0.  The comparison is done
# inside Python because `expr $(VER) \>= 3.6` compares dotted versions as
# strings, so "3.11.x" sorted below "3.6" and black was silently skipped.
PY36 = $(shell python -c "import sys; print(int(sys.version_info >= (3, 6)))")
CEXT_DIR = src/cext
CEXT_BUILD = $(CEXT_DIR)/build
CAPI_DIR = src/cext/capi
CPP_FROM_PY_DIR = src/cpp_from_python
CPP_FROM_PY_BUILD = $(CPP_FROM_PY_DIR)/build
PY_ARCH = $(shell python -c "import platform; print(platform.machine())")

.PHONY: build deps test format cext clean cpp_from_python

# Default target: build the HTML documentation.
build: html

# Any target not defined here is forwarded to the Sphinx Makefile in docs/.
%:
	cd docs && make $@

# Remove the Sphinx output and every native-extension build directory.
clean:
	cd docs && make clean
	rm -rf $(CEXT_BUILD)
	rm -rf $(CPP_FROM_PY_BUILD)
	rm -rf $(CAPI_DIR)/build $(CAPI_DIR)/*.so $(CAPI_DIR)/*.egg-info

# Build the CMake-based extensions, then the C API examples in place.
cext:
	@echo "Building C/C++ extensions for $(PY_ARCH)..."
	mkdir -p $(CEXT_BUILD) && \
	cd $(CEXT_BUILD) && \
	cmake -DCMAKE_OSX_ARCHITECTURES=$(PY_ARCH) .. && \
	make
	cd $(CAPI_DIR) && python setup.py build_ext --inplace

# Build the "C++ from Python" examples.
cpp_from_python:
	@echo "Building C++ from Python examples..."
	mkdir -p $(CPP_FROM_PY_BUILD) && \
	cd $(CPP_FROM_PY_BUILD) && \
	cmake .. && \
	make

# Lint, run the web app tests with 100% coverage, then the example suites.
test: clean build cext cpp_from_python
	pycodestyle $(SRC)
	pydocstyle $(SRC)
	bandit app.py
	coverage run app_test.py && coverage report --fail-under=100 -m $(SRC)
	python -m pytest src/basic/*.py src/new_py3/*.py -v
	python -m pytest $(CEXT_DIR)/test_cext.py -v
	python -m pytest $(CAPI_DIR)/test_capi.py -v
	cd $(CPP_FROM_PY_BUILD) && make test
ifeq ($(PY36), 1)
	black --quiet --diff --check --line-length 79 $(SRC)
endif

# Install Python dependencies (black only where PY36 is 1).
deps:
	pip install -r requirements.txt
	pip install pybind11
ifeq ($(PY36), 1)
	pip install black==22.3.0
endif

# Reformat Python sources with black and C/C++ sources with clang-format.
format:
	black --line-length 79 $(SRC) src/
	find src/cext -type f \( -name "*.cpp" -o -name "*.c" -o -name "*.h" \) | xargs -I{} clang-format -style=file -i {}


pysheeet

Build Status Coverage License MIT DOI

Introduction ============= This project was started to bring together useful Python code snippets that make coding faster, easier, and more enjoyable. You can explore all the cheat sheets at `Pysheeet `_. Contributions are always welcome—feel free to fork the repo and submit a pull request to help it grow! Plugin ====== **pysheeet** is available as a Claude Code plugin. Once installed, Claude automatically uses the cheat sheets to answer Python questions — just ask naturally and the skill triggers based on context. Installation ------------ **As a Claude Code plugin (recommended):** .. code-block:: bash # Step 1: Add the marketplace claude plugin marketplace add crazyguitar/pysheeet # Step 2: Install the plugin claude plugin install pysheeet@pysheeet **Local testing (single session only):** .. code-block:: bash claude --plugin-dir /path/to/pysheeet **Manual installation (requires cloning the repo):** .. code-block:: bash git clone https://github.com/crazyguitar/pysheeet.git mkdir -p ~/.claude/skills cp -r pysheeet/skills/py ~/.claude/skills/py What's New In Python 3 ====================== This part only provides a quick glance at some important features in Python 3. If you're interested in all of the most important features, please read the official document, `What’s New in Python `_. - `New in Python3 `_ Cheat Sheet =========== Core Python fundamentals including data types, functions, classes, and commonly used patterns for everyday programming tasks. - `From Scratch `_ - `Future `_ - `Typing `_ - `Class `_ - `Function `_ - `Unicode `_ - `List `_ - `Set `_ - `Dictionary `_ - `Heap `_ - `Generator `_ - `Regular expression `_ System ====== Date/time handling, file I/O, and operating system interfaces. 
- `Datetime `_ - Timestamps, formatting, parsing, timezones, timedelta - `Files and I/O `_ - Reading, writing, pathlib, shutil, tempfile - `Operating System `_ - Processes, environment, system calls Concurrency =========== Threading, multiprocessing, and concurrent.futures for parallel execution. Covers synchronization primitives, process pools, and bypassing the GIL. - `Threading `_ - Threads, locks, semaphores, events, conditions - `Multiprocessing `_ - Processes, pools, shared memory, IPC - `concurrent.futures `_ - Executors, futures, callbacks Asyncio ======= Asynchronous programming with Python's ``asyncio`` module. Covers coroutines, event loops, tasks, networking, and advanced patterns. - `A Hitchhiker's Guide to Asynchronous Programming `_ - Design philosophy and evolution - `Asyncio Basics `_ - Coroutines, tasks, gather, timeouts - `Asyncio Networking `_ - TCP/UDP servers, HTTP, SSL/TLS - `Asyncio Advanced `_ - Synchronization, queues, subprocesses C/C++ Extensions ================ Native extensions for performance-critical code. Covers modern pybind11 (used by PyTorch, TensorFlow), ctypes, cffi, Cython, and the traditional Python C API. Also includes a guide for Python developers learning modern C++ syntax. - `ctypes `_ - Load shared libraries without compilation - `Python C API `_ - Traditional C extension reference - `Modern C/C++ Extensions `_ - pybind11, Cython - `Learn C++ from Python `_ - Modern C++ for Python developers Security ======== Modern cryptographic practices and common security vulnerabilities. Covers encryption, TLS/SSL, and why legacy patterns are dangerous. - `Modern Cryptography `_ - AES-GCM, RSA-OAEP, Ed25519, Argon2 - `TLS/SSL and Certificates `_ - HTTPS servers, certificate generation - `Common Vulnerabilities `_ - Padding oracle, injection, timing attacks Network ======= Low-level network programming with Python sockets. 
Covers TCP/UDP communication, server implementations, asynchronous I/O, SSL/TLS encryption, and packet analysis. - `Socket Basics `_ - `Socket Servers `_ - `Async Socket I/O `_ - `SSL/TLS Sockets `_ - `Packet Sniffing `_ - `SSH and Tunnels `_ Database ======== Database access with SQLAlchemy, Python's most popular ORM. Covers connection management, raw SQL, object-relational mapping, and common query patterns. - `SQLAlchemy Basics `_ - `SQLAlchemy ORM `_ - `SQLAlchemy Query Recipes `_ LLM === Large Language Models (LLM) training, inference, and optimization. Covers PyTorch for model development, distributed training across GPUs, and vLLM/SGLang for high-performance LLM inference and serving. - `PyTorch `_ - Tensors, autograd, neural networks, training loops - `Megatron `_ - NVIDIA Megatron training/fine-tuning framework with enroot/pyxis - `LLM Serving `_ - vLLM and SGLang for production inference with TP/PP/DP/EP - `LLM Benchmark `_ - Benchmark suite for measuring serving performance HPC === High-Performance Computing tools for cluster management and job scheduling. Covers Slurm workload manager for distributed computing and GPU clusters. - `Slurm `_ Blog ==== Supplementary topics covering Python internals, debugging techniques, and language features that don't fit elsewhere. - `Is Disaggregated Prefill/Decode a Silver Bullet for LLM Serving? `_ - `Monitoring EFA with NCCL GIN and Nsys `_ - `GPU-Initiated Networking for NCCL on AWS `_ - `PEP 572 and the walrus operator `_ - `Python Interpreter in GNU Debugger `_ PDF Version ============ `pdf`_ .. _pdf: https://media.readthedocs.org/pdf/pysheeet/latest/pysheeet.pdf How to run the server ======================= .. code-block:: bash $ virtualenv venv $ . 
def find_key(token):
    """Look up the ACME key that answers a challenge token.

    The environment is searched in two steps:

    1. ``ACME_TOKEN`` / ``ACME_KEY`` -- the primary pair.
    2. ``ACME_TOKEN_<name>`` / ``ACME_KEY_<name>`` -- additional pairs that
       are matched up by their shared ``<name>`` suffix.

    :param token: challenge token to look for.
    :returns: the value of the paired key variable, or ``None`` when no
        token variable holds ``token`` or the paired key variable is unset.
    """
    if token == os.environ.get("ACME_TOKEN"):
        return os.environ.get("ACME_KEY")

    prefix = "ACME_TOKEN_"
    for name, value in os.environ.items():
        if value == token and name.startswith(prefix):
            # Drop only the leading prefix.  str.replace() would also strip
            # later occurrences inside the suffix and select the wrong key.
            return os.environ.get("ACME_KEY_" + name[len(prefix):])
    return None
@app.route("/<path:path>")
def static_proxy(path):
    """Serve one file from the built documentation tree.

    The ``path`` converter is required: the view takes a ``path`` argument
    and the documentation pages live in nested directories, so the rule
    must capture a value that may contain slashes.

    :param path: location of the requested file relative to ``ROOT``.
    :returns: the file response, or the rendered ``404.html`` page together
        with status 404 when the file cannot be served.
    """
    try:
        return send_from_directory(ROOT, path)
    except NotFound:
        # Show the themed 404 page instead of the default error response.
        return render_template("404.html"), 404
def check_security_headers(self, resp):
    """Assert that ``resp`` carries the expected security headers.

    :param resp: response object exposing a ``headers`` mapping.
    """
    headers = resp.headers
    # Check presence first so that a missing header is reported as a test
    # failure naming the header, not as a KeyError from the lookups below.
    for name in (
        "Content-Security-Policy",
        "X-Content-Type-Options",
        "Feature-Policy",
        "X-Frame-Options",
    ):
        self.assertIn(name, headers)
    self.assertEqual(headers["Feature-Policy"], "geolocation 'none'")
    self.assertEqual(headers["X-Frame-Options"], "SAMEORIGIN")
def test_acme(self):
    """Test the acme view function and the suffixed token lookup."""
    token = self.token
    key = self.key
    self.assertEqual(acme(token), key)

    token = token + "_env"
    key = key + "_env"
    os.environ["ACME_TOKEN_ENV"] = token
    os.environ["ACME_KEY_ENV"] = key
    try:
        self.assertEqual(find_key(token), key)
    finally:
        # Always drop the temporary pair so that a failed assertion cannot
        # leak it into the tests that run afterwards.
        os.environ.pop("ACME_TOKEN_ENV", None)
        os.environ.pop("ACME_KEY_ENV", None)
    # With the pair removed the token is unknown, so the view must abort.
    self.assertRaises(NotFound, acme, token)
If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) endif # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . # the i18n builder cannot share the environment and doctrees with the others I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . .PHONY: help help: @echo "Please use \`make ' where is one of" @echo " html to make standalone HTML files" @echo " dirhtml to make HTML files named index.html in directories" @echo " singlehtml to make a single large HTML file" @echo " pickle to make pickle files" @echo " json to make JSON files" @echo " htmlhelp to make HTML files and a HTML help project" @echo " qthelp to make HTML files and a qthelp project" @echo " applehelp to make an Apple Help Book" @echo " devhelp to make HTML files and a Devhelp project" @echo " epub to make an epub" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" @echo " text to make text files" @echo " man to make manual pages" @echo " texinfo to make Texinfo files" @echo " info to make Texinfo files and run them through makeinfo" @echo " gettext to make PO message catalogs" @echo " changes to make an overview of all changed/added/deprecated items" @echo " xml to make Docutils-native XML files" @echo " pseudoxml to make pseudoxml-XML files for display purposes" @echo " linkcheck to check all external links for integrity" @echo " doctest to run all doctests embedded in the documentation (if enabled)" @echo " coverage to run coverage check of the documentation (if enabled)" .PHONY: clean clean: rm -rf $(BUILDDIR)/* .PHONY: html html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 
.PHONY: dirhtml dirhtml: $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." .PHONY: singlehtml singlehtml: $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml @echo @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." .PHONY: pickle pickle: $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle @echo @echo "Build finished; now you can process the pickle files." .PHONY: json json: $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json @echo @echo "Build finished; now you can process the JSON files." .PHONY: htmlhelp htmlhelp: $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp @echo @echo "Build finished; now you can run HTML Help Workshop with the" \ ".hhp project file in $(BUILDDIR)/htmlhelp." .PHONY: qthelp qthelp: $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/pysheeet.qhcp" @echo "To view the help file:" @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/pysheeet.qhc" .PHONY: applehelp applehelp: $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp @echo @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." @echo "N.B. You won't be able to view it unless you put it in" \ "~/Library/Documentation/Help or install it in your application" \ "bundle." .PHONY: devhelp devhelp: $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp @echo @echo "Build finished." @echo "To view the help file:" @echo "# mkdir -p $$HOME/.local/share/devhelp/pysheeet" @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/pysheeet" @echo "# devhelp" .PHONY: epub epub: $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub @echo @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 
.PHONY: latex latex: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Run \`make' in that directory to run these through (pdf)latex" \ "(use \`make latexpdf' here to do that automatically)." .PHONY: latexpdf latexpdf: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." .PHONY: latexpdfja latexpdfja: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through platex and dvipdfmx..." $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." .PHONY: text text: $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text @echo @echo "Build finished. The text files are in $(BUILDDIR)/text." .PHONY: man man: $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man @echo @echo "Build finished. The manual pages are in $(BUILDDIR)/man." .PHONY: texinfo texinfo: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." @echo "Run \`make' in that directory to run these through makeinfo" \ "(use \`make info' here to do that automatically)." .PHONY: info info: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo "Running Texinfo files through makeinfo..." make -C $(BUILDDIR)/texinfo info @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." .PHONY: gettext gettext: $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale @echo @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." .PHONY: changes changes: $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes @echo @echo "The overview file is in $(BUILDDIR)/changes." 
.PHONY: linkcheck linkcheck: $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck @echo @echo "Link check complete; look for any errors in the above output " \ "or in $(BUILDDIR)/linkcheck/output.txt." .PHONY: doctest doctest: $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." .PHONY: coverage coverage: $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage @echo "Testing of coverage in the sources finished, look at the " \ "results in $(BUILDDIR)/coverage/python.txt." .PHONY: xml xml: $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml @echo @echo "Build finished. The XML files are in $(BUILDDIR)/xml." .PHONY: pseudoxml pseudoxml: $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml @echo @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." ================================================ FILE: docs/_extra/robots.txt ================================================ User-agent: * Allow: / Sitemap: https://www.pythonsheets.com/sitemap.xml ================================================ FILE: docs/_static/.gitignore ================================================ # Ignore everything in this directory * # Except this file !.gitignore !guido.png !logo.svg !style.css !carbonad.css !favicon.ico ================================================ FILE: docs/_static/carbonad.css ================================================ #carbonads { display: block; overflow: hidden; padding: 1em; padding-bottom: 0.3em; line-height: 1.5; margin-top: 10px; margin-bottom: 10px; } #carbonads a { text-decoration: none !important; border-bottom: none; } #carbonads a:hover { color: inherit; } #carbonads span { display: block; overflow: hidden; } .carbon-img { display: block; margin: 0 auto 8px; } .carbon-text { display: block; text-align: left; margin-bottom: .1em; color: #666; } .carbon-poweredby { 
display: block; text-align: left; font-size: .9em; color: #888 !important; } @media only screen and (min-width: 320px) and (max-width: 875px) { #carbonads { float: none; max-width: 330px; border: 0; display: block; overflow: hidden; margin-top: 20px; margin-bottom: 20px; border-radius: 4px; text-align: center; box-shadow: 0 0 0 1px hsla(0, 0%, 0%, .1); font-size: var(--font-size); background-color: #eee; line-height: 1.5; } #carbonads span { position: relative; } #carbonads > span { max-width: none; } .carbon-img { float: left; margin: 0; } .carbon-img img { max-width: 130px !important; } .carbon-text { float: left; margin-bottom: 0; padding: 8px 20px; text-align: left; color: #333 !important; max-width: calc(100% - 130px - 3em); } .carbon-poweredby { left: 130px; bottom: 0; display: block; color: #555 !important; text-align: right; width: 100%; } } ================================================ FILE: docs/_static/style.css ================================================ nav#table-of-contents { display: none; } div.highlight > pre { font-size: 14px; border-radius: 3px; background: #f6f8fa !important; border: 1px solid #000000 !important; } :root { --cu-boulder-gold: #CFB87C; } .bd-container { max-width: 99%; } .bd-container .bd-container__inner { max-width: 99%; } .bd-main .bd-content .bd-article-container { max-width: 100em; } .code-block-caption { color: black; } .bd-sidebar-primary li.has-children>details>summary .toctree-toggle { justify-content: left; } html[data-theme=light] { --pst-font-size-base: none; --pst-color-secondary: #176de8; --pst-color-primary: #176de8; } .graph#doc-flowchart .node text { font-weight: bold; } .bd-content .sd-tab-set .sd-tab-content { padding: 1.5rem; } a { text-decoration: none; } a:hover { text-decoration: underline; } button.theme-switch-button { display: none !important; } blockquote { background-color: transparent; border: none; } ================================================ FILE: docs/_templates/carbonad.html 
================================================ ================================================ FILE: docs/_templates/cheatsheets.html ================================================

Cheat Sheets

================================================ FILE: docs/_templates/layout.html ================================================ {% extends "!layout.html" %} {%- block extrahead %} {%- if pagename == 'index' %} {%- elif pagename == '404' -%} {%- else %} {%- endif %} {%- if tracking_id %} {% endif -%} {%- if pagename == '404' -%} {%- endif -%} {% endblock %} ================================================ FILE: docs/_templates/link.html ================================================

Useful Links

================================================ FILE: docs/_templates/sidebarintro.html ================================================

This project tries to provide many snippets of Python code that make life easier.

================================================ FILE: docs/conf.py ================================================ # -*- coding: utf-8 -*- # # python-cheatsheet documentation build configuration file, created by # sphinx-quickstart on Sun Feb 28 09:26:04 2016. # # This file is execfile()d with the current directory set to its # containing dir. # # Note that not all possible configuration values are present in this # autogenerated file. # # All configuration values have a default; values that are commented out # serve to show the default. from datetime import datetime import os # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # sys.path.insert(0, os.path.abspath('.')) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. #needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.todo', 'sphinx.ext.coverage', 'sphinx.ext.viewcode', 'myst_parser', 'sphinx_copybutton', 'sphinx.ext.graphviz', 'sphinx_design', 'sphinx.ext.extlinks' ] myst_enable_extensions = [ "colon_fence", "attrs_inline", "attrs_block", "tasklist", "substitution", ] myst_enable_checkboxes = True myst_heading_anchors = 6 copybutton_prompt_text = r'^\$ ' copybutton_prompt_is_regexp = True # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = ['.rst', '.md'] source_suffix = '.rst' # The encoding of source files. #source_encoding = 'utf-8-sig' # The master toctree document. 
master_doc = 'index' # General information about the project. year = datetime.now().year project = u'pysheeet' copyright = u'2016-{}, crazyguitar'.format(year) author = u'crazyguitar' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. version = u'0.1.0' # The full version, including alpha/beta/rc tags. release = u'0.1.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. language = 'en' # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: #today = '' # Else, today_fmt is used as the format for a strftime call. #today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. exclude_patterns = [] # The reST default role (used for this markup: `text`) to use for all # documents. #default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. #add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). #add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. #show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. #keep_warnings = False # If true, `todo` and `todoList` produce output, else they produce nothing. 
todo_include_todos = True # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. html_theme = 'sphinx_book_theme' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. html_theme_options = { "repository_url": "https://github.com/crazyguitar/pysheeet", "use_repository_button": True, } # Custom sidebar templates # Add any paths that contain custom themes here, relative to this directory. #html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". html_title = "Python Cheat Sheet" # A shorter title for the navigation bar. Default is the same as html_title. #html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. html_logo = "_static/logo.svg" # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. html_favicon = '_static/favicon.ico' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] html_css_files = ['style.css'] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. 
html_extra_path = ['_extra'] html_context = { "tracking_id": os.environ.get("TRACKING_ID"), } has_carbonad = os.environ.get("CARBONAD_SERVE") and os.environ.get("CARBONAD_PLACEMENT") if has_carbonad: html_context["carbonad_serve"] = os.environ.get("CARBONAD_SERVE") html_context["carbonad_placement"] = os.environ.get("CARBONAD_PLACEMENT") # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. #html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. #html_use_smartypants = True # Additional templates that should be rendered to pages, maps page names to # template names. #html_additional_pages = {} # If false, no module index is generated. #html_domain_indices = True # If false, no index is generated. #html_use_index = True # If true, the index is split into individual pages for each letter. #html_split_index = False # If true, links to the reST sources are added to the pages. #html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. #html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. #html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. #html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). #html_file_suffix = None # Language to be used for generating the HTML full-text search index. # Sphinx supports the following languages: # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' #html_search_language = 'en' # A dictionary with options for the search language support, empty by default. 
# Now only 'ja' uses this config value #html_search_options = {'type': 'default'} # The name of a javascript file (relative to the configuration directory) that # implements a search results scorer. If empty, the default will be used. #html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. htmlhelp_basename = 'python-cheatsheetdoc' # -- Options for LaTeX output --------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', # Additional stuff for the LaTeX preamble. #'preamble': '', # Latex figure (float) alignment #'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ (master_doc, 'python-cheatsheet.tex', u'python-cheatsheet Documentation', u'crazyguitar', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of # the title page. #latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. #latex_use_parts = False # If true, show page references after internal links. #latex_show_pagerefs = False # If true, show URL addresses after external links. #latex_show_urls = False # Documents to append as an appendix to all manuals. #latex_appendices = [] # If false, no module index is generated. #latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ (master_doc, 'python-cheatsheet', u'python-cheatsheet Documentation', [author], 1) ] # If true, show URL addresses after external links. 
#man_show_urls = False # -- Options for Texinfo output ------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'python-cheatsheet', u'python-cheatsheet Documentation', author, 'python-cheatsheet', 'One line description of project.', 'Miscellaneous'), ] html_sidebars = { "**": [ "navbar-logo.html", "search-button-field.html", "sbt-sidebar-nav.html", "carbonad.html", ] } # Documents to append as an appendix to all manuals. #texinfo_appendices = [] # If false, no module index is generated. #texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. #texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. #texinfo_no_detailmenu = False def add_html_link(app, pagename, templatename, context, doctree): """Append html page.""" if pagename in ['404', 'search', 'genindex']: return app.sitemaps.append({ 'pagename': pagename + ".html", 'priority': '1.0' if pagename == 'index' else '0.8', 'changefreq': 'weekly' if pagename == 'index' else 'monthly' }) def create_sitemap(app, exception): """Generate a sitemap.xml""" from xml.etree.ElementTree import ElementTree, Element, SubElement from datetime import datetime r = Element("urlset") r.set("xmlns", "http://www.sitemaps.org/schemas/sitemap/0.9") r.set("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance") r.set("xsi:schemaLocation", "http://www.sitemaps.org/schemas/sitemap/0.9" + " http://www.sitemaps.org/schemas/sitemap/0.9/sitemap.xsd") for link_info in app.sitemaps: url = SubElement(r, "url") now = datetime.now() SubElement(url, "loc").text = app.pysheeet + link_info['pagename'] SubElement(url, "lastmod").text = now.date().isoformat() SubElement(url, "changefreq").text = link_info['changefreq'] SubElement(url, "priority").text = link_info['priority'] f = app.outdir + 
"/sitemap.xml" t = ElementTree(r) t.write(f, xml_declaration=True, encoding='utf-8', method="xml") def setup(app): """Customize setup.""" site = os.environ.get("PYSHEEET") if not site: return if site[-1] != '/': site += '/' # create a sitemap app.pysheeet = site app.sitemaps = [] app.connect('html-page-context', add_html_link) app.connect('build-finished', create_sitemap) ================================================ FILE: docs/index.rst ================================================ .. python-cheatsheet documentation master file, created by sphinx-quickstart on Sun Feb 28 09:26:04 2016. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. .. meta:: :description lang=en: Comprehensive Python cheat sheet with practical code snippets, examples, and tutorials for Python developers. Learn Python basics, advanced topics, databases, networking, and more. :keywords: Python, Python Cheat Sheet, Python Tutorial, Python Examples, Python Code Snippets, Programming, Development, Python Reference, Python Guide Python Cheat Sheet - Complete Guide with Code Examples ====================================================== Welcome to **pysheeet** - your ultimate Python cheat sheet! This comprehensive resource contains practical Python code snippets, examples, and tutorials to make coding easier and more efficient for Python developers of all levels. From basic Python syntax to advanced topics like databases, networking, and multitasking, this cheat sheet serves as your complete Python reference guide. Ideal for beginners learning Python fundamentals and experienced developers seeking quick code examples. Whether you're learning Python for web development, data science, automation, or general programming, you'll find practical examples that save you time and improve your coding efficiency. Contributions are always welcome—feel free to share ideas for new snippets, improvements, or clearer explanations! 
If you'd like to contribute, `fork pysheeet on GitHub`_. If you have any questions or suggestions, please create an issue on `GitHub Issues`_. .. _fork pysheeet on GitHub: https://github.com/crazyguitar/pysheeet .. _GitHub Issues: https://github.com/crazyguitar/pysheeet/issues Plugin ------ **pysheeet** is available as a `Claude Code `_ plugin. Once installed, Claude automatically uses the cheat sheets to answer Python questions. .. code-block:: bash # Step 1: Add the marketplace claude plugin marketplace add crazyguitar/pysheeet # Step 2: Install the plugin claude plugin install pysheeet@pysheeet For local testing and manual installation, see the main `README `_. What's New In Python 3 ---------------------- The official document, `What's New In Python`_, lists all of the most important changes. However, if you're too busy to read through all of the changes, this part provides a brief overview of the new features in Python 3. .. _What's New In Python: https://docs.python.org/3/whatsnew/index.html .. toctree:: :maxdepth: 1 notes/python-new-py3 Python Cheat Sheet ------------------ This section focuses on commonly used Python code snippets. The cheat sheet covers not only core Python features but also essential data structures, algorithms, and frequently used modules to help programmers efficiently tackle everyday tasks. .. toctree:: :maxdepth: 1 notes/basic/index notes/os/index notes/concurrency/index notes/asyncio/index notes/network/index notes/database/index notes/security/index notes/extension/index notes/llm/index notes/hpc/index notes/appendix/index ================================================ FILE: docs/notes/appendix/disaggregated-prefill-decode.rst ================================================ .. 
meta:: :description lang=en: Evaluating disaggregated prefill/decode for LLM serving with vLLM, NIXL, and EFA on AWS :keywords: LLM, vLLM, NIXL, disaggregated prefill decode, KV cache, EFA, inference serving Is Disaggregated Prefill/Decode a Silver Bullet for LLM Serving? ================================================================ :Date: 2026-03-10 Abstract -------- Disaggregated prefill/decode has gained traction as a promising architecture for LLM serving, separating the compute-intensive prefill phase from the memory-bound decode phase onto dedicated node groups. Proponents argue that this separation enables independent scaling and eliminates interference between the two phases. But is it truly a silver bullet? This article puts the claim to the test by evaluating disaggregated prefill/decode using vLLM with NIXL over the AWS Elastic Fabric Adapter (EFA) on a 4-node cluster. We compare data parallelism and simple load-balanced routing as baselines against disaggregated configurations. Our results show that while disaggregation dramatically reduces inter-token latency (ITL), it comes at a significant cost to throughput and time-to-first-token (TTFT), revealing that the architecture is far from a universal solution. Introduction ------------ In standard LLM serving, each node handles both prefill and decode for incoming requests. The prefill phase is compute-bound and processes the entire input prompt in parallel, while the decode phase is memory-bandwidth-bound and generates tokens autoregressively. When both phases share the same GPU pool, long prefill requests can block decode iterations, increasing inter-token latency for concurrent requests. Disaggregated prefill/decode addresses this interference by assigning prefill and decode to separate node groups. After a prefill node completes prompt processing, the KV cache is transferred to a decode node via a high-bandwidth interconnect. 
NIXL [1]_ (NVIDIA Inference Xfer Library) provides the KV cache transfer mechanism, and on AWS, this transfer occurs over EFA using the ``LIBFABRIC`` backend. The appeal is intuitive: by isolating decode nodes from prefill interference, token generation should proceed at a steady, low-latency pace. However, this separation introduces new costs — KV cache transfer overhead, prefill node saturation at long input lengths, and reduced effective cluster capacity for each phase. The question is whether these trade-offs are worthwhile compared to simpler alternatives like data parallelism or stateless load-balanced routing. This experiment uses vLLM [2]_ with the ``NixlConnector`` to orchestrate disaggregated serving, and ``vllm-router`` [3]_ as a reverse proxy to load-balance requests across node groups. The experiment code is available under `src/nixl `_ in the companion repository. Container Image --------------- The experiment uses a custom Docker image that bundles all required components. The ``Dockerfile`` builds on ``nvidia/cuda:12.8.1-devel-ubuntu24.04`` and installs the following stack: - **GDRCopy** v2.5.1 for GPU-direct memory registration - **EFA installer** v1.47.0 for AWS Elastic Fabric Adapter support - **UCX** v1.20.0 built with verbs, rdmacm, and EFA transport - **NIXL** v0.10.1 with ``LIBFABRIC`` backend for KV cache transfer - **nixlbench** for standalone NIXL bandwidth/latency microbenchmarks - **PyTorch** 2.9.1, **flash-attn** 2.8.1, and **DeepGEMM** v2.1.1.post3 - **vLLM** 0.15.1 with ``NixlConnector`` support - **vllm-router** for load-balancing across disaggregated node groups The image is built and saved as a portable tarball via the ``Makefile``: .. code-block:: bash make docker && make save This produces ``nixl-latest.tar.gz``, which is distributed to all Slurm nodes at launch time via ``pigz`` decompression and ``docker load``. Serving Script -------------- The ``vllm.sbatch`` script orchestrates multi-node vLLM serving on Slurm. 
It accepts two key flags that control the serving topology: - ``--route R``: splits the allocated nodes into ``R`` identical groups, each running an independent vLLM instance. A ``vllm-router`` process on the head node round-robins requests across groups. - ``--prefill P``: within each group, assigns ``P`` nodes as prefill-only (``kv_producer``) and the remaining nodes as decode-only (``kv_consumer``). KV cache transfer between prefill and decode nodes uses ``NixlConnector`` with the ``LIBFABRIC`` backend over EFA. When ``--prefill 0`` (default), all nodes in a group run standard data-parallel serving. The script computes ``DP = nodes_per_group * (8 / TP)`` and launches vLLM with ``--data-parallel-size`` accordingly. For disaggregated mode, each prefill and decode node runs as an independent vLLM process with explicit KV transfer configuration: .. code-block:: bash # Prefill node vllm serve ... \ --kv-transfer-config.kv_connector NixlConnector \ --kv-transfer-config.kv_role kv_producer \ --kv-transfer-config.kv_connector_extra_config.backends+ LIBFABRIC # Decode node vllm serve ... \ --kv-transfer-config.kv_connector NixlConnector \ --kv-transfer-config.kv_role kv_consumer \ --kv-transfer-config.kv_connector_extra_config.backends+ LIBFABRIC The router uses ``round_robin`` policy for pure-DP groups and ``consistent_hash`` with ``--vllm-pd-disaggregation`` for PD groups, directing initial requests to prefill endpoints and subsequent decode traffic to decode endpoints: .. 
code-block:: bash # Router for pure-DP groups (round-robin across group endpoints) vllm-router \ --policy round_robin \ --worker-urls http://<group0-host>:8000 http://<group1-host>:8001 \ --host 0.0.0.0 --port 8010 # Router for PD disaggregation (consistent hash with prefill/decode split) vllm-router \ --policy consistent_hash \ --vllm-pd-disaggregation \ --prefill http://<prefill-host>:8000 \ --decode http://<decode-host-1>:8001 --decode http://<decode-host-2>:8002 \ --host 0.0.0.0 --port 8010 Each container is launched with ``--privileged``, ``--net=host``, and explicit ``/dev/infiniband/uverbs*`` and ``/dev/gdrdrv`` device mounts to enable GPU-direct RDMA over EFA. Benchmark Script ---------------- The ``bench.sh`` script wraps ``vllm bench serve`` and handles Docker image loading transparently. If the ``vllm`` CLI is not available on the host, the script re-executes itself inside the container. It points the benchmark client at the router endpoint (or the direct vLLM endpoint for single-group configurations): .. code-block:: bash bash bench.sh -H <router-host> -p <router-port> -- \ --model /fsx/models/deepseek-ai/DeepSeek-V2-Lite \ --dataset-name random \ --random-input-len 512 --random-output-len 256 \ --num-prompts 1024 Experimental Setup ------------------ All experiments run on 4 nodes with 8 GPUs each (TP=8) using DeepSeek-V2-Lite as the model. The benchmark uses random input/output data with 1024 prompts via ``vllm bench serve``. The configurations are: - **Baseline (data parallelism)**: 4 nodes, TP=8, DP=4. All nodes serve both prefill and decode. This is the standard data-parallel serving setup. - **Route 2**: 2 groups of 2 nodes each, TP=8, DP=2 per group. A router round-robins requests across groups. Each group independently handles both prefill and decode. - **Route 4**: 4 groups of 1 node each, TP=8, no data parallelism. A router distributes requests across all 4 independent nodes. - **PD 1P3D**: Disaggregated prefill/decode with 1 prefill node and 3 decode nodes. KV cache is transferred from the prefill node to decode nodes via NIXL. 
- **PD 2P2D**: Disaggregated prefill/decode with 2 prefill nodes and 2 decode nodes. .. code-block:: bash # Exp 1: Baseline — 4 nodes, TP=8, pure DP salloc -N 4 bash vllm.sbatch \ --model /fsx/models/deepseek-ai/DeepSeek-V2-Lite \ --gpu-memory-utilization 0.9 # Exp 2: 2 groups × 2 nodes, DP=2 per group, router round-robins salloc -N 4 bash vllm.sbatch --route 2 \ --model /fsx/models/deepseek-ai/DeepSeek-V2-Lite \ --gpu-memory-utilization 0.9 # Exp 3: 4 groups × 1 node, no DP, router round-robins salloc -N 4 bash vllm.sbatch --route 4 \ --model /fsx/models/deepseek-ai/DeepSeek-V2-Lite \ --gpu-memory-utilization 0.9 # Exp 4: 1 prefill + 3 decode salloc -N 4 bash vllm.sbatch --prefill 1 \ --model /fsx/models/deepseek-ai/DeepSeek-V2-Lite \ --gpu-memory-utilization 0.9 # Exp 5: 2 prefill + 2 decode salloc -N 4 bash vllm.sbatch --prefill 2 \ --model /fsx/models/deepseek-ai/DeepSeek-V2-Lite \ --gpu-memory-utilization 0.9 Results ------- We evaluate each configuration along four metrics: output token throughput, request throughput, time to first token (TTFT), and inter-token latency (ITL). Each plot contains two panels — the left panel sweeps input length with a fixed output length of 256 tokens (prefill-dominated regime), while the right panel sweeps output length with a fixed input length of 512 tokens (decode-dominated regime). This allows us to observe how each configuration behaves when the workload shifts from prefill-heavy to decode-heavy. Microbenchmark: KV Cache Transfer Bandwidth ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Before examining end-to-end serving results, we use ``nixlbench`` to measure the raw NIXL transfer bandwidth over EFA between two nodes. This establishes an upper bound on KV cache transfer speed and helps contextualize the TTFT overhead observed in disaggregated configurations. The benchmark runs in Multi-GPU (MG) mode with all 8 GPUs per node performing VRAM-to-VRAM transfers over the ``LIBFABRIC`` backend: .. 
code-block:: bash salloc -N 2 bash nixl.sbatch --backend LIBFABRIC \ --initiator_seg_type VRAM --target_seg_type VRAM \ --mode MG --num_initiator_dev 8 --num_target_dev 8 Block Size (B) Batch Size B/W (GB/Sec) Avg Lat. (us) P99 Tx (us) --------------------------------------------------------------------------------- 4096 1 0.670064 6.1 47.0 8192 1 1.315392 6.2 45.0 16384 1 2.511416 6.5 47.0 32768 1 4.820423 6.8 50.0 65536 1 8.733224 7.5 56.0 131072 1 12.341950 10.6 52.0 262144 1 23.272188 11.3 59.0 524288 1 43.365764 12.1 62.0 1048576 1 74.816773 14.0 77.0 2097152 1 121.086563 17.3 105.0 4194304 1 180.631395 23.2 146.0 8388608 1 239.037623 35.1 247.0 16777216 1 289.500030 58.0 432.0 33554432 1 327.436372 102.5 796.0 67108864 1 349.608429 192.0 1724.0 **Mapping to DeepSeek-V2-Lite KV cache transfer.** DeepSeek-V2-Lite uses Multi-head Latent Attention (MLA), which compresses the KV cache into a latent vector per token per layer. The per-token-per-layer KV cache size is ``(kv_lora_rank + qk_rope_head_dim) × dtype_size = (512 + 64) × 2 = 1,152 bytes``. For 512 input tokens across 27 layers, the total KV cache is ``1,152 × 512 × 27 = 15,925,248 bytes``, or approximately **15.2 MiB**. With TP=8, each GPU transfers about **1.9 MiB**, which is closest to the 2 MiB block size (~121 GB/s) in the table above. Without tensor parallelism, the full 15.2 MiB transfer is closest to the 16 MiB block size and achieves approximately 289 GB/s. Output Token Throughput ~~~~~~~~~~~~~~~~~~~~~~~ .. image:: https://raw.githubusercontent.com/crazyguitar/pysheeet/master/docs/_static/appendix/nixl/throughput.png :alt: Output token throughput comparison The left panel varies input length with a fixed output length of 256 tokens (prefill-dominated), while the right panel varies output length with a fixed input length of 512 tokens (decode-dominated). For prefill-dominated workloads, Route 4 achieves the highest throughput since each node operates independently without the overhead of data parallelism coordination. 
The disaggregated configurations (PD 1P3D and PD 2P2D) show competitive throughput at shorter input lengths but degrade at longer inputs where the prefill nodes become the bottleneck. For decode-dominated workloads, Route 4 again leads, followed by PD 1P3D. PD 2P2D shows the lowest throughput in this regime, as its two decode nodes cannot match the decode capacity of other configurations. Request Throughput ~~~~~~~~~~~~~~~~~~ .. image:: https://raw.githubusercontent.com/crazyguitar/pysheeet/master/docs/_static/appendix/nixl/req_throughput.png :alt: Request throughput comparison Request throughput follows a similar pattern. Route 4 consistently achieves the highest request throughput across all configurations. The disaggregated PD 1P3D configuration maintains reasonable request throughput for short inputs but drops significantly at longer input lengths (4096 tokens), where the single prefill node becomes saturated. Time to First Token (TTFT) ~~~~~~~~~~~~~~~~~~~~~~~~~~ .. image:: https://raw.githubusercontent.com/crazyguitar/pysheeet/master/docs/_static/appendix/nixl/ttft.png :alt: TTFT comparison TTFT is critical for user-perceived latency. The baseline DP and Route 2 configurations show moderate TTFT that scales with input length. Route 4 achieves the lowest TTFT across all input lengths due to the absence of cross-node coordination. The disaggregated configurations exhibit higher TTFT, particularly at longer input lengths. PD 1P3D shows TTFT exceeding 37 seconds at 4096 input tokens, as all prefill work funnels through a single node. PD 2P2D improves on this but still lags behind the non-disaggregated configurations. The additional latency from KV cache transfer over NIXL contributes to the elevated TTFT. For decode-dominated workloads (right panel), the differences are smaller. At short output lengths (256–512 tokens), PD 1P3D shows 1–2 seconds higher TTFT than the baseline, as the KV cache transfer overhead is proportionally more significant. 
At longer output lengths (1024+ tokens), the disaggregated configurations converge with or improve upon the baseline, as the baseline suffers from increased prefill/decode contention under heavier concurrent decode load. Inter-Token Latency (ITL) ~~~~~~~~~~~~~~~~~~~~~~~~~ .. image:: https://raw.githubusercontent.com/crazyguitar/pysheeet/master/docs/_static/appendix/nixl/itl.png :alt: ITL comparison ITL measures the latency between consecutive generated tokens during the decode phase. This is where disaggregated serving shows its primary advantage. In the prefill-dominated regime (left panel), PD 1P3D achieves the lowest ITL across all input lengths, with mean ITL as low as 10 ms at 4096 input tokens. By isolating decode nodes from prefill interference, the decode phase runs uninterrupted. PD 2P2D also shows reduced ITL compared to the baseline, though the benefit is less pronounced due to having fewer decode nodes. The baseline DP and Route configurations show higher ITL, particularly at longer input lengths where prefill and decode contend for the same GPU resources. In the decode-dominated regime (right panel), Route 4 achieves the lowest ITL (~25–29 ms) since each node serves independently without cross-node coordination. Among the disaggregated configurations, PD 1P3D outperforms PD 2P2D due to its greater decode capacity (3 decode nodes vs. 2), maintaining ITL around 26–35 ms. PD 2P2D, with only 2 decode nodes, shows ITL comparable to the baseline (~45–50 ms). As output length increases, ITL gradually rises across all configurations, reflecting the growing decode load. Discussion ---------- So, is disaggregated prefill/decode a silver bullet? The answer is clearly no — at least not under the conditions tested here. All benchmarks use randomly generated prompts, meaning every request produces a unique KV cache with zero prefix cache hit rate. 
This represents a worst-case scenario for disaggregated serving, where every prefill must be computed from scratch and the full KV cache must be transferred over the network. In production workloads with shared system prompts or repeated prefixes, prefix caching on prefill nodes could substantially reduce redundant computation and transfer volume, potentially shifting the balance in favor of disaggregation. Even so, the results reveal a set of sharp trade-offs that make disaggregation a specialized tool rather than a universal improvement: - **ITL wins, but throughput depends on scaling**: Disaggregated configurations deliver dramatically lower inter-token latency — PD 1P3D achieves as low as 10 ms ITL at long input lengths, up to 14× better than the baseline in prefill-dominated regimes and 1.4–2.4× better in decode-dominated regimes. The throughput and TTFT degradation observed here is partly an artifact of a fixed 4-node cluster: dedicating nodes to one role starves the other. In practice, prefill and decode pools can be scaled independently — adding more prefill nodes to eliminate the prefill bottleneck, or more decode nodes to increase token throughput. The challenge is finding the right ratio between prefill and decode capacity for a given workload, as over-provisioning either side increases cost without proportional benefit. - **Prefill bottleneck is a hard constraint**: With a fixed cluster size, dedicating nodes to prefill reduces decode capacity and vice versa. PD 1P3D suffers severe prefill saturation at long input lengths (TTFT > 37s at 4096 tokens), while PD 2P2D has fewer decode nodes, limiting decode throughput. Frameworks such as `NVIDIA Dynamo `_ aim to address this by dynamically scaling prefill and decode pools based on real-time demand, though this adds operational complexity. 
- **Simple routing beats disaggregation on throughput**: Route 4 (pure routing, no DP, no disaggregation) consistently achieves the highest throughput across all configurations by eliminating cross-node synchronization entirely. It also achieves the lowest TTFT in prefill-dominated workloads, though PD 1P3D edges it out on TTFT in decode-dominated regimes where the fixed 512-token input is short enough to avoid prefill saturation. This is a surprisingly strong baseline — for workloads where ITL is not the primary concern, stateless load-balanced independent nodes outperform both data parallelism and disaggregated configurations. - **KV cache transfer is not free**: The NIXL transfer over EFA adds measurable latency to TTFT in disaggregated configurations. This overhead is amortized for longer decode sequences but is noticeable for short output lengths, making disaggregation less attractive for short-response workloads. In summary, disaggregated prefill/decode aims to optimize both TTFT and ITL by isolating the two phases, but achieving these goals is not guaranteed. KV cache transfer over the network introduces additional overhead that can negate the TTFT benefit, particularly at long input lengths where the transfer volume is large. While ITL improvements are consistently observed due to the elimination of prefill interference on decode nodes, the overall serving performance depends heavily on the prefill-to-decode ratio, workload characteristics, and network bandwidth. Teams considering this architecture should carefully profile their input/output length distributions, latency SLAs, and throughput requirements before committing to the added complexity. References ---------- .. [1] NVIDIA, "NIXL: NVIDIA Inference Xfer Library," GitHub, 2025. https://github.com/ai-dynamo/nixl .. [2] vLLM Project, "vLLM: Easy, fast, and cheap LLM serving," GitHub, 2024. https://github.com/vllm-project/vllm .. 
[3] vLLM Project, "vllm-router: Production-ready router for vLLM," GitHub, 2025. https://github.com/vllm-project/vllm-router ================================================ FILE: docs/notes/appendix/index.rst ================================================ .. meta:: :description lang=en: Python appendix covering advanced topics including the walrus operator (PEP 572) and Python debugging with GDB :keywords: Python, Python3, walrus operator, PEP 572, GDB, debugging, advanced Python Blog ---- This section explores advanced programming topics to help users build a deeper understanding of complex concepts and practical techniques. Programmers working in other languages, such as C/C++, often use Python as a versatile debugging tool. With debuggers like GDB, they may write Python scripts to parse memory regions, improve output readability, or automate troubleshooting tasks. More advanced topics and examples can be found in the following link. .. toctree:: :maxdepth: 1 disaggregated-prefill-decode megatron-efa-monitoring nccl-gin python-walrus python-gdb ================================================ FILE: docs/notes/appendix/megatron-efa-monitoring.rst ================================================ .. meta:: :description lang=en: Monitoring EFA network performance with NCCL GIN and Nsys during distributed LLM training on AWS :keywords: EFA, NCCL, GIN, Nsys, Megatron-LM, distributed training, network monitoring, AWS Monitoring EFA with NCCL GIN and Nsys ====================================== :Date: 2026-02-28 Abstract -------- Distributed training at scale requires deep visibility into network behavior to identify bottlenecks and optimize communication patterns. When training large language models with Megatron-LM on AWS infrastructure using the Elastic Fabric Adapter (EFA), understanding network performance becomes critical for achieving optimal throughput. 
This article demonstrates how to enable NCCL GPU-Initiated Networking (GIN) in Megatron-LM using Megatron Bridge and leverage Nsys with EFA metrics to monitor network behavior during distributed training workloads. The techniques presented here are based on best practices from AWS re:Invent 2024 [1]_. Introduction ------------ `Megatron-LM `_ is a widely adopted framework for training large transformer models using model parallelism, pipeline parallelism, and data parallelism. When deployed on AWS instances with EFA, the network fabric provides high-bandwidth, low-latency communication essential for scaling to hundreds or thousands of GPUs. However, achieving peak performance requires careful tuning and monitoring of the communication layer. NCCL GPU-Initiated Networking allows GPUs to initiate network operations directly without CPU involvement, reducing latency and enabling kernel fusion. Nsys (NVIDIA Nsight Systems) provides comprehensive profiling of GPU kernels, CUDA API calls, and network operations. When combined with EFA metrics collection (``--enable efa_metrics``), Nsys captures detailed network adapter statistics including bandwidth utilization, packet counts, and error rates, correlated with GPU execution timelines. This enables practitioners to diagnose performance issues and validate that the network is operating at expected capacity. `Megatron Bridge `_ simplifies the configuration and deployment of Megatron-LM training jobs by providing a high-level recipe-based interface. This eliminates the need to manually construct complex command-line arguments and makes it easier to enable advanced features like NCCL GIN and DeepEP for MoE models. Therefore, the tutorial in this article will use Megatron Bridge. 
Prerequisites ------------- This guide assumes the following environment: - AWS HyperPod or EC2 instances with EFA support (e.g., P5, P5e, P5en) - NCCL >= v2.29.3-1 with Device API support - aws-ofi-nccl plugin with GIN support - Megatron-LM with Megatron-Bridge We have demonstrated how to use vLLM with NCCL GIN and DeepEP in a previous article. If you are interested in building NCCL and aws-ofi-nccl from source, refer to the `NCCL GIN article `_ in this repository. Building the Megatron Container -------------------------------- The Megatron training environment is packaged as a Docker container and converted to an Enroot squash file for deployment on Slurm clusters. The container includes NCCL with Device API support, aws-ofi-nccl with GIN support, and Megatron-LM with Megatron Bridge. To build the container and create the Enroot image: .. code-block:: bash cd src/megatron make build This will create a ``megatron-lm+latest.sqsh`` file that can be used with the Slurm launcher scripts. For details on the container build process, refer to the `Dockerfile `_ and `enroot.sh `_ scripts in the repository. Enabling NCCL GIN in Megatron Bridge ------------------------------------- Megatron Bridge recipes provide a declarative way to configure training jobs. To enable NCCL GIN for MoE models using DeepEP, the following environment variables are set automatically by the ``srun.sh`` launcher script: .. code-block:: bash export DEEP_EP_BACKEND=nccl export NCCL_GIN_TYPE=2 # proxy-based GIN export LD_LIBRARY_PATH=/opt/amazon/ofi-nccl/lib:$LD_LIBRARY_PATH ``NCCL_GIN_TYPE=2`` selects the proxy-based implementation, where a CPU thread mediates GPU-initiated transfers. This mode is currently supported on EFA, while GPU Direct Async Kernel-Initiated (DAKI) networking (``NCCL_GIN_TYPE=3``) is not yet available on AWS at the time of writing (February 2026). The ``srun.sh`` script also configures additional EFA-specific settings for optimal performance: .. 
code-block:: bash export FI_PROVIDER=efa export FI_EFA_USE_DEVICE_RDMA=1 export FI_EFA_FORK_SAFE=1 export NCCL_NET_PLUGIN=/opt/amazon/ofi-nccl/lib/libnccl-net-ofi.so export NCCL_TUNER_PLUGIN=/opt/amazon/ofi-nccl/lib/libnccl-tuner-ofi.so export NCCL_BUFFSIZE=8388608 export NCCL_P2P_NET_CHUNKSIZE=524288 Launching Megatron Training with DeepEP and NCCL GIN ----------------------------------------------------- The following example demonstrates how to launch a DeepSeek-V2-Lite pretraining job with DeepEP enabled for MoE token dispatching. The recipe configures the model to use expert parallelism across 64 ranks with NCCL GIN for low-latency all-to-all communication. .. code-block:: bash cd src/megatron # Allocate 2 nodes on Slurm salloc -N 2 # Launch DeepSeek-V2-Lite with DeepEP and NCCL GIN ./srun.sh recipes/deepseek_v2_lite_pretrain.py \ hf_path=/fsx/models/deepseek-ai/DeepSeek-V2-Lite \ moe_token_dispatcher_type=deepep \ model.tensor_model_parallel_size=1 \ model.expert_model_parallel_size=64 \ model.sequence_parallel=false The ``moe_token_dispatcher_type=deepep`` argument enables DeepEP as the MoE dispatcher backend. Under the hood, the recipe configures the following settings: .. code-block:: python cfg.model.moe_token_dispatcher_type = "flex" cfg.model.moe_flex_dispatcher_backend = "deepep" cfg.model.moe_enable_deepep = True cfg.model.moe_shared_expert_overlap = False When the training job starts, verify that NCCL initializes with GIN enabled by checking the logs for Device API initialization messages: .. code-block:: text [NCCL] Device API initialized [NCCL] GIN proxy mode enabled (type=2) [NCCL Backend] LOW LATENCY MODE: Rank 0 connecting to all ranks [NCCL Backend] Initialized global rank 0/64 Monitoring EFA with Nsys and EFA Metrics ----------------------------------------- Nsys (NVIDIA Nsight Systems) provides comprehensive profiling of GPU kernels, CUDA API calls, and network operations. 
The ``--enable efa_metrics`` flag instructs Nsys to collect EFA adapter statistics in real-time from the EFA device counters (e.g., rdmap113s0, rdmap114s0) at 10Hz sampling rate, including: - **TX/RX Bandwidth**: Transmit and receive throughput - **TX/RX Packets**: Packet counts sent and received - **Error Counters**: Link errors and dropped packets Additionally, aws-ofi-nccl uses NVTX annotations to mark NCCL operations in the timeline, allowing correlation between NCCL collective calls and EFA network activity. These metrics are embedded in the Nsys timeline and correlated with GPU kernel execution and NCCL operations, making it easy to identify communication bottlenecks and validate network saturation. To profile a Megatron training run with Nsys and capture EFA metrics: .. code-block:: bash cd src/megatron salloc -N 8 ./srun.sh --nsys recipes/deepseek_v2_lite_pretrain.py \ hf_path=/fsx/models/deepseek-ai/DeepSeek-V2-Lite \ moe_token_dispatcher_type=deepep \ model.tensor_model_parallel_size=1 \ model.expert_model_parallel_size=64 \ model.sequence_parallel=false \ profiling.use_nsys_profiler=true \ profiling.profile_step_start=10 \ profiling.profile_step_end=15 \ profiling.profile_ranks=[0] The ``--nsys`` flag enables Nsys profiling with the following configuration: .. code-block:: bash nsys profile \ -t cuda,nvtx \ -s none \ --cpuctxsw=none \ --capture-range=cudaProfilerApi \ --capture-range-end=stop \ --enable efa_metrics \ -o nsys-megatron/profile--rank.nsys-rep \ --force-overwrite=true The ``--enable efa_metrics`` flag is the key parameter that enables EFA adapter monitoring. Nsys will automatically detect all EFA devices (typically ``rdmap182s0``, ``rdmap183s0``, etc.) and collect statistics at regular intervals throughout the profiling session. After profiling completes, the ``.nsys-rep`` files can be downloaded and opened in Nsight Systems GUI for analysis. 
The EFA metrics appear as additional rows in the timeline view, showing bandwidth and packet rate correlated with GPU kernel execution and NCCL collective operations. .. image:: https://raw.githubusercontent.com/crazyguitar/pysheeet/master/docs/_static/appendix/deepep-nsys.png Profiling with Viztracer ------------------------- For Python-level profiling of the training loop, Megatron Bridge supports Viztracer, a low-overhead tracing tool that captures function calls and timing information. This is useful for identifying CPU bottlenecks in data loading, preprocessing, or scheduler logic that may indirectly impact network performance. .. code-block:: bash salloc -N 2 ./srun.sh recipes/deepseek_v2_lite_pretrain.py \ hf_path=/fsx/models/deepseek-ai/DeepSeek-V2-Lite \ train.train_iters=100 \ profiling.use_viztracer=true \ profiling.profile_step_start=10 \ profiling.profile_step_end=15 \ profiling.profile_ranks=[0] The resulting ``.json`` trace files can be visualized in the Viztracer web UI or Chrome's ``chrome://tracing`` interface. By enabling ``log_torch``, Viztracer can capture additional PyTorch-level details such as NCCL stream and CUDA stream operations, providing visibility into the execution flow of collective communications and GPU kernels. However, to observe detailed EFA adapter statistics (bandwidth, packet counts, error counters), Nsys with ``--enable efa_metrics`` remains the required tool. Conclusion ---------- Nsys profiling with ``--enable efa_metrics`` now provides the capability to monitor both EFA adapter behavior and NCCL operations simultaneously during distributed training. This visibility is essential for diagnosing whether long NCCL operation times are caused by actual EFA transmission delays or other issues such as CPU bottlenecks, memory contention, or suboptimal NCCL configuration. 
By examining the correlated timeline of GPU kernels, NCCL collectives, and EFA bandwidth utilization, practitioners can pinpoint the root cause of performance bottlenecks and validate that the network fabric is operating at expected capacity. In this article, we demonstrated this monitoring approach using Megatron-LM with NCCL GIN and DeepEP as an example. The recipe-based approach of Megatron Bridge simplifies the deployment of complex training configurations, making it easier to adopt advanced features like DeepEP and NCCL GIN for large-scale MoE model training while maintaining full observability into network performance. For complete examples and scripts, refer to the `megatron directory `_ in this repository. References ---------- .. [1] `AWS re:Invent 2024 - CMP335: Drilling down into performance for distributed training `_ ================================================ FILE: docs/notes/appendix/nccl-gin.rst ================================================ .. meta:: :description lang=en: Enabling GPU-Initiated Networking for NCCL with DeepEP on AWS using EFA :keywords: NCCL, GIN, GPU-Initiated Networking, DeepEP, EFA, AWS, MoE, HyperPod GPU-Initiated Networking for NCCL on AWS ======================================== :Date: 2026-02-22 Abstract -------- GPU-Initiated Networking (GIN) has attracted significant attention as a key enabler for kernel fusion in large language model (LLM) training and inference. Mixture-of-Experts (MoE) architectures, such as DeepSeek-V3 and Qwen3-30B, require efficient token dispatching and combining across MoE layers. Conventionally, inter-GPU communication is initiated by the CPU through collective libraries such as NCCL or Gloo, necessitating explicit GPU synchronization barriers and additional ``cudaLaunchKernel`` calls that introduce non-trivial overhead. 
GPU-Initiated Networking eliminates this CPU-mediated round-trip by allowing data exchange to occur directly within CUDA kernels, thereby enabling kernel fusion and efficient CUDA Graph capture for accelerating end-to-end LLM layer computation. This article demonstrates how to enable NCCL GIN with DeepEP on AWS HyperPod Slurm using the AWS Elastic Fabric Adapter (EFA). Introduction ------------ Prior to 2026, adopting DeepEP as a Mixture-of-Experts dispatch and combine backend on AWS presented a significant challenge. The DeepEP kernel was originally built on top of InfiniBand with a customized NVSHMEM implementation, a transport layer unavailable on AWS infrastructure. This incompatibility effectively prevented users from leveraging DeepEP on instances equipped with the Elastic Fabric Adapter (EFA). Recent collaborative efforts by NVIDIA and Amazon Annapurna Labs have addressed this gap by introducing GPU-Initiated Networking support in NCCL and the EFA provider, enabling DeepEP to operate over EFA without relying on InfiniBand (see `DeepEP PR #521 `_ and `aws-ofi-nccl PR #1069 `_). The following experiment builds upon these contributions to illustrate how to deploy DeepEP with NCCL GIN on AWS using EFA. Build DeepEP ------------ Before deploying DeepEP on AWS HyperPod Slurm, several components must be built from source. First, NCCL >= v2.29.3-1 is required, as this is the minimum version that exposes the Device API needed for GPU-Initiated Networking. The build targets ``sm_90`` (NVIDIA H100) and ``sm_100`` (NVIDIA B200) compute capabilities to ensure compatibility with current-generation GPU instances. .. 
code-block:: bash NCCL_VERSION=v2.29.3-1 git clone -b ${NCCL_VERSION} https://github.com/NVIDIA/nccl.git /opt/nccl \ && cd /opt/nccl \ && make -j $(nproc) src.build CUDA_HOME=/usr/local/cuda \ NVCC_GENCODE="-gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_100,code=sm_100" Optionally, the NCCL Device API examples can be built to verify that GPU-initiated communication functions correctly in the target environment. In addition, the latest release of nccl-tests (v2.17.9) ships with a GIN-enabled microbenchmark for the ``alltoall`` collective, which is useful for validating inter-GPU bandwidth and latency before running full-scale MoE workloads (see `nccl-tests alltoall.cu `_). .. code-block:: bash ## Build NCCL Device API examples cd /opt/nccl/examples/06_device_api \ && make -j $(nproc) NCCL_HOME=/opt/nccl/build CUDA_HOME=/usr/local/cuda MPI=1 MPI_HOME=/opt/amazon/openmpi NCCL_TESTS_VERSION=v2.17.9 git clone -b ${NCCL_TESTS_VERSION} https://github.com/NVIDIA/nccl-tests.git /opt/nccl-tests \ && cd /opt/nccl-tests \ && make -j $(nproc) \ MPI=1 \ MPI_HOME=/opt/amazon/openmpi/ \ CUDA_HOME=/usr/local/cuda \ NCCL_HOME=/opt/nccl/build \ NVCC_GENCODE="-gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_100,code=sm_100" To test DeepEP on HyperPod Slurm, both DeepEP and aws-ofi-nccl must be pinned to specific commits that include the NCCL GIN transport path. The DeepEP fork by Aamir Shafi introduces an NCCL-based communication backend as an alternative to the original NVSHMEM/InfiniBand path, while the aws-ofi-nccl plugin provides the libfabric-to-NCCL translation layer required for EFA. Note that the NCCL GIN implementation has since been merged into the aws-ofi-nccl main branch; the commit hash is pinned here for reproducibility. .. 
code-block:: bash ## Install DeepEP with NCCL GIN backend (PR #521) unset NVSHMEM_DIR NVSHMEM_HOME \ && export ENABLE_NCCL=1 \ && export NCCL_DIR=/opt/nccl/build \ && export LD_LIBRARY_PATH=/opt/nccl/build/lib:$LD_LIBRARY_PATH \ && export LD_PRELOAD=/opt/nccl/build/lib/libnccl.so.2 \ && git clone -b nccl https://github.com/aamirshafi/DeepEP.git /opt/DeepEP \ && cd /opt/DeepEP \ && git checkout 6d29f34 \ && python3 setup.py build_ext --inplace \ && pip install --break-system-packages --no-build-isolation . AWS_OFI_NCCL_VERSION=5f4202f11db1585d878196db4430aeda0e834a0c git clone https://github.com/aws/aws-ofi-nccl.git /tmp/aws-ofi-nccl \ && cd /tmp/aws-ofi-nccl \ && git checkout ${AWS_OFI_NCCL_VERSION} \ && ./autogen.sh \ && ./configure --prefix=/opt/amazon/ofi-nccl \ --with-libfabric=/opt/amazon/efa \ --with-cuda=/usr/local/cuda \ && make -j$(nproc) \ && make install \ && rm -rf /tmp/aws-ofi-nccl For a complete build with all necessary dependencies, refer to the `Dockerfile `_ provided in this repository. Test NCCL GIN ------------- With the Docker image (or Enroot squash file) prepared in the previous section, NCCL GIN functionality can be validated on a Slurm cluster. The following examples demonstrate how to launch the NCCL Device API samples and nccl-tests benchmarks. The corresponding Slurm wrapper scripts are available under the `gin `_ directory in this repository. .. code-block:: bash make docker && make save # build a docker image and import an Enroot squash file # 01_allreduce_lsa (single node only) salloc -N 1 ./run.enroot /opt/nccl/examples/06_device_api/01_allreduce_lsa/allreduce_lsa # 01_allreduce_lsa (multi-node) — requires MNNVL (e.g. 
P6e-GB200), does NOT work over RDMA/EFA salloc -N 2 ./run.enroot /opt/nccl/examples/06_device_api/01_allreduce_lsa/allreduce_lsa # 02_alltoall_gin (multi-node) salloc -N 2 ./run.enroot /opt/nccl/examples/06_device_api/02_alltoall_gin/alltoall_gin # 03_alltoall_hybrid (multi-node) salloc -N 2 ./run.enroot /opt/nccl/examples/06_device_api/03_alltoall_hybrid/alltoall_hybrid The nccl-tests ``alltoall`` benchmark exposes two critical flags for selecting the GIN transport mode and memory registration strategy: The ``-D`` flag selects the device-side implementation for the ``alltoall`` collective: .. code-block:: text -D 0 — Host API (default) -D 1 — NVL simple (LSA/NVLink only) -D 2 — NVL optimized (LSA/NVLink only) -D 3 — GIN only (network) -D 4 — Hybrid (LSA intra-node + GIN inter-node) The ``-R`` flag controls memory registration. Symmetric memory allocation (``NCCL_MEM_SHARED``) is required for any device-side implementation (``-D > 0``), as it maps GPU memory across all ranks to enable direct remote read and write over the network: .. code-block:: text -R 0 — no registration (default) -R 1 — register memory with ncclMemAlloc -R 2 — register memory with symmetric memory allocation (NCCL_MEM_SHARED) The following examples launch the nccl-tests ``alltoall_perf`` benchmark in GIN-only mode (``-D 3``) and hybrid mode (``-D 4``), sweeping message sizes from 32 MB to 2048 MB. The ``--blocking 0`` flag enables non-blocking collectives, which is representative of how MoE layers overlap communication with computation in production workloads: .. 
code-block:: bash # alltoall_perf with GIN (-D 3) salloc -N 2 ./run.enroot /opt/nccl-tests/build/alltoall_perf \ -D 3 -R 2 -b 32M -e 2048M -f 2 -n 1000 -w 10 --blocking 0 # alltoall_perf with Hybrid LSA+GIN (-D 4) salloc -N 2 ./run.enroot /opt/nccl-tests/build/alltoall_perf \ -D 4 -R 2 -b 32M -e 2048M -f 2 -n 1000 -w 10 --blocking 0 Serving MoE Models with vLLM and DeepEP over NCCL GIN ----------------------------------------------------- With NCCL GIN and EFA validated on AWS HyperPod Slurm, this section demonstrates an end-to-end inference deployment using vLLM with DeepEP as the MoE all-to-all communication backend. DeepEP's low-latency dispatch and combine kernels, now operating over NCCL GIN rather than NVSHMEM, enable efficient expert-parallel inference for large MoE models such as DeepSeek-V3. The Slurm launch script ``run.sbatch`` is the same one used to launch a vLLM server in the `vllm example directory `_. However, to direct the DeepEP backend to use NCCL GIN, the following environment variables must be set at launch time: .. code-block:: bash DEEP_EP_BACKEND=nccl NCCL_GIN_TYPE=2 # proxy-based GIN ``NCCL_GIN_TYPE=2`` selects the proxy-based GIN path, in which a CPU-side proxy thread mediates network transfers on behalf of the GPU. ``NCCL_GIN_TYPE=3`` would enable GPU Direct Async Kernel-Initiated (DAKI) networking, which bypasses the proxy entirely; however, DAKI is not yet supported on AWS with EFA at the time of writing. For additional details on serving configurations and benchmarking, refer to `llm-serving.rst `_ or the `vLLM README `_. The following example launches a multi-node vLLM inference server for DeepSeek-V3-0324 with expert parallelism enabled and the DeepEP low-latency all-to-all backend: .. 
code-block:: bash IMAGE="${PWD}/src/gin/nccl+latest.tar.gz" MODEL="/fsx/models/deepseek-ai/DeepSeek-V3-0324" salloc -N 4 bash run.sbatch "${MODEL}" \ --image "${IMAGE}" \ --all2all-backend deepep_low_latency \ --tensor-parallel-size 8 \ --enable-expert-parallel \ --gpu-memory-utilization 0.8 \ --enforce-eager Upon successful launch, the vLLM server logs confirm that DeepEP is active as the all-to-all backend and that NCCL GIN is being used for inter-GPU communication. The key indicators are the ``DeepEPLLAll2AllManager`` manager selection and the ``[NCCL Backend]`` initialization messages showing communicator setup, symmetric memory allocation, and window registration across all ranks: .. code-block:: bash ... INFO 02-22 19:06:49 [serve.py:100] Defaulting api_server_count to data_parallel_size (4). INFO 02-22 19:06:49 [utils.py:325] INFO 02-22 19:06:49 [utils.py:325] █ █ █▄ ▄█ INFO 02-22 19:06:49 [utils.py:325] ▄▄ ▄█ █ █ █ ▀▄▀ █ version 0.15.1 INFO 02-22 19:06:49 [utils.py:325] █▄█▀ █ █ █ █ model /fsx/models/deepseek-ai/DeepSeek-V3-0324 INFO 02-22 19:06:49 [utils.py:325] ▀▀ ▀▀▀▀▀ ▀▀▀▀▀ ▀ ▀ INFO 02-22 19:06:49 [utils.py:325] ... INFO 02-22 19:07:51 [cuda_communicator.py:124] Using DeepEPLLAll2AllManager all2all manager. ... 
[NCCL Backend] LOW LATENCY MODE: Rank 0 connecting to all 32 ranks [NCCL Backend] NCCL version: 2.29.3 (loaded library) [NCCL Backend] Initializing 2 communicator(s) (qps_per_rank=8) for rank 0/32 [NCCL Backend] Rank 0 successfully initialized 2 communicator(s) [NCCL Backend] Rank 0 created 2 device communication(s) with 32 barrier sessions each [NCCL Backend] Initialized global rank 0/32 (comm rank 0/32) [NCCL Backend - Memory Alloc] Rank 0: Allocated ptr=0xf882000000, size=3816818816 [NCCL Backend - Memory Register] Rank 0: Copying 2 NCCL windows to GPU [NCCL Backend - Memory Register] Rank 0: Successfully copied windows to GPU [NCCL Backend - Memory Register] Rank 0: Registered windows for ptr=0xf882000000, size=3816818816 Once the server is ready, inference requests can be issued via the OpenAI-compatible completions API: .. code-block:: bash curl -sf -X POST http://:8000/v1/completions \ -H 'Content-Type: application/json' \ -d '{ "model": "/fsx/models/deepseek-ai/DeepSeek-V3-0324", "prompt": "Hello", "max_tokens": 10 }' # output {"id":"cmpl-b6e9530a07561f11","object":"text_completion" ... } Conclusion ---------- This article has demonstrated how to deploy vLLM with DeepEP and NCCL GIN on AWS HyperPod Slurm using the Elastic Fabric Adapter. As this integration is still under active development, certain limitations remain at the time of writing. For instance, although DeepEP's low-latency mode supports CUDA Graph capture, enabling it by removing ``--enforce-eager`` currently results in a startup failure in vLLM. Additionally, performance over EFA may not yet match that of InfiniBand-based deployments, as further optimizations are ongoing. This article is intended as an early reference for evaluating DeepEP with NCCL GIN on AWS. For production workloads, it is advisable to wait for official stable releases from NVIDIA and Amazon Annapurna Labs. 
================================================ FILE: docs/notes/appendix/python-gdb.rst ================================================ .. meta:: :description lang=en: Python interpreter in GNU Debugger (GDB) :keywords: Python, Python3, GDB ================================== Python Interpreter in GNU Debugger ================================== :Date: 2025-08-30 Abstract -------- The GNU Debugger (GDB) is the most powerful debugging tool for developers to troubleshoot errors in their code. However, it is hard for beginners to learn, and that is why many programmers prefer to insert ``print`` to examine runtime status. Fortunately, `GDB Text User Interface (TUI)`_ provides a way for developers to review their source code and debug simultaneously. More excitingly, since GDB 7, a **Python interpreter** has been built into GDB. This feature offers more straightforward ways to customize GDB printers and commands through the Python library. By discussing examples, this article tries to explore advanced debugging techniques via Python to develop tool kits for GDB. Introduction ------------ Troubleshooting software bugs is a big challenge for developers. While GDB provides many “debug commands” to inspect programs’ runtime status, its non-intuitive usage discourages programmers from using it to solve problems. Indeed, mastering GDB is a long-term process. However, getting started is not complicated; as Yoda says, you must unlearn what you have learned. To better understand how to use Python in GDB, this article will focus on discussing the Python interpreter in GDB. Define Commands --------------- GDB supports customizing commands by using ``define``. It is useful to run a batch of commands to troubleshoot at the same time. For example, a developer can display the current frame information by defining a ``sf`` command. ..
code-block:: bash # define in .gdbinit define sf where # find out where the program is info args # show arguments info locals # show local variables end However, writing a user-defined command may be inconvenient due to limited APIs. Fortunately, by interacting with Python interpreter in GDB, developers can utilize Python libraries to establish their debugging tool kits readily. The following sections show how to use Python to simplify debugging processes. Dump Memory ----------- Inspecting a process’s memory information is an effective way to troubleshoot memory issues. Developers can acquire memory contents by ``info proc mappings`` and ``dump memory``. To simplify these steps, defining a customized command is useful. However, the implementation is not straightforward by using pure GDB syntax. Even though GDB supports conditions, processing output is not intuitive. To solve this problem, using Python API in GDB would be helpful because Python contains many useful operations for handling strings. .. code-block:: python # mem.py import gdb import time import re class DumpMemory(gdb.Command): """Dump memory info into a file.""" def __init__(self): super().__init__("dm", gdb.COMMAND_USER) def get_addr(self, p, tty): """Get memory addresses.""" cmd = "info proc mappings" out = gdb.execute(cmd, tty, True) addrs = [] for l in out.split("\n"): if re.match(f".*{p}*", l): s, e, *_ = l.split() addrs.append((s, e)) return addrs def dump(self, addrs): """Dump memory result.""" if not addrs: return for s, e in addrs: f = int(time.time() * 1000) gdb.execute(f"dump memory {f}.bin {s} {e}") def invoke(self, args, tty): try: # cat /proc/self/maps addrs = self.get_addr(args, tty) # dump memory self.dump(addrs) except Exception as e: print("Usage: dm [pattern]") DumpMemory() Running the ``dm`` command will invoke ``DumpMemory.invoke``. By sourcing or implementing Python scripts in *.gdbinit*, developers can utilize user-defined commands to trace bugs when a program is running. 
For example, the following steps show how to invoke ``DumpMemory`` in GDB. .. code-block:: bash (gdb) start ... (gdb) source mem.py # source commands (gdb) dm stack # dump stack to ${timestamp}.bin (gdb) shell ls # ls current dir 1577283091687.bin a.cpp a.out mem.py Dump JSON --------- Parsing JSON is helpful when a developer is inspecting a JSON string in a running program. GDB can parse a ``std::string`` via ``gdb.parse_and_eval`` and return it as a ``gdb.Value``. By processing ``gdb.Value``, developers can pass a JSON string into Python ``json`` API and print it in a pretty format. .. code-block:: python # dj.py import gdb import re import json class DumpJson(gdb.Command): """Dump std::string as a styled JSON.""" def __init__(self): super().__init__("dj", gdb.COMMAND_USER) def get_json(self, args): """Parse std::string to JSON string.""" ret = gdb.parse_and_eval(args) typ = str(ret.type) if re.match("^std::.*::string", typ): return json.loads(str(ret)) return None def invoke(self, args, tty): try: # string to json string s = self.get_json(args) # json string to object o = json.loads(s) print(json.dumps(o, indent=2)) except Exception as e: print(f"Parse json error! {args}") DumpJson() The command ``dj`` displays a more readable JSON format in GDB. This command helps improve visual recognition when a JSON string is large. Also, by using this command, developers can detect or monitor whether a ``std::string`` is JSON or not. .. code-block:: bash (gdb) start (gdb) list 1 #include <string> 2 3 int main(int argc, char *argv[]) 4 { 5 std::string json = R"({"foo": "FOO","bar": "BAR"})"; 6 return 0; 7 } ... (gdb) ptype json type = std::string (gdb) p json $1 = "{\"foo\": \"FOO\",\"bar\": \"BAR\"}" (gdb) source dj.py (gdb) dj json { "foo": "FOO", "bar": "BAR" } Highlight Syntax ---------------- Syntax highlighting is useful for developers to trace source code or to troubleshoot issues. By using `Pygments`_, applying color to the source is easy without defining ANSI escape code manually.
The following example shows how to apply color to the ``list`` command output. .. code-block:: python import gdb from pygments import highlight from pygments.lexers import CLexer from pygments.formatters import TerminalFormatter class PrettyList(gdb.Command): """Print source code with color.""" def __init__(self): super().__init__("pl", gdb.COMMAND_USER) self.lex = CLexer() self.fmt = TerminalFormatter() def invoke(self, args, tty): try: out = gdb.execute(f"l {args}", tty, True) print(highlight(out, self.lex, self.fmt)) except Exception as e: print(e) PrettyList() Tracepoints ----------- Although a developer can insert ``printf``, ``std::cout``, or ``syslog`` to inspect functions, printing messages is not an effective way to debug when a project is enormous. Developers may waste their time in building source code and may acquire little information. Even worse, the output may become too much to detect problems. In fact, inspecting functions or variables does not require embedding *print functions* in code. By writing a Python script with GDB API, developers can customize watchpoints to trace issues dynamically at runtime. For example, by implementing a ``gdb.Breakpoint`` and a ``gdb.Command``, developers can acquire essential information, such as parameters, call stacks, or memory usage. ..
code-block:: python # tp.py import gdb tp = {} class Tracepoint(gdb.Breakpoint): def __init__(self, *args): super().__init__(*args) self.silent = True self.count = 0 def stop(self): self.count += 1 frame = gdb.newest_frame() block = frame.block() sym_and_line = frame.find_sal() framename = frame.name() filename = sym_and_line.symtab.filename line = sym_and_line.line # show tracepoint info print(f"{framename} @ {filename}:{line}") # show args and vars for s in block: if not s.is_argument and not s.is_variable: continue typ = s.type val = s.value(frame) size = typ.sizeof name = s.name print(f"\t{name}({typ}: {val}) [{size}]") # do not stop at tracepoint return False class SetTracepoint(gdb.Command): def __init__(self): super().__init__("tp", gdb.COMMAND_USER) def invoke(self, args, tty): try: global tp tp[args] = Tracepoint(args) except Exception as e: print(e) def finish(event): for t, p in tp.items(): c = p.count print(f"Tracepoint '{t}' Count: {c}") gdb.events.exited.connect(finish) SetTracepoint() Instead of inserting ``std::cout`` at the beginning of functions, using a tracepoint at a function's entry point provides useful information to inspect arguments, variables, and stacks. For instance, by setting a tracepoint at ``fib``, it is helpful to examine memory usage, stack, and the number of calls. .. code-block:: cpp int fib(int n) { if (n < 2) { return 1; } return fib(n-1) + fib(n-2); } int main(int argc, char *argv[]) { fib(3); return 0; } The following output shows the result of an inspection of the function ``fib``. In this case, tracepoints display all information a developer needs, including arguments' value, recursive flow, and variables' size. By using tracepoints, developers can acquire more useful information comparing with ``std::cout``. .. code-block:: bash (gdb) source tp.py (gdb) tp main Breakpoint 1 at 0x647: file a.cpp, line 12. (gdb) tp fib Breakpoint 2 at 0x606: file a.cpp, line 3. 
(gdb) r Starting program: /root/a.out main @ a.cpp:12 argc(int: 1) [4] argv(char **: 0x7fffffffe788) [8] fib @ a.cpp:3 n(int: 3) [4] fib @ a.cpp:3 n(int: 2) [4] fib @ a.cpp:3 n(int: 1) [4] fib @ a.cpp:3 n(int: 0) [4] fib @ a.cpp:3 n(int: 1) [4] [Inferior 1 (process 5409) exited normally] Tracepoint 'main' Count: 1 Tracepoint 'fib' Count: 5 Profiling --------- Without inserting timestamps, profiling is still feasible through tracepoints. By using a ``gdb.FinishBreakpoint`` after a ``gdb.Breakpoint``, GDB sets a temporary breakpoint at the return address of a frame for developers to get the current timestamp and to calculate the time difference. Note that profiling via GDB is not precise. Other tools, such as `Linux perf`_ or `Valgrind`_, provide more useful and accurate information to trace performance issues. .. code-block:: python import gdb import time class EndPoint(gdb.FinishBreakpoint): def __init__(self, breakpoint, *a, **kw): super().__init__(*a, **kw) self.silent = True self.breakpoint = breakpoint def stop(self): # normal finish end = time.time() start, out = self.breakpoint.stack.pop() diff = end - start print(out.strip()) print(f"\tCost: {diff}") return False class StartPoint(gdb.Breakpoint): def __init__(self, *a, **kw): super().__init__(*a, **kw) self.silent = True self.stack = [] def stop(self): start = time.time() # start, end, diff frame = gdb.newest_frame() sym_and_line = frame.find_sal() func = frame.function().name filename = sym_and_line.symtab.filename line = sym_and_line.line block = frame.block() args = [] for s in block: if not s.is_argument: continue name = s.name typ = s.type val = s.value(frame) args.append(f"{name}: {val} [{typ}]") # format out = "" out += f"{func} @ {filename}:{line}\n" for a in args: out += f"\t{a}\n" # append current status to a breakpoint stack self.stack.append((start, out)) EndPoint(self, internal=True) return False class Profile(gdb.Command): def __init__(self): super().__init__("prof", gdb.COMMAND_USER) def 
invoke(self, args, tty): try: StartPoint(args) except Exception as e: print(e) Profile() The following output shows the profiling result by setting a tracepoint at the function ``fib``. It is convenient to inspect the function's performance and stack at the same time. .. code-block:: bash (gdb) source prof.py (gdb) prof fib Breakpoint 1 at 0x606: file a.cpp, line 3. (gdb) r Starting program: /root/a.out fib(int) @ a.cpp:3 n: 1 [int] Cost: 0.0007786750793457031 fib(int) @ a.cpp:3 n: 0 [int] Cost: 0.002572298049926758 fib(int) @ a.cpp:3 n: 2 [int] Cost: 0.008517265319824219 fib(int) @ a.cpp:3 n: 1 [int] Cost: 0.0014069080352783203 fib(int) @ a.cpp:3 n: 3 [int] Cost: 0.01870584487915039 Pretty Print ------------ Although ``set print pretty on`` in GDB offers a better format to inspect variables, developers may require to parse variables' value for readability. Take the system call ``stat`` as an example. While it provides useful information to examine file attributes, the output values, such as the permission, may not be readable for debugging. By implementing a user-defined pretty print, developers can parse ``struct stat`` and output information in a readable format. .. 
code-block:: python import gdb import pwd import grp import stat import time from datetime import datetime class StatPrint: def __init__(self, val): self.val = val def get_filetype(self, st_mode): if stat.S_ISDIR(st_mode): return "directory" if stat.S_ISCHR(st_mode): return "character device" if stat.S_ISBLK(st_mode): return "block device" if stat.S_ISREG(st_mode): return "regular file" if stat.S_ISFIFO(st_mode): return "FIFO" if stat.S_ISLNK(st_mode): return "symbolic link" if stat.S_ISSOCK(st_mode): return "socket" return "unknown" def get_access(self, st_mode): out = "-" info = ("r", "w", "x") perm = [ (stat.S_IRUSR, stat.S_IWUSR, stat.S_IXUSR), (stat.S_IRGRP, stat.S_IWGRP, stat.S_IXGRP), (stat.S_IROTH, stat.S_IWOTH, stat.S_IXOTH), ] for pm in perm: for c, p in zip(pm, info): out += p if st_mode & c else "-" return out def get_time(self, st_time): tv_sec = int(st_time["tv_sec"]) return datetime.fromtimestamp(tv_sec).isoformat() def to_string(self): st = self.val st_ino = int(st["st_ino"]) st_mode = int(st["st_mode"]) st_uid = int(st["st_uid"]) st_gid = int(st["st_gid"]) st_size = int(st["st_size"]) st_blksize = int(st["st_blksize"]) st_blocks = int(st["st_blocks"]) st_atim = st["st_atim"] st_mtim = st["st_mtim"] st_ctim = st["st_ctim"] out = "{\n" out += f"Size: {st_size}\n" out += f"Blocks: {st_blocks}\n" out += f"IO Block: {st_blksize}\n" out += f"Inode: {st_ino}\n" out += f"Access: {self.get_access(st_mode)}\n" out += f"File Type: {self.get_filetype(st_mode)}\n" out += f"Uid: ({st_uid}/{pwd.getpwuid(st_uid).pw_name})\n" out += f"Gid: ({st_gid}/{grp.getgrgid(st_gid).gr_name})\n" out += f"Access: {self.get_time(st_atim)}\n" out += f"Modify: {self.get_time(st_mtim)}\n" out += f"Change: {self.get_time(st_ctim)}\n" out += "}" return out p = gdb.printing.RegexpCollectionPrettyPrinter("sp") p.add_printer("stat", "^stat$", StatPrint) o = gdb.current_objfile() gdb.printing.register_pretty_printer(o, p) By sourcing the previous Python script, the ``PrettyPrinter`` can recognize
``struct stat`` and output a readable format for developers to inspect file attributes. Without inserting functions to parse and print ``struct stat``, it is a more convenient way to acquire a better output from Python API. .. code-block:: bash (gdb) list 15 10 struct stat st; 11 12 if ((rc = stat("./a.cpp", &st)) < 0) { 13 perror("stat failed."); 14 goto end; 15 } 16 17 rc = 0; 18 end: 19 return rc; (gdb) source st.py (gdb) b 17 Breakpoint 1 at 0x762: file a.cpp, line 17. (gdb) r Starting program: /root/a.out Breakpoint 1, main (argc=1, argv=0x7fffffffe788) at a.cpp:17 17 rc = 0; (gdb) p st $1 = { Size: 298 Blocks: 8 IO Block: 4096 Inode: 1322071 Access: -rw-rw-r-- File Type: regular file Uid: (0/root) Gid: (0/root) Access: 2019-12-28T15:53:17 Modify: 2019-12-28T15:53:01 Change: 2019-12-28T15:53:01 } Note that developers can disable a user-defined pretty-print via the command ``disable``. For example, the previous Python script registers a pretty printer under the global pretty-printers. By calling ``disable pretty-print``, the printer ``sp`` will be disabled. .. code-block:: bash (gdb) disable pretty-print global sp 1 printer disabled 1 of 2 printers enabled (gdb) i pretty-print global pretty-printers: builtin mpx_bound128 sp [disabled] stat Additionally, developers can exclude a printer in the current GDB debugging session if it is no longer required. The following snippet shows how to delete the ``sp`` printer through ``gdb.pretty_printers.remove``. .. code-block:: bash (gdb) python >import gdb >for p in gdb.pretty_printers: > if p.name == "sp": > gdb.pretty_printers.remove(p) >end (gdb) i pretty-print global pretty-printers: builtin mpx_bound128 Conclusion ---------- Integrating Python interpreter into GDB offers many flexible ways to troubleshoot issues. While many integrated development environments (IDEs) may embed GDB to debug visually, GDB allows developers to implement their commands and parse variables’ output at runtime. 
By using debugging scripts, developers can monitor and record necessary information without modifying their code. Honestly, inserting or enabling debugging code blocks may change a program’s behaviors, and developers should get rid of this bad habit. Also, when a problem is reproduced, GDB can attach that process and examine its status without stopping it. Obviously, debugging via GDB is inevitable if a challenging issue emerges. Thanks to integrating Python into GDB, developing a script to troubleshoot becomes more accessible that leads to developers establishing their debugging methods diversely. Reference --------- 1. `Extending GDB using Python`_ 2. `gcc/gcc/gdbhooks.py`_ 3. `gdbinit/Gdbinit`_ 4. `cyrus-and/gdb-dashboard`_ 5. `hugsy/gef`_ 6. `sharkdp/stack-inspector`_ 7. `gdb Debugging Full Example (Tutorial)`_ .. _Pygments: https://pygments.org/ .. _Extending GDB using Python: https://sourceware.org/gdb/onlinedocs/gdb/Python.html .. _gcc/gcc/gdbhooks.py: https://github.com/gcc-mirror/gcc/blob/master/gcc/gdbhooks.py .. _hugsy/gef: https://github.com/hugsy/gef .. _cyrus-and/gdb-dashboard: https://github.com/cyrus-and/gdb-dashboard .. _gdbinit/Gdbinit: https://github.com/gdbinit/Gdbinit .. _sharkdp/stack-inspector: https://github.com/sharkdp/stack-inspector .. _GDB Text User Interface (TUI): https://sourceware.org/gdb/onlinedocs/gdb/TUI.html .. _Linux perf: https://github.com/torvalds/linux/tree/master/tools/perf .. _Valgrind: https://valgrind.org/ .. _gdb Debugging Full Example (Tutorial): http://www.brendangregg.com/blog/2016-08-09/gdb-example-ncurses.html ================================================ FILE: docs/notes/appendix/python-walrus.rst ================================================ .. 
meta:: :description lang=en: Design philosophy of pep 572, the walrus operator :keywords: Python3, PEP 572, walrus operator PEP 572 and The Walrus Operator =============================== :Date: 2025-08-30 Abstract -------- `PEP 572`_ is one of the most contentious proposals in Python3 history because assigning a value within an expression seems unnecessary. Also, it is ambiguous for developers to distinguish the difference between **the walrus operator** (``:=``) and the equal operator (``=``). Even though sophisticated developers can use "``:=``" smoothly, they may be concerned about the readability of their code. To better understand the usage of "``:=``," this article discusses its design philosophy and what kind of problems it tries to solve. Introduction ------------ For C/C++ developers, assigning a function return to a variable is common due to error code style handling. Managing function errors includes two steps; one is to check the return value; another is to check ``errno``. For example, .. code-block:: cpp #include <errno.h> #include <stdio.h> #include <string.h> #include <unistd.h> int main(int argc, char *argv[]) { int rc = -1; // assign access return to rc and check its value if ((rc = access("hello_walrus", R_OK)) == -1) { fprintf(stderr, "%s", strerror(errno)); goto end; } rc = 0; end: return rc; } In this case, ``access`` will assign its return value to the variable ``rc`` first. Then, the program will compare the ``rc`` value with ``-1`` to check whether the execution of ``access`` is successful or not. However, Python did not allow assigning values to variables within an expression before 3.8. To fix this problem, therefore, PEP 572 introduced the walrus operator for developers. The following Python snippet is equal to the previous C example. .. code-block:: python >>> import os >>> import sys >>> from ctypes import * >>> libc = CDLL("libc.dylib", use_errno=True) >>> access = libc.access >>> path = create_string_buffer(b"hello_walrus") >>> if (rc := access(path, os.R_OK)) == -1: ... errno = get_errno() ...
print(os.strerror(errno), file=sys.stderr) ... No such file or directory Why ``:=`` ? ------------ Developers may confuse "``:=``" with "``=``." In fact, they serve the same purpose, assigning something to variables. Why did Python introduce "``:=``" instead of using "``=``"? What is the benefit of using "``:=``"? One reason is to reinforce the visual recognition due to a common mistake made by C/C++ developers. For instance, .. code-block:: cpp int rc = access("hello_walrus", R_OK); // rc is unintentionally assigned to -1 if (rc = -1) { fprintf(stderr, "%s", strerror(errno)); goto end; } Rather than comparison, the variable, ``rc``, is mistakenly assigned to -1. To prevent this error, some people advocate using `Yoda conditions`_ within an expression. .. code-block:: cpp int rc = access("hello_walrus", R_OK); // -1 = rc will raise a compile error if (-1 == rc) { fprintf(stderr, "%s", strerror(errno)); goto end; } However, Yoda style is not very readable, much as Yoda speaks non-standard English. Also, unlike C/C++ compilers, which can detect an accidental assignment at compile time via compiler options (e.g., -Wparentheses), it is difficult for the Python interpreter to distinguish such mistakes at runtime. Thus, the final result of PEP 572 was to use a new syntax as a solution to implement *assignment expressions*. The walrus operator was not the first solution for PEP 572. The original proposal used ``EXPR as NAME`` to assign values to variables. Unfortunately, there were reasons to reject this solution and other solutions as well. After intense debates, the final decision was ``:=``. Scopes ------ Unlike other expressions, in which a variable is bound to a scope, an assignment expression belongs to the current scope. The purpose of this design is to allow a compact way to write code. .. code-block:: python3 >>> if not (env := os.environ.get("HOME")): ... raise KeyError("env HOME does not find!") ...
>>> print(env) /root In PEP 572, another benefit is to conveniently capture a "witness" for an ``any()`` or an ``all()`` expression. Although capturing function inputs can assist an interactive debugger, the advantage is not so obvious, and examples lack readability. Therefore, this benefit does not discuss here. Note that other languages (e.g., C/C++ or Go) may bind an assignment to a scope. Take Golang as an example. .. code-block:: go package main import ( "fmt" "os" ) func main() { if env := os.Getenv("HOME"); env == "" { panic(fmt.Sprintf("Home does not find")) } fmt.Print(env) // <--- compile error: undefined: env } Pitfalls -------- Although an assigning expression allows writing compact code, there are many pitfalls when a developer uses it in a list comprehension. A common ``SyntaxError`` is to rebind iteration variables. .. code-block:: python3 >>> [i := i+1 for i in range(5)] # invalid However, updating an iteration variable will reduce readability and introduce bugs. Even if Python 3.8 did not implement the walrus operator, a programmer should avoid reusing iteration variables within a scope. Another pitfall is Python prohibits using assignment expressions within a comprehension under a class scope. .. code-block:: python3 >>> class Example: ... [(j := i) for i in range(5)] # invalid ... This limitation was from `bpo-3692`_. The interpreter's behavior is unpredictable when a class declaration contains a list comprehension. To avoid this corner case, assigning expression is invalid under a class. .. code-block:: python3 >>> class Foo: ... a = [1, 2, 3] ... b = [4, 5, 6] ... c = [i for i in zip(a, b)] # b is defined ... >>> class Bar: ... a = [1,2,3] ... b = [4,5,6] ... c = [x * y for x in a for y in b] # b is undefined ... 
Traceback (most recent call last): File "<stdin>", line 1, in <module> File "<stdin>", line 4, in Bar File "<stdin>", line 4, in <listcomp> NameError: name 'b' is not defined Conclusion ---------- The reason why the walrus operator (``:=``) is so controversial is that code readability may decrease. In fact, in the discussion `mail thread `_, the author of PEP 572, Christoph Groth, had considered using "``=``" to implement inline assignment like C/C++. Without judging whether "``:=``" is ugly, many developers argue that distinguishing the functionality between "``:=``" and "``=``" is difficult because they serve the same purpose, but behaviors are not consistent. Also, writing compact code is not persuasive enough because smaller is not always better. However, in some cases, the walrus operator can enhance readability (if you understand how to use ``:=``). For example, .. code-block:: python3 buf = b"" while True: data = read(1024) if not data: break buf += data By using ``:=``, the previous example can be simplified. .. code-block:: python3 buf = b"" while (data := read(1024)): buf += data `Python document`_ and GitHub `issue-8122`_ provide many great examples about improving code readability by "``:=``". However, the walrus operator should be used with care. Some cases, such as ``foo(x := 3, cat='vector')``, may introduce new bugs if developers are not aware of scopes. Although PEP 572 carries the risk of developers writing buggy code, an in-depth understanding of design philosophy and useful examples will help us use it to write readable code at the right time. References ---------- 1. `PEP 572 - Assignment Expressions`_ 2. `What’s New In Python 3.8`_ 3. `PEP 572 and decision-making in Python`_ 4. `The PEP 572 endgame`_ 5. `Use assignment expression in stdlib (combined PR)`_ 6. `Improper scope in list comprehension, when used in class declaration`_ .. _PEP 572: https://www.python.org/dev/peps/pep-0572/ .. _PEP 572 - Assignment Expressions: https://www.python.org/dev/peps/pep-0572/ ..
_What’s New In Python 3.8: https://docs.python.org/3/whatsnew/3.8.html .. _PEP 572 and decision-making in Python: https://lwn.net/Articles/757713/ .. _The PEP 572 endgame: https://lwn.net/Articles/759558/ .. _Use assignment expression in stdlib (combined PR): https://github.com/python/cpython/pull/8122/files .. _improper scope in list comprehension, when used in class declaration: https://bugs.python.org/issue3692 .. _Yoda conditions: https://en.wikipedia.org/wiki/Yoda_conditions .. _bpo-3692: https://bugs.python.org/issue3692 .. _Python document: https://docs.python.org/3/whatsnew/3.8.html#assignment-expressions .. _issue-8122: https://github.com/python/cpython/pull/8122/files ================================================ FILE: docs/notes/asyncio/index.rst ================================================ .. meta:: :description lang=en: Python asyncio tutorial covering coroutines, event loops, tasks, async/await syntax, networking, and asynchronous programming patterns :keywords: Python, Python3, asyncio, async, await, coroutine, event loop, asynchronous, concurrent, networking, TCP, UDP Asyncio ======= Python's ``asyncio`` module provides infrastructure for writing single-threaded concurrent code using coroutines, multiplexing I/O access over sockets and other resources, running network clients and servers, and other related primitives. Unlike threading, asyncio uses cooperative multitasking, where tasks voluntarily yield control to allow other tasks to run. This makes it ideal for I/O-bound applications like web servers, database clients, and network services where waiting for external resources is the primary bottleneck. This section covers asyncio from basic concepts to advanced patterns, including the event loop, coroutines, tasks, synchronization primitives, and real-world examples like TCP/UDP servers, HTTP clients, and connection pools. .. 
toctree:: :maxdepth: 1 python-asyncio-guide python-asyncio-basic python-asyncio-server python-asyncio-advanced ================================================ FILE: docs/notes/asyncio/python-asyncio-advanced.rst ================================================ .. meta:: :description lang=en: Python asyncio advanced - synchronization, queues, subprocesses, debugging, patterns :keywords: Python, Python3, Asyncio, Synchronization, Queue, Semaphore, Lock, Subprocess, Debugging ================= Asyncio Advanced ================= :Source: `src/basic/asyncio_.py `_ .. contents:: Table of Contents :backlinks: none Introduction ------------ Beyond basic coroutines and networking, asyncio provides synchronization primitives, queues, subprocess management, and debugging tools. This section covers advanced patterns for building robust async applications, including producer-consumer patterns, rate limiting, graceful shutdown, and integration with synchronous code. Locks ----- ``asyncio.Lock`` prevents multiple coroutines from accessing a shared resource simultaneously. Unlike threading locks, async locks must be used with ``await`` and only work within the same event loop. .. code-block:: python import asyncio class SharedCounter: def __init__(self): self.value = 0 self._lock = asyncio.Lock() async def increment(self): async with self._lock: current = self.value await asyncio.sleep(0.01) # Simulate work self.value = current + 1 async def worker(counter, name, count): for _ in range(count): await counter.increment() print(f"{name} done") async def main(): counter = SharedCounter() await asyncio.gather( worker(counter, "A", 100), worker(counter, "B", 100), worker(counter, "C", 100), ) print(f"Final value: {counter.value}") # Should be 300 asyncio.run(main()) Semaphores for Rate Limiting ---------------------------- ``asyncio.Semaphore`` limits the number of concurrent operations. 
This is essential for rate limiting API calls, limiting database connections, or controlling resource usage. .. code-block:: python import asyncio async def fetch(url, semaphore): async with semaphore: print(f"Fetching {url}") await asyncio.sleep(1) # Simulate network request return f"Response from {url}" async def main(): # Limit to 3 concurrent requests semaphore = asyncio.Semaphore(3) urls = [f"https://api.example.com/{i}" for i in range(10)] tasks = [fetch(url, semaphore) for url in urls] results = await asyncio.gather(*tasks) for r in results: print(r) asyncio.run(main()) Events for Signaling -------------------- ``asyncio.Event`` allows coroutines to wait for a signal from another coroutine. This is useful for coordinating startup, shutdown, or state changes between multiple tasks. .. code-block:: python import asyncio async def waiter(event, name): print(f"{name} waiting for event") await event.wait() print(f"{name} got the event!") async def setter(event): print("Setting event in 2 seconds...") await asyncio.sleep(2) event.set() print("Event set!") async def main(): event = asyncio.Event() await asyncio.gather( waiter(event, "Task 1"), waiter(event, "Task 2"), waiter(event, "Task 3"), setter(event), ) asyncio.run(main()) Conditions for Complex Synchronization -------------------------------------- ``asyncio.Condition`` combines a lock with the ability to wait for a condition. This is useful for producer-consumer patterns where consumers need to wait for specific conditions. .. 
code-block:: python import asyncio class Buffer: def __init__(self, size): self.buffer = [] self.size = size self.condition = asyncio.Condition() async def put(self, item): async with self.condition: while len(self.buffer) >= self.size: await self.condition.wait() self.buffer.append(item) self.condition.notify() async def get(self): async with self.condition: while not self.buffer: await self.condition.wait() item = self.buffer.pop(0) self.condition.notify() return item async def producer(buffer, name): for i in range(5): await buffer.put(f"{name}-{i}") print(f"Produced: {name}-{i}") await asyncio.sleep(0.1) async def consumer(buffer, name): for _ in range(5): item = await buffer.get() print(f"{name} consumed: {item}") await asyncio.sleep(0.2) async def main(): buffer = Buffer(size=2) await asyncio.gather( producer(buffer, "P1"), producer(buffer, "P2"), consumer(buffer, "C1"), consumer(buffer, "C2"), ) asyncio.run(main()) Queues for Producer-Consumer ---------------------------- ``asyncio.Queue`` is the preferred way to implement producer-consumer patterns. It handles synchronization internally and provides blocking get/put operations with optional timeouts. ..
code-block:: python import asyncio async def producer(queue, name): for i in range(5): item = f"{name}-item-{i}" await queue.put(item) print(f"Produced: {item}") await asyncio.sleep(0.5) async def consumer(queue, name): while True: try: item = await asyncio.wait_for(queue.get(), timeout=2.0) print(f"{name} consumed: {item}") queue.task_done() await asyncio.sleep(0.1) except asyncio.TimeoutError: print(f"{name} timed out, exiting") break async def main(): queue = asyncio.Queue(maxsize=3) producers = [ asyncio.create_task(producer(queue, "P1")), asyncio.create_task(producer(queue, "P2")), ] consumers = [ asyncio.create_task(consumer(queue, "C1")), asyncio.create_task(consumer(queue, "C2")), ] await asyncio.gather(*producers) await queue.join() # Wait for all items to be processed for c in consumers: c.cancel() asyncio.run(main()) Priority Queue -------------- ``asyncio.PriorityQueue`` processes items by priority. Lower priority values are processed first. Items must be comparable or wrapped in tuples with priority as the first element. .. code-block:: python import asyncio async def producer(queue): items = [ (3, "low priority"), (1, "high priority"), (2, "medium priority"), ] for priority, item in items: await queue.put((priority, item)) print(f"Added: {item} (priority {priority})") async def consumer(queue): while not queue.empty(): priority, item = await queue.get() print(f"Processing: {item} (priority {priority})") await asyncio.sleep(0.5) queue.task_done() async def main(): queue = asyncio.PriorityQueue() await producer(queue) await consumer(queue) asyncio.run(main()) Running Subprocesses -------------------- Asyncio can run and communicate with subprocesses asynchronously. This is useful for running shell commands, external tools, or parallel processes without blocking the event loop. .. 
code-block:: python import asyncio async def run_command(cmd): proc = await asyncio.create_subprocess_shell( cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await proc.communicate() return { 'cmd': cmd, 'returncode': proc.returncode, 'stdout': stdout.decode().strip(), 'stderr': stderr.decode().strip() } async def main(): commands = [ "echo 'Hello World'", "python --version", "date", ] results = await asyncio.gather(*[run_command(c) for c in commands]) for r in results: print(f"Command: {r['cmd']}") print(f"Output: {r['stdout']}") print() asyncio.run(main()) Subprocess with Streaming Output -------------------------------- For long-running processes, you can stream output line by line instead of waiting for the process to complete. This is useful for monitoring logs or progress. .. code-block:: python import asyncio async def stream_subprocess(cmd): proc = await asyncio.create_subprocess_shell( cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT ) while True: line = await proc.stdout.readline() if not line: break print(f"[{cmd[:20]}] {line.decode().strip()}") await proc.wait() return proc.returncode async def main(): # Run multiple commands and stream their output await asyncio.gather( stream_subprocess("for i in 1 2 3; do echo $i; sleep 1; done"), stream_subprocess("for i in a b c; do echo $i; sleep 0.5; done"), ) asyncio.run(main()) Graceful Shutdown ----------------- Proper shutdown handling ensures all tasks complete cleanly and resources are released. Use signal handlers to catch SIGINT/SIGTERM and cancel tasks gracefully. .. 
code-block:: python import asyncio import signal async def worker(name): try: while True: print(f"{name} working...") await asyncio.sleep(1) except asyncio.CancelledError: print(f"{name} cancelled, cleaning up...") await asyncio.sleep(0.5) # Cleanup time print(f"{name} cleanup done") raise async def main(): loop = asyncio.get_running_loop() tasks = [ asyncio.create_task(worker("Worker-1")), asyncio.create_task(worker("Worker-2")), ] def shutdown(): print("\nShutdown requested...") for task in tasks: task.cancel() loop.add_signal_handler(signal.SIGINT, shutdown) loop.add_signal_handler(signal.SIGTERM, shutdown) try: await asyncio.gather(*tasks) except asyncio.CancelledError: print("All tasks cancelled") asyncio.run(main()) Running Async Code in Threads ----------------------------- When you need to run async code from synchronous code (e.g., in a callback or from another thread), use ``asyncio.run_coroutine_threadsafe()``. .. code-block:: python import asyncio import threading import time async def async_task(value): await asyncio.sleep(1) return value * 2 def thread_function(loop): # Run async code from a different thread future = asyncio.run_coroutine_threadsafe( async_task(21), loop ) result = future.result(timeout=5) print(f"Thread got result: {result}") async def main(): loop = asyncio.get_running_loop() # Start a thread that will call async code thread = threading.Thread(target=thread_function, args=(loop,)) thread.start() # Keep the event loop running await asyncio.sleep(2) thread.join() asyncio.run(main()) Debugging Asyncio ----------------- Enable debug mode to catch common mistakes like blocking calls, unawaited coroutines, and slow callbacks. Debug mode adds overhead so use it only during development. ..
code-block:: python import asyncio import logging # Enable debug logging logging.basicConfig(level=logging.DEBUG) async def slow_callback(): import time time.sleep(0.2) # This will trigger a warning in debug mode async def main(): await slow_callback() # Method 1: Environment variable # PYTHONASYNCIODEBUG=1 python script.py # Method 2: asyncio.run with debug=True asyncio.run(main(), debug=True) Custom Event Loop ----------------- You can customize the event loop behavior by subclassing or patching. This is useful for debugging, profiling, or adding custom functionality. .. code-block:: python import asyncio class DebugEventLoop(asyncio.SelectorEventLoop): def _run_once(self): # Track number of scheduled callbacks num_ready = len(self._ready) num_scheduled = len(self._scheduled) if num_ready or num_scheduled: print(f"Ready: {num_ready}, Scheduled: {num_scheduled}") super()._run_once() async def task(n): await asyncio.sleep(n) print(f"Task {n} done") # Use custom event loop loop = DebugEventLoop() asyncio.set_event_loop(loop) try: loop.run_until_complete(asyncio.gather( task(0.1), task(0.2), task(0.3), )) finally: loop.close() Timeout Patterns ---------------- Different timeout patterns for various use cases: per-operation timeout, overall timeout, and timeout with fallback. .. 
code-block:: python import asyncio async def fetch(url, delay): await asyncio.sleep(delay) return f"Response from {url}" async def fetch_with_timeout(url, delay, timeout): """Per-operation timeout.""" try: return await asyncio.wait_for(fetch(url, delay), timeout) except asyncio.TimeoutError: return f"Timeout for {url}" async def fetch_all_with_timeout(urls, timeout): """Overall timeout for all operations.""" async def fetch_all(): return await asyncio.gather(*[fetch(u, i) for i, u in enumerate(urls)]) try: return await asyncio.wait_for(fetch_all(), timeout) except asyncio.TimeoutError: return ["Overall timeout"] async def fetch_with_fallback(url, delay, timeout, fallback): """Timeout with fallback value.""" try: return await asyncio.wait_for(fetch(url, delay), timeout) except asyncio.TimeoutError: return fallback async def main(): # Per-operation timeout result = await fetch_with_timeout("slow.com", 5, 1) print(result) # Timeout with fallback result = await fetch_with_fallback("slow.com", 5, 1, "cached response") print(result) asyncio.run(main()) Retry Pattern ------------- Implement retry logic for transient failures with exponential backoff. This is essential for robust network clients. .. code-block:: python import asyncio import random class RetryError(Exception): pass async def unreliable_operation(): """Simulates an operation that fails randomly.""" if random.random() < 0.7: raise ConnectionError("Network error") return "Success!" 
async def retry(coro_func, max_retries=3, base_delay=1.0): """Retry with exponential backoff.""" last_exception = None for attempt in range(max_retries): try: return await coro_func() except Exception as e: last_exception = e if attempt < max_retries - 1: delay = base_delay * (2 ** attempt) jitter = random.uniform(0, 0.1 * delay) print(f"Attempt {attempt + 1} failed, retrying in {delay:.2f}s") await asyncio.sleep(delay + jitter) raise RetryError(f"Failed after {max_retries} attempts") from last_exception async def main(): try: result = await retry(unreliable_operation, max_retries=5) print(f"Result: {result}") except RetryError as e: print(f"All retries failed: {e}") asyncio.run(main()) Async Context Variable ---------------------- Context variables (Python 3.7+) provide task-local storage, similar to thread-local storage but for async tasks. Useful for request IDs, user context, or database connections. .. code-block:: python import asyncio import contextvars # Create context variable request_id = contextvars.ContextVar('request_id', default=None) async def process_request(rid): request_id.set(rid) await step1() await step2() async def step1(): rid = request_id.get() print(f"[{rid}] Step 1") await asyncio.sleep(0.1) async def step2(): rid = request_id.get() print(f"[{rid}] Step 2") await asyncio.sleep(0.1) async def main(): await asyncio.gather( process_request("req-001"), process_request("req-002"), process_request("req-003"), ) asyncio.run(main()) TaskGroup (Python 3.11+) ------------------------ ``TaskGroup`` provides structured concurrency, ensuring all tasks complete or are cancelled together. Exceptions in any task cancel all other tasks in the group. .. 
code-block:: python import asyncio async def task(name, delay, should_fail=False): await asyncio.sleep(delay) if should_fail: raise ValueError(f"{name} failed!") return f"{name} done" async def main(): try: async with asyncio.TaskGroup() as tg: tg.create_task(task("A", 1)) tg.create_task(task("B", 2)) tg.create_task(task("C", 0.5, should_fail=True)) except* ValueError as eg: for exc in eg.exceptions: print(f"Caught: {exc}") # Python 3.11+ asyncio.run(main()) ================================================ FILE: docs/notes/asyncio/python-asyncio-basic.rst ================================================ .. meta:: :description lang=en: Python asyncio basics - coroutines, tasks, event loop, async/await syntax :keywords: Python, Python3, Asyncio, Coroutines, Event Loop, async await, Asynchronous Programming ================ Asyncio Basics ================ :Source: `src/basic/asyncio_.py <https://github.com/crazyguitar/pysheeet/blob/master/src/basic/asyncio_.py>`_ .. contents:: Table of Contents :backlinks: none Introduction ------------ The ``asyncio`` module, introduced in Python 3.4 and significantly improved in Python 3.5+ with ``async/await`` syntax, provides a foundation for writing asynchronous code. Unlike threads which use preemptive multitasking (the OS decides when to switch), asyncio uses cooperative multitasking where coroutines explicitly yield control using ``await``. Because a coroutine can only be suspended at an ``await``, this avoids many of the race conditions common in threaded code (shared state can still change across an ``await``) and makes reasoning about program flow much easier. Key concepts: - **Coroutine**: A function defined with ``async def`` that can be paused and resumed - **Event Loop**: The central scheduler that runs coroutines and handles I/O events - **Task**: A wrapper around a coroutine that schedules it for execution - **Future**: A placeholder for a result that will be available later Running Coroutines with asyncio.run ----------------------------------- The simplest way to run async code is ``asyncio.run()``, introduced in Python 3.7.
It creates an event loop, runs the coroutine until completion, and cleans up
automatically. This is the recommended entry point for asyncio programs.

.. code-block:: python

    import asyncio

    async def hello():
        print("Hello")
        await asyncio.sleep(1)
        print("World")

    # Python 3.7+
    asyncio.run(hello())

For file I/O or other blocking operations, use ``run_in_executor`` to avoid
blocking the event loop:

.. code-block:: python

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    def read_text(path):
        """Blocking helper: open and read the whole file."""
        with open(path) as f:
            return f.read()

    async def read_file(path):
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as pool:
            # open() can block as well, so the whole operation runs in
            # the pool instead of only the read() call
            return await loop.run_in_executor(pool, read_text, path)

    content = asyncio.run(read_file('/etc/hosts'))

Creating and Managing Tasks
---------------------------

Tasks allow multiple coroutines to run concurrently. When you create a task,
it is scheduled on the event loop right away and starts running as soon as
the current coroutine yields control (for example at the next ``await``).
Use ``asyncio.create_task()`` (Python 3.7+) or ``loop.create_task()`` to
create tasks.

.. code-block:: python

    import asyncio

    async def fetch(name, delay):
        await asyncio.sleep(delay)
        return f"{name} done"

    async def main():
        # Create tasks - they are scheduled now and start at the next await
        task1 = asyncio.create_task(fetch("A", 2))
        task2 = asyncio.create_task(fetch("B", 1))

        # Wait for both to complete
        result1 = await task1
        result2 = await task2
        print(result1, result2)

    asyncio.run(main())

Gathering Multiple Coroutines
-----------------------------

``asyncio.gather()`` runs multiple coroutines concurrently and collects
their results in order. This is the most common way to run multiple async
operations in parallel and wait for all of them to complete.

..
code-block:: python import asyncio async def fetch(url, delay): await asyncio.sleep(delay) return f"Response from {url}" async def main(): urls = ["site1.com", "site2.com", "site3.com"] coros = [fetch(url, i * 0.5) for i, url in enumerate(urls)] # Run all concurrently, results in same order as input results = await asyncio.gather(*coros) for r in results: print(r) asyncio.run(main()) Waiting with Timeout -------------------- Use ``asyncio.wait_for()`` to set a timeout on async operations. This is essential for network operations where you don't want to wait indefinitely for a response that may never come. .. code-block:: python import asyncio async def slow_operation(): await asyncio.sleep(10) return "done" async def main(): try: result = await asyncio.wait_for(slow_operation(), timeout=2.0) except asyncio.TimeoutError: print("Operation timed out!") asyncio.run(main()) Waiting for First Completed --------------------------- ``asyncio.wait()`` provides more control than ``gather()``. You can wait for the first task to complete, first exception, or all tasks. This is useful when you want to process results as they become available. .. code-block:: python import asyncio async def fetch(name, delay): await asyncio.sleep(delay) return f"{name}: {delay}s" async def main(): tasks = [ asyncio.create_task(fetch("fast", 1)), asyncio.create_task(fetch("slow", 3)), ] # Wait for first to complete done, pending = await asyncio.wait( tasks, return_when=asyncio.FIRST_COMPLETED ) for task in done: print(f"Completed: {task.result()}") print(f"Still pending: {len(pending)}") # Cancel pending tasks for task in pending: task.cancel() asyncio.run(main()) Asynchronous Iteration ---------------------- Async iterators allow you to iterate over data that arrives asynchronously, such as streaming responses or database cursors. Implement ``__aiter__`` and ``__anext__`` methods to create custom async iterators. .. 
code-block:: python import asyncio class AsyncRange: """Async iterator that yields numbers with delays.""" def __init__(self, start, stop): self.current = start self.stop = stop def __aiter__(self): return self async def __anext__(self): if self.current >= self.stop: raise StopAsyncIteration await asyncio.sleep(0.5) value = self.current self.current += 1 return value async def main(): async for num in AsyncRange(0, 5): print(num) asyncio.run(main()) Asynchronous Context Managers ----------------------------- Async context managers are essential for managing resources that require async setup or cleanup, such as database connections, file handles, or network sessions. Use ``async with`` to ensure proper resource management. .. code-block:: python import asyncio class AsyncConnection: """Simulated async database connection.""" async def __aenter__(self): print("Connecting...") await asyncio.sleep(1) print("Connected") return self async def __aexit__(self, exc_type, exc_val, exc_tb): print("Disconnecting...") await asyncio.sleep(0.5) print("Disconnected") async def query(self, sql): await asyncio.sleep(0.1) return f"Result of: {sql}" async def main(): async with AsyncConnection() as conn: result = await conn.query("SELECT * FROM users") print(result) asyncio.run(main()) Using @asynccontextmanager -------------------------- The ``@asynccontextmanager`` decorator (Python 3.7+) provides a simpler way to create async context managers using generator syntax, similar to the synchronous ``@contextmanager`` decorator. .. 
code-block:: python import asyncio from contextlib import asynccontextmanager @asynccontextmanager async def managed_resource(name): print(f"Acquiring {name}") await asyncio.sleep(0.5) try: yield name finally: print(f"Releasing {name}") await asyncio.sleep(0.2) async def main(): async with managed_resource("database") as resource: print(f"Using {resource}") asyncio.run(main()) Running Blocking Code in Executor --------------------------------- When you need to call blocking code (file I/O, CPU-intensive operations, or libraries without async support), use ``run_in_executor()`` to run it in a thread pool without blocking the event loop. .. code-block:: python import asyncio import time from concurrent.futures import ThreadPoolExecutor def blocking_io(): """Simulates blocking I/O operation.""" time.sleep(2) return "IO complete" def cpu_bound(): """Simulates CPU-intensive operation.""" return sum(i * i for i in range(10**6)) async def main(): loop = asyncio.get_event_loop() # Run in default executor (ThreadPoolExecutor) result1 = await loop.run_in_executor(None, blocking_io) print(result1) # Run in custom executor with ThreadPoolExecutor(max_workers=4) as pool: result2 = await loop.run_in_executor(pool, cpu_bound) print(result2) asyncio.run(main()) Async Generators ---------------- Async generators (Python 3.6+) combine generators with async/await, allowing you to yield values asynchronously. They're useful for streaming data or implementing async iterators more concisely. .. 
code-block:: python import asyncio async def async_range(start, stop): """Async generator that yields numbers with delays.""" for i in range(start, stop): await asyncio.sleep(0.5) yield i async def main(): async for num in async_range(0, 5): print(num) # Async comprehension results = [x async for x in async_range(0, 3)] print(results) asyncio.run(main()) Exception Handling in Tasks --------------------------- Exceptions in tasks are stored and re-raised when you await the task or call ``result()``. Unhandled exceptions in tasks that are never awaited will be logged but may be silently ignored, so always await your tasks. .. code-block:: python import asyncio async def failing_task(): await asyncio.sleep(1) raise ValueError("Something went wrong") async def main(): task = asyncio.create_task(failing_task()) try: await task except ValueError as e: print(f"Caught exception: {e}") # Using gather with return_exceptions tasks = [ asyncio.create_task(asyncio.sleep(1)), asyncio.create_task(failing_task()), ] results = await asyncio.gather(*tasks, return_exceptions=True) for r in results: if isinstance(r, Exception): print(f"Task failed: {r}") else: print(f"Task succeeded: {r}") asyncio.run(main()) Cancelling Tasks ---------------- Tasks can be cancelled using ``task.cancel()``. The cancelled task will raise ``asyncio.CancelledError`` at the next await point. Handle this exception to perform cleanup when a task is cancelled. .. 
code-block:: python import asyncio async def long_running(): try: while True: print("Working...") await asyncio.sleep(1) except asyncio.CancelledError: print("Task was cancelled, cleaning up...") raise # Re-raise to mark task as cancelled async def main(): task = asyncio.create_task(long_running()) await asyncio.sleep(3) task.cancel() try: await task except asyncio.CancelledError: print("Task cancellation confirmed") asyncio.run(main()) ================================================ FILE: docs/notes/asyncio/python-asyncio-guide.rst ================================================ .. meta:: :description lang=en: A comprehensive guide to understanding asynchronous programming in Python, from blocking I/O to event loops, callbacks, generators, and async/await syntax :keywords: Python, Python3, asyncio, coroutine, event loop, async await, asynchronous programming, C10k problem, non-blocking I/O, selectors, generators, callback ================================================ A Hitchhiker's Guide to Asynchronous Programming ================================================ .. contents:: Table of Contents :backlinks: none Abstract -------- The `C10k problem`_ remains a fundamental challenge for programmers seeking to handle massive concurrent connections efficiently. Traditionally, developers address extensive I/O operations using **threads**, **epoll**, or **kqueue** to prevent software from blocking on expensive operations. However, developing readable and bug-free concurrent code is challenging due to complexities around data sharing and task dependencies. Even powerful tools like `Valgrind`_ that help detect deadlocks and race conditions cannot eliminate the time-consuming debugging process as software scales. To address these challenges, many programming languages—including Python, JavaScript, and C++—have developed better libraries, frameworks, and syntaxes to help programmers manage concurrent tasks properly. 
Rather than focusing on how to use modern parallel APIs, this article concentrates on the **design philosophy** behind asynchronous programming patterns, tracing the evolution from blocking I/O to the elegant ``async/await`` syntax. Using threads is the most natural approach for dispatching tasks without blocking the main thread. However, threads introduce performance overhead from context switching and require careful locking of critical sections for atomic operations. While event loops can enhance performance in I/O-bound scenarios, writing readable event-driven code is challenging due to callback complexity (commonly known as "callback hell"). Fortunately, Python introduced the ``async/await`` syntax to help developers write understandable code with high performance. The following figure illustrates how ``async/await`` enables handling socket connections with the simplicity of threads but the efficiency of event loops. .. image:: https://raw.githubusercontent.com/crazyguitar/pysheeet/master/docs/_static/appendix/event-loop-vs-thread.png Introduction ------------ Handling I/O operations such as network connections is among the most expensive tasks in any program. Consider a simple TCP blocking echo server (shown below). If a client connects without sending any data, it blocks all other connections. Even when clients send data promptly, the server cannot handle concurrent requests because it wastes significant time waiting for I/O responses from hardware like network interfaces. Thus, socket programming with concurrency becomes essential for managing high request volumes. .. code-block:: python import socket s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind(("127.0.0.1", 5566)) s.listen(10) while True: conn, addr = s.accept() msg = conn.recv(1024) conn.send(msg) One solution to prevent blocking is dispatching tasks to separate threads. 
The following example demonstrates handling connections simultaneously using
threads. However, creating numerous threads consumes computing resources
without proportional throughput gains. Worse, applications may waste time
waiting for locks when processing tasks in critical sections. While threads
solve blocking issues, factors like CPU utilization and memory overhead
remain critical for solving the C10k problem. Without creating unlimited
threads, the **event loop** provides an alternative solution for managing
connections efficiently.

.. code-block:: python

    import threading
    import socket

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind(("127.0.0.1", 5566))
    s.listen(10240)

    def handler(conn):
        # ``with`` closes the socket once the client is gone
        with conn:
            while True:
                msg = conn.recv(65535)
                if not msg:
                    # Empty read: the peer closed the connection.
                    # Without this check the thread would spin forever.
                    break
                # sendall() retries until the whole message is written
                conn.sendall(msg)

    while True:
        conn, addr = s.accept()
        t = threading.Thread(target=handler, args=(conn,))
        t.start()

A simple event-driven socket server comprises three main components: an
**I/O multiplexing module** (e.g., `select`_), a **scheduler** (the loop),
and **callback functions** (event handlers). The following server uses
Python's high-level I/O multiplexing module, `selectors`_, within a loop to
check whether I/O operations are ready. When data becomes available for
reading or writing, the loop retrieves I/O events and executes the
appropriate callback functions—``accept``, ``read``, or ``write``—to
complete tasks.

..
code-block:: python import socket from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE from functools import partial s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind(("127.0.0.1", 5566)) s.listen(10240) s.setblocking(False) sel = DefaultSelector() def accept(s, mask): conn, addr = s.accept() conn.setblocking(False) sel.register(conn, EVENT_READ, read) def read(conn, mask): msg = conn.recv(65535) if not msg: sel.unregister(conn) return conn.close() sel.modify(conn, EVENT_WRITE, partial(write, msg=msg)) def write(conn, mask, msg=None): if msg: conn.send(msg) sel.modify(conn, EVENT_READ, read) sel.register(s, EVENT_READ, accept) while True: events = sel.select() for e, m in events: cb = e.data cb(e.fileobj, m) Although managing connections via threads may be inefficient, event-loop-based programs are harder to read and maintain. To enhance code readability, many programming languages—including Python—introduce abstract concepts such as **coroutines**, **futures**, and **async/await** to handle I/O multiplexing elegantly. The following sections explore these concepts and the problems they solve. Callback Functions ------------------ Callback functions control data flow at runtime when events occur. However, preserving state across callbacks is challenging. For example, implementing a handshake protocol over TCP requires storing previous state somewhere accessible to subsequent callbacks. .. 
code-block:: python

    import socket

    from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE
    from functools import partial

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind(("127.0.0.1", 5566))
    s.listen(10240)
    s.setblocking(False)

    sel = DefaultSelector()
    is_hello = {}

    def accept(s, mask):
        conn, addr = s.accept()
        conn.setblocking(False)
        is_hello[conn] = False
        sel.register(conn, EVENT_READ, read)

    def close(conn):
        sel.unregister(conn)
        # Drop the per-connection state, otherwise ``is_hello`` keeps
        # every closed socket alive forever
        is_hello.pop(conn, None)
        conn.close()

    def read(conn, mask):
        msg = conn.recv(65535)
        if not msg:
            return close(conn)

        # Check whether handshake is successful
        if is_hello[conn]:
            sel.modify(conn, EVENT_WRITE, partial(write, msg=msg))
            return

        # Perform handshake. Compare bytes: decoding arbitrary input
        # could raise UnicodeDecodeError and stop the whole server.
        if msg.strip() != b"hello":
            return close(conn)

        is_hello[conn] = True

    def write(conn, mask, msg=None):
        if msg:
            conn.send(msg)
        sel.modify(conn, EVENT_READ, read)

    sel.register(s, EVENT_READ, accept)
    while True:
        events = sel.select()
        for e, m in events:
            cb = e.data
            cb(e.fileobj, m)

Although the ``is_hello`` dictionary stores state to track handshake status,
the code becomes difficult to understand. The underlying logic is actually
simple—equivalent to this blocking version:

.. code-block:: python

    def accept(s):
        conn, addr = s.accept()
        success = handshake(conn)
        if not success:
            conn.close()

    def handshake(conn):
        data = conn.recv(65535)
        if not data:
            return False
        if data.strip() != b"hello":
            return False
        conn.send(b"hello")
        return True

To achieve similar structure in non-blocking code, a function (or task) must
snapshot its current state—including arguments, local variables, and
execution position—when waiting for I/O operations. The scheduler must then
be able to **re-enter** the function and execute remaining code after I/O
completes. Unlike languages like C++, Python achieves this naturally because
**generators** preserve all state and can be re-entered by calling ``next()``.
By utilizing generators, handling I/O operations in a non-blocking manner
with readable, linear code—called *inline callbacks*—becomes possible within
an event loop.

Event Loop
----------

An event loop is a user-space scheduler that manages tasks within a program
instead of relying on operating system thread scheduling. The following
snippet demonstrates a simple event loop handling socket connections
asynchronously. The implementation appends tasks to a FIFO job queue and
registers with a *selector* when I/O operations are not ready. A *generator*
preserves task state, allowing execution to resume without callback functions
when I/O results become available. Understanding how this event loop works
reveals that a Python generator is indeed a form of **coroutine**.

.. code-block:: python

    # loop.py

    from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE

    class Loop:
        """Minimal generator-based event loop.

        A task is a ``(generator, value)`` pair. A generator yields
        ``(event_type, fileobj)`` when it has to wait for I/O; it is then
        parked in the selector and put back on the queue by ``polling()``
        once the file object is ready.
        """

        def __init__(self):
            self.sel = DefaultSelector()
            self.queue = []

        def create_task(self, task):
            self.queue.append(task)

        def polling(self):
            # Only peek when tasks are runnable; otherwise sleep in
            # select() until some I/O is ready instead of busy-looping
            timeout = 0 if self.queue else None
            for e, m in self.sel.select(timeout):
                self.queue.append((e.data, None))
                self.sel.unregister(e.fileobj)

        def is_registered(self, fileobj):
            try:
                self.sel.get_key(fileobj)
            except KeyError:
                return False
            return True

        def register(self, t, data):
            if not data:
                return False

            event_type, fileobj = data
            if event_type in (EVENT_READ, EVENT_WRITE):
                if self.is_registered(fileobj):
                    self.sel.modify(fileobj, event_type, t)
                else:
                    self.sel.register(fileobj, event_type, t)
                return True
            return False

        def accept(self, s):
            while True:
                try:
                    conn, addr = s.accept()
                except BlockingIOError:
                    yield (EVENT_READ, s)
                else:
                    break
            return conn, addr

        def recv(self, conn, size):
            while True:
                try:
                    msg = conn.recv(size)
                except BlockingIOError:
                    yield (EVENT_READ, conn)
                else:
                    break
            return msg

        def send(self, conn, msg):
            while True:
                try:
                    size = conn.send(msg)
                except BlockingIOError:
                    yield (EVENT_WRITE, conn)
                else:
                    break
            return size

        def once(self):
            self.polling()
            # Swap in a fresh queue: tasks created while this batch runs
            # are picked up on the next pass
            ready, self.queue = self.queue, []
            for t, data in ready:
                try:
                    data = t.send(data)
                except StopIteration:
                    continue

                # A task waiting for I/O lives only in the selector.
                # Re-queuing it here as well would make the loop spin and
                # duplicate the task each time polling() wakes it up.
                self.register(t, data)

        def run(self):
            while self.queue or self.sel.get_map():
                self.once()

By assigning jobs to an event loop, the programming pattern resembles using
threads but with a user-level scheduler. `PEP 380`_ introduced generator
delegation via ``yield from``, allowing a generator to wait for other
generators to complete. The following snippet is far more intuitive and
readable than callback-based I/O handling:

.. code-block:: python

    # server.py
    # $ python3 server.py &
    # $ nc localhost 5566

    import socket

    from loop import Loop

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind(("127.0.0.1", 5566))
    s.listen(10240)
    s.setblocking(False)

    loop = Loop()

    def handler(conn):
        while True:
            msg = yield from loop.recv(conn, 1024)
            if not msg:
                conn.close()
                break
            yield from loop.send(conn, msg)

    def main():
        while True:
            conn, addr = yield from loop.accept(s)
            conn.setblocking(False)
            loop.create_task((handler(conn), None))

    loop.create_task((main(), None))
    loop.run()

Using an event loop with ``yield from`` manages connections without blocking
the main thread—this was how ``asyncio`` worked before Python 3.5. However,
``yield from`` is ambiguous: why does adding ``@asyncio.coroutine`` transform
a generator into a coroutine? Instead of overloading generator syntax for
asynchronous operations, `PEP 492`_ proposed that coroutines should become a
**standalone concept** in Python. This led to the introduction of
``async/await`` syntax, dramatically improving readability for asynchronous
programming.

What is a Coroutine?
--------------------

Python documentation defines coroutines as "a generalized form of
subroutines." This definition, while technically accurate, can be confusing.
Based on our discussion, an event loop schedules generators to perform
specific tasks—similar to how an OS dispatches jobs to threads.
In this context, generators serve as "routine workers." A **coroutine** is simply a task scheduled by an event loop within a program, rather than by the operating system. The following snippet illustrates what ``@coroutine`` does. This decorator transforms a function into a generator function and wraps it with ``types.coroutine`` for backward compatibility: .. code-block:: python import asyncio import inspect import types from functools import wraps from asyncio.futures import Future def coroutine(func): """Simple prototype of coroutine decorator""" if inspect.isgeneratorfunction(func): return types.coroutine(func) @wraps(func) def coro(*a, **k): res = func(*a, **k) if isinstance(res, Future) or inspect.isgenerator(res): res = yield from res return res return types.coroutine(coro) @coroutine def foo(): yield from asyncio.sleep(1) print("Hello Foo") loop = asyncio.get_event_loop() loop.run_until_complete(loop.create_task(foo())) loop.close() With Python 3.5+, the ``async def`` syntax creates native coroutines directly, and ``await`` replaces ``yield from`` for suspending execution. This makes the intent explicit: ``async def`` declares a coroutine, and ``await`` marks suspension points where the event loop can switch to other tasks. Conclusion ---------- Asynchronous programming via event loops has become more straightforward and readable thanks to modern syntax and library support. Most programming languages, including Python, implement libraries that manage task scheduling through integration with new syntaxes. While ``async/await`` may seem enigmatic initially, it provides a way for programmers to develop logical, linear code structure—similar to using threads—while gaining the performance benefits of event-driven I/O. Without callback functions passing state between handlers, programmers no longer need to worry about preserving local variables and arguments across asynchronous boundaries. 
This allows developers to focus on application logic rather than spending time troubleshooting concurrency issues. The evolution from callbacks to generators to ``async/await`` represents a significant advancement in making concurrent programming accessible and maintainable. References ---------- 1. `asyncio — Asynchronous I/O`_ 2. `PEP 342 - Coroutines via Enhanced Generators`_ 3. `PEP 380 - Syntax for Delegating to a Subgenerator`_ 4. `PEP 492 - Coroutines with async and await syntax`_ .. _C10k problem: https://en.wikipedia.org/wiki/C10k_problem .. _Valgrind: https://valgrind.org/ .. _select: https://docs.python.org/3/library/select.html .. _selectors: https://docs.python.org/3/library/selectors.html .. _asyncio — Asynchronous I/O: https://docs.python.org/3/library/asyncio.html .. _PEP 492: https://www.python.org/dev/peps/pep-0492/ .. _PEP 380: https://www.python.org/dev/peps/pep-0380/ .. _PEP 342 - Coroutines via Enhanced Generators: https://www.python.org/dev/peps/pep-0342/ .. _PEP 492 - Coroutines with async and await syntax: https://www.python.org/dev/peps/pep-0492/ .. _PEP 380 - Syntax for Delegating to a Subgenerator: https://www.python.org/dev/peps/pep-0380/ ================================================ FILE: docs/notes/asyncio/python-asyncio-server.rst ================================================ .. meta:: :description lang=en: Python asyncio networking - TCP/UDP servers, HTTP clients, SSL/TLS, protocols :keywords: Python, Python3, Asyncio, TCP Server, UDP Server, HTTP Client, SSL TLS, Network Programming =================== Asyncio Networking =================== :Source: `src/basic/asyncio_.py <https://github.com/crazyguitar/pysheeet/blob/master/src/basic/asyncio_.py>`_ .. contents:: Table of Contents :backlinks: none Introduction ------------ Asyncio excels at network programming because network I/O is inherently asynchronous - you send a request and wait for a response. Instead of blocking a thread while waiting, asyncio allows other tasks to run.
This section covers building TCP/UDP servers and clients, HTTP requests, SSL/TLS encryption, and the Transport/Protocol API for low-level control. TCP Echo Server with Streams ---------------------------- The streams API (``asyncio.start_server``, ``open_connection``) provides a high-level interface for TCP networking. It handles buffering, encoding, and connection management automatically, making it the recommended approach for most applications. .. code-block:: python import asyncio async def handle_client(reader, writer): addr = writer.get_extra_info('peername') print(f"Connected: {addr}") while True: data = await reader.read(1024) if not data: break message = data.decode() print(f"Received: {message!r} from {addr}") writer.write(data) await writer.drain() print(f"Disconnected: {addr}") writer.close() await writer.wait_closed() async def main(): server = await asyncio.start_server( handle_client, 'localhost', 8888 ) addr = server.sockets[0].getsockname() print(f"Serving on {addr}") async with server: await server.serve_forever() asyncio.run(main()) TCP Client with Streams ----------------------- The client side uses ``asyncio.open_connection()`` to establish a connection. The returned reader and writer objects provide async methods for sending and receiving data. .. code-block:: python import asyncio async def tcp_client(message): reader, writer = await asyncio.open_connection( 'localhost', 8888 ) print(f"Sending: {message!r}") writer.write(message.encode()) await writer.drain() data = await reader.read(1024) print(f"Received: {data.decode()!r}") writer.close() await writer.wait_closed() asyncio.run(tcp_client("Hello, Server!")) Low-Level TCP with Sockets -------------------------- For more control, you can use raw sockets with the event loop's socket methods. This approach is useful when you need fine-grained control over socket options or when integrating with existing socket-based code. .. 
code-block:: python import asyncio import socket async def handle_client(loop, conn): while True: data = await loop.sock_recv(conn, 1024) if not data: break await loop.sock_sendall(conn, data) conn.close() async def server(): loop = asyncio.get_event_loop() sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setblocking(False) sock.bind(('localhost', 8888)) sock.listen(100) print("Server listening on localhost:8888") while True: conn, addr = await loop.sock_accept(sock) print(f"Connected: {addr}") asyncio.create_task(handle_client(loop, conn)) asyncio.run(server()) UDP Echo Server --------------- UDP is connectionless, so the API is different from TCP. Use ``create_datagram_endpoint()`` with a protocol class to handle UDP packets. Each packet is independent and may arrive out of order or not at all. .. code-block:: python import asyncio class EchoUDPProtocol(asyncio.DatagramProtocol): def connection_made(self, transport): self.transport = transport def datagram_received(self, data, addr): message = data.decode() print(f"Received {message!r} from {addr}") self.transport.sendto(data, addr) async def main(): loop = asyncio.get_event_loop() transport, protocol = await loop.create_datagram_endpoint( EchoUDPProtocol, local_addr=('localhost', 9999) ) print("UDP server listening on localhost:9999") try: await asyncio.sleep(3600) # Run for 1 hour finally: transport.close() asyncio.run(main()) HTTP Client with SSL -------------------- Making HTTPS requests requires SSL context configuration. This example shows how to fetch web pages using low-level streams with proper SSL verification. .. 
code-block:: python import asyncio import ssl async def fetch_https(host, path="/"): # Create SSL context with certificate verification ctx = ssl.create_default_context() reader, writer = await asyncio.open_connection( host, 443, ssl=ctx ) # Send HTTP request request = f"GET {path} HTTP/1.1\r\nHost: {host}\r\nConnection: close\r\n\r\n" writer.write(request.encode()) await writer.drain() # Read response response = await reader.read() writer.close() await writer.wait_closed() return response.decode() async def main(): urls = [ ("www.python.org", "/"), ("github.com", "/"), ] tasks = [fetch_https(host, path) for host, path in urls] responses = await asyncio.gather(*tasks) for (host, _), resp in zip(urls, responses): status = resp.split('\r\n')[0] print(f"{host}: {status}") asyncio.run(main()) HTTPS Server with SSL --------------------- Creating an HTTPS server requires SSL certificates. This example shows a simple HTTPS server that serves static content with TLS encryption. .. code-block:: python import asyncio import ssl async def handle_request(reader, writer): request = await reader.read(1024) response = b"HTTP/1.1 200 OK\r\n" response += b"Content-Type: text/html\r\n\r\n" response += b"

Hello HTTPS!

" writer.write(response) await writer.drain() writer.close() await writer.wait_closed() async def main(): # Create SSL context ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ctx.load_cert_chain('cert.pem', 'key.pem') server = await asyncio.start_server( handle_request, 'localhost', 8443, ssl=ctx ) print("HTTPS server on https://localhost:8443") async with server: await server.serve_forever() # Generate self-signed cert: # openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes asyncio.run(main()) Transport and Protocol API -------------------------- The Transport/Protocol API provides low-level control over network connections. Transports handle the actual I/O while Protocols handle the data processing. This separation allows for flexible and reusable network code. .. code-block:: python import asyncio class EchoProtocol(asyncio.Protocol): def connection_made(self, transport): self.transport = transport peername = transport.get_extra_info('peername') print(f"Connection from {peername}") def data_received(self, data): print(f"Received: {data.decode()!r}") self.transport.write(data) def connection_lost(self, exc): print("Connection closed") async def main(): loop = asyncio.get_event_loop() server = await loop.create_server( EchoProtocol, 'localhost', 8888 ) async with server: await server.serve_forever() asyncio.run(main()) DNS Resolution -------------- Asyncio provides async DNS resolution through ``getaddrinfo()``. This is useful when you need to resolve hostnames without blocking the event loop. .. 
code-block:: python import asyncio import socket async def resolve_host(host, port=80): loop = asyncio.get_event_loop() infos = await loop.getaddrinfo( host, port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM ) for family, type_, proto, canonname, sockaddr in infos: ip, port = sockaddr[:2] family_name = "IPv4" if family == socket.AF_INET else "IPv6" print(f"{host} -> {ip} ({family_name})") async def main(): hosts = ["python.org", "github.com", "google.com"] await asyncio.gather(*[resolve_host(h) for h in hosts]) asyncio.run(main()) Simple HTTP Server ------------------ A minimal HTTP server implementation showing how to parse requests and send responses. For production use, consider frameworks like aiohttp or FastAPI. .. code-block:: python import asyncio async def handle_http(reader, writer): request = await reader.read(1024) request_line = request.decode().split('\r\n')[0] method, path, _ = request_line.split(' ') print(f"{method} {path}") # Simple routing if path == '/': body = b"

Home

" status = "200 OK" elif path == '/about': body = b"

About

" status = "200 OK" else: body = b"

404 Not Found

" status = "404 Not Found" response = f"HTTP/1.1 {status}\r\n" response += f"Content-Length: {len(body)}\r\n" response += "Content-Type: text/html\r\n\r\n" writer.write(response.encode() + body) await writer.drain() writer.close() await writer.wait_closed() async def main(): server = await asyncio.start_server( handle_http, 'localhost', 8080 ) print("HTTP server on http://localhost:8080") async with server: await server.serve_forever() asyncio.run(main()) Using sendfile for Efficient File Transfer ------------------------------------------ The ``sendfile()`` method (Python 3.7+) efficiently transfers file contents to a transport using the OS's sendfile syscall, avoiding copying data through Python. .. code-block:: python import asyncio async def handle_request(reader, writer): await reader.read(1024) # Read request with open('index.html', 'rb') as f: # Get file size f.seek(0, 2) size = f.tell() f.seek(0) # Send headers headers = f"HTTP/1.1 200 OK\r\n" headers += f"Content-Length: {size}\r\n" headers += "Content-Type: text/html\r\n\r\n" writer.write(headers.encode()) # Send file efficiently loop = asyncio.get_event_loop() await loop.sendfile(writer.transport, f) writer.close() await writer.wait_closed() async def main(): server = await asyncio.start_server( handle_request, 'localhost', 8080 ) async with server: await server.serve_forever() asyncio.run(main()) Connection Pool --------------- Connection pools reuse connections to avoid the overhead of establishing new connections for each request. This is essential for high-performance clients that make many requests to the same server. .. 
code-block:: python import asyncio from collections import deque class ConnectionPool: def __init__(self, host, port, size=5): self.host = host self.port = port self.size = size self._pool = deque() self._lock = asyncio.Lock() async def get(self): async with self._lock: if self._pool: return self._pool.popleft() # Create new connection reader, writer = await asyncio.open_connection( self.host, self.port ) return reader, writer async def put(self, reader, writer): async with self._lock: if len(self._pool) < self.size: self._pool.append((reader, writer)) else: writer.close() await writer.wait_closed() async def close(self): async with self._lock: while self._pool: reader, writer = self._pool.popleft() writer.close() await writer.wait_closed() async def fetch(pool, message): reader, writer = await pool.get() try: writer.write(message.encode()) await writer.drain() data = await reader.read(1024) return data.decode() finally: await pool.put(reader, writer) async def main(): pool = ConnectionPool('localhost', 8888, size=3) try: tasks = [fetch(pool, f"msg{i}") for i in range(10)] results = await asyncio.gather(*tasks) for r in results: print(r) finally: await pool.close() asyncio.run(main()) ================================================ FILE: docs/notes/basic/index.rst ================================================ .. meta:: :description lang=en: Python basics cheat sheet covering syntax, data types, functions, classes, generators, typing, and essential Python programming concepts :keywords: Python, Python3, basics, syntax, data types, functions, classes, generators, typing, list, dict, set, comprehension Quick Start =========== This cheat sheet is designed to help developers learn Python syntax from the ground up. It covers the fundamentals while also introducing common patterns and idioms that experienced Python developers use, which may feel unfamiliar to beginners. For instance, constructs like ``for ... else ...`` are rarely seen in other programming languages. 
Additionally, we’ll explore interesting topics such as ``__future__``, typing, and Unicode—concepts you may have heard of but never fully understood. By working through this cheat sheet, you’ll gain a solid foundation in Python and learn to write code that feels truly Pythonic. .. toctree:: :maxdepth: 1 python-basic python-future python-func python-object python-typing python-list python-set python-dict python-heap python-generator python-unicode python-rexp ================================================ FILE: docs/notes/basic/python-basic.rst ================================================ .. meta:: :description lang=en: Python basics tutorial covering fundamental syntax, data types, control flow, string formatting, and essential Python programming concepts :keywords: Python, Python3, basics, syntax, data types, control flow, string formatting, variables, operators, conditionals ============ From Scratch ============ .. contents:: Table of Contents :backlinks: none The main goal of this cheat sheet is to collect some common and basic semantics or snippets. The cheat sheet includes some syntax that we already know but that is still ambiguous in our minds, and some snippets that we google again and again. In addition, because **the end of life date for Python 2** is coming, most of the snippets are mainly based on **Python 3**'s syntax. Hello world! ------------ When we start to learn a new language, we usually start by printing **Hello world!**. In Python, we can also print the message by importing the ``__hello__`` module. The source code can be found in `frozen.c `_. .. code-block:: python >>> print("Hello world!") Hello world! >>> import __hello__ Hello world! >>> import __phello__ Hello world! >>> import __phello__.spam Hello world! Python Version -------------- It is important for a programmer to know the current Python version because not every syntax works in every version.
In this case, we can get the Python version by ``python -V`` or using the module, ``sys``. .. code-block:: python >>> import sys >>> print(sys.version) 3.7.1 (default, Nov 6 2018, 18:46:03) [Clang 10.0.0 (clang-1000.11.45.5)] We can also use ``platform.python_version`` to get the Python version. .. code-block:: python >>> import platform >>> platform.python_version() '3.7.1' Sometimes, checking the current Python version is important because we may want to enable some features in some specific versions. ``sys.version_info`` provides more detailed information about the interpreter. We can use it to compare with the version we want. .. code-block:: python >>> import sys >>> sys.version_info >= (3, 6) True >>> sys.version_info >= (3, 7) False Ellipsis -------- `Ellipsis `_ is a built-in constant. After Python 3.0, we can use ``...`` as ``Ellipsis``. It may be the most enigmatic constant in Python. Based on the official document, we can use it to extend slicing syntax. Nevertheless, there are some other conventions in type hinting, stub files, or function expressions. .. code-block:: python >>> ... Ellipsis >>> ... == Ellipsis True >>> type(...) The following snippet shows that we can use the ellipsis to represent a function or a class which has not been implemented yet. .. code-block:: python >>> class Foo: ... ... >>> def foo(): ... ... if ... elif ... else -------------------- The **if statements** are used to control the code flow. Instead of using ``switch`` or ``case`` statements to control the logic of the code, Python uses an ``if ... elif ... else`` sequence. Although some people propose using a ``dict`` to emulate ``switch`` statements, this solution may introduce unnecessary overhead such as creating disposable dictionaries and may undermine the readability of the code. Thus, the solution is not recommended. .. code-block:: python >>> import random >>> num = random.randint(0, 10) >>> if num < 3: ... print("less than 3") ... elif num < 5: ... print("less than 5") ... else: ...
print(num) ... less than 3 for Loop -------- In Python, we can access an iterable object's items directly through the **for statement**. If we need to get indexes and items of an iterable object such as a list or tuple at the same time, using ``enumerate`` is better than ``range(len(iterable))``. Further information can be found on `Looping Techniques `_. .. code-block:: python >>> for val in ["foo", "bar"]: ... print(val) ... foo bar >>> for idx, val in enumerate(["foo", "bar", "baz"]): ... print(idx, val) ... 0 foo 1 bar 2 baz for ... else ... ---------------- It may look a little weird the first time we see an ``else`` clause that belongs to a ``for`` loop. The ``else`` clause can help us avoid using flag variables in loops. A loop’s ``else`` clause runs when no break occurs. .. code-block:: python >>> for _ in range(5): ... pass ... else: ... print("no break") ... no break The following snippet shows the difference between using a flag variable and the ``else`` clause to control the loop. We can see that the ``else`` does not run when the ``break`` occurs in the loop. .. code-block:: python >>> is_break = False >>> for x in range(5): ... if x % 2 == 0: ... is_break = True ... break ... >>> if is_break: ... print("break") ... break >>> for x in range(5): ... if x % 2 == 0: ... print("break") ... break ... else: ... print("no break") ... break Using ``range`` --------------- The problem with ``range`` in Python 2 is that ``range`` may take up a lot of memory if we need to iterate a loop many times. Consequently, using ``xrange`` is recommended in Python 2. .. code-block:: python >>> import platform >>> import sys >>> platform.python_version() '2.7.15' >>> sys.getsizeof(range(100000000)) 800000072 >>> sys.getsizeof(xrange(100000000)) 40 In Python 3, the built-in function ``range`` returns an iterable **range object** instead of a list. The behavior of ``range`` is the same as the ``xrange`` in Python 2.
Therefore, using ``range`` does not take up huge memory anymore if we want to run a code block many times within a loop. Further information can be found on PEP `3100 `_. .. code-block:: python >>> import platform >>> import sys >>> platform.python_version() '3.7.1' >>> sys.getsizeof(range(100000000)) 48 while ... else ... ------------------ The ``else`` clause that belongs to a ``while`` loop serves the same purpose as the ``else`` clause in a ``for`` loop. We can observe that the ``else`` does not run when the ``break`` occurs in the while loop. .. code-block:: python >>> n = 0 >>> while n < 5: ... if n == 3: ... break ... n += 1 ... else: ... print("no break") ... The ``do while`` Statement -------------------------- Many programming languages, such as C/C++, Ruby, or JavaScript, provide the ``do while`` statement. In Python, there is no ``do while`` statement. However, we can place the condition and the ``break`` at the end of a ``while`` loop to achieve the same thing. .. code-block:: python >>> n = 0 >>> while True: ... n += 1 ... if n == 5: ... break ... >>> n 5 try ... except ... else ... --------------------------- Most of the time, we handle errors in the ``except`` clause and clean up resources in the ``finally`` clause. Interestingly, the ``try`` statement also provides an ``else`` clause for us to avoid catching an exception which was raised by the code that should not be protected by ``try ... except``. The ``else`` clause runs when no exception occurs between ``try`` and ``except``. .. code-block:: python >>> try: ... print("No exception") ... except: ... pass ... else: ... print("Success") ... No exception Success String ------ Unlike some other programming languages, Python does not support item assignment on strings directly. Therefore, if it is necessary to manipulate a string’s items, e.g., swap items, we have to convert the string to a list and do a join operation after the series of item assignments finishes. ..
code-block:: python >>> a = "Hello Python" >>> l = list(a) >>> l[0], l[6] = 'h', 'p' >>> ''.join(l) 'hello python' List ---- Lists are versatile containers. Python provides a lot of ways such as **negative index**, **slicing statement**, or **list comprehension** to manipulate lists. The following snippet shows some common operations of lists. .. code-block:: python >>> a = [1, 2, 3, 4, 5] >>> a[-1] # negative index 5 >>> a[1:] # slicing [2, 3, 4, 5] >>> a[1:-1] [2, 3, 4] >>> a[1:-1:2] [2, 4] >>> a[::-1] # reverse [5, 4, 3, 2, 1] >>> a[0] = 0 # set an item >>> a [0, 2, 3, 4, 5] >>> a.append(6) # append an item >>> a [0, 2, 3, 4, 5, 6] >>> del a[-1] # del an item >>> a [0, 2, 3, 4, 5] >>> b = [x for x in range(3)] # list comprehension >>> b [0, 1, 2] >>> a + b # add two lists [0, 2, 3, 4, 5, 0, 1, 2] Dict ---- Dictionaries are key-value pair containers. Like lists, Python supports many ways such as **dict comprehensions** to manipulate dictionaries. After Python 3.6, dictionaries preserve the insertion order of keys. The following snippet shows some common operations of dictionaries. .. code-block:: python >>> d = {'timmy': 'red', 'barry': 'green', 'guido': 'blue'} >>> d {'timmy': 'red', 'barry': 'green', 'guido': 'blue'} >>> d['timmy'] = "yellow" # set data >>> d {'timmy': 'yellow', 'barry': 'green', 'guido': 'blue'} >>> del d['guido'] # del data >>> d {'timmy': 'yellow', 'barry': 'green'} >>> 'guido' in d # contain data False >>> {k: v for k, v in d.items()} # dict comprehension {'timmy': 'yellow', 'barry': 'green'} >>> d.keys() # list all keys dict_keys(['timmy', 'barry']) >>> d.values() # list all values dict_values(['yellow', 'green']) Function -------- Defining a function in Python is flexible. We can define a function with **function documents**, **default values**, **arbitrary arguments**, **keyword arguments**, **keyword-only arguments**, and so on. The following snippet shows some common expressions to define functions. ..
code-block:: python def foo_with_doc(): """Documentation String.""" def foo_with_arg(arg): ... def foo_with_args(*arg): ... def foo_with_kwarg(a, b="foo"): ... def foo_with_args_kwargs(*args, **kwargs): ... def foo_with_kwonly(a, b, *, k): ... # python3 def foo_with_annotations(a: int) -> int: ... # python3 Function Annotations -------------------- Instead of writing string documents in functions to hint the type of parameters and return values, we can denote types by **function annotations**. Function annotations, whose details can be found in PEP `3107 `_ and PEP `484 `_, were introduced in Python 3.0. They are an **optional** feature in **Python 3**. Using function annotations breaks compatibility with **Python 2**. We can solve this issue with stub files. In addition, we can do static type checking through `mypy `_. .. code-block:: python >>> def fib(n: int) -> int: ... a, b = 0, 1 ... for _ in range(n): ... b, a = a + b, b ... return a ... >>> fib(10) 55 Generators ---------- Python uses the ``yield`` statement to define a **generator function**. In other words, when we call a generator function, it returns a **generator**, which acts as an **iterator**, instead of returning values directly. .. code-block:: python >>> def fib(n): ... a, b = 0, 1 ... for _ in range(n): ... yield a ... b, a = a + b, b ... >>> g = fib(10) >>> g >>> for f in fib(5): ... print(f) ... 0 1 1 2 3 Generator Delegation -------------------- Python 3.3 introduced the ``yield from`` expression. It allows a generator to delegate parts of its operations to another generator. In other words, we can **yield** a sequence **from** other **generators** in the current **generator function**. Further information can be found on PEP `380 `_. .. code-block:: python >>> def fib(n): ... a, b = 0, 1 ... for _ in range(n): ... yield a ... b, a = a + b, b ... >>> def fibonacci(n): ... yield from fib(n) ...
>>> [f for f in fibonacci(5)] [0, 1, 1, 2, 3] Class ----- Python supports many common features such as **class documents**, **multiple inheritance**, **class variables**, **instance variables**, **static method**, **class method**, and so on. Furthermore, Python provides some special methods for programmers to implement **iterators**, **context manager**, etc. The following snippet displays common definition of a class. .. code-block:: python class A: ... class B: ... class Foo(A, B): """A class document.""" foo = "class variable" def __init__(self, v): self.attr = v self.__private = "private var" @staticmethod def bar_static_method(): ... @classmethod def bar_class_method(cls): ... def bar(self): """A method document.""" def bar_with_arg(self, arg): ... def bar_with_args(self, *args): ... def bar_with_kwarg(self, kwarg="bar"): ... def bar_with_args_kwargs(self, *args, **kwargs): ... def bar_with_kwonly(self, *, k): ... def bar_with_annotations(self, a: int): ... ``async`` / ``await`` --------------------- ``async`` and ``await`` syntax was introduced from Python 3.5. They were designed to be used with an event loop. Some other features such as the **asynchronous generator** were implemented in later versions. A **coroutine function** (``async def``) are used to create a **coroutine** for an event loop. Python provides a built-in module, **asyncio**, to write a concurrent code through ``async``/``await`` syntax. The following snippet shows a simple example of using **asyncio**. The code must be run on Python 3.7 or above. .. code-block:: python import asyncio async def http_ok(r, w): head = b"HTTP/1.1 200 OK\r\n" head += b"Content-Type: text/html\r\n" head += b"\r\n" body = b"" body += b"

Hello world!

" body += b"" _ = await r.read(1024) w.write(head + body) await w.drain() w.close() async def main(): server = await asyncio.start_server( http_ok, "127.0.0.1", 8888 ) async with server: await server.serve_forever() asyncio.run(main()) Avoid ``exec`` and ``eval`` --------------------------- The following snippet shows how to use the built-in function ``exec``. Yet, using ``exec`` and ``eval`` are not recommended because of some security issues and unreadable code for a human. Further reading can be found on `Be careful with exec and eval in Python `_ and `Eval really is dangerous `_ .. code-block:: python >>> py = ''' ... def fib(n): ... a, b = 0, 1 ... for _ in range(n): ... b, a = b + a, b ... return a ... print(fib(10)) ... ''' >>> exec(py, globals(), locals()) 55 ================================================ FILE: docs/notes/basic/python-dict.rst ================================================ .. meta:: :description lang=en: Python dictionary cheat sheet covering creation, manipulation, merging, comprehensions, defaultdict, OrderedDict, and LRU cache with code examples :keywords: Python, Python3, Python dictionary, Python dict cheat sheet, dict, hashmap, key-value pairs, defaultdict, OrderedDict, dictionary comprehension, LRU cache, dict methods ========== Dictionary ========== .. contents:: Table of Contents :backlinks: none Dictionaries are one of Python's most powerful and frequently used data structures. They store key-value pairs and provide O(1) average time complexity for lookups, insertions, and deletions. Since Python 3.7, dictionaries maintain insertion order as a language feature. This cheat sheet covers essential dictionary operations, from basic manipulation to advanced patterns like emulating dictionary behavior with special methods and implementing an LRU (Least Recently Used) cache. The source code is available on `GitHub `_. 
References ---------- - `Mapping Types — dict `_ - `collections — Container datatypes `_ - `PEP 584 -- Add Union Operators To dict `_ Get All Keys with ``dict.keys()`` --------------------------------- The ``keys()`` method returns a view object containing all dictionary keys. In Python 3, this is a dynamic view that reflects changes to the dictionary. .. code-block:: python >>> a = {"1":1, "2":2, "3":3} >>> b = {"2":2, "3":3, "4":4} >>> a.keys() dict_keys(['1', '2', '3']) Get Key-Value Pairs with ``dict.items()`` ----------------------------------------- The ``items()`` method returns key-value pairs as tuples, which is useful for iterating over both keys and values simultaneously. .. code-block:: python >>> a = {"1":1, "2":2, "3":3} >>> a.items() dict_items([('1', 1), ('2', 2), ('3', 3)]) Find Common Keys Between Dictionaries ------------------------------------- Finding keys that exist in multiple dictionaries is a common operation. Using set intersection is the most efficient approach. .. code-block:: python >>> a = {"1":1, "2":2, "3":3} >>> b = {"2":2, "3":3, "4":4} >>> [_ for _ in a.keys() if _ in b.keys()] ['2', '3'] >>> # better way >>> c = set(a).intersection(set(b)) >>> list(c) ['3', '2'] >>> # or >>> [_ for _ in a if _ in b] ['2', '3'] Set Default Values with ``setdefault()`` and ``defaultdict`` ------------------------------------------------------------ When working with dictionaries, you often need to set default values for missing keys. Python provides ``setdefault()`` and ``collections.defaultdict`` for this. .. code-block:: python >>> # intuitive but not recommended >>> d = {} >>> key = "foo" >>> if key not in d: ... d[key] = [] ...
# using d.setdefault(key[, default]) >>> d = {} >>> key = "foo" >>> d.setdefault(key, []) [] >>> d[key] = 'bar' >>> d {'foo': 'bar'} # using collections.defaultdict >>> from collections import defaultdict >>> d = defaultdict(list) >>> d["key"] [] >>> d["foo"] [] >>> d["foo"].append("bar") >>> d defaultdict(, {'key': [], 'foo': ['bar']}) ``dict.setdefault(key[, default])`` returns its default value if *key* is not in the dictionary. However, if the key exists in the dictionary, the function will return its value. .. code-block:: python >>> d = {} >>> d.setdefault("key", []) [] >>> d["key"] = "bar" >>> d.setdefault("key", []) 'bar' Update Dictionary with ``dict.update()`` ---------------------------------------- The ``update()`` method merges another dictionary into the current one. Keys from the second dictionary overwrite existing keys in the first. .. code-block:: python >>> a = {"1":1, "2":2, "3":3} >>> b = {"2":2, "3":3, "4":4} >>> a.update(b) >>> a {'1': 1, '3': 3, '2': 2, '4': 4} Merge Two Dictionaries in Python -------------------------------- There are several ways to merge dictionaries depending on your Python version. Python 3.9+ also supports the ``|`` operator for dictionary merging. Python 3.4 or lower .. code-block:: python >>> a = {"x": 55, "y": 66} >>> b = {"a": "foo", "b": "bar"} >>> c = a.copy() >>> c.update(b) >>> c {'y': 66, 'x': 55, 'b': 'bar', 'a': 'foo'} Python 3.5 or above .. code-block:: python >>> a = {"x": 55, "y": 66} >>> b = {"a": "foo", "b": "bar"} >>> c = {**a, **b} >>> c {'x': 55, 'y': 66, 'a': 'foo', 'b': 'bar'} Emulate a Dictionary with Special Methods ----------------------------------------- You can create dictionary-like objects by implementing special methods: ``__getitem__``, ``__setitem__``, ``__delitem__``, ``__contains__``, and ``__iter__``. .. code-block:: python >>> class EmuDict(object): ... def __init__(self, dict_): ... self._dict = dict_ ... def __repr__(self): ... return "EmuDict: " + repr(self._dict) ... 
def __getitem__(self, key): ... return self._dict[key] ... def __setitem__(self, key, val): ... self._dict[key] = val ... def __delitem__(self, key): ... del self._dict[key] ... def __contains__(self, key): ... return key in self._dict ... def __iter__(self): ... return iter(self._dict.keys()) ... >>> _ = {"1":1, "2":2, "3":3} >>> emud = EmuDict(_) >>> emud # __repr__ EmuDict: {'1': 1, '2': 2, '3': 3} >>> emud['1'] # __getitem__ 1 >>> emud['5'] = 5 # __setitem__ >>> emud EmuDict: {'1': 1, '2': 2, '3': 3, '5': 5} >>> del emud['2'] # __delitem__ >>> emud EmuDict: {'1': 1, '3': 3, '5': 5} >>> for _ in emud: ... print(emud[_], end=' ') # __iter__ ... else: ... print() ... 1 3 5 >>> '1' in emud # __contains__ True Implement LRU Cache with OrderedDict ------------------------------------ An LRU (Least Recently Used) cache evicts the least recently accessed items when full. ``OrderedDict.move_to_end()`` makes implementation straightforward. .. code-block:: python from collections import OrderedDict class LRU(object): def __init__(self, maxsize=128): self._maxsize = maxsize self._cache = OrderedDict() def get(self, k): if k not in self._cache: return None self._cache.move_to_end(k) return self._cache[k] def put(self, k, v): if k in self._cache: self._cache.move_to_end(k) self._cache[k] = v if len(self._cache) > self._maxsize: self._cache.popitem(last=False) def __str__(self): return str(self._cache) def __repr__(self): return self.__str__() Note that dictionaries preserve insertion order from Python 3.7. Moreover, updating a key does not affect the order. Therefore, a dictionary can also simulate an LRU cache, which is similar to using an OrderedDict. .. 
code-block:: python class LRU(object): def __init__(self, maxsize=128): self._maxsize = maxsize self._cache = {} def get(self, k): if k not in self._cache: return None self.move_to_end(k) return self._cache[k] def put(self, k, v): if k in self._cache: self.move_to_end(k) self._cache[k] = v if len(self._cache) > self._maxsize: self.pop() def pop(self): it = iter(self._cache.keys()) del self._cache[next(it)] def move_to_end(self, k): if k not in self._cache: return v = self._cache[k] del self._cache[k] self._cache[k] = v def __str__(self): return str(self._cache) def __repr__(self): return self.__str__() ================================================ FILE: docs/notes/basic/python-func.rst ================================================ .. meta:: :description lang=en: Python function cheat sheet covering function definitions, arguments, decorators, lambda, closures, and functools with code examples :keywords: Python, Python3, Python function, Python function cheat sheet, decorator, lambda, closure, *args, **kwargs, functools, lru_cache, partial ======== Function ======== .. contents:: Table of Contents :backlinks: none A function helps programmers wrap their logic into a reusable task and avoid duplicate code. In Python, function definitions are so versatile that we can use many features such as decorators, annotations, docstrings, default arguments and so on to define a function. This cheat sheet collects many ways to define a function and demystifies some enigmatic syntax in functions. Document Functions ------------------ Documentation provides programmers hints about how a function is supposed to be used. A docstring gives an expedient way to write a readable document of functions. The docstring should be placed as the first statement in the function body, enclosed in triple quotes. It can be accessed via the ``__doc__`` attribute or the built-in ``help()`` function.
PEP `257 `_ defines conventions for docstrings, and tools like ``pydocstyle`` can help enforce these conventions in your codebase. .. code-block:: python >>> def example(): ... """This is an example function.""" ... print("Example function") ... >>> example.__doc__ 'This is an example function.' >>> help(example) Default Arguments ----------------- Defining a function where the arguments are optional and have a default value is quite simple in Python. We can just assign values in the definition and make sure the default arguments appear in the end. When calling the function, you can omit arguments that have defaults, pass them positionally, or use keyword syntax to specify them explicitly. This flexibility makes functions more versatile and easier to use in different contexts. .. code-block:: python >>> def add(a, b=0): ... return a + b ... >>> add(1) 1 >>> add(1, 2) 3 >>> add(1, b=2) 3 .. warning:: Avoid using mutable objects (like lists or dictionaries) as default arguments. Default argument values are evaluated only once when the function is defined, not each time the function is called. This means mutable defaults are shared across all calls, which can lead to unexpected behavior where modifications persist between function calls. .. code-block:: python >>> def bad(items=[]): # DON'T do this ... items.append(1) ... return items ... >>> bad() [1] >>> bad() # unexpected! [1, 1] >>> def good(items=None): # DO this instead ... if items is None: ... items = [] ... items.append(1) ... return items Variable Arguments ``*args`` and ``**kwargs`` --------------------------------------------- Python provides a flexible way to handle functions that need to accept a variable number of arguments. Use ``*args`` to collect any number of positional arguments into a tuple, and ``**kwargs`` to collect any number of keyword arguments into a dictionary. 
These are commonly used when writing wrapper functions, decorators, or functions that need to pass arguments through to other functions. The names ``args`` and ``kwargs`` are conventions; you can use any valid identifier after the ``*`` or ``**``. .. code-block:: python >>> def example(a, b=None, *args, **kwargs): ... print(a, b) ... print(args) ... print(kwargs) ... >>> example(1, "var", 2, 3, word="hello") 1 var (2, 3) {'word': 'hello'} Unpack Arguments ---------------- When calling a function, you can use ``*`` to unpack a sequence (like a list or tuple) into separate positional arguments, and ``**`` to unpack a dictionary into keyword arguments. This is the inverse of ``*args`` and ``**kwargs`` in function definitions. Unpacking is particularly useful when you have data in a collection that you want to pass to a function that expects separate arguments. .. code-block:: python >>> def foo(a, b, c='BAZ'): ... print(a, b, c) ... >>> foo(*("FOO", "BAR"), **{"c": "baz"}) FOO BAR baz >>> args = [1, 2, 3] >>> print(*args) 1 2 3 Keyword-Only Arguments ---------------------- Arguments that appear after ``*`` or ``*args`` in a function definition are keyword-only, meaning they must be passed by name and cannot be passed positionally. This feature, introduced in Python 3.0, helps prevent errors when functions have many parameters, as it forces callers to be explicit about which argument they're providing. Keyword-only arguments can have default values, making them optional. **New in Python 3.0** .. code-block:: python >>> def f(a, b, *, kw): ... print(a, b, kw) ... >>> f(1, 2, kw=3) 1 2 3 >>> f(1, 2, 3) Traceback (most recent call last): TypeError: f() takes 2 positional arguments but 3 were given >>> # keyword-only with default >>> def g(a, *, kw=10): ... return a + kw ... >>> g(5) 15 Positional-Only Arguments ------------------------- Arguments that appear before ``/`` in a function definition are positional-only, meaning they cannot be passed by keyword name. 
This feature, introduced in Python 3.8, is useful when parameter names are not meaningful to callers or when you want to reserve the flexibility to change parameter names without breaking existing code. Many built-in functions like ``len()`` and ``divmod()`` use positional-only parameters. You can combine positional-only (``/``) and keyword-only (``*``) in the same function. **New in Python 3.8** .. code-block:: python >>> def f(a, b, /, c): ... print(a, b, c) ... >>> f(1, 2, 3) 1 2 3 >>> f(1, 2, c=3) 1 2 3 >>> f(a=1, b=2, c=3) Traceback (most recent call last): TypeError: f() got some positional-only arguments passed as keyword arguments: 'a, b' >>> # combining positional-only and keyword-only >>> def g(a, /, b, *, c): ... return a + b + c ... >>> g(1, 2, c=3) 6 Annotations ----------- Function annotations provide a way to attach metadata to function parameters and return values. While Python doesn't enforce these annotations at runtime, they serve as documentation and are used by static type checkers like ``mypy`` to catch type errors before code runs. Annotations are stored in the function's ``__annotations__`` attribute as a dictionary. The ``typing`` module (Python 3.5+) provides additional types like ``List``, ``Dict``, ``Optional``, and ``Union`` for more expressive type hints. **New in Python 3.0** .. code-block:: python >>> def fib(n: int) -> int: ... a, b = 0, 1 ... for _ in range(n): ... b, a = a + b, b ... return a ... >>> fib(10) 55 >>> fib.__annotations__ {'n': <class 'int'>, 'return': <class 'int'>} Lambda ------ Lambda expressions create small anonymous functions inline. They are syntactically restricted to a single expression, which is implicitly returned. Lambdas are useful for short, throwaway functions, especially as arguments to higher-order functions like ``sorted()``, ``map()``, ``filter()``, and ``reduce()``. While lambdas can make code more concise, complex logic should be written as regular named functions for better readability and debugging. ..
code-block:: python >>> square = lambda x: x ** 2 >>> square(5) 25 >>> # lambda with multiple arguments >>> add = lambda a, b: a + b >>> add(2, 3) 5 >>> # lambda with conditional >>> max_val = lambda a, b: a if a > b else b >>> max_val(3, 5) 5 >>> # common use: sorting key >>> pairs = [(1, 'b'), (2, 'a'), (3, 'c')] >>> sorted(pairs, key=lambda x: x[1]) [(2, 'a'), (1, 'b'), (3, 'c')] Callable -------- In Python, any object that implements the ``__call__`` method is callable, meaning it can be invoked like a function using parentheses. This includes functions, methods, lambdas, classes (calling a class creates an instance), and instances of classes that define ``__call__``. The built-in ``callable()`` function returns ``True`` if an object appears callable, which is useful for checking before attempting to call an object to avoid ``TypeError`` exceptions. .. code-block:: python >>> callable(print) True >>> callable(42) False >>> class Adder: ... def __init__(self, n): ... self.n = n ... def __call__(self, x): ... return self.n + x ... >>> add_five = Adder(5) >>> callable(add_five) True >>> add_five(10) 15 Get Function Name ----------------- Functions in Python are first-class objects with various attributes that provide metadata about them. The ``__name__`` attribute contains the function's name as defined, ``__doc__`` contains the docstring, ``__module__`` indicates which module the function was defined in, and ``__annotations__`` holds type hints. These attributes are useful for debugging, logging, and introspection. .. code-block:: python >>> def example_function(): ... """Example docstring.""" ... pass ... >>> example_function.__name__ 'example_function' >>> example_function.__doc__ 'Example docstring.' >>> example_function.__module__ '__main__' Closure ------- A closure is a function that captures and remembers values from its enclosing lexical scope even after that scope has finished executing. 
This happens when a nested function references variables from its outer function. Closures are powerful for creating function factories (functions that return customized functions), implementing decorators, and maintaining state without using global variables or classes. Use the ``nonlocal`` keyword to modify captured variables from the enclosing scope. .. code-block:: python >>> def make_multiplier(n): ... def multiplier(x): ... return x * n ... return multiplier ... >>> double = make_multiplier(2) >>> triple = make_multiplier(3) >>> double(5) 10 >>> triple(5) 15 >>> # closure with mutable state >>> def make_counter(): ... count = 0 ... def counter(): ... nonlocal count ... count += 1 ... return count ... return counter ... >>> counter = make_counter() >>> counter() 1 >>> counter() 2 Generator --------- Generator functions use the ``yield`` statement to produce a sequence of values lazily, one at a time, instead of computing all values upfront and storing them in memory. When called, a generator function returns a generator iterator that can be iterated over with ``for`` loops or ``next()``. Generators are memory-efficient for large sequences and can represent infinite sequences. Generator expressions provide a concise syntax similar to list comprehensions but with lazy evaluation. .. code-block:: python >>> def fib(n): ... a, b = 0, 1 ... for _ in range(n): ... yield a ... b, a = a + b, b ... >>> list(fib(10)) [0, 1, 1, 2, 3, 5, 8, 13, 21, 34] >>> # generator expression >>> squares = (x**2 for x in range(5)) >>> list(squares) [0, 1, 4, 9, 16] Decorator --------- Decorators are a powerful pattern for modifying or extending the behavior of functions without changing their source code. A decorator is a function that takes a function as input and returns a new function (usually a wrapper) that adds some functionality before or after calling the original. The ``@decorator`` syntax is syntactic sugar for ``func = decorator(func)``. 
Always use ``@wraps`` from ``functools`` in your wrapper function to preserve the original function's metadata like ``__name__``, ``__doc__``, and ``__annotations__``. **New in Python 2.4** - PEP `318 `_ .. code-block:: python >>> from functools import wraps >>> def log_calls(func): ... @wraps(func) ... def wrapper(*args, **kwargs): ... print(f"Calling {func.__name__}") ... return func(*args, **kwargs) ... return wrapper ... >>> @log_calls ... def greet(name): ... return f"Hello, {name}!" ... >>> greet("Alice") Calling greet 'Hello, Alice!' >>> # equivalent to: >>> # greet = log_calls(greet) .. note:: Always use ``@wraps(func)`` in decorators to preserve the original function's ``__name__``, ``__doc__``, and other attributes. Without it, the decorated function will have the wrapper's attributes, which makes debugging harder. Decorator with Arguments ------------------------ To create a decorator that accepts arguments, you need an extra layer of nesting. The outermost function takes the decorator's arguments and returns the actual decorator. The middle function takes the function being decorated and returns the wrapper. The innermost function is the wrapper that executes when the decorated function is called. This pattern is commonly used for decorators like ``@repeat(3)`` or ``@route('/path')``. .. code-block:: python >>> from functools import wraps >>> def repeat(times): ... def decorator(func): ... @wraps(func) ... def wrapper(*args, **kwargs): ... for _ in range(times): ... result = func(*args, **kwargs) ... return result ... return wrapper ... return decorator ... >>> @repeat(3) ... def say_hello(): ... print("Hello!") ... >>> say_hello() Hello! Hello! Hello! >>> # equivalent to: >>> # say_hello = repeat(3)(say_hello) Class Decorator --------------- Decorators can also be implemented as classes instead of functions. A class-based decorator implements ``__init__`` to receive the decorated function and ``__call__`` to act as the wrapper. 
This approach is useful when the decorator needs to maintain state across multiple calls to the decorated function, such as counting calls, caching results, or tracking timing information. .. code-block:: python >>> class CountCalls: ... def __init__(self, func): ... self.func = func ... self.count = 0 ... def __call__(self, *args, **kwargs): ... self.count += 1 ... return self.func(*args, **kwargs) ... >>> @CountCalls ... def example(): ... return "result" ... >>> example() 'result' >>> example() 'result' >>> example.count 2 Cache with ``lru_cache`` ------------------------ The ``lru_cache`` decorator from ``functools`` automatically caches function results based on the arguments passed. When the function is called with the same arguments again, the cached result is returned instead of recomputing it. This is especially useful for expensive computations or recursive functions like Fibonacci. The ``maxsize`` parameter limits cache size (use ``None`` for unlimited). Use ``cache_info()`` to see hit/miss statistics and ``cache_clear()`` to reset the cache. **New in Python 3.2** .. code-block:: python >>> from functools import lru_cache >>> @lru_cache(maxsize=None) ... def fib(n): ... if n < 2: ... return n ... return fib(n - 1) + fib(n - 2) ... >>> fib(100) 354224848179261915075 >>> fib.cache_info() CacheInfo(hits=98, misses=101, maxsize=None, currsize=101) >>> fib.cache_clear() # clear the cache **New in Python 3.9** - ``@cache`` is a simpler alias for ``@lru_cache(maxsize=None)`` .. code-block:: python >>> from functools import cache >>> @cache ... def factorial(n): ... return n * factorial(n-1) if n else 1 Partial Functions ----------------- The ``functools.partial`` function creates a new callable with some arguments of the original function pre-filled. This is useful for adapting functions to interfaces that expect fewer arguments, creating specialized versions of general functions, or preparing callback functions. 
The resulting partial object can be called with the remaining arguments. You can pre-fill both positional and keyword arguments. .. code-block:: python >>> from functools import partial >>> def power(base, exponent): ... return base ** exponent ... >>> square = partial(power, exponent=2) >>> cube = partial(power, exponent=3) >>> square(5) 25 >>> cube(5) 125 >>> # useful for callbacks >>> from functools import partial >>> def greet(greeting, name): ... return f"{greeting}, {name}!" ... >>> say_hello = partial(greet, "Hello") >>> say_hello("Alice") 'Hello, Alice!' ``singledispatch`` - Function Overloading ----------------------------------------- The ``singledispatch`` decorator from ``functools`` enables function overloading based on the type of the first argument. You define a base function and then register specialized implementations for different types using the ``@func.register`` decorator. When the function is called, Python automatically dispatches to the appropriate implementation based on the argument's type. This is useful for writing generic functions that behave differently for different types. **New in Python 3.4** .. code-block:: python >>> from functools import singledispatch >>> @singledispatch ... def process(arg): ... return f"Default: {arg}" ... >>> @process.register(int) ... def _(arg): ... return f"Integer: {arg * 2}" ... >>> @process.register(list) ... def _(arg): ... return f"List with {len(arg)} items" ... >>> process("hello") 'Default: hello' >>> process(5) 'Integer: 10' >>> process([1, 2, 3]) 'List with 3 items' ``reduce`` - Cumulative Operations ---------------------------------- The ``reduce`` function from ``functools`` applies a two-argument function cumulatively to the items of a sequence, from left to right, reducing the sequence to a single value. For example, ``reduce(f, [a, b, c, d])`` computes ``f(f(f(a, b), c), d)``. An optional third argument provides an initial value. 
While ``reduce`` can be powerful, built-in functions such as ``sum()`` and ``math.prod()`` or explicit loops are often more readable for simple cases. .. code-block:: python >>> from functools import reduce >>> # sum of list >>> reduce(lambda x, y: x + y, [1, 2, 3, 4, 5]) 15 >>> # product of list >>> reduce(lambda x, y: x * y, [1, 2, 3, 4, 5]) 120 >>> # with initial value >>> reduce(lambda x, y: x + y, [1, 2, 3], 10) 16 Higher-Order Functions ---------------------- Higher-order functions are functions that take other functions as arguments or return functions as results. Python provides several built-in higher-order functions that are commonly used for functional programming patterns. ``map()`` applies a function to every item in an iterable, ``filter()`` keeps items for which the function returns a true value, and ``sorted()``/``min()``/``max()`` accept a ``key`` function to customize comparison. ``map()`` and ``filter()`` return lazy iterators, so wrap them in ``list()`` if you need a list; ``sorted()`` returns a new list, and ``min()``/``max()`` return a single item. .. code-block:: python >>> # map - apply function to each item >>> list(map(lambda x: x**2, [1, 2, 3, 4])) [1, 4, 9, 16] >>> # filter - keep items where function returns True >>> list(filter(lambda x: x > 2, [1, 2, 3, 4])) [3, 4] >>> # sorted with key function >>> sorted(['banana', 'apple', 'cherry'], key=len) ['apple', 'banana', 'cherry'] >>> # min/max with key function >>> max(['apple', 'banana', 'cherry'], key=len) 'banana' ================================================ FILE: docs/notes/basic/python-future.rst ================================================ .. meta:: :description lang=en: Python __future__ module guide covering future statements, backward compatibility, and feature backporting from newer Python versions :keywords: Python, __future__, future statements, backward compatibility, print_function, annotations, division ====== Future ====== ..
contents:: Table of Contents :backlinks: none `Future statements <https://docs.python.org/3/reference/simple_stmts.html#future-statements>`_ tell the interpreter to compile code using semantics that will become available in a future Python version. In other words, Python uses ``from __future__ import feature`` to backport features from newer Python versions to the current interpreter. In Python 3, many features such as ``print_function`` are already enabled, but these future statements are still accepted for backward compatibility. Future statements are **NOT** import statements. Future statements change how Python interprets the code. They **MUST** be at the top of the file. Otherwise, the Python interpreter will raise ``SyntaxError``. If you're interested in future statements and want more explanation, further information can be found in `PEP 236 - Back to the __future__ <https://peps.python.org/pep-0236/>`_. List All New Features --------------------- `__future__ <https://docs.python.org/3/library/__future__.html>`_ is a Python module. We can use it to check which future features can be imported into the current Python interpreter. Interestingly, ``import __future__`` is **NOT** a future statement; it is an ordinary import statement. .. code-block:: python >>> from pprint import pprint >>> import __future__ >>> pprint(__future__.all_feature_names) ['nested_scopes', 'generators', 'division', 'absolute_import', 'with_statement', 'print_function', 'unicode_literals', 'barry_as_FLUFL', 'generator_stop', 'annotations'] Future statements not only change the behavior of the Python interpreter but also import a ``__future__._Feature`` instance into the current program. .. code-block:: python >>> from __future__ import print_function >>> print_function _Feature((2, 6, 0, 'alpha', 2), (3, 0, 0, 'alpha', 0), 65536) Print Function -------------- Replacing the **print statement** with the **print function** is one of the most notorious decisions in Python history. However, this change brings some flexibility to extend the ability of ``print``. Further information can be found in PEP `3105 <https://peps.python.org/pep-3105/>`_. ..
code-block:: python >>> print "Hello World" # print is a statement Hello World >>> from __future__ import print_function >>> print "Hello World" File "<stdin>", line 1 print "Hello World" ^ SyntaxError: invalid syntax >>> print("Hello World") # print becomes a function Hello World Unicode ------- Like the **print function**, making text Unicode by default is another infamous decision. Nevertheless, text in many modern programming languages is Unicode. This change compels us to decode texts early in order to prevent runtime errors after we run programs for a while. Further information can be found in PEP `3112 <https://peps.python.org/pep-3112/>`_. .. code-block:: python >>> type("Guido") # string type is str in Python 2 <type 'str'> >>> from __future__ import unicode_literals >>> type("Guido") # string type becomes unicode <type 'unicode'> Division -------- Sometimes, it is counterintuitive when the division result is int or long. In this case, Python 3 enables the **true division** by default. However, in Python 2, we have to backport ``division`` to the current interpreter. Further information can be found in PEP `238 <https://peps.python.org/pep-0238/>`_. .. code-block:: python >>> 1 / 2 0 >>> from __future__ import division >>> 1 / 2 # return a float (true division) 0.5 >>> 1 // 2 # return an int (floor division) 0 Annotations ----------- Before Python 3.7, we cannot use a name in an annotation of a class or a function if that name is not yet defined in the current scope. A common situation is the definition of a container class. .. code-block:: python class Tree(object): def insert(self, tree: Tree): ... Example .. code-block:: bash $ python3 foo.py Traceback (most recent call last): File "foo.py", line 1, in <module> class Tree(object): File "foo.py", line 3, in Tree def insert(self, tree: Tree): ... NameError: name 'Tree' is not defined In this case, the definition of the class is not available yet. The Python interpreter cannot evaluate the annotation at definition time. To solve this issue, we can use a string literal in place of the class name. ..
code-block:: python class Tree(object): def insert(self, tree: 'Tree'): ... Starting with version 3.7, Python provides the future statement ``annotations`` to postpone the evaluation of annotations. It was originally planned to become the default in a later release (Python 4.0, later rescheduled to Python 3.10), but that plan was superseded by PEP 649, which defers annotation evaluation by default starting in Python 3.14. For further information please refer to PEP `563 <https://peps.python.org/pep-0563/>`_. .. code-block:: python from __future__ import annotations class Tree(object): def insert(self, tree: Tree): ... BDFL Retirement --------------- **New in Python 3.1** PEP `401 <https://peps.python.org/pep-0401/>`_ is just an Easter egg. This feature brings the current interpreter back to the past. It enables the diamond operator ``<>`` in Python 3. .. code-block:: python >>> 1 != 2 True >>> from __future__ import barry_as_FLUFL >>> 1 != 2 File "<stdin>", line 1 1 != 2 ^ SyntaxError: with Barry as BDFL, use '<>' instead of '!=' >>> 1 <> 2 True Braces ------ ``braces`` is an Easter egg. The source code can be found in `future.c <https://github.com/python/cpython/blob/main/Python/future.c>`_. .. code-block:: python >>> from __future__ import braces File "<stdin>", line 1 SyntaxError: not a chance ================================================ FILE: docs/notes/basic/python-generator.rst ================================================ .. meta:: :description lang=en: Python generator cheat sheet covering generator functions, generator expressions, yield, yield from, send, async generators, and coroutines with code examples :keywords: Python, Python3, Python generator, Python generator cheat sheet, yield, yield from, generator expression, async generator, iterator, coroutine, contextmanager ========= Generator ========= .. contents:: Table of Contents :backlinks: none Generators are a powerful feature in Python for creating iterators. They allow you to iterate over data without storing the entire sequence in memory, making them ideal for processing large datasets or infinite sequences. This cheat sheet covers generator functions, generator expressions, ``yield``, ``yield from``, sending values to generators, and async generators.
Generator Function vs Generator Expression ------------------------------------------ A generator function is defined like a normal function but uses ``yield`` to produce a sequence of values. When called, it returns a generator object that can be iterated over. A generator expression is a compact syntax similar to list comprehensions but produces values lazily on demand. .. code-block:: python # generator function >>> def gen_func(): ... yield 5566 ... >>> g = gen_func() >>> g <generator object gen_func at 0x...> >>> next(g) 5566 # generator expression >>> g = (x for x in range(3)) >>> next(g) 0 >>> next(g) 1 Yield Values from Generator --------------------------- The ``yield`` statement produces a value and suspends the generator's execution. When ``next()`` is called again, execution resumes from where it left off. This example generates prime numbers by checking divisibility for each candidate. .. code-block:: python >>> def prime(n): ... p = 2 ... while n > 0: ... for x in range(2, p): ... if p % x == 0: ... break ... else: ... yield p ... n -= 1 ... p += 1 ... >>> list(prime(5)) [2, 3, 5, 7, 11] Unpack Generators ----------------- Generators can be unpacked with the ``*`` operator. Python 3.5+ (PEP 448) allows unpacking them directly inside list and set displays and multiple times in a single function call, while unpacking into variables (including extended unpacking, PEP 3132) works in any Python 3 version. This provides a convenient way to consume generator values without explicit iteration. .. code-block:: python # PEP 448 - unpacking inside a list >>> g1 = (x for x in range(3)) >>> g2 = (x**2 for x in range(2)) >>> [1, *g1, 2, *g2] [1, 0, 1, 2, 2, 0, 1] # unpacking inside a set >>> g = (x for x in [5, 5, 6, 6]) >>> {*g} {5, 6} # unpacking to variables >>> g = (x for x in range(3)) >>> a, b, c = g >>> a, b, c (0, 1, 2) # extended unpacking >>> g = (x for x in range(6)) >>> a, b, *c, d = g >>> a, b, d (0, 1, 5) >>> c [2, 3, 4] # unpacking inside a function >>> print(*(x for x in range(3))) 0 1 2 Iterable Class via Generator ---------------------------- You can make a class iterable by implementing ``__iter__`` as a generator method.
This approach is cleaner than implementing a separate iterator class. The ``__reversed__`` method can also be implemented as a generator to support the built-in ``reversed()`` function. .. code-block:: python >>> class Count: ... def __init__(self, n): ... self._n = n ... def __iter__(self): ... n = self._n ... while n > 0: ... yield n ... n -= 1 ... def __reversed__(self): ... n = 1 ... while n <= self._n: ... yield n ... n += 1 ... >>> list(Count(5)) [5, 4, 3, 2, 1] >>> list(reversed(Count(5))) [1, 2, 3, 4, 5] Send Values to Generator ------------------------ Generators can receive values through the ``send()`` method. The sent value becomes the result of the ``yield`` expression inside the generator. Before sending values, you must start the generator by calling ``next()`` or ``send(None)`` to advance it to the first ``yield``. .. code-block:: python >>> def spam(): ... msg = yield ... print("Message:", msg) ... >>> g = spam() >>> next(g) # start generator >>> try: ... g.send("Hello World!") ... except StopIteration: ... pass Message: Hello World! yield from Expression --------------------- The ``yield from`` expression delegates iteration to another generator or iterable. It automatically handles forwarding ``send()``, ``throw()``, and ``close()`` calls to the subgenerator, making it ideal for creating generator pipelines and recursive generators. .. code-block:: python >>> def subgen(): ... try: ... yield 9527 ... except ValueError: ... print("got ValueError") ... >>> def delegating_gen(): ... yield from subgen() ... >>> g = delegating_gen() >>> next(g) 9527 >>> try: ... g.throw(ValueError) ... except StopIteration: ... pass got ValueError You can chain multiple ``yield from`` expressions together. The ``inspect.getgeneratorstate()`` function helps track the generator's lifecycle through its states: GEN_CREATED, GEN_RUNNING, GEN_SUSPENDED, and GEN_CLOSED. .. code-block:: python # yield from + yield from >>> import inspect >>> def subgen(): ... 
yield from range(3) ... >>> def delegating_gen(): ... yield from subgen() ... >>> g = delegating_gen() >>> inspect.getgeneratorstate(g) 'GEN_CREATED' >>> next(g) 0 >>> inspect.getgeneratorstate(g) 'GEN_SUSPENDED' >>> g.close() >>> inspect.getgeneratorstate(g) 'GEN_CLOSED' yield from with Return ---------------------- Generators can return a value using the ``return`` statement. The returned value is accessible through the ``value`` attribute of the ``StopIteration`` exception. When using ``yield from``, the return value of the subgenerator becomes the value of the ``yield from`` expression. .. code-block:: python >>> def average(): ... total = .0 ... count = 0 ... while True: ... val = yield ... if not val: ... break ... total += val ... count += 1 ... return total / count ... >>> g = average() >>> next(g) >>> g.send(3) >>> g.send(5) >>> try: ... g.send(None) ... except StopIteration as e: ... print(e.value) 4.0 .. code-block:: python >>> def subgen(): ... yield 9527 ... >>> def delegating_gen(): ... yield from subgen() ... return 5566 ... >>> g = delegating_gen() >>> next(g) 9527 >>> try: ... next(g) ... except StopIteration as e: ... print(e.value) 5566 Generate Sequences ------------------ The ``yield from`` expression provides a concise way to yield all values from an iterable. This is particularly useful for chaining multiple sequences together or flattening nested structures. .. code-block:: python >>> def chain(): ... yield from 'ab' ... yield from range(3) ... >>> list(chain()) ['a', 'b', 0, 1, 2] What ``RES = yield from EXP`` Does ---------------------------------- This snippet shows the simplified equivalent of what ``yield from`` does internally, as described in PEP 380. It handles iteration, value passing via ``send()``, and captures the return value from the subgenerator. .. code-block:: python # Simplified version (ref: PEP 380) >>> def subgen(): ... for x in range(3): ... yield x ... >>> def delegating_gen(): ... _i = iter(subgen()) ... try: ... 
_y = next(_i) ... except StopIteration as _e: ... RES = _e.value ... else: ... while True: ... _s = yield _y ... try: ... _y = _i.send(_s) ... except StopIteration as _e: ... RES = _e.value ... break ... >>> list(delegating_gen()) [0, 1, 2] Check Generator Type -------------------- Use ``types.GeneratorType`` to check if an object is a generator. This is useful for writing functions that need to handle generators differently from other iterables. .. code-block:: python >>> from types import GeneratorType >>> def gen_func(): ... yield 5566 ... >>> isinstance(gen_func(), GeneratorType) True Check Generator State --------------------- The ``inspect.getgeneratorstate()`` function returns the current state of a generator. This is helpful for debugging and understanding the generator lifecycle. The four possible states are: GEN_CREATED (not started), GEN_RUNNING (currently executing), GEN_SUSPENDED (paused at yield), and GEN_CLOSED (completed or closed). .. code-block:: python >>> import inspect >>> def gen_func(): ... yield 9527 ... >>> g = gen_func() >>> inspect.getgeneratorstate(g) 'GEN_CREATED' >>> next(g) 9527 >>> inspect.getgeneratorstate(g) 'GEN_SUSPENDED' >>> g.close() >>> inspect.getgeneratorstate(g) 'GEN_CLOSED' Context Manager via Generator ----------------------------- The ``@contextlib.contextmanager`` decorator transforms a generator function into a context manager. Code before ``yield`` runs on entering the ``with`` block, and code after ``yield`` (typically in ``finally``) runs on exit. The yielded value is bound to the variable after ``as``. .. code-block:: python >>> import contextlib >>> @contextlib.contextmanager ... def mylist(): ... try: ... l = [1, 2, 3, 4, 5] ... yield l ... finally: ... print("exit scope") ... >>> with mylist() as l: ... print(l) [1, 2, 3, 4, 5] exit scope What ``@contextmanager`` Does ----------------------------- This snippet shows a simplified implementation of how ``@contextmanager`` works internally. 
It wraps a generator in a class that implements the context manager protocol (``__enter__`` and ``__exit__``), handling both normal exit and exception propagation. .. code-block:: python class GeneratorCM: def __init__(self, gen): self._gen = gen def __enter__(self): return next(self._gen) def __exit__(self, *exc_info): try: if exc_info[0] is None: next(self._gen) else: self._gen.throw(*exc_info) except StopIteration: return True raise def contextmanager(func): def run(*a, **k): return GeneratorCM(func(*a, **k)) return run Profile Code Block ------------------ A practical example of using generator-based context managers to measure execution time of code blocks. The ``yield`` statement marks the boundary between setup (recording start time) and teardown (calculating elapsed time). .. code-block:: python >>> import time >>> from contextlib import contextmanager >>> @contextmanager ... def profile(msg): ... try: ... s = time.time() ... yield ... finally: ... print(f'{msg} cost: {time.time() - s:.2f}s') ... >>> with profile('block'): ... time.sleep(0.1) block cost: 0.10s ``yield from`` and ``__iter__`` ------------------------------- When using ``yield from`` with a class instance, Python calls the object's ``__iter__`` method to get an iterator. This allows custom classes to work seamlessly with ``yield from`` delegation, enabling elegant composition of iterables. .. code-block:: python >>> class FakeGen: ... def __iter__(self): ... n = 0 ... while n < 3: ... yield n ... n += 1 ... def __reversed__(self): ... n = 2 ... while n >= 0: ... yield n ... n -= 1 ... >>> def spam(): ... yield from FakeGen() ... >>> list(spam()) [0, 1, 2] >>> list(reversed(FakeGen())) [2, 1, 0] Closure Using Generator ----------------------- Generators provide an elegant way to implement closures that maintain state between calls. Each call to ``next()`` resumes execution and can access and modify the enclosed variables. This is often cleaner than using ``nonlocal`` or class-based approaches. 
.. code-block:: python # generator version >>> def closure_gen(): ... x = 5566 ... while True: ... x += 1 ... yield x ... >>> g = closure_gen() >>> next(g) 5567 >>> next(g) 5568 Simple Scheduler ---------------- This example demonstrates how generators can be used to implement cooperative multitasking. Each generator represents a task that yields control back to the scheduler. The scheduler uses a deque to round-robin between tasks, advancing each one step at a time. .. code-block:: python >>> from collections import deque >>> def fib(n): ... if n <= 2: return 1 ... return fib(n-1) + fib(n-2) ... >>> def g_fib(n): ... for x in range(1, n + 1): ... yield fib(x) ... >>> q = deque([g_fib(3), g_fib(5)]) >>> def run(): ... while q: ... try: ... t = q.popleft() ... print(next(t)) ... q.append(t) ... except StopIteration: ... print("Task done") ... >>> run() 1 1 1 1 2 2 Task done 3 5 Task done Simple Round-Robin with Blocking -------------------------------- A more advanced scheduler that handles I/O blocking using ``select()``. Tasks yield tuples indicating what operation they're waiting for ('recv' or 'send') and which socket. The scheduler moves blocked tasks to wait queues and only runs them when their I/O is ready. This is the foundation of async I/O frameworks. .. 
code-block:: python from collections import deque from select import select import socket tasks = deque() w_read = {} w_send = {} def run(): while any([tasks, w_read, w_send]): while not tasks: can_r, can_s, _ = select(w_read, w_send, []) for _r in can_r: tasks.append(w_read.pop(_r)) for _w in can_s: tasks.append(w_send.pop(_w)) try: task = tasks.popleft() why, what = next(task) if why == 'recv': w_read[what] = task elif why == 'send': w_send[what] = task except StopIteration: pass def server(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(('localhost', 5566)) sock.listen(5) while True: yield 'recv', sock conn, addr = sock.accept() tasks.append(client_handler(conn)) def client_handler(conn): while True: yield 'recv', conn msg = conn.recv(1024) if not msg: break yield 'send', conn conn.send(msg) conn.close() tasks.append(server()) run() Async Generator (Python 3.6+) ----------------------------- Async generators combine ``async def`` with ``yield`` to create asynchronous iterators. They can use ``await`` to pause for async operations between yields. Use ``async for`` to iterate over async generators. This is essential for streaming data from async sources like network connections or databases. .. code-block:: python >>> import asyncio >>> async def slow_gen(n, t): ... for x in range(n): ... await asyncio.sleep(t) ... yield x ... >>> async def task(n): ... async for x in slow_gen(n, 0.1): ... print(x) ... >>> asyncio.run(task(3)) 0 1 2 Async Generator with try..finally --------------------------------- Async generators support ``try..finally`` blocks for cleanup, just like regular generators. The ``finally`` block executes when the generator is closed or garbage collected, ensuring resources are properly released even if an exception occurs during iteration. .. code-block:: python >>> import asyncio >>> async def agen(t): ... try: ... await asyncio.sleep(t) ... yield 1 / 0 ... 
finally: ... print("finally") ... >>> async def main(): ... try: ... g = agen(0.1) ... await g.__anext__() ... except Exception as e: ... print(repr(e)) ... >>> asyncio.run(main()) finally ZeroDivisionError('division by zero') Send and Throw to Async Generator --------------------------------- Async generators support ``asend()`` to send values and ``athrow()`` to throw exceptions, similar to regular generators. These methods are coroutines that must be awaited. This enables two-way communication with async generators for building complex async data pipelines. .. code-block:: python >>> import asyncio >>> async def agen(n): ... try: ... for x in range(n): ... await asyncio.sleep(0.1) ... val = yield x ... print(f'got: {val}') ... except RuntimeError as e: ... yield repr(e) ... >>> async def main(): ... g = agen(5) ... ret = await g.asend(None) + await g.asend('foo') ... print(ret) ... ret = await g.athrow(RuntimeError('error')) ... print(ret) ... >>> asyncio.run(main()) got: foo 1 RuntimeError('error') Async Comprehension (Python 3.6+) --------------------------------- PEP 530 introduced async comprehensions, allowing ``async for`` in list, set, and dict comprehensions. This provides a concise way to collect values from async generators. You can also use ``if`` clauses to filter values and conditional expressions for transformations. .. code-block:: python >>> import asyncio >>> async def agen(n): ... for x in range(n): ... await asyncio.sleep(0.01) ... yield x ... >>> async def main(): ... ret = [x async for x in agen(5)] ... print(ret) ... ret = [x async for x in agen(5) if x < 3] ... print(ret) ... ret = {f'{x}': x async for x in agen(3)} ... print(ret) ... >>> asyncio.run(main()) [0, 1, 2, 3, 4] [0, 1, 2] {'0': 0, '1': 1, '2': 2} Simple Async Round-Robin ------------------------ This example shows cooperative multitasking with async generators. Multiple async generators are scheduled in a deque, and the scheduler awaits each one in turn using ``__anext__()``. 
This pattern is useful for interleaving multiple async data streams fairly. .. code-block:: python >>> import asyncio >>> from collections import deque >>> async def agen(n): ... for x in range(n): ... await asyncio.sleep(0.1) ... yield x ... >>> async def main(): ... q = deque([agen(3), agen(5)]) ... while q: ... try: ... g = q.popleft() ... print(await g.__anext__()) ... q.append(g) ... except StopAsyncIteration: ... pass ... >>> asyncio.run(main()) 0 0 1 1 2 2 3 4 Async Generator vs Async Iterator Performance ---------------------------------------------- Async generators have better performance than manually implemented async iterators because they are optimized at the C level in CPython. This benchmark shows that async generators can be significantly faster for iteration-heavy workloads. .. code-block:: python >>> import time >>> import asyncio >>> class AsyncIter: ... def __init__(self, n): ... self._n = n ... def __aiter__(self): ... return self ... async def __anext__(self): ... ret = self._n ... if self._n == 0: ... raise StopAsyncIteration ... self._n -= 1 ... return ret ... >>> async def agen(n): ... for i in range(n): ... yield i ... >>> async def task_agen(n): ... s = time.time() ... async for _ in agen(n): pass ... cost = time.time() - s ... print(f"agen cost time: {cost}") ... >>> async def task_aiter(n): ... s = time.time() ... async for _ in AsyncIter(n): pass ... cost = time.time() - s ... print(f"aiter cost time: {cost}") ... >>> n = 10 ** 7 >>> asyncio.run(task_agen(n)) agen cost time: 1.2698817253112793 >>> asyncio.run(task_aiter(n)) aiter cost time: 4.168368101119995 ``yield from == await`` Expression ---------------------------------- Before Python 3.5 introduced ``async``/``await`` syntax, coroutines were implemented using generators with ``@asyncio.coroutine`` decorator and ``yield from``. The ``await`` keyword is essentially equivalent to ``yield from`` for coroutines. This example shows both the old and new syntax for an echo server. .. 
code-block:: python import asyncio import socket loop = asyncio.get_event_loop() host = 'localhost' port = 5566 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setblocking(False) sock.bind((host, port)) sock.listen(10) # old syntax (Python 3.4) @asyncio.coroutine def echo_server(): while True: conn, addr = yield from loop.sock_accept(sock) loop.create_task(handler(conn)) @asyncio.coroutine def handler(conn): while True: msg = yield from loop.sock_recv(conn, 1024) if not msg: break yield from loop.sock_sendall(conn, msg) conn.close() # new syntax (Python 3.5+) async def echo_server(): while True: conn, addr = await loop.sock_accept(sock) loop.create_task(handler(conn)) async def handler(conn): while True: msg = await loop.sock_recv(conn, 1024) if not msg: break await loop.sock_sendall(conn, msg) conn.close() loop.create_task(echo_server()) loop.run_forever() Simple Compiler Using Generators -------------------------------- This advanced example from David Beazley demonstrates using generators to implement a simple expression compiler. It includes a tokenizer, parser, and evaluator using the visitor pattern with generators for stack-based evaluation. .. 
code-block:: python import re import types from collections import namedtuple tokens = [ r'(?P<NUMBER>\d+)', r'(?P<PLUS>\+)', r'(?P<MINUS>-)', r'(?P<TIMES>\*)', r'(?P<DIVIDE>/)', r'(?P<WS>\s+)'] Token = namedtuple('Token', ['type', 'value']) lex = re.compile('|'.join(tokens)) def tokenize(text): scan = lex.scanner(text) gen = (Token(m.lastgroup, m.group()) for m in iter(scan.match, None) if m.lastgroup != 'WS') return gen class Node: _fields = [] def __init__(self, *args): for attr, value in zip(self._fields, args): setattr(self, attr, value) class Number(Node): _fields = ['value'] class BinOp(Node): _fields = ['op', 'left', 'right'] def parse(toks): lookahead, current = next(toks, None), None def accept(*toktypes): nonlocal lookahead, current if lookahead and lookahead.type in toktypes: current, lookahead = lookahead, next(toks, None) return True def expr(): left = term() while accept('PLUS', 'MINUS'): left = BinOp(current.value, left) left.right = term() return left def term(): left = factor() while accept('TIMES', 'DIVIDE'): left = BinOp(current.value, left) left.right = factor() return left def factor(): if accept('NUMBER'): return Number(int(current.value)) else: raise SyntaxError() return expr() class NodeVisitor: def visit(self, node): stack = [self.genvisit(node)] ret = None while stack: try: node = stack[-1].send(ret) stack.append(self.genvisit(node)) ret = None except StopIteration as e: stack.pop() ret = e.value return ret def genvisit(self, node): ret = getattr(self, 'visit_' + type(node).__name__)(node) if isinstance(ret, types.GeneratorType): ret = yield from ret return ret class Evaluator(NodeVisitor): def visit_Number(self, node): return node.value def visit_BinOp(self, node): leftval = yield node.left rightval = yield node.right if node.op == '+': return leftval + rightval elif node.op == '-': return leftval - rightval elif node.op == '*': return leftval * rightval elif node.op == '/': return leftval / rightval def evaluate(exp): toks = tokenize(exp) tree = parse(toks) return
Evaluator().visit(tree) print(evaluate('2 * 3 + 5 / 2')) # 8.5 print(evaluate('+'.join([str(x) for x in range(10000)]))) # 49995000 ================================================ FILE: docs/notes/basic/python-heap.rst ================================================ .. meta:: :description lang=en: Python heap and priority queue cheat sheet covering heapq module operations, heap sort algorithm, priority queue implementation with custom comparators, and practical examples :keywords: Python, Python Cheat Sheet, heap, heapq, priority queue, heap sort, min heap, max heap, Python heapq, nlargest, nsmallest ==== Heap ==== .. contents:: Table of Contents :backlinks: none The heapq module provides an implementation of the heap queue algorithm, also known as the priority queue algorithm. Heaps are binary trees where every parent node has a value less than or equal to any of its children (min-heap). This cheat sheet covers heap operations including heap sort, priority queues, merging sorted iterables, and finding the n largest or smallest elements efficiently. The source code is available on `GitHub <https://github.com/crazyguitar/pysheeet/blob/master/src/basic/heap.py>`_. References ---------- - `heapq — Heap queue algorithm <https://docs.python.org/3/library/heapq.html>`_ - `queue.PriorityQueue <https://docs.python.org/3/library/queue.html#queue.PriorityQueue>`_ Basic Heap Operations --------------------- The ``heapq`` module provides functions to create and manipulate heaps. Use ``heapify`` to convert a list into a heap in-place in O(n) time. Use ``heappush`` and ``heappop`` to add and remove elements while maintaining the heap property. ..
code-block:: python >>> import heapq >>> # Convert list to heap in-place >>> h = [5, 1, 3, 2, 6] >>> heapq.heapify(h) >>> h[0] # smallest element at root 1 >>> # Push and pop >>> heapq.heappush(h, 0) >>> heapq.heappop(h) 0 >>> # Push and pop in one operation >>> heapq.heappushpop(h, 4) # push 4, then pop smallest 1 >>> # Pop and push in one operation >>> heapq.heapreplace(h, 0) # pop smallest, then push 0 2 Implement Heap Sort with ``heapq`` ---------------------------------- Heap sort works by pushing all elements onto a heap and then popping them off one by one. Since the heap maintains the min-heap property, elements come out in sorted order. The time complexity is O(n log n). .. code-block:: python >>> import heapq >>> a = [5, 1, 3, 2, 6] >>> h = [] >>> for x in a: ... heapq.heappush(h, x) ... >>> x = [heapq.heappop(h) for _ in range(len(a))] >>> x [1, 2, 3, 5, 6] A more efficient approach uses ``heapify`` to convert the list in-place: .. code-block:: python >>> import heapq >>> def heap_sort(items): ... h = items.copy() ... heapq.heapify(h) ... return [heapq.heappop(h) for _ in range(len(h))] ... >>> heap_sort([5, 1, 3, 2, 6]) [1, 2, 3, 5, 6] Implement Max Heap ------------------ Python's ``heapq`` only provides a min-heap. To implement a max-heap, negate the values when pushing and negate again when popping. .. code-block:: python >>> import heapq >>> # Max heap using negation >>> h = [] >>> for x in [5, 1, 3, 2, 6]: ... heapq.heappush(h, -x) ... >>> [-heapq.heappop(h) for _ in range(len(h))] [6, 5, 3, 2, 1] For custom objects, implement ``__lt__`` with reversed comparison: .. 
code-block:: python import heapq class MaxHeapItem: def __init__(self, val): self.val = val def __lt__(self, other): return self.val > other.val # reversed for max heap h = [] for x in [5, 1, 3]: heapq.heappush(h, MaxHeapItem(x)) print(heapq.heappop(h).val) # 5 (largest) Implement Priority Queue with ``heapq`` --------------------------------------- A priority queue processes elements based on their priority rather than insertion order. Use tuples ``(priority, value)`` where lower numbers indicate higher priority. .. code-block:: python >>> import heapq >>> pq = [] >>> heapq.heappush(pq, (2, "medium")) >>> heapq.heappush(pq, (1, "high")) >>> heapq.heappush(pq, (3, "low")) >>> [heapq.heappop(pq) for _ in range(len(pq))] [(1, 'high'), (2, 'medium'), (3, 'low')] For custom objects, implement the ``__lt__`` method to define comparison behavior: .. code-block:: python import heapq class Task: def __init__(self, priority, name): self.priority = priority self.name = name def __lt__(self, other): return self.priority < other.priority def __repr__(self): return f"Task({self.priority}, {self.name!r})" h = [] heapq.heappush(h, Task(3, "low")) heapq.heappush(h, Task(1, "high")) heapq.heappush(h, Task(2, "medium")) while h: print(heapq.heappop(h)) # Task(1, 'high') # Task(2, 'medium') # Task(3, 'low') Find K Largest or Smallest Elements ----------------------------------- The ``nlargest`` and ``nsmallest`` functions efficiently find the k largest or smallest elements. They are more efficient than sorting when k is small relative to the list size. .. code-block:: python >>> import heapq >>> nums = [5, 1, 8, 3, 9, 2, 7] >>> heapq.nsmallest(3, nums) [1, 2, 3] >>> heapq.nlargest(3, nums) [9, 8, 7] Use the ``key`` parameter to extract comparison keys from complex objects: .. code-block:: python >>> import heapq >>> data = [ ... {'name': 'Alice', 'score': 85}, ... {'name': 'Bob', 'score': 92}, ... {'name': 'Charlie', 'score': 78}, ... 
] >>> heapq.nlargest(2, data, key=lambda x: x['score']) [{'name': 'Bob', 'score': 92}, {'name': 'Alice', 'score': 85}] Merge Sorted Iterables ---------------------- The ``merge`` function merges multiple sorted inputs into a single sorted output. It returns an iterator, making it memory-efficient for large datasets. .. code-block:: python >>> import heapq >>> a = [1, 3, 5, 7] >>> b = [2, 4, 6, 8] >>> c = [0, 9, 10] >>> list(heapq.merge(a, b, c)) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] Use ``key`` and ``reverse`` parameters for custom merging: .. code-block:: python >>> import heapq >>> # Merge in descending order >>> a = [5, 3, 1] >>> b = [6, 4, 2] >>> list(heapq.merge(a, b, reverse=True)) [6, 5, 4, 3, 2, 1] Maintain a Fixed-Size Heap -------------------------- To maintain a heap of fixed size k (e.g., tracking top k elements), use ``heappushpop`` or check the size after each push. .. code-block:: python >>> import heapq >>> def top_k(items, k): ... """Keep track of k largest elements using min-heap.""" ... h = [] ... for x in items: ... if len(h) < k: ... heapq.heappush(h, x) ... elif x > h[0]: ... heapq.heapreplace(h, x) ... return sorted(h, reverse=True) ... >>> top_k([5, 1, 8, 3, 9, 2, 7, 4, 6], 3) [9, 8, 7] Heap with Index Tracking ------------------------ When you need to update priorities in a heap, use a dictionary to track element positions or mark entries as invalid. .. 
code-block:: python import heapq class IndexedHeap: def __init__(self): self.heap = [] self.entry_finder = {} self.REMOVED = '<removed>' def push(self, item, priority): if item in self.entry_finder: self.remove(item) entry = [priority, item] self.entry_finder[item] = entry heapq.heappush(self.heap, entry) def remove(self, item): entry = self.entry_finder.pop(item) entry[-1] = self.REMOVED def pop(self): while self.heap: priority, item = heapq.heappop(self.heap) if item is not self.REMOVED: del self.entry_finder[item] return item raise KeyError('pop from empty heap') # Usage h = IndexedHeap() h.push('task1', 3) h.push('task2', 1) h.push('task1', 0) # update priority print(h.pop()) # task1 (now has priority 0) ================================================ FILE: docs/notes/basic/python-list.rst ================================================ .. meta:: :description lang=en: Python list cheat sheet covering list operations, comprehensions, slicing, sorting, filtering, and common list manipulation patterns with code examples :keywords: Python, Python3, Python list, Python list cheat sheet, list comprehension, slicing, sorting, filtering, append, extend, iteration ==== List ==== .. contents:: Table of Contents :backlinks: none The list is a common data structure which we use to store objects. Most of the time, programmers are concerned with getting, setting, searching, filtering, and sorting. Furthermore, we sometimes waltz ourselves into common memory-management pitfalls. Thus, the main goal of this cheat sheet is to collect some common operations and pitfalls. Python List Basics and Common Operations ---------------------------------------- There are so many ways that we can manipulate lists in Python. Before we start to learn those versatile manipulations, the following snippet shows the most common operations of lists. ..
code-block:: python >>> a = [1, 2, 3, 4, 5] >>> # contains >>> 2 in a True >>> # positive index >>> a[0] 1 >>> # negative index >>> a[-1] 5 >>> # slicing list[start:end:step] >>> a[1:] [2, 3, 4, 5] >>> a[1:-1] [2, 3, 4] >>> a[1:-1:2] [2, 4] >>> # reverse >>> a[::-1] [5, 4, 3, 2, 1] >>> a[:0:-1] [5, 4, 3, 2] >>> # set an item >>> a[0] = 0 >>> a [0, 2, 3, 4, 5] >>> # append items to list >>> a.append(6) >>> a [0, 2, 3, 4, 5, 6] >>> a.extend([7, 8, 9]) >>> a [0, 2, 3, 4, 5, 6, 7, 8, 9] >>> # delete an item >>> del a[-1] >>> a [0, 2, 3, 4, 5, 6, 7, 8] >>> # list comprehension >>> b = [x for x in range(3)] >>> b [0, 1, 2] >>> # add two lists >>> a + b [0, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2] Initialize Lists with Multiplication Operator --------------------------------------------- Generally speaking, we can create a list through ``*`` operator if the item in the list expression is an immutable object. .. code-block:: python >>> a = [None] * 3 >>> a [None, None, None] >>> a[0] = "foo" >>> a ['foo', None, None] However, if the item in the list expression is a mutable object, the ``*`` operator will copy the reference of the item N times. In order to avoid this pitfall, we should use a list comprehension to initialize a list. .. code-block:: python >>> a = [[]] * 3 >>> b = [[] for _ in range(3)] >>> a[0].append("Hello") >>> a [['Hello'], ['Hello'], ['Hello']] >>> b[0].append("Python") >>> b [['Python'], [], []] Copy Lists: Shallow vs Deep Copy -------------------------------- Assigning a list to a variable is a common pitfall. This assignment does not copy the list to the variable. The variable only refers to the list and increase the reference count of the list. .. code-block:: python import sys >>> a = [1, 2, 3] >>> sys.getrefcount(a) 2 >>> b = a >>> sys.getrefcount(a) 3 >>> b[2] = 123456 # a[2] = 123456 >>> b [1, 2, 123456] >>> a [1, 2, 123456] There are two types of copy. 
The first one is called *shallow copy* (non-recursive copy) and the second one is called *deep copy* (recursive copy). Most of the time, it is sufficient for us to copy a list by shallow copy. However, if a list is nested, we have to use a deep copy. .. code-block:: python >>> # shallow copy >>> a = [1, 2] >>> b = list(a) >>> b[0] = 123 >>> a [1, 2] >>> b [123, 2] >>> a = [[1], [2]] >>> b = list(a) >>> b[0][0] = 123 >>> a [[123], [2]] >>> b [[123], [2]] >>> # deep copy >>> import copy >>> a = [[1], [2]] >>> b = copy.deepcopy(a) >>> b[0][0] = 123 >>> a [[1], [2]] >>> b [[123], [2]] Slice Lists with slice Objects ------------------------------ Sometimes, our data may concatenate as a large segment such as packets. In this case, we will represent the range of data by using ``slice`` objects as explaining variables instead of using *slicing expressions*. .. code-block:: python >>> icmp = ( ... b"080062988e2100005bff49c20005767c" ... b"08090a0b0c0d0e0f1011121314151617" ... b"18191a1b1c1d1e1f2021222324252627" ... b"28292a2b2c2d2e2f3031323334353637" ... ) >>> head = slice(0, 32) >>> data = slice(32, len(icmp)) >>> icmp[head] b'080062988e2100005bff49c20005767c' Create Lists with List Comprehensions ------------------------------------- `List comprehensions `_ which was proposed in PEP `202 `_ provides a graceful way to create a new list based on another list, sequence, or some object which is iterable. In addition, we can use this expression to substitute ``map`` and ``filter`` sometimes. .. 
code-block:: python >>> [x for x in range(10)] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] >>> [(lambda x: x**2)(i) for i in range(10)] [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] >>> [x for x in range(10) if x > 5] [6, 7, 8, 9] >>> [x if x > 5 else 0 for x in range(10)] [0, 0, 0, 0, 0, 0, 6, 7, 8, 9] >>> [x + 1 if x < 5 else x + 2 if x > 5 else x + 5 for x in range(10)] [1, 2, 3, 4, 5, 10, 8, 9, 10, 11] >>> [(x, y) for x in range(3) for y in range(2)] [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)] Unpack Lists into Variables --------------------------- Sometimes, we want to unpack our list to variables in order to make our code become more readable. In this case, we assign N elements to N variables as following example. .. code-block:: python >>> arr = [1, 2, 3] >>> a, b, c = arr >>> a, b, c (1, 2, 3) Based on PEP `3132 `_, we can use a single asterisk to unpack N elements to the number of variables which is less than N in Python 3. .. code-block:: python >>> arr = [1, 2, 3, 4, 5] >>> a, b, *c, d = arr >>> a, b, d (1, 2, 5) >>> c [3, 4] Iterate with Index Using enumerate() ------------------------------------ ``enumerate`` is a built-in function. It helps us to acquire indexes (or a count) and elements at the same time without using ``range(len(list))``. Further information can be found on `Looping Techniques `_. .. code-block:: python >>> for i, v in enumerate(range(3)): ... print(i, v) ... 0 0 1 1 2 2 >>> for i, v in enumerate(range(3), 1): # start = 1 ... print(i, v) ... 1 0 2 1 3 2 Combine Lists with zip() ------------------------ `zip `_ enables us to iterate over items contained in multiple lists at a time. Iteration stops whenever one of the lists is exhausted. As a result, the length of the iteration is the same as the shortest list. If this behavior is not desired, we can use ``itertools.zip_longest`` in **Python 3** or ``itertools.izip_longest`` in **Python 2**. .. 
code-block:: python >>> a = [1, 2, 3] >>> b = [4, 5, 6] >>> list(zip(a, b)) [(1, 4), (2, 5), (3, 6)] >>> c = [1] >>> list(zip(a, b, c)) [(1, 4, 1)] >>> from itertools import zip_longest >>> list(zip_longest(a, b, c)) [(1, 4, 1), (2, 5, None), (3, 6, None)] Filter List Items ----------------- `filter <https://docs.python.org/3/library/functions.html#filter>`_ is a built-in function to assist us to remove unnecessary items. In **Python 2**, ``filter`` returns a list. However, in **Python 3**, ``filter`` returns an *iterable object*. Note that *list comprehension* or *generator expression* provides a more concise way to remove items. .. code-block:: python >>> [x for x in range(5) if x > 1] [2, 3, 4] >>> l = ['1', '2', 3, 'Hello', 4] >>> f = lambda x: isinstance(x, int) >>> filter(f, l) <filter object at 0x10bee2198> >>> list(filter(f, l)) [3, 4] >>> list((i for i in l if f(i))) [3, 4] Implement Stack with List ------------------------- There is no need for an additional data structure, stack, in Python because the ``list`` provides ``append`` and ``pop`` methods which enable us to use a list as a stack. .. code-block:: python >>> stack = [] >>> stack.append(1) >>> stack.append(2) >>> stack.append(3) >>> stack [1, 2, 3] >>> stack.pop() 3 >>> stack.pop() 2 >>> stack [1] Check Membership with in Operator --------------------------------- We can implement the ``__contains__`` method to make a class do ``in`` operations. It is a common way for a programmer to emulate membership test operations for custom classes. .. code-block:: python class Stack: def __init__(self): self.__list = [] def push(self, val): self.__list.append(val) def pop(self): return self.__list.pop() def __contains__(self, item): return True if item in self.__list else False stack = Stack() stack.push(1) print(1 in stack) print(0 in stack) Example .. code-block:: bash python stack.py True False Access Items with __getitem__ and __setitem__ --------------------------------------------- Making custom classes perform get and set operations like lists is simple.
We can implement a ``__getitem__`` method and a ``__setitem__`` method to enable a class to retrieve and overwrite data by index. In addition, if we want to use the function, ``len``, to calculate the number of elements, we can implement a ``__len__`` method. .. code-block:: python class Stack: def __init__(self): self.__list = [] def push(self, val): self.__list.append(val) def pop(self): return self.__list.pop() def __repr__(self): return "{}".format(self.__list) def __len__(self): return len(self.__list) def __getitem__(self, idx): return self.__list[idx] def __setitem__(self, idx, val): self.__list[idx] = val stack = Stack() stack.push(1) stack.push(2) print("stack:", stack) stack[0] = 3 print("stack:", stack) print("num items:", len(stack)) Example .. code-block:: bash $ python stack.py stack: [1, 2] stack: [3, 2] num items: 2 Delegate Iteration with __iter__ -------------------------------- If a custom container class holds a list and we want iterations to work on the container, we can implement a ``__iter__`` method to delegate iterations to the list. Note that the method, ``__iter__``, should return an *iterator object*, so we cannot return the list directly; otherwise, Python raises a ``TypeError``. .. code-block:: python class Stack: def __init__(self): self.__list = [] def push(self, val): self.__list.append(val) def pop(self): return self.__list.pop() def __iter__(self): return iter(self.__list) stack = Stack() stack.push(1) stack.push(2) for s in stack: print(s) Example .. code-block:: bash $ python stack.py 1 2 Sort Lists with sort() and sorted() ----------------------------------- Python list provides a built-in ``list.sort`` method which sorts a list `in-place `_ without using extra memory. Moreover, the return value of ``list.sort`` is ``None`` in order to avoid confusion with ``sorted`` and the function can only be used for ``list``. .. 
code-block:: python >>> l = [5, 4, 3, 2, 1] >>> l.sort() >>> l [1, 2, 3, 4, 5] >>> l.sort(reverse=True) >>> l [5, 4, 3, 2, 1] The ``sorted`` function does not modify any iterable object in-place. Instead, it returns a new sorted list. Using ``sorted`` is safer than ``list.sort`` if some list's elements are read-only or immutable. Besides, another difference between ``list.sort`` and ``sorted`` is that ``sorted`` accepts any **iterable object**. .. code-block:: python >>> l = [5, 4, 3, 2, 1] >>> new = sorted(l) >>> new [1, 2, 3, 4, 5] >>> l [5, 4, 3, 2, 1] >>> d = {3: 'andy', 2: 'david', 1: 'amy'} >>> sorted(d) # sort iterable [1, 2, 3] To sort a list with its elements are tuples, using ``operator.itemgetter`` is helpful because it assigns a key function to the ``sorted`` key parameter. Note that the key should be comparable; otherwise, it will raise a ``TypeError``. .. code-block:: python >>> from operator import itemgetter >>> l = [('andy', 10), ('david', 8), ('amy', 3)] >>> l.sort(key=itemgetter(1)) >>> l [('amy', 3), ('david', 8), ('andy', 10)] ``operator.itemgetter`` is useful because the function returns a getter method which can be applied to other objects with a method ``__getitem__``. For example, sorting a list with its elements are dictionary can be achieved by using ``operator.itemgetter`` due to all elements have ``__getitem__``. .. code-block:: python >>> from pprint import pprint >>> from operator import itemgetter >>> l = [ ... {'name': 'andy', 'age': 10}, ... {'name': 'david', 'age': 8}, ... {'name': 'amy', 'age': 3}, ... ] >>> l.sort(key=itemgetter('age')) >>> pprint(l) [{'age': 3, 'name': 'amy'}, {'age': 8, 'name': 'david'}, {'age': 10, 'name': 'andy'}] If it is necessary to sort a list with its elements are neither comparable nor having ``__getitem__`` method, assigning a customized key function is feasible. .. code-block:: python >>> class Node(object): ... def __init__(self, val): ... self.val = val ... def __repr__(self): ... 
return f"Node({self.val})" ... >>> nodes = [Node(3), Node(2), Node(1)] >>> nodes.sort(key=lambda x: x.val) >>> nodes [Node(1), Node(2), Node(3)] >>> nodes.sort(key=lambda x: x.val, reverse=True) >>> nodes [Node(3), Node(2), Node(1)] The above snippet can be simplified by using ``operator.attrgetter``. The function returns an attribute getter based on the attribute's name. Note that the attribute should be comparable; otherwise, ``sorted`` or ``list.sort`` will raise ``TypeError``. .. code-block:: python >>> from operator import attrgetter >>> class Node(object): ... def __init__(self, val): ... self.val = val ... def __repr__(self): ... return f"Node({self.val})" ... >>> nodes = [Node(3), Node(2), Node(1)] >>> nodes.sort(key=attrgetter('val')) >>> nodes [Node(1), Node(2), Node(3)] If an object has ``__lt__`` method, it means that the object is comparable and ``sorted`` or ``list.sort`` is not necessary to input a key function to its key parameter. A list or an iterable sequence can be sorted directly. .. code-block:: python >>> class Node(object): ... def __init__(self, val): ... self.val = val ... def __repr__(self): ... return f"Node({self.val})" ... def __lt__(self, other): ... return self.val - other.val < 0 ... >>> nodes = [Node(3), Node(2), Node(1)] >>> nodes.sort() >>> nodes [Node(1), Node(2), Node(3)] If an object does not have ``__lt__`` method, it is likely to patch the method after a declaration of the object's class. In other words, after the patching, the object becomes comparable. .. code-block:: python >>> class Node(object): ... def __init__(self, val): ... self.val = val ... def __repr__(self): ... return f"Node({self.val})" ... >>> Node.__lt__ = lambda s, o: s.val < o.val >>> nodes = [Node(3), Node(2), Node(1)] >>> nodes.sort() >>> nodes [Node(1), Node(2), Node(3)] Note that ``sorted`` or ``list.sort`` in Python3 does not support ``cmp`` parameter which is an **ONLY** valid argument in Python2. 
If it is necessary to use an old comparison function, e.g., some legacy code, ``functools.cmp_to_key`` is useful since it converts a comparison function to a key function. .. code-block:: python >>> from functools import cmp_to_key >>> class Node(object): ... def __init__(self, val): ... self.val = val ... def __repr__(self): ... return f"Node({self.val})" ... >>> nodes = [Node(3), Node(2), Node(1)] >>> nodes.sort(key=cmp_to_key(lambda x,y: x.val - y.val)) >>> nodes [Node(1), Node(2), Node(3)] Maintain Sorted List with bisect -------------------------------- The `bisect `_ module provides functions to maintain a list in sorted order without having to sort the list after each insertion. It uses a binary search algorithm, making insertions efficient for large lists. .. code-block:: python import bisect class Foo(object): def __init__(self, k): self.k = k def __eq__(self, rhs): return self.k == rhs.k def __ne__(self, rhs): return self.k != rhs.k def __lt__(self, rhs): return self.k < rhs.k def __gt__(self, rhs): return self.k > rhs.k def __le__(self, rhs): return self.k <= rhs.k def __ge__(self, rhs): return self.k >= rhs.k def __repr__(self): return f"Foo({self.k})" def __str__(self): return self.__repr__() foo = [Foo(1), Foo(3), Foo(2), Foo(0)] bar = [] for x in foo: bisect.insort(bar, x) print(bar) # [Foo(0), Foo(1), Foo(2), Foo(3)] Create Nested Lists Correctly ----------------------------- When creating nested lists (2D lists or matrices), we should use list comprehension to ensure each inner list is a separate object. The following snippet shows the correct way to create a 2D list. .. code-block:: python # new a list with size = 3 >>> [0] * 3 [0, 0, 0] # new a 2d list with size 3x3 >>> [[0] * 3 for _ in range(3)] [[0, 0, 0], [0, 0, 0], [0, 0, 0]] Note that we should avoid creating a multi-dimension list via the following snippet because all objects in the list point to the same address. .. 
code-block:: python >>> a = [[0] * 3] * 3 >>> a [[0, 0, 0], [0, 0, 0], [0, 0, 0]] >>> a[1][1] = 2 >>> a [[0, 2, 0], [0, 2, 0], [0, 2, 0]] Implement Circular Buffer with deque ------------------------------------ `collections.deque `_ is a double-ended queue that supports adding and removing elements from both ends efficiently. By setting ``maxlen``, we can create a circular buffer that automatically discards old elements when new ones are added. .. code-block:: python >>> from collections import deque >>> d = deque(maxlen=8) >>> for x in range(9): ... d.append(x) ... >>> d deque([1, 2, 3, 4, 5, 6, 7, 8], maxlen=8) The following example shows how to implement a ``tail`` function similar to the Unix command using ``deque``. .. code-block:: python >>> from collections import deque >>> def tail(path, n=10): ... with open(path) as f: ... return deque(f, n) ... >>> tail("/etc/hosts") Split List into Chunks ---------------------- Sometimes, we need to split a list into smaller chunks of a specific size. The following generator function yields successive chunks from the list. .. code-block:: python >>> def chunk(lst, n): ... for i in range(0, len(lst), n): ... yield lst[i:i+n] ... >>> a = [1, 2, 3, 4, 5, 6, 7, 8] >>> list(chunk(a, 3)) [[1, 2, 3], [4, 5, 6], [7, 8]] Group Consecutive Elements with itertools.groupby ------------------------------------------------- `itertools.groupby `_ groups consecutive elements in an iterable that have the same key. It is useful for run-length encoding or grouping sorted data. .. code-block:: python >>> import itertools >>> s = "AAABBCCCCC" >>> for k, v in itertools.groupby(s): ... print(k, list(v)) ... A ['A', 'A', 'A'] B ['B', 'B'] C ['C', 'C', 'C', 'C', 'C'] # group by key >>> x = [('gp1', 'a'), ('gp2', 'b'), ('gp2', 'c')] >>> for k, v in itertools.groupby(x, lambda x: x[0]): ... print(k, list(v)) ... 
gp1 [('gp1', 'a')] gp2 [('gp2', 'b'), ('gp2', 'c')] Binary Search in Sorted List ---------------------------- Binary search is an efficient algorithm for finding an item in a sorted list. The following snippet shows how to implement binary search using ``bisect_left``. .. code-block:: python >>> def binary_search(arr, x, lo=0, hi=None): ... if not hi: hi = len(arr) ... pos = bisect_left(arr, x, lo, hi) ... return pos if pos != hi and arr[pos] == x else -1 ... >>> a = [1, 1, 1, 2, 3] >>> binary_search(a, 1) 0 >>> binary_search(a, 2) 3 Find Lower Bound with bisect_left --------------------------------- ``bisect_left`` returns the leftmost position where an element can be inserted to keep the list sorted. This is equivalent to finding the lower bound. .. code-block:: python >>> import bisect >>> a = [1,2,3,3,4,5] >>> bisect.bisect_left(a, 3) 2 >>> bisect.bisect_left(a, 3.5) 4 Find Upper Bound with bisect_right ---------------------------------- ``bisect_right`` (or ``bisect``) returns the rightmost position where an element can be inserted to keep the list sorted. This is equivalent to finding the upper bound. .. code-block:: python >>> import bisect >>> a = [1,2,3,3,4,5] >>> bisect.bisect_right(a, 3) 4 >>> bisect.bisect_right(a, 3.5) 4 Sort Tuples Lexicographically ----------------------------- Python compares tuples and lists lexicographically by default. This means it compares the first elements, and if they are equal, it compares the second elements, and so on. .. code-block:: python # python compare lists lexicographically >>> a = [(1,2), (1,1), (1,0), (2,1)] >>> a.sort() >>> a [(1, 0), (1, 1), (1, 2), (2, 1)] Implement Trie (Prefix Tree) ---------------------------- A `Trie `_ (prefix tree) is a tree data structure used for efficient retrieval of keys in a dataset of strings. The following snippet shows a compact implementation using ``defaultdict``. .. 
code-block:: python >>> from functools import reduce >>> from collections import defaultdict >>> Trie = lambda: defaultdict(Trie) >>> prefixes = ['abc', 'de', 'g'] >>> trie = Trie() >>> end = True >>> for p in prefixes: ... reduce(dict.__getitem__, p, trie)[end] = p ... # search prefix >>> def find(trie, word): ... curr = trie ... for c in word: ... if c not in curr: ... return False ... curr = curr[c] ... return True ... >>> find(trie, "abcdef") False >>> find(trie, "abc") True >>> find(trie, "ab") True # search word >>> def find(trie, p): ... curr = trie ... for c in p: ... if c not in curr or True in curr: ... break ... curr = curr[c] ... return True if True in curr else False ... >>> find(trie, "abcdef") True >>> find(trie, "abc") True >>> find(trie, "ab") False ================================================ FILE: docs/notes/basic/python-object.rst ================================================ .. meta:: :description lang=en: Python class cheat sheet covering magic methods, property decorators, inheritance, context managers, and OOP design patterns with code examples :keywords: Python, Python3, Python class, Python OOP cheat sheet, magic methods, property decorator, context manager, singleton, abstract class, descriptor, inheritance ===== Class ===== .. contents:: Table of Contents :backlinks: none Python is an object-oriented programming language. This cheat sheet covers class definitions, inheritance, magic methods, property decorators, context managers, and common design patterns. Understanding these concepts is essential for writing clean, maintainable Python code. List Attributes with dir() -------------------------- The ``dir()`` function returns a list of all attributes and methods of an object. This is useful for introspection and discovering what operations are available. .. code-block:: python >>> dir(list) # check all attr of list ['__add__', '__class__', ...] 
Check Type with isinstance() ---------------------------- Use ``isinstance()`` to check if an object is an instance of a class or its subclasses. This is preferred over ``type()`` comparison because it supports inheritance. .. code-block:: python >>> ex = 10 >>> isinstance(ex, int) True >>> isinstance(ex, (int, float)) # check multiple types True Check Inheritance with issubclass() ----------------------------------- Use ``issubclass()`` to check if a class is a subclass of another class. .. code-block:: python >>> class Animal: pass >>> class Dog(Animal): pass >>> issubclass(Dog, Animal) True >>> issubclass(Dog, object) True Get Class Name -------------- Access the class name through the ``__class__.__name__`` attribute. .. code-block:: python >>> class ExampleClass: ... pass ... >>> ex = ExampleClass() >>> ex.__class__.__name__ 'ExampleClass' Has / Get / Set Attributes -------------------------- Python provides built-in functions to dynamically access and modify object attributes at runtime. .. code-block:: python >>> class Example: ... def __init__(self): ... self.name = "ex" ... >>> ex = Example() >>> hasattr(ex, "name") True >>> getattr(ex, 'name') 'ex' >>> setattr(ex, 'name', 'example') >>> ex.name 'example' >>> getattr(ex, 'missing', 'default') # with default 'default' Declare Class with type() ------------------------- Classes can be created dynamically using ``type()``. This is useful for metaprogramming and creating classes at runtime. .. code-block:: python >>> def greet(self): ... return f"Hello, I'm {self.name}" ... >>> Person = type('Person', (object,), { ... 'name': 'Anonymous', ... 'greet': greet ... }) >>> p = Person() >>> p.greet() "Hello, I'm Anonymous" This is equivalent to: .. code-block:: python >>> class Person: ... name = 'Anonymous' ... def greet(self): ... return f"Hello, I'm {self.name}" __new__ vs __init__ ------------------- ``__new__`` creates the instance, ``__init__`` initializes it. 
``__init__`` is only called if ``__new__`` returns an instance of the class. .. code-block:: python >>> class Example: ... def __new__(cls, arg): ... print(f'__new__ {arg}') ... return super().__new__(cls) ... def __init__(self, arg): ... print(f'__init__ {arg}') ... >>> o = Example("Hello") __new__ Hello __init__ Hello __str__ and __repr__ -------------------- ``__str__`` returns a human-readable string, ``__repr__`` returns an unambiguous representation for debugging. When ``__str__`` is not defined, ``__repr__`` is used. .. code-block:: python >>> class Vector: ... def __init__(self, x, y): ... self.x, self.y = x, y ... def __repr__(self): ... return f"Vector({self.x}, {self.y})" ... def __str__(self): ... return f"({self.x}, {self.y})" ... >>> v = Vector(1, 2) >>> repr(v) 'Vector(1, 2)' >>> str(v) '(1, 2)' >>> print(v) (1, 2) Comparison Magic Methods ------------------------ Implement comparison operators by defining magic methods. Use ``functools.total_ordering`` to generate all comparisons from ``__eq__`` and one other. .. code-block:: python >>> from functools import total_ordering >>> @total_ordering ... class Number: ... def __init__(self, val): ... self.val = val ... def __eq__(self, other): ... return self.val == other.val ... def __lt__(self, other): ... return self.val < other.val ... >>> Number(1) < Number(2) True >>> Number(2) >= Number(1) True Arithmetic Magic Methods ------------------------ Implement arithmetic operators to make objects work with ``+``, ``-``, ``*``, etc. .. code-block:: python >>> class Vector: ... def __init__(self, x, y): ... self.x, self.y = x, y ... def __add__(self, other): ... return Vector(self.x + other.x, self.y + other.y) ... def __mul__(self, scalar): ... return Vector(self.x * scalar, self.y * scalar) ... def __repr__(self): ... return f"Vector({self.x}, {self.y})" ... 
>>> Vector(1, 2) + Vector(3, 4) Vector(4, 6) >>> Vector(1, 2) * 3 Vector(3, 6) Callable with __call__ ---------------------- Implement ``__call__`` to make instances callable like functions. This is useful for creating function-like objects that maintain state. .. code-block:: python >>> class Multiplier: ... def __init__(self, factor): ... self.factor = factor ... def __call__(self, x): ... return x * self.factor ... >>> double = Multiplier(2) >>> double(5) 10 >>> callable(double) True @property Decorator ------------------- Use ``@property`` to define getters, setters, and deleters for managed attributes. This allows attribute access syntax while running custom code. .. code-block:: python >>> class Circle: ... def __init__(self, radius): ... self._radius = radius ... @property ... def radius(self): ... return self._radius ... @radius.setter ... def radius(self, value): ... if value < 0: ... raise ValueError("Radius must be positive") ... self._radius = value ... @property ... def area(self): ... return 3.14159 * self._radius ** 2 ... >>> c = Circle(5) >>> c.area 78.53975 >>> c.radius = 10 >>> c.radius 10 Descriptor Protocol ------------------- Descriptors control attribute access at the class level. They implement ``__get__``, ``__set__``, and/or ``__delete__`` methods. .. code-block:: python >>> class Positive: ... def __init__(self, name): ... self.name = name ... def __get__(self, obj, objtype=None): ... return obj.__dict__[self.name] ... def __set__(self, obj, value): ... if value < 0: ... raise ValueError("Must be positive") ... obj.__dict__[self.name] = value ... >>> class Example: ... x = Positive('x') ... def __init__(self, x): ... self.x = x ... >>> ex = Example(10) >>> ex.x 10 Context Manager Protocol ------------------------ Context managers implement ``__enter__`` and ``__exit__`` to manage resources with the ``with`` statement. This ensures proper cleanup even if exceptions occur. .. 
code-block:: python class ManagedFile: def __init__(self, filename): self.filename = filename def __enter__(self): self.file = open(self.filename, 'r') return self.file def __exit__(self, exc_type, exc_val, exc_tb): self.file.close() return False # don't suppress exceptions with ManagedFile('example.txt') as f: content = f.read() Using contextlib ---------------- The ``contextlib`` module provides utilities for creating context managers without writing a full class. .. code-block:: python from contextlib import contextmanager @contextmanager def managed_file(filename): f = open(filename, 'r') try: yield f finally: f.close() with managed_file('example.txt') as f: content = f.read() @staticmethod and @classmethod ------------------------------ ``@staticmethod`` defines a method that doesn't access instance or class. ``@classmethod`` receives the class as the first argument, useful for alternative constructors. .. code-block:: python >>> class Date: ... def __init__(self, year, month, day): ... self.year, self.month, self.day = year, month, day ... @classmethod ... def from_string(cls, date_string): ... year, month, day = map(int, date_string.split('-')) ... return cls(year, month, day) ... @staticmethod ... def is_valid(date_string): ... try: ... y, m, d = map(int, date_string.split('-')) ... return 1 <= m <= 12 and 1 <= d <= 31 ... except: ... return False ... >>> d = Date.from_string('2024-01-15') >>> d.year 2024 >>> Date.is_valid('2024-13-01') False Abstract Base Classes with abc ------------------------------ Use ``abc`` module to define abstract base classes that cannot be instantiated and require subclasses to implement certain methods. .. code-block:: python >>> from abc import ABC, abstractmethod >>> class Shape(ABC): ... @abstractmethod ... def area(self): ... pass ... >>> class Rectangle(Shape): ... def __init__(self, width, height): ... self.width, self.height = width, height ... def area(self): ... return self.width * self.height ... 
>>> r = Rectangle(3, 4) >>> r.area() 12 >>> Shape() # raises TypeError The Diamond Problem (MRO) ------------------------- Python uses Method Resolution Order (MRO) to resolve the diamond problem in multiple inheritance. Use ``ClassName.mro()`` to see the resolution order. .. code-block:: python >>> class A: ... def method(self): ... return "A" ... >>> class B(A): ... def method(self): ... return "B" ... >>> class C(A): ... def method(self): ... return "C" ... >>> class D(B, C): ... pass ... >>> D().method() 'B' >>> D.mro() [, , , , ] Singleton Pattern ----------------- Singleton ensures only one instance of a class exists. Implement using ``__new__`` or a decorator. .. code-block:: python class Singleton: _instance = None def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance a = Singleton() b = Singleton() print(a is b) # True Using __slots__ --------------- ``__slots__`` restricts instance attributes and reduces memory usage by avoiding ``__dict__`` per instance. .. code-block:: python >>> class Point: ... __slots__ = ['x', 'y'] ... def __init__(self, x, y): ... self.x, self.y = x, y ... >>> p = Point(1, 2) >>> p.x 1 >>> p.z = 3 # raises AttributeError Common Magic Methods Reference ------------------------------ .. code-block:: python # Object Creation and Representation __new__(cls, ...) # create instance __init__(self, ...) 
# initialize instance __del__(self) # destructor __repr__(self) # repr(obj) __str__(self) # str(obj) # Comparison __eq__(self, other) # == __ne__(self, other) # != __lt__(self, other) # < __le__(self, other) # <= __gt__(self, other) # > __ge__(self, other) # >= # Arithmetic __add__(self, other) # + __sub__(self, other) # - __mul__(self, other) # * __truediv__(self, other) # / __floordiv__(self, other)# // __mod__(self, other) # % __pow__(self, other) # ** # Container __len__(self) # len(obj) __getitem__(self, key) # obj[key] __setitem__(self, k, v) # obj[key] = value __delitem__(self, key) # del obj[key] __contains__(self, item) # item in obj __iter__(self) # iter(obj) # Attribute Access __getattr__(self, name) # obj.name (when not found) __setattr__(self, n, v) # obj.name = value __delattr__(self, name) # del obj.name # Callable __call__(self, ...) # obj() # Context Manager __enter__(self) # with obj __exit__(self, ...) # exit with block # Descriptor __get__(self, obj, type) # descriptor access __set__(self, obj, val) # descriptor assignment __delete__(self, obj) # descriptor deletion ================================================ FILE: docs/notes/basic/python-rexp.rst ================================================ .. meta:: :description lang=en: Python regex cheat sheet covering re module, pattern matching, groups, lookahead, lookbehind, substitution, and common regex patterns with code examples :keywords: Python, Python3, Python regex, Python regex cheat sheet, regular expression, re module, pattern matching, findall, search, match, sub, lookahead, lookbehind, named groups ================== Regular Expression ================== .. contents:: Table of Contents :backlinks: none Regular expressions (regex) are powerful tools for pattern matching and text manipulation. Python's ``re`` module provides comprehensive support for regex operations. 
This cheat sheet covers basic matching, groups, lookaround assertions, substitution, and common patterns for validating emails, URLs, IP addresses, etc. Basic Operations ---------------- The ``re`` module provides several functions for pattern matching. Use ``search()`` to find the first match anywhere in the string, ``match()`` to match at the beginning, and ``fullmatch()`` to match the entire string. .. code-block:: python >>> import re >>> # search - find anywhere in string >>> re.search(r'\d+', 'abc123def') >>> # match - match at beginning only >>> re.match(r'\d+', '123abc') >>> re.match(r'\d+', 'abc123') is None True >>> # fullmatch - match entire string >>> re.fullmatch(r'\d+', '123') >>> re.fullmatch(r'\d+', '123abc') is None True ``re.findall()`` - Find All Matches ----------------------------------- The ``findall()`` function returns all non-overlapping matches as a list of strings. If the pattern has groups, it returns a list of tuples. .. code-block:: python >>> # find all words >>> source = "Hello World Ker HAHA" >>> re.findall(r'[\w]+', source) ['Hello', 'World', 'Ker', 'HAHA'] >>> # find all digits >>> re.findall(r'\d+', 'a1b22c333') ['1', '22', '333'] >>> # with groups - returns tuples >>> re.findall(r'(\w+)=(\d+)', 'a=1 b=2 c=3') [('a', '1'), ('b', '2'), ('c', '3')] ``re.split()`` - Split by Pattern --------------------------------- The ``split()`` function splits a string by pattern occurrences. Use ``maxsplit`` to limit the number of splits. .. code-block:: python >>> re.split(r'\s+', 'a b c') ['a', 'b', 'c'] >>> re.split(r'[,;]', 'a,b;c,d') ['a', 'b', 'c', 'd'] >>> re.split(r'(\s+)', 'a b c') # keep delimiters ['a', ' ', 'b', ' ', 'c'] >>> re.split(r'\s+', 'a b c d', maxsplit=2) ['a', 'b', 'c d'] Group Matching -------------- Parentheses ``(...)`` create capturing groups. Use ``group()`` to access matched groups. Group 0 is the entire match, group 1 is the first parenthesized group, etc. .. 
code-block:: python >>> m = re.search(r'(\d{4})-(\d{2})-(\d{2})', '2016-01-01') >>> m.groups() ('2016', '01', '01') >>> m.group() # entire match '2016-01-01' >>> m.group(1) # first group '2016' >>> m.group(2, 3) # multiple groups ('01', '01') # Nested groups - numbered left to right by opening parenthesis >>> m = re.search(r'(((\d{4})-\d{2})-\d{2})', '2016-01-01') >>> m.groups() ('2016-01-01', '2016-01', '2016') Non-Capturing Group ``(?:...)`` ------------------------------- Use ``(?:...)`` when you need grouping for alternation or quantifiers but don't need to capture the match. This improves performance and keeps group numbering clean. .. code-block:: python >>> url = 'http://stackoverflow.com/' >>> # non-capturing group for protocol >>> m = re.search(r'(?:http|ftp)://([^/\r\n]+)(/[^\r\n]*)?', url) >>> m.groups() ('stackoverflow.com', '/') >>> # capturing group - protocol is captured >>> m = re.search(r'(http|ftp)://([^/\r\n]+)(/[^\r\n]*)?', url) >>> m.groups() ('http', 'stackoverflow.com', '/') Named Groups ``(?P...)`` ------------------------------ Named groups make patterns more readable and allow access by name instead of number. Use ``(?P...)`` to define and ``(?P=name)`` for back reference. .. code-block:: python >>> pattern = r'(?P\d{4})-(?P\d{2})-(?P\d{2})' >>> m = re.search(pattern, '2016-01-01') >>> m.group('year') '2016' >>> m.group('month') '01' >>> m.groupdict() {'year': '2016', 'month': '01', 'day': '01'} # named back reference >>> re.search(r'^(?P[a-z])(?P=char)', 'aa') >>> re.search(r'^(?P[a-z])(?P=char)', 'ab') is None True Back Reference ``\1``, ``\2`` ----------------------------- Back references match the same text as a previous capturing group. Use ``\1`` for the first group, ``\2`` for the second, etc. .. code-block:: python >>> # match repeated characters >>> re.search(r'([a-z])\1', 'aa') is not None True >>> re.search(r'([a-z])\1', 'ab') is not None False >>> # match HTML tags with matching close tag >>> pattern = r'<([^>]+)>[\s\S]*?' 
>>> re.search(pattern, 'test') is not None True >>> re.search(pattern, 'test') is not None False Substitute with ``re.sub()`` ---------------------------- The ``sub()`` function replaces pattern matches with a replacement string. Use ``\1``, ``\2`` in the replacement to reference captured groups. .. code-block:: python >>> # basic substitution >>> re.sub(r'[a-z]', ' ', '1a2b3c') '1 2 3 ' >>> # substitute with group reference >>> re.sub(r'(\d{4})-(\d{2})-(\d{2})', r'\2/\3/\1', '2016-01-01') '01/01/2016' >>> # using function as replacement >>> re.sub(r'\d+', lambda m: str(int(m.group()) * 2), 'a1b2c3') 'a2b4c6' >>> # camelCase to snake_case >>> def to_snake(s): ... s = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', s) ... return re.sub(r'([a-z])([A-Z])', r'\1_\2', s).lower() ... >>> to_snake('CamelCase') 'camel_case' >>> to_snake('SimpleHTTPServer') 'simple_http_server' Lookahead and Lookbehind ------------------------ Lookaround assertions match a position without consuming characters. They are useful for matching patterns based on context. +---------------+---------------------+---------------------------+ | Notation | Name | Description | +===============+=====================+===========================+ | ``(?=...)`` | Positive lookahead | Followed by ... | +---------------+---------------------+---------------------------+ | ``(?!...)`` | Negative lookahead | Not followed by ... | +---------------+---------------------+---------------------------+ | ``(?<=...)`` | Positive lookbehind | Preceded by ... 
| +---------------+---------------------+---------------------------+ | ``(?<!...)`` | Negative lookbehind | Not preceded by ... | +---------------+---------------------+---------------------------+ .. code-block:: python >>> # positive lookahead - find word before @ >>> re.findall(r'\w+(?=@)', 'user@example.com') ['user'] >>> # negative lookahead - find digits not followed by px >>> re.findall(r'\d+(?!px)', '12px 34em 56') ['1', '34', '56'] >>> # positive lookbehind - find digits after $ >>> re.findall(r'(?<=\$)\d+', '$100 $200') ['100', '200'] >>> # negative lookbehind - find digits not after $ >>> re.findall(r'(?<!\$)\b\d+', '$100 200 $300') ['200'] >>> # insert space before groups of 3 digits from right >>> re.sub(r'(?=(\d{3})+$)', ' ', '12345678') '12 345 678' Compile Pattern for Reuse ------------------------- Use ``re.compile()`` to create a reusable pattern object. This improves performance when the same pattern is used multiple times. .. code-block:: python >>> pattern = re.compile(r'\d{4}-\d{2}-\d{2}') >>> pattern.search('Date: 2024-01-15') <re.Match object; span=(6, 16), match='2024-01-15'> >>> pattern.findall('2024-01-15 and 2024-02-20') ['2024-01-15', '2024-02-20'] Regex Flags ----------- Flags modify pattern behavior. Common flags include ``re.IGNORECASE`` (``re.I``), ``re.MULTILINE`` (``re.M``), ``re.DOTALL`` (``re.S``), and ``re.VERBOSE`` (``re.X``). .. code-block:: python >>> # case insensitive >>> re.findall(r'[a-z]+', 'Hello World', re.I) ['Hello', 'World'] >>> # multiline - ^ and $ match line boundaries >>> re.findall(r'^\w+', 'line1\nline2', re.M) ['line1', 'line2'] >>> # dotall - . matches newline >>> re.search(r'a.b', 'a\nb', re.S) <re.Match object; span=(0, 3), match='a\nb'> >>> # verbose - allow comments and whitespace >>> pattern = re.compile(r''' ... \d{4} # year ... - ... \d{2} # month ... - ... \d{2} # day ... ''', re.X) >>> pattern.match('2024-01-15') <re.Match object; span=(0, 10), match='2024-01-15'> Compare HTML Tags ----------------- Common patterns for matching different types of HTML tags. +------------+--------------+--------------+ | Tag Type | Pattern | Example | +============+==============+==============+ | All tags | <[^>]+> | <br />, <a> | +------------+--------------+--------------+ | Open tag | <[^/>][^>]*> | <a>, <table> | +------------+--------------+--------------+ | Close tag | </[^>]+> | </p>, </a> | +------------+--------------+--------------+ | Self-close | <[^/>]+/> | <br /> | +------------+--------------+--------------+ .. code-block:: python >>> # open tag >>> re.search(r'<[^/>][^>]*>', '<table>') is not None True >>> re.search(r'<[^/>][^>]*>', '</table>') is not None False >>> # close tag >>> re.search(r'</[^>]+>', '</table>') is not None True >>> # self-closing tag >>> re.search(r'<[^/>]+/>', '<br />') is not None True
code-block:: python >>> # US phone number >>> pattern = re.compile(r'^(\+1)?[-.\s]?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}$') >>> pattern.match('123-456-7890') is not None True >>> pattern.match('(123) 456-7890') is not None True >>> pattern.match('+1 123 456 7890') is not None True Match Password Strength ----------------------- Pattern to validate password with minimum requirements: at least 8 characters, one uppercase, one lowercase, one digit, and one special character. .. code-block:: python >>> pattern = re.compile(r''' ... ^(?=.*[a-z]) # at least one lowercase ... (?=.*[A-Z]) # at least one uppercase ... (?=.*\d) # at least one digit ... (?=.*[@$!%*?&]) # at least one special char ... [A-Za-z\d@$!%*?&]{8,}$ # at least 8 chars ... ''', re.X) >>> pattern.match('Passw0rd!') is not None True >>> pattern.match('weakpass') is not None False Simple Lexer ------------ Using regex to build a simple tokenizer for arithmetic expressions. This demonstrates using named groups and ``scanner()`` for lexical analysis. .. code-block:: python >>> from collections import namedtuple >>> tokens = [ ... r'(?P\d+)', ... r'(?P\+)', ... r'(?P-)', ... r'(?P\*)', ... r'(?P/)', ... r'(?P\s+)' ... ] >>> lex = re.compile('|'.join(tokens)) >>> Token = namedtuple('Token', ['type', 'value']) >>> def tokenize(text): ... scan = lex.scanner(text) ... return (Token(m.lastgroup, m.group()) ... for m in iter(scan.match, None) if m.lastgroup != 'WS') ... >>> list(tokenize('9 + 5 * 2')) [Token(type='NUMBER', value='9'), Token(type='PLUS', value='+'), Token(type='NUMBER', value='5'), Token(type='TIMES', value='*'), Token(type='NUMBER', value='2')] Common Patterns Reference ------------------------- .. 
code-block:: python # Digits only r'^\d+$' # Alphanumeric r'^[a-zA-Z0-9]+$' # Username (3-16 chars, alphanumeric, underscore, hyphen) r'^[a-zA-Z0-9_-]{3,16}$' # Hex color r'^#?([a-fA-F0-9]{6}|[a-fA-F0-9]{3})$' # Date (YYYY-MM-DD) r'^\d{4}-\d{2}-\d{2}$' # Time (HH:MM:SS) r'^\d{2}:\d{2}:\d{2}$' # Slug (URL-friendly string) r'^[a-z0-9]+(?:-[a-z0-9]+)*$' # Remove HTML tags re.sub(r'<[^>]+>', '', html) # Extract domain from URL re.search(r'https?://([^/]+)', url).group(1) # Find all hashtags re.findall(r'#\w+', text) # Find all @mentions re.findall(r'@\w+', text) ================================================ FILE: docs/notes/basic/python-set.rst ================================================ .. meta:: :description lang=en: Python set cheat sheet covering set comprehensions, set operations (union, intersection, difference), removing duplicates, subsets, supersets, and frozenset with code examples :keywords: Python, Python3, Python set, Python set cheat sheet, set comprehension, set operations, union, intersection, difference, symmetric difference, frozenset, subset, superset === Set === .. contents:: Table of Contents :backlinks: none Sets are unordered collections of unique elements in Python. They provide O(1) average time complexity for membership testing and support mathematical set operations like union, intersection, and difference. This cheat sheet covers set comprehensions, common set operations, uniquifying lists, and the immutable frozenset type. The source code is available on `GitHub `_. References ---------- - `Set Types — set, frozenset `_ - `Sets `_ Create a Set ------------ Create sets using curly braces ``{}`` or the ``set()`` constructor. Note that empty curly braces ``{}`` create a dict, not a set. .. 
code-block:: python >>> s = {1, 2, 3} >>> s {1, 2, 3} >>> s = set([1, 2, 2, 3]) >>> s {1, 2, 3} >>> empty = set() # not {} >>> type(empty) Create Sets with Set Comprehension ---------------------------------- Like list comprehensions, set comprehensions provide a concise way to create sets. The syntax uses curly braces ``{}`` instead of square brackets. .. code-block:: python >>> a = [1, 2, 5, 6, 6, 6, 7] >>> s = {x for x in a} >>> s {1, 2, 5, 6, 7} >>> s = {x for x in a if x > 3} >>> s {5, 6, 7} >>> s = {x ** 2 for x in range(5)} >>> s {0, 1, 4, 9, 16} Remove Duplicates from a List ----------------------------- Converting a list to a set automatically removes duplicate elements. This is one of the most common use cases for sets. .. code-block:: python >>> a = [1, 2, 2, 2, 3, 4, 5, 5] >>> list(set(a)) [1, 2, 3, 4, 5] To preserve the original order, use ``dict.fromkeys()`` (Python 3.7+): .. code-block:: python >>> a = [3, 1, 2, 1, 3, 2] >>> list(dict.fromkeys(a)) [3, 1, 2] Add Items to a Set ------------------ Use ``add()`` to add a single element, or ``update()`` to add multiple elements. .. code-block:: python >>> s = {1, 2, 3} >>> s.add(4) >>> s {1, 2, 3, 4} >>> s.update([5, 6, 7]) >>> s {1, 2, 3, 4, 5, 6, 7} >>> s |= {8, 9} # same as update >>> s {1, 2, 3, 4, 5, 6, 7, 8, 9} Remove Items from a Set ----------------------- Use ``remove()`` to remove an element (raises KeyError if not found), or ``discard()`` to remove without error. Use ``pop()`` to remove an arbitrary element. .. code-block:: python >>> s = {1, 2, 3, 4, 5} >>> s.remove(3) >>> s {1, 2, 4, 5} >>> s.discard(10) # no error if not found >>> s.pop() # remove arbitrary element 1 >>> s.clear() # remove all >>> s set() Union with ``|`` Operator ------------------------- The union of two sets contains all elements from both sets. Use the ``|`` operator or the ``union()`` method. .. 
code-block:: python >>> a = {1, 2, 3} >>> b = {3, 4, 5} >>> a | b {1, 2, 3, 4, 5} >>> a.union(b) {1, 2, 3, 4, 5} >>> a | b | {6, 7} # multiple sets {1, 2, 3, 4, 5, 6, 7} Intersection with ``&`` Operator -------------------------------- The intersection of two sets contains only elements that exist in both sets. Use the ``&`` operator or the ``intersection()`` method. .. code-block:: python >>> a = {1, 2, 3, 4} >>> b = {3, 4, 5, 6} >>> a & b {3, 4} >>> a.intersection(b) {3, 4} Find Common Elements Between Lists ---------------------------------- Finding common items between two lists is a practical application of set intersection. .. code-block:: python >>> a = [1, 1, 2, 3] >>> b = [3, 5, 5, 6] >>> list(set(a) & set(b)) [3] Difference with ``-`` Operator ------------------------------ The difference of two sets contains elements that are in the first set but not in the second. Use the ``-`` operator or the ``difference()`` method. .. code-block:: python >>> a = {1, 2, 3, 4} >>> b = {3, 4, 5, 6} >>> a - b {1, 2} >>> b - a {5, 6} Symmetric Difference with ``^`` Operator ---------------------------------------- The symmetric difference contains elements that are in either set, but not in both. Use the ``^`` operator or the ``symmetric_difference()`` method. .. code-block:: python >>> a = {1, 2, 3} >>> b = {3, 4, 5} >>> a ^ b {1, 2, 4, 5} Check Subset with ``<=`` Operator --------------------------------- Use ``<=`` or ``issubset()`` to check if all elements of one set are in another. Use ``<`` for proper subset (subset but not equal). .. code-block:: python >>> a = {1, 2} >>> b = {1, 2, 3, 4} >>> a <= b # a is subset of b True >>> a < b # a is proper subset True >>> a <= a # equal sets True >>> a < a # not proper subset False Check Superset with ``>=`` Operator ----------------------------------- Use ``>=`` or ``issuperset()`` to check if a set contains all elements of another. .. 
code-block:: python >>> a = {1, 2, 3, 4} >>> b = {1, 2} >>> a >= b # a is superset of b True >>> a > b # a is proper superset True Check Disjoint Sets ------------------- Two sets are disjoint if they have no elements in common. Use ``isdisjoint()`` to check. .. code-block:: python >>> a = {1, 2, 3} >>> b = {4, 5, 6} >>> a.isdisjoint(b) True >>> c = {3, 4, 5} >>> a.isdisjoint(c) False Membership Testing ------------------ Sets provide O(1) average time complexity for membership testing, making them much faster than lists for this operation. .. code-block:: python >>> s = {1, 2, 3, 4, 5} >>> 3 in s True >>> 10 in s False >>> 10 not in s True Frozenset - Immutable Set ------------------------- ``frozenset`` is an immutable version of set. It can be used as a dictionary key or as an element of another set. .. code-block:: python >>> fs = frozenset([1, 2, 3]) >>> fs frozenset({1, 2, 3}) >>> fs.add(4) # raises AttributeError AttributeError: 'frozenset' object has no attribute 'add' Use frozenset as dictionary key: .. code-block:: python >>> d = {frozenset([1, 2]): "a", frozenset([3, 4]): "b"} >>> d[frozenset([1, 2])] 'a' Use frozenset in a set: .. code-block:: python >>> s = {frozenset([1, 2]), frozenset([3, 4])} >>> frozenset([1, 2]) in s True Set Operations Summary ---------------------- .. 
code-block:: python # Creation s = {1, 2, 3} # literal s = set([1, 2, 3]) # from iterable s = {x for x in range(5)} # comprehension # Add/Remove s.add(x) # add single element s.update([x, y]) # add multiple elements s.remove(x) # remove (KeyError if missing) s.discard(x) # remove (no error if missing) s.pop() # remove arbitrary element s.clear() # remove all # Set Operations a | b # union a & b # intersection a - b # difference a ^ b # symmetric difference # Comparisons a <= b # subset a < b # proper subset a >= b # superset a > b # proper superset a.isdisjoint(b) # no common elements # Membership x in s # O(1) lookup x not in s ================================================ FILE: docs/notes/basic/python-typing.rst ================================================ .. meta:: :description lang=en: Python typing cheat sheet covering type hints, annotations, generics, protocols, TypeVar, and mypy type checking with code examples :keywords: Python, Python3, Python typing, Python type hints cheat sheet, type annotations, generics, Protocol, TypeVar, mypy, static typing ====== Typing ====== .. contents:: Table of Contents :backlinks: none PEP `484 <https://peps.python.org/pep-0484/>`_, which specifies what a type system should look like in Python 3, introduced the concept of type hints. Moreover, to better understand the design philosophy behind type hints, it is worth reading PEP `483 <https://peps.python.org/pep-0483/>`_, which helps a pythoneer understand the reasons why Python introduced a type system. The main goal of this cheat sheet is to show some common usages of type hints in Python 3. Without type check ------------------- .. code-block:: python def fib(n): a, b = 0, 1 for _ in range(n): yield a b, a = a + b, b print([n for n in fib(3.6)]) output: ..
code-block:: bash # errors will not be detected until runtime $ python fib.py Traceback (most recent call last): File "fib.py", line 8, in <module> print([n for n in fib(3.6)]) File "fib.py", line 8, in <listcomp> print([n for n in fib(3.6)]) File "fib.py", line 3, in fib for _ in range(n): TypeError: 'float' object cannot be interpreted as an integer With type check ---------------- .. code-block:: python # give a type hint from typing import Generator def fib(n: int) -> Generator: a: int = 0 b: int = 1 for _ in range(n): yield a b, a = a + b, b print([n for n in fib(3.6)]) output: .. code-block:: bash # errors will be detected before running $ mypy --strict fib.py fib.py:12: error: Argument 1 to "fib" has incompatible type "float"; expected "int" Basic types ----------- .. code-block:: python import io import re from collections import deque, namedtuple from typing import ( Dict, List, Tuple, Set, Deque, NamedTuple, IO, Pattern, Match, Text, Optional, Sequence, Iterable, Mapping, MutableMapping, Any, ) # without initializing x: int # any type y: Any y = 1 y = "1" # built-in var_int: int = 1 var_str: str = "Hello Typing" var_byte: bytes = b"Hello Typing" var_bool: bool = True var_float: float = 1. var_unicode: Text = u'\u2713' # could be none var_could_be_none: Optional[int] = None var_could_be_none = 1 # collections var_set: Set[int] = {i for i in range(3)} var_dict: Dict[str, str] = {"foo": "Foo"} var_list: List[int] = [i for i in range(3)] var_static_length_Tuple: Tuple[int, int, int] = (1, 2, 3) var_dynamic_length_Tuple: Tuple[int, ...]
= (i for i in range(10, 3)) var_deque: Deque = deque([1, 2, 3]) var_nametuple: NamedTuple = namedtuple('P', ['x', 'y']) # io var_io_str: IO[str] = io.StringIO("Hello String") var_io_byte: IO[bytes] = io.BytesIO(b"Hello Bytes") var_io_file_str: IO[str] = open(__file__) var_io_file_byte: IO[bytes] = open(__file__, 'rb') # re p: Pattern = re.compile("(https?)://([^/\r\n]+)(/[^\r\n]*)?") m: Optional[Match] = p.match("https://www.python.org/") # duck types: list-like var_seq_list: Sequence[int] = [1, 2, 3] var_seq_tuple: Sequence[int] = (1, 2, 3) var_iter_list: Iterable[int] = [1, 2, 3] var_iter_tuple: Iterable[int] = (1, 2, 3) # duck types: dict-like var_map_dict: Mapping[str, str] = {"foo": "Foo"} var_mutable_dict: MutableMapping[str, str] = {"bar": "Bar"} Functions ---------- .. code-block:: python from typing import Generator, Callable # function def gcd(a: int, b: int) -> int: while b: a, b = b, a % b return a # callback def fun(cb: Callable[[int, int], int]) -> int: return cb(55, 66) # lambda f: Callable[[int], int] = lambda x: x * 2 Classes -------- .. code-block:: python from typing import ClassVar, Dict, List class Foo: x: int = 1 # instance variable. default = 1 y: ClassVar[str] = "class var" # class variable def __init__(self) -> None: self.i: List[int] = [0] def foo(self, a: int, b: str) -> Dict[int, str]: return {a: b} foo = Foo() foo.x = 123 print(foo.x) print(foo.i) print(Foo.y) print(foo.foo(1, "abc")) Generator ---------- .. code-block:: python from typing import Generator # Generator[YieldType, SendType, ReturnType] def fib(n: int) -> Generator[int, None, None]: a: int = 0 b: int = 1 while n > 0: yield a b, a = a + b, b n -= 1 g: Generator = fib(10) i: Iterator[int] = (x for x in range(3)) Asynchronous Generator ----------------------- .. 
code-block:: python import asyncio from typing import AsyncGenerator, AsyncIterator async def fib(n: int) -> AsyncGenerator: a: int = 0 b: int = 1 while n > 0: await asyncio.sleep(0.1) yield a b, a = a + b, b n -= 1 async def main() -> None: async for f in fib(10): print(f) ag: AsyncIterator = (f async for f in fib(10)) loop = asyncio.get_event_loop() loop.run_until_complete(main()) Context Manager --------------- .. code-block:: python from typing import ContextManager, Generator, IO from contextlib import contextmanager @contextmanager def open_file(name: str) -> Generator: f = open(name) yield f f.close() cm: ContextManager[IO] = open_file(__file__) with cm as f: print(f.read()) Asynchronous Context Manager ----------------------------- .. code-block:: python import asyncio from typing import AsyncContextManager, AsyncGenerator, IO from contextlib import asynccontextmanager # need python 3.7 or above @asynccontextmanager async def open_file(name: str) -> AsyncGenerator: await asyncio.sleep(0.1) f = open(name) yield f await asyncio.sleep(0.1) f.close() async def main() -> None: acm: AsyncContextManager[IO] = open_file(__file__) async with acm as f: print(f.read()) loop = asyncio.get_event_loop() loop.run_until_complete(main()) Avoid ``None`` access ---------------------- .. code-block:: python import re from typing import Pattern, Dict, Optional # like c++ # std::regex url("(https?)://([^/\r\n]+)(/[^\r\n]*)?"); # std::regex color("^#?([a-f0-9]{6}|[a-f0-9]{3})$"); url: Pattern = re.compile("(https?)://([^/\r\n]+)(/[^\r\n]*)?") color: Pattern = re.compile("^#?([a-f0-9]{6}|[a-f0-9]{3})$") x: Dict[str, Pattern] = {"url": url, "color": color} y: Optional[Pattern] = x.get("baz", None) print(y.match("https://www.python.org/")) output: .. code-block:: bash $ mypy --strict foo.py foo.py:15: error: Item "None" of "Optional[Pattern[Any]]" has no attribute "match" Positional-only arguments -------------------------- .. 
code-block:: python # define arguments with names beginning with __ def fib(__n: int) -> int: # positional only arg a, b = 0, 1 for _ in range(__n): b, a = a + b, b return a def gcd(*, a: int, b: int) -> int: # keyword only arg while b: a, b = b, a % b return a print(fib(__n=10)) # error print(gcd(10, 5)) # error output: .. code-block:: bash mypy --strict foo.py foo.py:1: note: "fib" defined here foo.py:14: error: Unexpected keyword argument "__n" for "fib" foo.py:15: error: Too many positional arguments for "gcd" Multiple return values ----------------------- .. code-block:: python from typing import Tuple, Iterable, Union def foo(x: int, y: int) -> Tuple[int, int]: return x, y # or def bar(x: int, y: str) -> Iterable[Union[int, str]]: # XXX: not recommend declaring in this way return x, y a: int b: int a, b = foo(1, 2) # ok c, d = bar(3, "bar") # ok Optional Type ---------------------------------- .. code-block:: python from typing import List, Union def first(l: List[Union[int, None]]) -> Union[int, None]: return None if len(l) == 0 else l[0] first([None]) # equal to from typing import List, Optional def first(l: List[Optional[int]]) -> Optional[int]: return None if len(l) == 0 else l[0] first([None]) Be careful of ``Optional`` --------------------------- .. code-block:: python from typing import cast, Optional def fib(n): a, b = 0, 1 for _ in range(n): b, a = a + b, b return a def cal(n: Optional[int]) -> None: print(fib(n)) cal(None) output: .. code-block:: bash # mypy will not detect errors $ mypy foo.py Explicitly declare .. code-block:: python from typing import Optional def fib(n: int) -> int: # declare n to be int a, b = 0, 1 for _ in range(n): b, a = a + b, b return a def cal(n: Optional[int]) -> None: print(fib(n)) output: .. code-block:: bash # mypy can detect errors even we do not check None $ mypy --strict foo.py foo.py:11: error: Argument 1 to "fib" has incompatible type "Optional[int]"; expected "int" Be careful of casting ---------------------- .. 
code-block:: python from typing import cast, Optional def gcd(a: int, b: int) -> int: while b: a, b = b, a % b return a def cal(a: Optional[int], b: Optional[int]) -> None: # XXX: Avoid casting ca, cb = cast(int, a), cast(int, b) print(gcd(ca, cb)) cal(None, None) output: .. code-block:: bash # mypy will not detect type errors $ mypy --strict foo.py Forward references ------------------- Based on PEP 484, if we want to reference a type before it has been declared, we have to use **string literal** to imply that there is a type of that name later on in the file. .. code-block:: python from typing import Optional class Tree: def __init__( self, data: int, left: Optional["Tree"], # Forward references. right: Optional["Tree"] ) -> None: self.data = data self.left = left self.right = right .. note:: There are some issues that mypy does not complain about Forward References. Get further information from `Issue#948`_. .. _Issue\#948: https://github.com/python/mypy/issues/948 .. code-block:: python class A: def __init__(self, a: A) -> None: # should fail self.a = a output: .. code-block:: bash $ mypy --strict type.py $ echo $? 0 $ python type.py # get runtime fail Traceback (most recent call last): File "type.py", line 1, in class A: File "type.py", line 2, in A def __init__(self, a: A) -> None: # should fail NameError: name 'A' is not defined Postponed Evaluation of Annotations ----------------------------------- **New in Python 3.7** - PEP 563_ - Postponed Evaluation of Annotations .. _563: https://www.python.org/dev/peps/pep-0563/ Before Python 3.7 .. code-block:: python >>> class A: ... def __init__(self, a: A) -> None: ... self._a = a ... Traceback (most recent call last): File "", line 1, in File "", line 2, in A NameError: name 'A' is not defined After Python 3.7 (include 3.7) .. code-block:: python >>> from __future__ import annotations >>> class A: ... def __init__(self, a: A) -> None: ... self._a = a ... .. 
note:: Annotation can only be used within the scope which names have already existed. Therefore, **forward reference** does not support the case which names are not available in the current scope. **Postponed evaluation of annotations** will become the default behavior in Python 4.0. Type Alias ---------- Like ``typedef`` or ``using`` in c/c++ .. code-block:: cpp #include #include #include #include typedef std::string Url; template using Vector = std::vector; int main(int argc, char *argv[]) { Url url = "https://python.org"; std::regex p("(https?)://([^/\r\n]+)(/[^\r\n]*)?"); bool m = std::regex_match(url, p); Vector v = {1, 2}; std::cout << m << std::endl; for (auto it : v) std::cout << it << std::endl; return 0; } Type aliases are defined by simple variable assignments .. code-block:: python import re from typing import Pattern, List # Like typedef, using in c/c++ # PEP 484 recommend capitalizing alias names Url = str url: Url = "https://www.python.org/" p: Pattern = re.compile("(https?)://([^/\r\n]+)(/[^\r\n]*)?") m = p.match(url) Vector = List[int] v: Vector = [1., 2.] Using NewType --------------------- Unlike alias, ``NewType`` returns a separate type but is identical to the original type at runtime. .. code-block:: python from sqlalchemy import Column, String, Integer from sqlalchemy.ext.declarative import declarative_base from typing import NewType, Any # check mypy #2477 Base: Any = declarative_base() # create a new type Id = NewType('Id', int) # not equal alias, it's a 'new type' class User(Base): __tablename__ = 'User' id = Column(Integer, primary_key=True) age = Column(Integer, nullable=False) name = Column(String, nullable=False) def __init__(self, id: Id, age: int, name: str) -> None: self.id = id self.age = age self.name = name # create users user1 = User(Id(1), 62, "Guido van Rossum") # ok user2 = User(2, 48, "David M. Beazley") # error output: .. 
code-block:: bash $ python foo.py $ mypy --ignore-missing-imports foo.py foo.py:24: error: Argument 1 to "User" has incompatible type "int"; expected "Id" Further reading: - `Issue\#1284`_ .. _`Issue\#1284`: https://github.com/python/mypy/issues/1284 Using ``TypeVar`` as template ------------------------------ Like c++ ``template `` .. code-block:: cpp #include template T add(T x, T y) { return x + y; } int main(int argc, char *argv[]) { std::cout << add(1, 2) << std::endl; std::cout << add(1., 2.) << std::endl; return 0; } Python using ``TypeVar`` .. code-block:: python from typing import TypeVar T = TypeVar("T") def add(x: T, y: T) -> T: return x + y add(1, 2) add(1., 2.) Using ``TypeVar`` and ``Generic`` as class template ---------------------------------------------------- Like c++ ``template class`` .. code-block:: cpp #include template class Foo { public: Foo(T foo) { foo_ = foo; } T Get() { return foo_; } private: T foo_; }; int main(int argc, char *argv[]) { Foo f(123); std::cout << f.Get() << std::endl; return 0; } Define a generic class in Python .. code-block:: python from typing import Generic, TypeVar T = TypeVar("T") class Foo(Generic[T]): def __init__(self, foo: T) -> None: self.foo = foo def get(self) -> T: return self.foo f: Foo[str] = Foo("Foo") v: int = f.get() output: .. code-block:: bash $ mypy --strict foo.py foo.py:13: error: Incompatible types in assignment (expression has type "str", variable has type "int") Scoping rules for ``TypeVar`` ------------------------------ - ``TypeVar`` used in different generic function will be inferred to be different types. .. code-block:: python from typing import TypeVar T = TypeVar("T") def foo(x: T) -> T: return x def bar(y: T) -> T: return y a: int = foo(1) # ok: T is inferred to be int b: int = bar("2") # error: T is inferred to be str output: .. 
code-block:: bash $ mypy --strict foo.py foo.py:12: error: Incompatible types in assignment (expression has type "str", variable has type "int") - ``TypeVar`` used in a generic class will be inferred to be same types. .. code-block:: python from typing import TypeVar, Generic T = TypeVar("T") class Foo(Generic[T]): def foo(self, x: T) -> T: return x def bar(self, y: T) -> T: return y f: Foo[int] = Foo() a: int = f.foo(1) # ok: T is inferred to be int b: str = f.bar("2") # error: T is expected to be int output: .. code-block:: bash $ mypy --strict foo.py foo.py:15: error: Incompatible types in assignment (expression has type "int", variable has type "str") foo.py:15: error: Argument 1 to "bar" of "Foo" has incompatible type "str"; expected "int" - ``TypeVar`` used in a method but did not match any parameters which declare in ``Generic`` can be inferred to be different types. .. code-block:: python from typing import TypeVar, Generic T = TypeVar("T") S = TypeVar("S") class Foo(Generic[T]): # S does not match params def foo(self, x: T, y: S) -> S: return y def bar(self, z: S) -> S: return z f: Foo[int] = Foo() a: str = f.foo(1, "foo") # S is inferred to be str b: int = f.bar(12345678) # S is inferred to be int output: .. code-block:: bash $ mypy --strict foo.py - ``TypeVar`` should not appear in body of method/function if it is unbound type. .. code-block:: python from typing import TypeVar, Generic T = TypeVar("T") S = TypeVar("S") def foo(x: T) -> None: a: T = x # ok b: S = 123 # error: invalid type output: .. code-block:: bash $ mypy --strict foo.py foo.py:8: error: Invalid type "foo.S" Restricting to a fixed set of possible types ---------------------------------------------- ``T = TypeVar('T', ClassA, ...)`` means we create a **type variable with a value restriction**. .. code-block:: python from typing import TypeVar # restrict T = int or T = float T = TypeVar("T", int, float) def add(x: T, y: T) -> T: return x + y add(1, 2) add(1., 2.) 
add("1", 2) add("hello", "world") output: .. code-block:: bash # mypy can detect wrong type $ mypy --strict foo.py foo.py:10: error: Value of type variable "T" of "add" cannot be "object" foo.py:11: error: Value of type variable "T" of "add" cannot be "str" ``TypeVar`` with an upper bound -------------------------------- ``T = TypeVar('T', bound=BaseClass)`` means we create a **type variable with an upper bound**. The concept is similar to **polymorphism** in c++. .. code-block:: cpp #include class Shape { public: Shape(double width, double height) { width_ = width; height_ = height; }; virtual double Area() = 0; protected: double width_; double height_; }; class Rectangle: public Shape { public: Rectangle(double width, double height) :Shape(width, height) {}; double Area() { return width_ * height_; }; }; class Triangle: public Shape { public: Triangle(double width, double height) :Shape(width, height) {}; double Area() { return width_ * height_ / 2; }; }; double Area(Shape &s) { return s.Area(); } int main(int argc, char *argv[]) { Rectangle r(1., 2.); Triangle t(3., 4.); std::cout << Area(r) << std::endl; std::cout << Area(t) << std::endl; return 0; } Like c++, create a base class and ``TypeVar`` which bounds to the base class. Then, static type checker will take every subclass as type of base class. .. code-block:: python from typing import TypeVar class Shape: def __init__(self, width: float, height: float) -> None: self.width = width self.height = height def area(self) -> float: return 0 class Rectangle(Shape): def area(self) -> float: width: float = self.width height: float = self.height return width * height class Triangle(Shape): def area(self) -> float: width: float = self.width height: float = self.height return width * height / 2 S = TypeVar("S", bound=Shape) def area(s: S) -> float: return s.area() r: Rectangle = Rectangle(1, 2) t: Triangle = Triangle(3, 4) i: int = 5566 print(area(r)) print(area(t)) print(area(i)) output: .. 
code-block:: bash $ mypy --strict foo.py foo.py:40: error: Value of type variable "S" of "area" cannot be "int" @overload ---------- Sometimes, we use ``Union`` to infer that the return of a function has multiple different types. However, type checker cannot distinguish which type do we want. Therefore, following snippet shows that type checker cannot determine which type is correct. .. code-block:: python from typing import List, Union class Array(object): def __init__(self, arr: List[int]) -> None: self.arr = arr def __getitem__(self, i: Union[int, str]) -> Union[int, str]: if isinstance(i, int): return self.arr[i] if isinstance(i, str): return str(self.arr[int(i)]) arr = Array([1, 2, 3, 4, 5]) x:int = arr[1] y:str = arr["2"] output: .. code-block:: bash $ mypy --strict foo.py foo.py:16: error: Incompatible types in assignment (expression has type "Union[int, str]", variable has type "int") foo.py:17: error: Incompatible types in assignment (expression has type "Union[int, str]", variable has type "str") Although we can use ``cast`` to solve the problem, it cannot avoid typo and ``cast`` is not safe. .. code-block:: python from typing import List, Union, cast class Array(object): def __init__(self, arr: List[int]) -> None: self.arr = arr def __getitem__(self, i: Union[int, str]) -> Union[int, str]: if isinstance(i, int): return self.arr[i] if isinstance(i, str): return str(self.arr[int(i)]) arr = Array([1, 2, 3, 4, 5]) x: int = cast(int, arr[1]) y: str = cast(str, arr[2]) # typo. we want to assign arr["2"] output: .. code-block:: bash $ mypy --strict foo.py $ echo $? 0 Using ``@overload`` can solve the problem. We can declare the return type explicitly. .. code-block:: python from typing import Generic, List, Union, overload class Array(object): def __init__(self, arr: List[int]) -> None: self.arr = arr @overload def __getitem__(self, i: str) -> str: ... @overload def __getitem__(self, i: int) -> int: ... 
def __getitem__(self, i: Union[int, str]) -> Union[int, str]: if isinstance(i, int): return self.arr[i] if isinstance(i, str): return str(self.arr[int(i)]) arr = Array([1, 2, 3, 4, 5]) x: int = arr[1] y: str = arr["2"] output: .. code-block:: bash $ mypy --strict foo.py $ echo $? 0 .. warning:: Based on PEP 484, the ``@overload`` decorator just **for type checker only**, it does not implement the real overloading like c++/java. Thus, we have to implement one exactly non-``@overload`` function. At the runtime, calling the ``@overload`` function will raise ``NotImplementedError``. .. code-block:: python from typing import List, Union, overload class Array(object): def __init__(self, arr: List[int]) -> None: self.arr = arr @overload def __getitem__(self, i: Union[int, str]) -> Union[int, str]: if isinstance(i, int): return self.arr[i] if isinstance(i, str): return str(self.arr[int(i)]) arr = Array([1, 2, 3, 4, 5]) try: x: int = arr[1] except NotImplementedError as e: print("NotImplementedError") output: .. code-block:: bash $ python foo.py NotImplementedError Stub Files ---------- Stub files just like header files which we usually use to define our interfaces in c/c++. In python, we can define our interfaces in the same module directory or ``export MYPYPATH=${stubs}`` First, we need to create a stub file (interface file) for module. .. code-block:: bash $ mkdir fib $ touch fib/__init__.py fib/__init__.pyi Then, define the interface of the function in ``__init__.pyi`` and implement the module. .. code-block:: python # fib/__init__.pyi def fib(n: int) -> int: ... # fib/__init__.py def fib(n): a, b = 0, 1 for _ in range(n): b, a = a + b, b return a Then, write a test.py for testing ``fib`` module. .. code-block:: python # touch test.py import sys from pathlib import Path p = Path(__file__).parent / "fib" sys.path.append(str(p)) from fib import fib print(fib(10.0)) output: .. 
code-block:: bash $ mypy --strict test.py test.py:10: error: Argument 1 to "fib" has incompatible type "float"; expected "int" ================================================ FILE: docs/notes/basic/python-unicode.rst ================================================ .. meta:: :description lang=en: Python Unicode tutorial covering string encoding, decoding, UTF-8, ASCII, bytes conversion, and character handling in Python 3 :keywords: Python, Python3, Unicode, UTF-8, encoding, decoding, bytes, string, ASCII, character, codec ======= Unicode ======= .. contents:: Table of Contents :backlinks: none The main goal of this cheat sheet is to collect some common snippets which are related to Unicode. In Python 3, strings are represented by Unicode instead of bytes. Further information can be found on PEP `3100 <https://peps.python.org/pep-3100/>`_ **ASCII** code is the most well-known standard which defines numeric codes for characters. The numeric values only define 128 characters originally, so ASCII only contains control codes, digits, lowercase letters, uppercase letters, etc. However, it is not enough for us to represent characters such as accented characters, Chinese characters, or emoji used around the world. Therefore, **Unicode** was developed to solve this issue. Like ASCII, it assigns a *code point* to each character, but the number of characters it can represent is up to 1,111,998. String ------ In Python 2, strings are represented in *bytes*, not *Unicode*. Python provides different types of string such as Unicode string, raw string, and so on. In this case, if we want to declare a Unicode string, we add the ``u`` prefix to string literals. .. code-block:: python >>> s = 'Café' # byte string >>> s 'Caf\xc3\xa9' >>> type(s) <type 'str'> >>> u = u'Café' # unicode string >>> u u'Caf\xe9' >>> type(u) <type 'unicode'> In Python 3, strings are represented in *Unicode*. If we want to represent a byte string, we add the ``b`` prefix to string literals. Note that the early Python versions (3.0-3.2) do not support the ``u`` prefix.
In order to ease the pain of migrating Unicode-aware applications from Python 2, Python 3.3 once again supports the ``u`` prefix for string literals. Further information can be found on PEP `414 <https://peps.python.org/pep-0414/>`_ .. code-block:: python >>> s = 'Café' >>> type(s) <class 'str'> >>> s 'Café' >>> s.encode('utf-8') b'Caf\xc3\xa9' >>> s.encode('utf-8').decode('utf-8') 'Café' Characters ---------- Python 2 takes all string characters as bytes. In this case, the length of a string may not be equal to the number of characters. For example, the length of ``Café`` is 5, not 4, because ``é`` is encoded as a 2-byte character. .. code-block:: python >>> s= 'Café' >>> print([_c for _c in s]) ['C', 'a', 'f', '\xc3', '\xa9'] >>> len(s) 5 >>> s = u'Café' >>> print([_c for _c in s]) [u'C', u'a', u'f', u'\xe9'] >>> len(s) 4 Python 3 takes all string characters as Unicode code points. The length of a string is always equal to the number of code points. .. code-block:: python >>> s = 'Café' >>> print([_c for _c in s]) ['C', 'a', 'f', 'é'] >>> len(s) 4 >>> bs = bytes(s, encoding='utf-8') >>> print(bs) b'Caf\xc3\xa9' >>> len(bs) 5 Porting unicode(s, 'utf-8') --------------------------- The `unicode() <https://docs.python.org/2.7/library/functions.html#unicode>`_ built-in function was removed in Python 3 so what is the best way to convert the expression ``unicode(s, 'utf-8')`` so it works in both Python 2 and 3? In Python 2: .. code-block:: python >>> s = 'Café' >>> unicode(s, 'utf-8') u'Caf\xe9' >>> s.decode('utf-8') u'Caf\xe9' >>> unicode(s, 'utf-8') == s.decode('utf-8') True In Python 3: .. code-block:: python >>> s = 'Café' >>> s.decode('utf-8') AttributeError: 'str' object has no attribute 'decode' So, the real answer is... Unicode Code Point ------------------ `ord <https://docs.python.org/3/library/functions.html#ord>`_ is a powerful built-in function to get a Unicode code point from a given character. Consequently, if we want to check the Unicode code point of a character, we can use ``ord``. .. code-block:: python >>> s = u'Café' >>> for _c in s: print('U+%04x' % ord(_c)) ...
U+0043 U+0061 U+0066 U+00e9 >>> u = '中文' >>> for _c in u: print('U+%04x' % ord(_c)) ... U+4e2d U+6587 Encoding -------- A *Unicode code point* transfers to a *byte string* is called encoding. .. code-block:: python >>> s = u'Café' >>> type(s.encode('utf-8')) Decoding --------- A *byte string* transfers to a *Unicode code point* is called decoding. .. code-block:: python >>> s = bytes('Café', encoding='utf-8') >>> s.decode('utf-8') 'Café' Unicode Normalization --------------------- Some characters can be represented in two similar form. For example, the character, ``é`` can be written as ``e ́`` (Canonical Decomposition) or ``é`` (Canonical Composition). In this case, we may acquire unexpected results when we are comparing two strings even though they look alike. Therefore, we can normalize a Unicode form to solve the issue. .. code-block:: python # python 3 >>> u1 = 'Café' # unicode string >>> u2 = 'Cafe\u0301' >>> u1, u2 ('Café', 'Café') >>> len(u1), len(u2) (4, 5) >>> u1 == u2 False >>> u1.encode('utf-8') # get u1 byte string b'Caf\xc3\xa9' >>> u2.encode('utf-8') # get u2 byte string b'Cafe\xcc\x81' >>> from unicodedata import normalize >>> s1 = normalize('NFC', u1) # get u1 NFC format >>> s2 = normalize('NFC', u2) # get u2 NFC format >>> s1 == s2 True >>> s1.encode('utf-8'), s2.encode('utf-8') (b'Caf\xc3\xa9', b'Caf\xc3\xa9') >>> s1 = normalize('NFD', u1) # get u1 NFD format >>> s2 = normalize('NFD', u2) # get u2 NFD format >>> s1, s2 ('Café', 'Café') >>> s1 == s2 True >>> s1.encode('utf-8'), s2.encode('utf-8') (b'Cafe\xcc\x81', b'Cafe\xcc\x81') Avoid ``UnicodeDecodeError`` ---------------------------- Python raises `UnicodeDecodeError` when byte strings cannot decode to Unicode code points. If we want to avoid this exception, we can pass *replace*, *backslashreplace*, or *ignore* to errors argument in `decode `_. .. 
code-block:: python >>> u = b"\xff" >>> u.decode('utf-8', 'strict') Traceback (most recent call last): File "", line 1, in UnicodeDecodeError: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte >>> # use U+FFFD, REPLACEMENT CHARACTER >>> u.decode('utf-8', "replace") '\ufffd' >>> # inserts a \xNN escape sequence >>> u.decode('utf-8', "backslashreplace") '\\xff' >>> # leave the character out of the Unicode result >>> u.decode('utf-8', "ignore") '' Long String ----------- The following snippet shows common ways to declare a multi-line string in Python. .. code-block:: python # original long string s = 'This is a very very very long python string' # Single quote with an escaping backslash s = "This is a very very very " \ "long python string" # Using brackets s = ( "This is a very very very " "long python string" ) # Using ``+`` s = ( "This is a very very very " + "long python string" ) # Using triple-quote with an escaping backslash s = '''This is a very very very \ long python string''' ================================================ FILE: docs/notes/concurrency/index.rst ================================================ .. meta:: :description lang=en: Python concurrency tutorial covering threading, multiprocessing, locks, semaphores, queues, process pools, and concurrent.futures :keywords: Python, Python3, threading, multiprocessing, concurrency, parallel, lock, semaphore, queue, ThreadPoolExecutor, ProcessPoolExecutor, GIL Concurrency =========== Python provides multiple approaches for concurrent execution to handle CPU-bound and I/O-bound tasks efficiently. The ``threading`` module enables lightweight concurrent execution within a single process, while ``multiprocessing`` bypasses the Global Interpreter Lock (GIL) by using separate processes for true parallelism. The ``concurrent.futures`` module offers a high-level interface that abstracts the differences between threads and processes behind a unified API. 
Understanding when to use each approach is crucial: threads excel at I/O-bound tasks (network requests, file operations) where the GIL is released during waiting, while processes are better for CPU-bound tasks (computation, data processing) where true parallel execution is needed. .. toctree:: :maxdepth: 1 python-threading python-multiprocessing python-futures ================================================ FILE: docs/notes/concurrency/python-futures.rst ================================================ .. meta:: :description lang=en: Python concurrent.futures tutorial covering ThreadPoolExecutor, ProcessPoolExecutor, Future objects, callbacks, and high-level parallel execution patterns :keywords: Python, Python3, concurrent.futures, ThreadPoolExecutor, ProcessPoolExecutor, Future, executor, submit, map, as_completed, parallel ================== concurrent.futures ================== :Source: `src/basic/concurrency_.py `_ .. contents:: Table of Contents :backlinks: none Introduction ------------ The ``concurrent.futures`` module provides a high-level interface for asynchronously executing callables using threads or processes. It abstracts the differences between threading and multiprocessing behind a unified API, making it easy to switch between them. The module introduces two key concepts: **Executors** that manage pools of workers, and **Futures** that represent the eventual result of an asynchronous operation. ThreadPoolExecutor Basics ------------------------- ``ThreadPoolExecutor`` manages a pool of threads that execute tasks concurrently. Use it for I/O-bound tasks like network requests, file operations, or database queries where threads spend time waiting for external resources. .. 
code-block:: python from concurrent.futures import ThreadPoolExecutor import time def fetch_url(url): """Simulate fetching a URL.""" time.sleep(1) # Simulate network delay return f"Content from {url}" urls = ["http://site1.com", "http://site2.com", "http://site3.com"] # Sequential - takes ~3 seconds start = time.time() results = [fetch_url(url) for url in urls] print(f"Sequential: {time.time() - start:.2f}s") # Concurrent - takes ~1 second start = time.time() with ThreadPoolExecutor(max_workers=3) as executor: results = list(executor.map(fetch_url, urls)) print(f"Concurrent: {time.time() - start:.2f}s") ProcessPoolExecutor Basics -------------------------- ``ProcessPoolExecutor`` manages a pool of processes for true parallel execution. Use it for CPU-bound tasks like data processing, calculations, or image manipulation where you need to utilize multiple CPU cores. .. code-block:: python from concurrent.futures import ProcessPoolExecutor import time def cpu_intensive(n): """CPU-bound computation.""" return sum(i * i for i in range(n)) if __name__ == "__main__": numbers = [10**7] * 4 # Sequential start = time.time() results = [cpu_intensive(n) for n in numbers] print(f"Sequential: {time.time() - start:.2f}s") # Parallel with processes start = time.time() with ProcessPoolExecutor(max_workers=4) as executor: results = list(executor.map(cpu_intensive, numbers)) print(f"Parallel: {time.time() - start:.2f}s") Using submit() and Future Objects --------------------------------- The ``submit()`` method schedules a callable and returns a ``Future`` object immediately. The Future represents the pending result and provides methods to check status, get the result, or cancel the task. This gives more control than ``map()`` for handling individual tasks. .. 
code-block:: python from concurrent.futures import ThreadPoolExecutor import time def task(name, duration): time.sleep(duration) return f"{name} completed in {duration}s" with ThreadPoolExecutor(max_workers=3) as executor: # Submit tasks - returns Future immediately future1 = executor.submit(task, "Task A", 2) future2 = executor.submit(task, "Task B", 1) future3 = executor.submit(task, "Task C", 3) # Check if done (non-blocking) print(f"Task A done: {future1.done()}") # Get result (blocking) print(future2.result()) # Waits for completion print(future1.result()) print(future3.result()) Processing Results as They Complete ----------------------------------- ``as_completed()`` yields futures as they complete, regardless of submission order. This is useful when you want to process results as soon as they're available rather than waiting for all tasks to finish. .. code-block:: python from concurrent.futures import ThreadPoolExecutor, as_completed import time import random def fetch_data(source_id): delay = random.uniform(0.5, 2.0) time.sleep(delay) return f"Data from source {source_id} (took {delay:.2f}s)" sources = range(5) with ThreadPoolExecutor(max_workers=5) as executor: # Submit all tasks future_to_source = { executor.submit(fetch_data, src): src for src in sources } # Process results as they complete for future in as_completed(future_to_source): source = future_to_source[future] try: result = future.result() print(f"Source {source}: {result}") except Exception as e: print(f"Source {source} failed: {e}") Using wait() for Completion Control ----------------------------------- ``wait()`` blocks until specified futures complete. You can wait for all tasks, the first task, or the first exception. This provides fine-grained control over when to proceed. .. 
code-block:: python from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED, ALL_COMPLETED import time def task(task_id, duration): time.sleep(duration) return f"Task {task_id} done" with ThreadPoolExecutor(max_workers=3) as executor: futures = [ executor.submit(task, 1, 3), executor.submit(task, 2, 1), executor.submit(task, 3, 2), ] # Wait for first to complete done, not_done = wait(futures, return_when=FIRST_COMPLETED) print(f"First completed: {done.pop().result()}") print(f"Still running: {len(not_done)}") # Wait for all remaining done, not_done = wait(not_done, return_when=ALL_COMPLETED) for f in done: print(f"Completed: {f.result()}") Adding Callbacks to Futures --------------------------- Callbacks are functions that execute automatically when a future completes. They're useful for processing results without blocking the main thread or for chaining operations. The callback receives the future as its argument. .. code-block:: python from concurrent.futures import ThreadPoolExecutor import time def compute(n): time.sleep(1) return n * n def on_complete(future): """Callback executed when future completes.""" try: result = future.result() print(f"Callback: result is {result}") except Exception as e: print(f"Callback: task failed with {e}") with ThreadPoolExecutor(max_workers=3) as executor: for i in range(5): future = executor.submit(compute, i) future.add_done_callback(on_complete) # Main thread continues while callbacks fire print("Main thread: tasks submitted") time.sleep(2) print("Main thread: done waiting") Exception Handling ------------------ Exceptions raised in tasks are captured and re-raised when you call ``result()``. You can also check for exceptions using ``exception()``. Always handle exceptions to prevent silent failures. .. 
code-block:: python from concurrent.futures import ThreadPoolExecutor, as_completed def risky_task(n): if n == 3: raise ValueError(f"Bad value: {n}") return n * 2 with ThreadPoolExecutor(max_workers=3) as executor: futures = {executor.submit(risky_task, i): i for i in range(5)} for future in as_completed(futures): n = futures[future] try: result = future.result() print(f"Task {n}: {result}") except ValueError as e: print(f"Task {n} failed: {e}") # Alternative: check exception without raising future = executor.submit(risky_task, 3) exc = future.exception() # Waits for completion; returns the error instead of raising it if exc is not None: print(f"Exception occurred: {exc}") Timeout Handling ---------------- Both ``result()`` and ``as_completed()`` accept timeout parameters. If a task doesn't complete within the timeout, a ``TimeoutError`` is raised. This prevents indefinite blocking on slow or stuck tasks. .. code-block:: python from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed import time def slow_task(duration): time.sleep(duration) return f"Completed after {duration}s" with ThreadPoolExecutor(max_workers=2) as executor: future = executor.submit(slow_task, 5) try: # Wait max 2 seconds for result result = future.result(timeout=2) print(result) except TimeoutError: print("Task timed out!") # Note: task continues running in background # Timeout with as_completed futures = [executor.submit(slow_task, i) for i in [1, 3, 5]] try: for future in as_completed(futures, timeout=2): print(future.result()) except TimeoutError: print("Some tasks didn't complete in time") Cancelling Tasks ---------------- Tasks can be cancelled before they start executing using ``cancel()``. Once a task has started, it cannot be cancelled. Check ``cancelled()`` to see if cancellation succeeded. ..
code-block:: python from concurrent.futures import ThreadPoolExecutor import time def long_task(n): time.sleep(2) return n with ThreadPoolExecutor(max_workers=1) as executor: # Submit multiple tasks to single worker future1 = executor.submit(long_task, 1) future2 = executor.submit(long_task, 2) # Queued, not started future3 = executor.submit(long_task, 3) # Queued, not started time.sleep(0.1) # Let first task start # Try to cancel queued tasks cancelled2 = future2.cancel() cancelled3 = future3.cancel() print(f"Future 2 cancelled: {cancelled2}") # True print(f"Future 3 cancelled: {cancelled3}") # True print(f"Future 1 cancelled: {future1.cancel()}") # False (already running) Executor Context Manager ------------------------ Using executors as context managers (``with`` statement) ensures proper cleanup. When exiting the context, ``shutdown(wait=True)`` is called automatically, which waits for all pending tasks to complete before returning. .. code-block:: python from concurrent.futures import ThreadPoolExecutor import time def task(n): time.sleep(1) return n * 2 # Context manager - automatic cleanup with ThreadPoolExecutor(max_workers=3) as executor: futures = [executor.submit(task, i) for i in range(5)] # Executor waits for all tasks when exiting 'with' block print("All tasks completed") # Manual management (not recommended) executor = ThreadPoolExecutor(max_workers=3) try: futures = [executor.submit(task, i) for i in range(5)] finally: executor.shutdown(wait=True) # Must call explicitly Map with Chunking ----------------- For large iterables, ``map()`` can be more efficient with chunking. The ``chunksize`` parameter groups items together, reducing overhead from inter-process communication when using ``ProcessPoolExecutor``. .. 
code-block:: python from concurrent.futures import ProcessPoolExecutor import time def process_item(x): return x * x if __name__ == "__main__": items = range(100000) # Without chunking - more IPC overhead start = time.time() with ProcessPoolExecutor(max_workers=4) as executor: results = list(executor.map(process_item, items)) print(f"No chunking: {time.time() - start:.2f}s") # With chunking - less IPC overhead start = time.time() with ProcessPoolExecutor(max_workers=4) as executor: results = list(executor.map(process_item, items, chunksize=1000)) print(f"With chunking: {time.time() - start:.2f}s") Real-World Example: Parallel Downloads -------------------------------------- This example demonstrates a practical use case: downloading multiple files concurrently with progress tracking, error handling, and timeout management. .. code-block:: python from concurrent.futures import ThreadPoolExecutor, as_completed import urllib.request import time def download(url, timeout=10): """Download URL content with timeout.""" try: with urllib.request.urlopen(url, timeout=timeout) as response: content = response.read() return url, len(content), None except Exception as e: return url, 0, str(e) urls = [ "https://www.python.org", "https://www.github.com", "https://www.google.com", "https://httpbin.org/delay/5", # Slow endpoint ] print("Starting downloads...") start = time.time() with ThreadPoolExecutor(max_workers=4) as executor: future_to_url = {executor.submit(download, url): url for url in urls} for future in as_completed(future_to_url, timeout=15): url, size, error = future.result() if error: print(f"FAILED: {url} - {error}") else: print(f"OK: {url} - {size} bytes") print(f"Total time: {time.time() - start:.2f}s") ================================================ FILE: docs/notes/concurrency/python-multiprocessing.rst ================================================ .. 
meta:: :description lang=en: Python multiprocessing tutorial covering process creation, pools, shared memory, inter-process communication, and parallel CPU-bound task execution :keywords: Python, Python3, multiprocessing, Process, Pool, Queue, Pipe, shared memory, parallel, CPU-bound, GIL bypass =============== Multiprocessing =============== :Source: `src/basic/concurrency_.py `_ .. contents:: Table of Contents :backlinks: none Introduction ------------ The ``multiprocessing`` module enables true parallel execution by spawning separate Python processes, each with its own Python interpreter and memory space. Unlike threads, processes bypass the Global Interpreter Lock (GIL), making multiprocessing ideal for CPU-bound tasks that need to utilize multiple CPU cores. The trade-off is higher overhead for process creation and inter-process communication compared to threads. Creating Processes ------------------ Creating processes is similar to creating threads. Each process runs in its own memory space, so changes to variables in one process don't affect others. Use ``start()`` to begin execution and ``join()`` to wait for completion. .. code-block:: python from multiprocessing import Process import os def worker(name): print(f"Worker {name}, PID: {os.getpid()}") if __name__ == "__main__": processes = [] for i in range(4): p = Process(target=worker, args=(i,)) processes.append(p) p.start() for p in processes: p.join() print(f"Main process PID: {os.getpid()}") Process Pool ------------ A ``Pool`` manages a collection of worker processes and distributes tasks among them. This is more efficient than creating a new process for each task, as processes are reused. The pool provides methods like ``map()``, ``apply()``, and their async variants for different use cases. .. 
code-block:: python from multiprocessing import Pool import time def cpu_intensive(n): """Simulate CPU-bound work.""" total = 0 for i in range(n): total += i * i return total if __name__ == "__main__": numbers = [10**6, 10**6, 10**6, 10**6] # Sequential execution start = time.time() results = [cpu_intensive(n) for n in numbers] print(f"Sequential: {time.time() - start:.2f}s") # Parallel execution with Pool start = time.time() with Pool(4) as pool: results = pool.map(cpu_intensive, numbers) print(f"Parallel: {time.time() - start:.2f}s") Pool Methods ------------ The Pool class provides several methods for distributing work. ``map()`` applies a function to each item in an iterable and returns results in order. ``apply()`` calls a function with arguments and blocks until complete. The ``_async`` variants return immediately with an ``AsyncResult`` object. .. code-block:: python from multiprocessing import Pool def square(x): return x * x def add(a, b): return a + b if __name__ == "__main__": with Pool(4) as pool: # map - apply function to iterable results = pool.map(square, range(10)) print(f"map: {results}") # starmap - unpack arguments from iterable pairs = [(1, 2), (3, 4), (5, 6)] results = pool.starmap(add, pairs) print(f"starmap: {results}") # apply_async - non-blocking single call result = pool.apply_async(square, (10,)) print(f"apply_async: {result.get()}") # map_async - non-blocking map result = pool.map_async(square, range(5)) print(f"map_async: {result.get()}") Sharing Data with Queue ----------------------- Processes don't share memory by default. ``multiprocessing.Queue`` provides a thread and process-safe way to exchange data between processes. It's the recommended approach for most inter-process communication scenarios. .. 
code-block:: python from multiprocessing import Process, Queue def producer(q, items): for item in items: q.put(item) print(f"Produced: {item}") q.put(None) # Sentinel def consumer(q): while True: item = q.get() if item is None: break print(f"Consumed: {item}") if __name__ == "__main__": q = Queue() items = list(range(5)) p1 = Process(target=producer, args=(q, items)) p2 = Process(target=consumer, args=(q,)) p1.start() p2.start() p1.join() p2.join() Sharing Data with Pipe ---------------------- A ``Pipe`` creates a two-way communication channel between two processes. It's simpler and faster than Queue for point-to-point communication but only supports two endpoints. Each end can send and receive data. .. code-block:: python from multiprocessing import Process, Pipe def sender(conn): conn.send("Hello from sender") conn.send([1, 2, 3]) response = conn.recv() print(f"Sender received: {response}") conn.close() def receiver(conn): msg = conn.recv() print(f"Receiver got: {msg}") data = conn.recv() print(f"Receiver got: {data}") conn.send("Thanks!") conn.close() if __name__ == "__main__": parent_conn, child_conn = Pipe() p1 = Process(target=sender, args=(parent_conn,)) p2 = Process(target=receiver, args=(child_conn,)) p1.start() p2.start() p1.join() p2.join() Shared Memory with Value and Array ---------------------------------- For simple shared state, ``Value`` and ``Array`` provide shared memory that multiple processes can access. These are faster than Queue/Pipe for frequently accessed data but require careful synchronization to avoid race conditions. .. 
code-block:: python from multiprocessing import Process, Value, Array def increment(counter, lock_needed=True): for _ in range(10000): with counter.get_lock(): counter.value += 1 def modify_array(arr): for i in range(len(arr)): arr[i] = arr[i] * 2 if __name__ == "__main__": # Shared integer counter = Value('i', 0) # 'i' = signed int processes = [Process(target=increment, args=(counter,)) for _ in range(4)] for p in processes: p.start() for p in processes: p.join() print(f"Counter: {counter.value}") # 40000 # Shared array arr = Array('d', [1.0, 2.0, 3.0, 4.0]) # 'd' = double p = Process(target=modify_array, args=(arr,)) p.start() p.join() print(f"Array: {list(arr)}") # [2.0, 4.0, 6.0, 8.0] Manager for Complex Shared Objects ---------------------------------- A ``Manager`` provides a way to share more complex Python objects (lists, dicts) between processes. The manager runs a server process that holds the actual objects, and other processes access them through proxies. This is slower than Value/Array but supports arbitrary Python objects. .. code-block:: python from multiprocessing import Process, Manager def worker(shared_dict, shared_list, worker_id): shared_dict[worker_id] = worker_id * 10 shared_list.append(worker_id) if __name__ == "__main__": with Manager() as manager: shared_dict = manager.dict() shared_list = manager.list() processes = [] for i in range(4): p = Process(target=worker, args=(shared_dict, shared_list, i)) processes.append(p) p.start() for p in processes: p.join() print(f"Dict: {dict(shared_dict)}") print(f"List: {list(shared_list)}") Process Synchronization ----------------------- Multiprocessing provides the same synchronization primitives as threading: ``Lock``, ``RLock``, ``Semaphore``, ``Event``, ``Condition``, and ``Barrier``. These work across processes instead of threads. .. 
code-block:: python from multiprocessing import Process, Lock, Value def safe_increment(counter, lock): for _ in range(10000): with lock: counter.value += 1 if __name__ == "__main__": lock = Lock() counter = Value('i', 0) processes = [ Process(target=safe_increment, args=(counter, lock)) for _ in range(4) ] for p in processes: p.start() for p in processes: p.join() print(f"Counter: {counter.value}") # 40000 Daemon Processes ---------------- Like daemon threads, daemon processes are terminated when the main process exits. They're useful for background tasks that shouldn't prevent program termination. Set ``daemon=True`` before calling ``start()``. .. code-block:: python from multiprocessing import Process import time def background_task(): while True: print("Background process running...") time.sleep(1) if __name__ == "__main__": p = Process(target=background_task, daemon=True) p.start() time.sleep(3) print("Main process exiting, daemon will be terminated") Handling Process Termination ---------------------------- Processes can be terminated gracefully using ``terminate()`` or forcefully using ``kill()``. Always clean up resources properly and consider using signals for graceful shutdown in production code. .. code-block:: python from multiprocessing import Process import time import signal def long_running_task(): try: while True: print("Working...") time.sleep(1) except KeyboardInterrupt: print("Graceful shutdown") if __name__ == "__main__": p = Process(target=long_running_task) p.start() time.sleep(3) # Graceful termination (SIGTERM) p.terminate() p.join(timeout=2) # Force kill if still alive if p.is_alive(): p.kill() p.join() print(f"Exit code: {p.exitcode}") ProcessPoolExecutor ------------------- ``concurrent.futures.ProcessPoolExecutor`` provides a higher-level interface for process pools that's consistent with ``ThreadPoolExecutor``. It's often easier to use than ``multiprocessing.Pool`` and integrates well with the futures pattern. .. 
code-block:: python from concurrent.futures import ProcessPoolExecutor, as_completed def compute(n): return sum(i * i for i in range(n)) if __name__ == "__main__": numbers = [10**6, 10**6, 10**6, 10**6] with ProcessPoolExecutor(max_workers=4) as executor: # Submit individual tasks futures = [executor.submit(compute, n) for n in numbers] # Process results as they complete for future in as_completed(futures): print(f"Result: {future.result()}") # Or use map for ordered results results = list(executor.map(compute, numbers)) print(f"All results: {results}") Comparing Threads vs Processes ------------------------------ Choose threads for I/O-bound tasks (network, file I/O) where the GIL is released during waiting. Choose processes for CPU-bound tasks that need true parallelism. This example demonstrates the performance difference. .. code-block:: python from threading import Thread from multiprocessing import Process, Pool import time def cpu_bound(n): """CPU-intensive task.""" return sum(i * i for i in range(n)) if __name__ == "__main__": n = 10**7 count = 4 # Sequential start = time.time() for _ in range(count): cpu_bound(n) print(f"Sequential: {time.time() - start:.2f}s") # Threads (limited by GIL) start = time.time() threads = [Thread(target=cpu_bound, args=(n,)) for _ in range(count)] for t in threads: t.start() for t in threads: t.join() print(f"Threads: {time.time() - start:.2f}s") # Processes (true parallelism) start = time.time() with Pool(count) as pool: pool.map(cpu_bound, [n] * count) print(f"Processes: {time.time() - start:.2f}s") ================================================ FILE: docs/notes/concurrency/python-threading.rst ================================================ .. 
meta:: :description lang=en: Python threading tutorial covering thread creation, synchronization primitives, locks, semaphores, events, conditions, and thread-safe data structures :keywords: Python, Python3, threading, Thread, Lock, RLock, Semaphore, Event, Condition, synchronization, GIL, concurrent, parallel ========= Threading ========= :Source: `src/basic/concurrency_.py `_ .. contents:: Table of Contents :backlinks: none Introduction ------------ The ``threading`` module provides a high-level interface for creating and managing threads in Python. Threads are lightweight units of execution that share the same memory space within a process, making them efficient for I/O-bound tasks where the program spends time waiting for external resources. However, due to Python's Global Interpreter Lock (GIL), threads cannot achieve true parallelism for CPU-bound tasks—only one thread can execute Python bytecode at a time. For CPU-intensive work, consider using ``multiprocessing`` instead. Creating Threads ---------------- There are two primary ways to create threads: subclassing ``Thread`` or passing a target function. The function-based approach is more flexible and commonly used, while subclassing is useful when you need to encapsulate thread state and behavior in a class. .. code-block:: python from threading import Thread # Method 1: Subclass Thread class Worker(Thread): def __init__(self, worker_id): super().__init__() self.worker_id = worker_id def run(self): print(f"Worker {self.worker_id} running") # Method 2: Pass target function (preferred) def task(worker_id): print(f"Task {worker_id} running") # Using subclass t1 = Worker(1) t1.start() t1.join() # Using target function t2 = Thread(target=task, args=(2,)) t2.start() t2.join() Thread with Return Value ------------------------ Threads don't directly return values from their target functions. 
To get results back, you can use shared mutable objects, queues, or store results as instance attributes when subclassing Thread. .. code-block:: python from threading import Thread from queue import Queue def compute(n, results): """Store result in shared dict.""" results[n] = n * n # Using shared dictionary results = {} threads = [] for i in range(5): t = Thread(target=compute, args=(i, results)) threads.append(t) t.start() for t in threads: t.join() print(results) # {0: 0, 1: 1, 2: 4, 3: 9, 4: 16} # Using Queue (thread-safe) def compute_queue(n, q): q.put((n, n * n)) q = Queue() threads = [] for i in range(5): t = Thread(target=compute_queue, args=(i, q)) threads.append(t) t.start() for t in threads: t.join() while not q.empty(): n, result = q.get() print(f"{n}: {result}") Daemon Threads -------------- Daemon threads run in the background and are automatically terminated when all non-daemon threads have finished. They're useful for background tasks that shouldn't prevent the program from exiting, such as monitoring or cleanup tasks. .. code-block:: python from threading import Thread import time def background_task(): while True: print("Background task running...") time.sleep(1) # Daemon thread - won't prevent program exit t = Thread(target=background_task, daemon=True) t.start() # Main thread work time.sleep(3) print("Main thread done, daemon will be killed") Lock - Mutual Exclusion ----------------------- A ``Lock`` is the simplest synchronization primitive that prevents multiple threads from accessing a shared resource simultaneously. Always use locks when modifying shared state to prevent race conditions. The context manager syntax (``with lock:``) is preferred as it guarantees the lock is released even if an exception occurs. .. 
code-block:: python from threading import Thread, Lock counter = 0 lock = Lock() def increment(n): global counter for _ in range(n): with lock: # Acquire and release automatically counter += 1 threads = [Thread(target=increment, args=(100000,)) for _ in range(5)] for t in threads: t.start() for t in threads: t.join() print(f"Counter: {counter}") # Always 500000 with lock RLock - Reentrant Lock ---------------------- An ``RLock`` (reentrant lock) can be acquired multiple times by the same thread without causing a deadlock. This is essential when a thread needs to call methods that also acquire the same lock, such as in recursive functions or when methods call other methods on the same object. .. code-block:: python from threading import Thread, RLock class Counter: def __init__(self): self.value = 0 self.lock = RLock() def increment(self): with self.lock: self.value += 1 def increment_twice(self): with self.lock: # First acquisition self.increment() # Second acquisition - OK with RLock self.increment() counter = Counter() threads = [Thread(target=counter.increment_twice) for _ in range(100)] for t in threads: t.start() for t in threads: t.join() print(f"Value: {counter.value}") # 200 Semaphore - Resource Limiting ----------------------------- A ``Semaphore`` limits the number of threads that can access a resource concurrently. Unlike a lock which allows only one thread, a semaphore with count N allows up to N threads to proceed. This is useful for connection pools, rate limiting, or controlling access to limited resources. .. 
code-block:: python from threading import Thread, Semaphore import time # Allow max 3 concurrent connections connection_pool = Semaphore(3) def access_database(thread_id): print(f"Thread {thread_id} waiting for connection...") with connection_pool: print(f"Thread {thread_id} connected") time.sleep(1) # Simulate database work print(f"Thread {thread_id} disconnected") threads = [Thread(target=access_database, args=(i,)) for i in range(10)] for t in threads: t.start() for t in threads: t.join() Event - Thread Signaling ------------------------ An ``Event`` is a simple signaling mechanism that allows one thread to signal other threads that something has happened. Threads can wait for the event to be set, and one thread can set or clear the event. This is useful for coordinating startup, shutdown, or state changes between threads. .. code-block:: python from threading import Thread, Event import time ready = Event() def worker(worker_id): print(f"Worker {worker_id} waiting for signal...") ready.wait() # Block until event is set print(f"Worker {worker_id} starting work") def coordinator(): print("Coordinator preparing...") time.sleep(2) print("Coordinator: All systems go!") ready.set() # Signal all waiting threads threads = [Thread(target=worker, args=(i,)) for i in range(3)] threads.append(Thread(target=coordinator)) for t in threads: t.start() for t in threads: t.join() Condition - Complex Synchronization ----------------------------------- A ``Condition`` combines a lock with the ability to wait for and notify about state changes. It's essential for producer-consumer patterns where threads need to wait for specific conditions (like "buffer not empty" or "buffer not full") before proceeding. .. 
code-block:: python from threading import Thread, Condition import time items = [] condition = Condition() def producer(): for i in range(5): time.sleep(0.5) with condition: items.append(i) print(f"Produced: {i}") condition.notify() # Wake up one waiting consumer def consumer(): while True: with condition: while not items: # Wait until items available condition.wait() item = items.pop(0) print(f"Consumed: {item}") if item == 4: break t1 = Thread(target=producer) t2 = Thread(target=consumer) t1.start() t2.start() t1.join() t2.join() Barrier - Synchronization Point ------------------------------- A ``Barrier`` blocks a specified number of threads until all of them have reached the barrier point, then releases them all simultaneously. This is useful when you need multiple threads to complete a phase before any can proceed to the next phase. .. code-block:: python from threading import Thread, Barrier import time import random barrier = Barrier(3) def worker(worker_id): # Phase 1: Initialization print(f"Worker {worker_id} initializing...") time.sleep(random.uniform(0.5, 2)) print(f"Worker {worker_id} waiting at barrier") barrier.wait() # Wait for all threads # Phase 2: All threads proceed together print(f"Worker {worker_id} proceeding") threads = [Thread(target=worker, args=(i,)) for i in range(3)] for t in threads: t.start() for t in threads: t.join() Timer - Delayed Execution ------------------------- A ``Timer`` is a thread that executes a function after a specified delay. It can be cancelled before it fires. This is useful for timeouts, delayed cleanup, or scheduling one-time tasks. .. code-block:: python from threading import Timer def delayed_task(): print("Task executed after delay") # Execute after 2 seconds timer = Timer(2.0, delayed_task) timer.start() # Can be cancelled before it fires # timer.cancel() Thread-Local Data ----------------- ``threading.local()`` provides thread-local storage where each thread has its own independent copy of the data. 
This is useful for storing per-thread state without passing it through function arguments, such as database connections or request context in web applications. .. code-block:: python from threading import Thread, local # Each thread gets its own 'data' attribute thread_data = local() def worker(worker_id): thread_data.value = worker_id process() def process(): # Access thread-local data without passing as argument print(f"Processing with value: {thread_data.value}") threads = [Thread(target=worker, args=(i,)) for i in range(3)] for t in threads: t.start() for t in threads: t.join() Producer-Consumer with Queue ---------------------------- The ``queue.Queue`` class provides a thread-safe FIFO queue that handles all locking internally. This is the recommended way to communicate between threads in a producer-consumer pattern, as it eliminates the need for manual synchronization. .. code-block:: python from threading import Thread from queue import Queue import time def producer(q, items): for item in items: time.sleep(0.1) q.put(item) print(f"Produced: {item}") q.put(None) # Sentinel to signal completion def consumer(q): while True: item = q.get() if item is None: break print(f"Consumed: {item}") q.task_done() q = Queue(maxsize=5) # Bounded queue items = list(range(10)) t1 = Thread(target=producer, args=(q, items)) t2 = Thread(target=consumer, args=(q,)) t1.start() t2.start() t1.join() t2.join() Deadlock Example and Prevention ------------------------------- Deadlock occurs when two or more threads are waiting for each other to release locks, creating a circular dependency. The classic example is when thread A holds lock 1 and waits for lock 2, while thread B holds lock 2 and waits for lock 1. Prevent deadlocks by always acquiring locks in a consistent order. .. 
code-block:: python from threading import Thread, Lock import time lock1 = Lock() lock2 = Lock() # DEADLOCK EXAMPLE - DON'T DO THIS def task_a_bad(): with lock1: print("Task A acquired lock1") time.sleep(0.1) with lock2: # Waits for lock2 print("Task A acquired lock2") def task_b_bad(): with lock2: print("Task B acquired lock2") time.sleep(0.1) with lock1: # Waits for lock1 - DEADLOCK! print("Task B acquired lock1") # CORRECT - Always acquire locks in same order def task_a_good(): with lock1: with lock2: print("Task A acquired both locks") def task_b_good(): with lock1: # Same order as task_a with lock2: print("Task B acquired both locks") Understanding the GIL --------------------- The Global Interpreter Lock (GIL) is a mutex that protects access to Python objects, preventing multiple threads from executing Python bytecode simultaneously. This means threads don't provide speedup for CPU-bound tasks. However, the GIL is released during I/O operations, making threads effective for I/O-bound tasks. .. code-block:: python from threading import Thread import time def cpu_bound(n): """CPU-bound task - GIL limits parallelism.""" count = 0 for i in range(n): count += i return count def io_bound(seconds): """I/O-bound task - GIL released during sleep.""" time.sleep(seconds) # CPU-bound: threads won't help (may be slower due to GIL contention) start = time.time() threads = [Thread(target=cpu_bound, args=(10**7,)) for _ in range(4)] for t in threads: t.start() for t in threads: t.join() print(f"CPU-bound threaded: {time.time() - start:.2f}s") # I/O-bound: threads help significantly start = time.time() threads = [Thread(target=io_bound, args=(1,)) for _ in range(4)] for t in threads: t.start() for t in threads: t.join() print(f"I/O-bound threaded: {time.time() - start:.2f}s") # ~1s, not 4s ================================================ FILE: docs/notes/database/index.rst ================================================ .. 
meta:: :description lang=en: Python SQLAlchemy tutorial covering database connections, ORM models, relationships, queries, joins, and advanced query patterns :keywords: Python, Python3, SQLAlchemy, database, SQL, ORM, query, join, relationship, session, model, PostgreSQL, MySQL, SQLite ======== Database ======== Working with databases is a core skill for most Python applications, from web development to data analysis. SQLAlchemy is Python's most popular database toolkit, providing both a low-level SQL expression language (Core) and a high-level Object-Relational Mapper (ORM). This section covers SQLAlchemy from basic connections and queries to advanced patterns like relationships, eager loading, and complex joins. Whether you're building a simple script or a large-scale application, these examples will help you interact with databases efficiently and safely. .. toctree:: :maxdepth: 1 python-sqlalchemy python-sqlalchemy-orm python-sqlalchemy-query ================================================ FILE: docs/notes/database/python-sqlalchemy-orm.rst ================================================ .. meta:: :description lang=en: SQLAlchemy ORM tutorial covering declarative models, sessions, relationships, and object-relational mapping patterns :keywords: Python, SQLAlchemy, ORM, database, model, session, relationship, declarative, object-relational mapping ============== SQLAlchemy ORM ============== .. contents:: Table of Contents :backlinks: none SQLAlchemy's Object-Relational Mapper (ORM) provides a high-level abstraction that allows you to work with database tables as Python classes and rows as objects. The ORM builds on top of SQLAlchemy Core and adds features like identity mapping, unit of work pattern, and relationship management. This approach lets you write database code in a more Pythonic way, focusing on objects and their relationships rather than SQL statements. 
The ORM is ideal for applications with complex domain models where you want to leverage object-oriented programming patterns. Define Models with Declarative Base ----------------------------------- The declarative system is the most common way to define ORM models in SQLAlchemy. You create a base class using ``declarative_base()`` and then define your models as subclasses. Each model class represents a database table, with class attributes defining columns. The ``__tablename__`` attribute specifies the table name. This approach keeps your model definitions clean and readable while providing full access to SQLAlchemy's features. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String >>> from sqlalchemy.orm import declarative_base >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... email = Column(String(100)) ... def __repr__(self): ... return f"User(id={self.id}, name='{self.name}')" >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) Session Basics -------------- The ``Session`` is the primary interface for persistence operations in the ORM. It manages a "holding zone" for objects you've loaded or associated with it, and handles the communication with the database. Sessions track changes to objects and synchronize them with the database when you call ``commit()``. The recommended pattern is to use ``sessionmaker`` to create a session factory, then create sessions as needed. Always close sessions when done to release database connections. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... 
name = Column(String(50)) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> try: ... user = User(name="Alice") ... session.add(user) ... session.commit() ... print(f"Created user with id: {user.id}") ... finally: ... session.close() Created user with id: 1 Add and Commit Objects ---------------------- To persist new objects to the database, add them to the session with ``add()`` or ``add_all()`` for multiple objects. Objects remain in a "pending" state until you call ``commit()``, which flushes all pending changes to the database in a transaction. If an error occurs, call ``rollback()`` to undo all changes since the last commit. After commit, auto-generated values like primary keys are available on the objects. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> # Add single object >>> user1 = User(name="Alice") >>> session.add(user1) >>> # Add multiple objects >>> users = [User(name="Bob"), User(name="Carol")] >>> session.add_all(users) >>> session.commit() >>> print([u.id for u in [user1] + users]) [1, 2, 3] >>> session.close() Query Objects ------------- SQLAlchemy 2.0 uses ``select()`` with ``session.execute()`` for queries, replacing the legacy ``session.query()`` API. The ``select()`` construct accepts model classes or specific columns. Use ``scalars()`` to get model instances directly, or ``execute()`` for row tuples. The result supports iteration, ``all()`` for a list, ``first()`` for the first result, and ``one()`` when exactly one result is expected. .. 
code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... age = Column(Integer) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([ ... User(name="Alice", age=30), ... User(name="Bob", age=25), ... User(name="Carol", age=35)]) >>> session.commit() >>> # Get all users >>> users = session.execute(select(User)).scalars().all() >>> print([u.name for u in users]) ['Alice', 'Bob', 'Carol'] >>> # Filter with where() >>> user = session.execute(select(User).where(User.age > 28)).scalars().first() >>> print(user.name) Alice >>> session.close() Filter Queries -------------- The ``where()`` method accepts filter conditions using column comparisons. SQLAlchemy overloads Python operators to generate SQL: ``==`` becomes ``=``, ``!=`` becomes ``!=`` (equivalent to standard SQL ``<>``), and so on. For complex conditions, use ``and_()``, ``or_()``, and ``not_()`` from SQLAlchemy. Columns also provide methods like ``in_()``, ``like()``, ``between()``, ``is_()``, and ``is_not()`` (named ``isnot()`` before SQLAlchemy 1.4; the old name still works) for SQL-specific operations. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select, and_, or_ >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... age = Column(Integer) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([ ... User(name="Alice", age=30), ... User(name="Bob", age=25), ... User(name="Carol", age=35), ...
User(name="Fred", age=30)]) >>> session.commit() >>> # AND condition >>> stmt = select(User).where(and_(User.age >= 30, User.name.like("A%"))) >>> print([u.name for u in session.execute(stmt).scalars()]) ['Alice'] >>> # OR condition >>> stmt = select(User).where(or_(User.name == "Alice", User.name == "Bob")) >>> print([u.name for u in session.execute(stmt).scalars()]) ['Alice', 'Bob'] >>> # IN clause >>> stmt = select(User).where(User.age.in_([25, 35])) >>> print([u.name for u in session.execute(stmt).scalars()]) ['Bob', 'Carol'] >>> session.close() Update Objects -------------- To update objects, simply modify their attributes and call ``commit()``. The session tracks changes to loaded objects automatically through a mechanism called "dirty tracking". When you commit, SQLAlchemy generates UPDATE statements only for changed attributes. You can also use bulk updates with ``update()`` for efficiency when modifying many rows without loading them into memory. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select, update >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add(User(name="Alice")) >>> session.commit() >>> # Update via object modification >>> user = session.execute(select(User).where(User.name == "Alice")).scalars().first() >>> user.name = "Alicia" >>> session.commit() >>> # Verify update >>> user = session.execute(select(User)).scalars().first() >>> print(user.name) Alicia >>> session.close() Delete Objects -------------- To delete objects, use ``session.delete()`` followed by ``commit()``. The session will generate a DELETE statement for the object. 
For bulk deletes without loading objects, use the ``delete()`` construct with ``session.execute()``. Be careful with cascading deletes when objects have relationships - SQLAlchemy can automatically delete related objects based on your cascade configuration. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select, delete >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([User(name="Alice"), User(name="Bob"), User(name="Carol")]) >>> session.commit() >>> # Delete via object >>> user = session.execute(select(User).where(User.name == "Bob")).scalars().first() >>> session.delete(user) >>> session.commit() >>> # Verify deletion >>> users = session.execute(select(User)).scalars().all() >>> print([u.name for u in users]) ['Alice', 'Carol'] >>> session.close() One-to-Many Relationship ------------------------ Relationships define how tables are connected. A one-to-many relationship means one record in the parent table can have multiple related records in the child table. Use ``relationship()`` on the parent side and ``ForeignKey`` on the child side. The ``back_populates`` parameter creates a bidirectional relationship, allowing navigation from both sides. SQLAlchemy handles the foreign key management automatically. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, select >>> from sqlalchemy.orm import declarative_base, sessionmaker, relationship >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... 
posts = relationship("Post", back_populates="author") >>> class Post(Base): ... __tablename__ = "posts" ... id = Column(Integer, primary_key=True) ... title = Column(String(100)) ... user_id = Column(Integer, ForeignKey("users.id")) ... author = relationship("User", back_populates="posts") >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> user = User(name="Alice") >>> user.posts.append(Post(title="First Post")) >>> user.posts.append(Post(title="Second Post")) >>> session.add(user) >>> session.commit() >>> # Access relationship >>> user = session.execute(select(User)).scalars().first() >>> print([p.title for p in user.posts]) ['First Post', 'Second Post'] >>> session.close() Many-to-Many Relationship ------------------------- Many-to-many relationships require an association table that contains foreign keys to both related tables. Define the association table using ``Table``, then use ``relationship()`` with the ``secondary`` parameter pointing to it. Both sides can have a relationship, and SQLAlchemy manages the association table entries automatically when you add or remove items from the relationship collections. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, Table, select >>> from sqlalchemy.orm import declarative_base, sessionmaker, relationship >>> Base = declarative_base() >>> # Association table >>> student_course = Table( ... "student_course", Base.metadata, ... Column("student_id", Integer, ForeignKey("students.id"), primary_key=True), ... Column("course_id", Integer, ForeignKey("courses.id"), primary_key=True)) >>> class Student(Base): ... __tablename__ = "students" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... courses = relationship("Course", secondary=student_course, back_populates="students") >>> class Course(Base): ... __tablename__ = "courses" ... 
id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... students = relationship("Student", secondary=student_course, back_populates="courses") >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> math = Course(name="Math") >>> physics = Course(name="Physics") >>> alice = Student(name="Alice", courses=[math, physics]) >>> bob = Student(name="Bob", courses=[math]) >>> session.add_all([alice, bob]) >>> session.commit() >>> # Query relationships >>> math = session.execute(select(Course).where(Course.name == "Math")).scalars().first() >>> print([s.name for s in math.students]) ['Alice', 'Bob'] >>> session.close() Self-Referential Relationship ----------------------------- Self-referential relationships connect a table to itself, useful for hierarchical data like organizational charts, categories, or threaded comments. Use ``ForeignKey`` pointing to the same table and ``relationship()`` with ``remote_side`` to indicate which side is the "parent". This pattern allows you to model tree structures where each node can have a parent and multiple children. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, select >>> from sqlalchemy.orm import declarative_base, sessionmaker, relationship >>> Base = declarative_base() >>> class Employee(Base): ... __tablename__ = "employees" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... manager_id = Column(Integer, ForeignKey("employees.id")) ... 
manager = relationship("Employee", remote_side=[id], backref="subordinates") >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> ceo = Employee(name="CEO") >>> session.add(ceo) >>> session.flush() >>> manager = Employee(name="Manager", manager_id=ceo.id) >>> session.add(manager) >>> session.flush() >>> worker1 = Employee(name="Worker1", manager_id=manager.id) >>> worker2 = Employee(name="Worker2", manager_id=manager.id) >>> session.add_all([worker1, worker2]) >>> session.commit() >>> # Navigate hierarchy >>> mgr = session.execute(select(Employee).where(Employee.name == "Manager")).scalars().first() >>> print(f"Manager: {mgr.name}, Boss: {mgr.manager.name}") Manager: Manager, Boss: CEO >>> print(f"Subordinates: {[e.name for e in mgr.subordinates]}") Subordinates: ['Worker1', 'Worker2'] >>> session.close() Cascade Deletes --------------- Cascade options control what happens to related objects when a parent is deleted or modified. The ``cascade`` parameter on ``relationship()`` accepts a comma-separated string of cascade rules. Common options include ``"all, delete-orphan"`` which deletes children when the parent is deleted and when children are removed from the collection. This ensures referential integrity and prevents orphaned records. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, select >>> from sqlalchemy.orm import declarative_base, sessionmaker, relationship >>> Base = declarative_base() >>> class Parent(Base): ... __tablename__ = "parents" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... children = relationship("Child", back_populates="parent", ... cascade="all, delete-orphan") >>> class Child(Base): ... __tablename__ = "children" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... parent_id = Column(Integer, ForeignKey("parents.id")) ... 
parent = relationship("Parent", back_populates="children") >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> parent = Parent(name="Parent1") >>> parent.children = [Child(name="Child1"), Child(name="Child2")] >>> session.add(parent) >>> session.commit() >>> # Delete parent - children are also deleted >>> session.delete(parent) >>> session.commit() >>> children = session.execute(select(Child)).scalars().all() >>> print(len(children)) 0 >>> session.close() Eager Loading ------------- By default, SQLAlchemy uses lazy loading for relationships, executing a new query when you access related objects. This can cause the "N+1 query problem" when iterating over many objects. Eager loading fetches related objects up front instead: ``joinedload()`` adds a JOIN to the same query, while ``selectinload()`` issues one additional ``SELECT ... IN`` query that loads the related rows for all parent objects at once. Use ``joinedload()`` for many-to-one references or small collections, and ``selectinload()`` for larger collections to avoid cartesian products. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, select >>> from sqlalchemy.orm import declarative_base, sessionmaker, relationship, joinedload >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... posts = relationship("Post", back_populates="author") >>> class Post(Base): ... __tablename__ = "posts" ... id = Column(Integer, primary_key=True) ... title = Column(String(100)) ... user_id = Column(Integer, ForeignKey("users.id")) ...
author = relationship("User", back_populates="posts") >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> user = User(name="Alice") >>> user.posts = [Post(title="Post1"), Post(title="Post2")] >>> session.add(user) >>> session.commit() >>> # Eager load posts with user in single query >>> stmt = select(User).options(joinedload(User.posts)) >>> user = session.execute(stmt).scalars().unique().first() >>> print([p.title for p in user.posts]) # No additional query ['Post1', 'Post2'] >>> session.close() Hybrid Properties ----------------- Hybrid properties allow you to define Python properties that work both at the instance level (in Python) and at the class level (in SQL queries). This is useful for computed attributes that you want to filter or sort by in database queries. Use the ``@hybrid_property`` decorator and optionally ``@<name>.expression`` (for example ``@full_name.expression``) to customize the SQL expression. Note that the getter only works in queries when it is built from operations SQLAlchemy can translate to SQL; the f-string used in the example below is evaluated in Python, so to filter or sort by ``full_name`` you would write the getter or a separate expression with column operators such as ``first_name + " " + last_name``. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> from sqlalchemy.ext.hybrid import hybrid_property >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... first_name = Column(String(50)) ... last_name = Column(String(50)) ... ... @hybrid_property ... def full_name(self): ... return f"{self.first_name} {self.last_name}" >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add(User(first_name="Alice", last_name="Smith")) >>> session.commit() >>> user = session.execute(select(User)).scalars().first() >>> print(user.full_name) Alice Smith >>> session.close() Event Hooks ----------- SQLAlchemy provides an event system that lets you hook into various ORM operations like before/after insert, update, or delete.
Use ``@event.listens_for()`` decorator to register event handlers. Events are useful for auditing, validation, automatic timestamps, or triggering side effects. Common events include ``before_insert``, ``after_insert``, ``before_update``, ``after_update``, ``before_delete``, and ``after_delete``. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, DateTime, select, event >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> from datetime import datetime >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... created_at = Column(DateTime) ... updated_at = Column(DateTime) >>> @event.listens_for(User, "before_insert") ... def set_created_at(mapper, connection, target): ... target.created_at = datetime.now() ... target.updated_at = datetime.now() >>> @event.listens_for(User, "before_update") ... def set_updated_at(mapper, connection, target): ... target.updated_at = datetime.now() >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> user = User(name="Alice") >>> session.add(user) >>> session.commit() >>> print(user.created_at is not None) True >>> session.close() ================================================ FILE: docs/notes/database/python-sqlalchemy-query.rst ================================================ .. meta:: :description lang=en: SQLAlchemy advanced query patterns including joins, subqueries, aggregations, window functions, and performance optimization :keywords: Python, SQLAlchemy, query, join, subquery, aggregate, window function, CTE, performance, optimization ======================== SQLAlchemy Query Recipes ======================== .. contents:: Table of Contents :backlinks: none This section covers advanced query patterns and recipes for SQLAlchemy. 
While the basics of querying are covered in the ORM section, real-world applications often require more sophisticated queries involving joins across multiple tables, subqueries, aggregations, and performance optimizations. These patterns help you write efficient database queries while maintaining readable Python code. Understanding these techniques is essential for building scalable applications that interact with relational databases. Order By -------- The ``order_by()`` method sorts query results by one or more columns. Pass column objects directly, or use ``desc()`` for descending order. You can pass multiple columns for secondary sorting. SQLAlchemy also supports ``nulls_first()`` and ``nulls_last()`` (named ``nullsfirst()`` and ``nullslast()`` before SQLAlchemy 1.4; the old names still work) to control how NULL values are sorted, which is particularly useful when dealing with optional fields. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select, desc >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... age = Column(Integer) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([ ... User(name="Alice", age=30), ... User(name="Bob", age=25), ... User(name="Carol", age=30)]) >>> session.commit() >>> # Ascending order >>> stmt = select(User).order_by(User.age) >>> print([u.name for u in session.execute(stmt).scalars()]) ['Bob', 'Alice', 'Carol'] >>> # Descending order >>> stmt = select(User).order_by(desc(User.age), User.name) >>> print([u.name for u in session.execute(stmt).scalars()]) ['Alice', 'Carol', 'Bob'] >>> session.close() Limit and Offset ---------------- Use ``limit()`` to restrict the number of results and ``offset()`` to skip rows, enabling pagination. These methods translate directly to SQL LIMIT and OFFSET clauses.
For large datasets, consider using keyset pagination (filtering by the last seen ID) instead of offset-based pagination, as OFFSET can become slow with large offsets since the database must scan and discard rows. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([User(name=f"User{i}") for i in range(10)]) >>> session.commit() >>> # First page (3 items) >>> stmt = select(User).order_by(User.id).limit(3) >>> print([u.name for u in session.execute(stmt).scalars()]) ['User0', 'User1', 'User2'] >>> # Second page >>> stmt = select(User).order_by(User.id).limit(3).offset(3) >>> print([u.name for u in session.execute(stmt).scalars()]) ['User3', 'User4', 'User5'] >>> session.close() Group By and Aggregates ----------------------- Use ``group_by()`` with aggregate functions like ``func.count()``, ``func.sum()``, ``func.avg()``, ``func.min()``, and ``func.max()`` for grouped calculations. The ``having()`` method filters groups after aggregation, similar to SQL's HAVING clause. When selecting both regular columns and aggregates, all non-aggregate columns must be included in the GROUP BY clause. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select, func >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class Sale(Base): ... __tablename__ = "sales" ... id = Column(Integer, primary_key=True) ... product = Column(String(50)) ... 
amount = Column(Integer) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([ ... Sale(product="A", amount=100), ... Sale(product="A", amount=150), ... Sale(product="B", amount=200), ... Sale(product="B", amount=50)]) >>> session.commit() >>> # Group by with sum >>> stmt = select(Sale.product, func.sum(Sale.amount).label("total"))\ ... .group_by(Sale.product) >>> for row in session.execute(stmt): ... print(f"{row.product}: {row.total}") A: 250 B: 250 >>> # Having clause >>> stmt = select(Sale.product, func.count().label("count"))\ ... .group_by(Sale.product).having(func.count() > 1) >>> print(session.execute(stmt).fetchall()) [('A', 2), ('B', 2)] >>> session.close() Join Queries ------------ SQLAlchemy provides several ways to join tables. For ORM models with relationships, use ``join()`` which automatically uses the foreign key. For explicit join conditions, pass the condition as the second argument. Use ``outerjoin()`` for LEFT OUTER JOIN. The ``select_from()`` method specifies the FROM clause when needed for complex joins. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, select >>> from sqlalchemy.orm import declarative_base, sessionmaker, relationship >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... orders = relationship("Order", back_populates="user") >>> class Order(Base): ... __tablename__ = "orders" ... id = Column(Integer, primary_key=True) ... product = Column(String(50)) ... user_id = Column(Integer, ForeignKey("users.id")) ... 
user = relationship("User", back_populates="orders") >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> alice = User(name="Alice") >>> alice.orders = [Order(product="Book"), Order(product="Pen")] >>> bob = User(name="Bob") >>> session.add_all([alice, bob]) >>> session.commit() >>> # Inner join >>> stmt = select(User.name, Order.product).join(Order) >>> print(session.execute(stmt).fetchall()) [('Alice', 'Book'), ('Alice', 'Pen')] >>> # Left outer join (includes users without orders) >>> stmt = select(User.name, Order.product).outerjoin(Order) >>> print(session.execute(stmt).fetchall()) [('Alice', 'Book'), ('Alice', 'Pen'), ('Bob', None)] >>> session.close() Subqueries ---------- Subqueries are queries nested inside other queries. Use ``subquery()`` to create a subquery that can be used in the FROM clause, or ``scalar_subquery()`` for single-value subqueries in SELECT or WHERE clauses. Subqueries are useful for complex filtering, computing derived values, or when you need to reference aggregated data in the main query. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select, func >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... score = Column(Integer) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([ ... User(name="Alice", score=85), ... User(name="Bob", score=90), ... 
User(name="Carol", score=75)]) >>> session.commit() >>> # Scalar subquery: users with above-average score >>> avg_score = select(func.avg(User.score)).scalar_subquery() >>> stmt = select(User).where(User.score > avg_score) >>> print([u.name for u in session.execute(stmt).scalars()]) ['Alice', 'Bob'] >>> session.close() Common Table Expressions (CTE) ------------------------------ CTEs (WITH clauses) improve query readability by naming subqueries. They're especially useful for recursive queries and when the same subquery is referenced multiple times. Use ``cte()`` to create a CTE from a select statement. CTEs can reference themselves for recursive queries, which is powerful for hierarchical data like organizational charts or category trees. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select, func >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class Sale(Base): ... __tablename__ = "sales" ... id = Column(Integer, primary_key=True) ... region = Column(String(50)) ... amount = Column(Integer) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([ ... Sale(region="East", amount=100), ... Sale(region="East", amount=200), ... Sale(region="West", amount=150)]) >>> session.commit() >>> # CTE for regional totals >>> regional_totals = select( ... Sale.region, ... func.sum(Sale.amount).label("total") ... ).group_by(Sale.region).cte("regional_totals") >>> # Use CTE in main query >>> stmt = select(regional_totals).where(regional_totals.c.total > 200) >>> print(session.execute(stmt).fetchall()) [('East', 300)] >>> session.close() Exists and Correlated Subqueries -------------------------------- The ``exists()`` construct creates an EXISTS subquery, which returns true if the subquery returns any rows. 
This is efficient for checking the existence of related records without loading them. Correlated subqueries reference columns from the outer query, allowing row-by-row comparisons. Use ``correlate()`` to explicitly specify which tables the subquery should correlate with. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, select, exists >>> from sqlalchemy.orm import declarative_base, sessionmaker, relationship >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) >>> class Order(Base): ... __tablename__ = "orders" ... id = Column(Integer, primary_key=True) ... user_id = Column(Integer, ForeignKey("users.id")) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([User(name="Alice"), User(name="Bob")]) >>> session.commit() >>> alice = session.execute(select(User).where(User.name == "Alice")).scalars().first() >>> session.add(Order(user_id=alice.id)) >>> session.commit() >>> # Users with orders >>> has_orders = exists().where(Order.user_id == User.id) >>> stmt = select(User).where(has_orders) >>> print([u.name for u in session.execute(stmt).scalars()]) ['Alice'] >>> # Users without orders >>> stmt = select(User).where(~has_orders) >>> print([u.name for u in session.execute(stmt).scalars()]) ['Bob'] >>> session.close() Union Queries ------------- Use ``union()`` to combine results from multiple SELECT statements, removing duplicates. Use ``union_all()`` to keep all rows including duplicates, which is faster when you know there are no duplicates or don't care about them. The queries must have the same number of columns with compatible types. Unions are useful for combining data from different tables or different filtered views of the same table. .. 
code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select, union_all >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class Customer(Base): ... __tablename__ = "customers" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) >>> class Supplier(Base): ... __tablename__ = "suppliers" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([Customer(name="Alice"), Customer(name="Bob")]) >>> session.add_all([Supplier(name="Acme"), Supplier(name="Bob")]) >>> session.commit() >>> # Union all names >>> stmt = union_all( ... select(Customer.name), ... select(Supplier.name)) >>> print(sorted([row[0] for row in session.execute(stmt)])) ['Acme', 'Alice', 'Bob', 'Bob'] >>> session.close() Case Expressions ---------------- The ``case()`` construct creates SQL CASE expressions for conditional logic in queries. It's useful for computed columns, conditional aggregations, and data transformations. Pass one or more ``(condition, result)`` tuples as positional arguments, with an optional ``else_`` for the default value. Case expressions can be used in SELECT, WHERE, ORDER BY, and other clauses. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select, case >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... score = Column(Integer) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([ ... User(name="Alice", score=95), ... User(name="Bob", score=75), ...
User(name="Carol", score=55)]) >>> session.commit() >>> # Grade based on score >>> grade = case( ... (User.score >= 90, "A"), ... (User.score >= 70, "B"), ... else_="C") >>> stmt = select(User.name, grade.label("grade")) >>> for row in session.execute(stmt): ... print(f"{row.name}: {row.grade}") Alice: A Bob: B Carol: C >>> session.close() Distinct and Count Distinct --------------------------- Use ``distinct()`` to remove duplicate rows from results. For counting unique values, combine ``func.count()`` with ``distinct()``. The ``distinct()`` method can be applied to the entire query or to specific columns. This is essential for accurate counting when dealing with joined tables that may produce duplicate rows. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select, func, distinct >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class Order(Base): ... __tablename__ = "orders" ... id = Column(Integer, primary_key=True) ... customer = Column(String(50)) ... product = Column(String(50)) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([ ... Order(customer="Alice", product="Book"), ... Order(customer="Alice", product="Pen"), ... Order(customer="Bob", product="Book")]) >>> session.commit() >>> # Distinct customers >>> stmt = select(Order.customer).distinct() >>> print(session.execute(stmt).fetchall()) [('Alice',), ('Bob',)] >>> # Count distinct products >>> stmt = select(func.count(distinct(Order.product))) >>> print(session.execute(stmt).scalar()) 2 >>> session.close() Aliased Tables -------------- Use ``aliased()`` to create aliases for tables, allowing you to reference the same table multiple times in a query with different names. This is essential for self-joins and queries that need to compare rows within the same table. 
Aliases create independent references that can have different join conditions and filters. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select >>> from sqlalchemy.orm import declarative_base, sessionmaker, aliased >>> Base = declarative_base() >>> class Employee(Base): ... __tablename__ = "employees" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) ... salary = Column(Integer) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([ ... Employee(name="Alice", salary=50000), ... Employee(name="Bob", salary=60000), ... Employee(name="Carol", salary=55000)]) >>> session.commit() >>> # Find employees earning more than Alice >>> alice_alias = aliased(Employee, name="alice") >>> stmt = select(Employee.name).select_from(Employee).join( ... alice_alias, alice_alias.name == "Alice" ... ).where(Employee.salary > alice_alias.salary) >>> print([row[0] for row in session.execute(stmt)]) ['Bob', 'Carol'] >>> session.close() Bulk Operations --------------- For inserting or updating many rows, bulk operations are much faster than adding objects one by one. Use ``session.bulk_insert_mappings()`` for inserts and ``session.bulk_update_mappings()`` for updates. These methods bypass the ORM's unit of work pattern for better performance. For even faster inserts, use Core's ``insert()`` with ``execute()`` passing a list of dictionaries. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, select, insert, func >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ...
name = Column(String(50)) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> # Bulk insert with Core (fastest) >>> data = [{"name": f"User{i}"} for i in range(100)] >>> session.execute(insert(User), data) >>> session.commit() >>> # Verify >>> count = session.execute(select(func.count()).select_from(User)).scalar() >>> print(count) 100 >>> session.close() Raw SQL with Text ----------------- When you need to execute raw SQL that's difficult to express with SQLAlchemy's constructs, use ``text()`` to wrap SQL strings. Always use bound parameters (``:name`` syntax) instead of string formatting to prevent SQL injection. The ``text()`` construct can be used with both Core and ORM queries, and results can be mapped to ORM objects using ``from_statement()``. .. code-block:: python >>> from sqlalchemy import create_engine, Column, Integer, String, text, select >>> from sqlalchemy.orm import declarative_base, sessionmaker >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = "users" ... id = Column(Integer, primary_key=True) ... name = Column(String(50)) >>> engine = create_engine("sqlite:///:memory:") >>> Base.metadata.create_all(engine) >>> Session = sessionmaker(bind=engine) >>> session = Session() >>> session.add_all([User(name="Alice"), User(name="Bob")]) >>> session.commit() >>> # Raw SQL with parameters >>> result = session.execute( ... text("SELECT * FROM users WHERE name = :name"), ... {"name": "Alice"}) >>> print(result.fetchall()) [(1, 'Alice')] >>> # Map raw SQL to ORM objects >>> stmt = select(User).from_statement(text("SELECT * FROM users ORDER BY name")) >>> users = session.execute(stmt).scalars().all() >>> print([u.name for u in users]) ['Alice', 'Bob'] >>> session.close() ================================================ FILE: docs/notes/database/python-sqlalchemy.rst ================================================ ..
meta:: :description lang=en: SQLAlchemy Core tutorial covering database connections, engine creation, metadata, table definitions, and SQL expression language :keywords: Python, SQLAlchemy, database, SQL, engine, metadata, table, connection, transaction, Core API ================= SQLAlchemy Basics ================= .. contents:: Table of Contents :backlinks: none SQLAlchemy is the most popular database toolkit and Object-Relational Mapping (ORM) library for Python. It provides a full suite of well-known enterprise-level persistence patterns, designed for efficient and high-performing database access. SQLAlchemy is divided into two main components: the Core (low-level SQL abstraction) and the ORM (high-level object mapping). This cheat sheet covers the Core API, which provides a SQL Expression Language that allows you to construct SQL statements in Python code while remaining database-agnostic. The Core is ideal when you need fine-grained control over SQL queries or when working with existing database schemas. Create an Engine ---------------- The ``Engine`` is the starting point for any SQLAlchemy application. It represents the connection pool and dialect for a particular database, managing connectivity and translating Python code into database-specific SQL. The ``create_engine()`` function takes a database URL that specifies the database type, credentials, host, and database name. SQLAlchemy supports many databases including SQLite, PostgreSQL, MySQL, Oracle, and Microsoft SQL Server through different dialects. .. 
code-block:: python >>> from sqlalchemy import create_engine >>> # SQLite in-memory database (great for testing) >>> engine = create_engine("sqlite:///:memory:") >>> # SQLite file-based database >>> engine = create_engine("sqlite:///mydb.sqlite") >>> # PostgreSQL >>> engine = create_engine("postgresql://user:pass@localhost/dbname") >>> # MySQL >>> engine = create_engine("mysql+pymysql://user:pass@localhost/dbname") Database URL Format ------------------- SQLAlchemy uses RFC-1738 style URLs to specify database connections. The URL format provides a standardized way to specify all connection parameters including the database driver, authentication credentials, host address, port number, and database name. Understanding this format is essential for configuring connections to different database systems. The ``make_url()`` function can parse and construct these URLs programmatically. .. code-block:: python >>> from sqlalchemy import make_url >>> # Format: dialect+driver://username:password@host:port/database >>> url = make_url("postgresql://user:pass@localhost:5432/mydb") >>> url.drivername 'postgresql' >>> url.username 'user' >>> url.host 'localhost' >>> url.database 'mydb' Connect and Execute Raw SQL --------------------------- While SQLAlchemy encourages using its SQL Expression Language, you can also execute raw SQL strings directly. This is useful for complex queries that are difficult to express in SQLAlchemy's API, or when migrating existing SQL code. The ``text()`` function wraps raw SQL strings and allows parameter binding for security. Always use parameter binding instead of string formatting to prevent SQL injection attacks. .. code-block:: python >>> from sqlalchemy import create_engine, text >>> engine = create_engine("sqlite:///:memory:") >>> with engine.connect() as conn: ... conn.execute(text("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")) ... conn.execute(text("INSERT INTO test (name) VALUES (:name)"), {"name": "Alice"}) ... conn.commit() ... 
result = conn.execute(text("SELECT * FROM test")) ... print(result.fetchall()) [(1, 'Alice')] Transaction Management ---------------------- Transactions ensure that a series of database operations either all succeed or all fail together, maintaining data integrity. SQLAlchemy provides several ways to manage transactions: implicit transactions with ``begin()``, context managers for automatic commit/rollback, and manual control with ``commit()`` and ``rollback()``. The ``begin()`` method starts a transaction that will automatically rollback on exceptions and commit on successful completion when used as a context manager. .. code-block:: python >>> from sqlalchemy import create_engine, text >>> engine = create_engine("sqlite:///:memory:") >>> # Using begin() for automatic commit/rollback >>> with engine.begin() as conn: ... conn.execute(text("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")) ... conn.execute(text("INSERT INTO users (name) VALUES ('Bob')")) ... # Commits automatically if no exception >>> # Manual transaction control >>> with engine.connect() as conn: ... trans = conn.begin() ... try: ... conn.execute(text("INSERT INTO users (name) VALUES ('Carol')")) ... trans.commit() ... except: ... trans.rollback() ... raise Define Tables with Metadata --------------------------- ``MetaData`` is a container that holds information about database tables and other schema constructs. You can define tables programmatically using the ``Table`` class, specifying columns with their types and constraints. This approach is part of SQLAlchemy Core and gives you explicit control over the table structure. The metadata can then create all defined tables in the database with ``create_all()``, which generates the appropriate DDL statements for your database dialect. .. code-block:: python >>> from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String >>> engine = create_engine("sqlite:///:memory:") >>> metadata = MetaData() >>> users = Table( ...
"users", metadata, ... Column("id", Integer, primary_key=True), ... Column("name", String(50)), ... Column("email", String(100)) ... ) >>> metadata.create_all(engine) >>> # Check table columns >>> [c.name for c in users.columns] ['id', 'name', 'email'] Reflect Existing Tables ----------------------- Table reflection allows SQLAlchemy to load table definitions from an existing database schema automatically. This is useful when working with legacy databases or when you want to avoid duplicating schema definitions. The ``reflect()`` method on ``MetaData`` reads the database schema and creates ``Table`` objects for all tables found. You can also reflect individual tables using the ``autoload_with`` parameter. .. code-block:: python >>> from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String, text >>> engine = create_engine("sqlite:///:memory:") >>> # Create a table first >>> with engine.begin() as conn: ... conn.execute(text("CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT, price REAL)")) >>> # Reflect the table >>> metadata = MetaData() >>> metadata.reflect(bind=engine) >>> list(metadata.tables.keys()) ['products'] >>> products = metadata.tables['products'] >>> [c.name for c in products.columns] ['id', 'name', 'price'] Inspect Database Schema ----------------------- The ``inspect()`` function provides a powerful way to examine database schema details at runtime. The inspector can retrieve information about tables, columns, indexes, foreign keys, and other database objects. This is particularly useful for database administration tasks, schema migrations, and debugging. The inspector works with any database supported by SQLAlchemy and provides a consistent API across different database systems. .. code-block:: python >>> from sqlalchemy import create_engine, inspect, MetaData, Table, Column, Integer, String >>> engine = create_engine("sqlite:///:memory:") >>> metadata = MetaData() >>> users = Table("users", metadata, ...
Column("id", Integer, primary_key=True), ... Column("name", String(50))) >>> metadata.create_all(engine) >>> inspector = inspect(engine) >>> inspector.get_table_names() ['users'] >>> inspector.get_columns('users') # doctest: +ELLIPSIS [{'name': 'id', ...}, {'name': 'name', ...}] Insert Data ----------- The ``insert()`` construct creates an INSERT statement for a table. You can specify values using the ``values()`` method or pass them as keyword arguments. For bulk inserts, pass a list of dictionaries to ``execute()``. SQLAlchemy will generate efficient multi-row INSERT statements when possible. The ``returning()`` method can retrieve auto-generated values like primary keys after insertion. .. code-block:: python >>> from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String, insert >>> engine = create_engine("sqlite:///:memory:") >>> metadata = MetaData() >>> users = Table("users", metadata, ... Column("id", Integer, primary_key=True), ... Column("name", String(50))) >>> metadata.create_all(engine) >>> # Single insert >>> with engine.begin() as conn: ... conn.execute(insert(users).values(name="Alice")) ... # Bulk insert ... conn.execute(insert(users), [{"name": "Bob"}, {"name": "Carol"}]) >>> with engine.connect() as conn: ... result = conn.execute(users.select()) ... print(result.fetchall()) [(1, 'Alice'), (2, 'Bob'), (3, 'Carol')] Select Data ----------- The ``select()`` construct creates SELECT statements with a Pythonic API. You can specify which columns to retrieve, add WHERE clauses with ``where()``, order results with ``order_by()``, and limit results with ``limit()`` and ``offset()``. The SQL Expression Language uses Python operators like ``==``, ``!=``, ``>``, ``<`` which are overloaded to generate SQL conditions. This provides type safety and prevents SQL injection. .. 
code-block:: python >>> from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String, select, insert >>> engine = create_engine("sqlite:///:memory:") >>> metadata = MetaData() >>> users = Table("users", metadata, ... Column("id", Integer, primary_key=True), ... Column("name", String(50)), ... Column("age", Integer)) >>> metadata.create_all(engine) >>> with engine.begin() as conn: ... conn.execute(insert(users), [ ... {"name": "Alice", "age": 30}, ... {"name": "Bob", "age": 25}, ... {"name": "Carol", "age": 35}]) >>> with engine.connect() as conn: ... # Select all ... result = conn.execute(select(users)) ... print(result.fetchall()) ... # Select with condition ... result = conn.execute(select(users).where(users.c.age > 28)) ... print(result.fetchall()) [(1, 'Alice', 30), (2, 'Bob', 25), (3, 'Carol', 35)] [(1, 'Alice', 30), (3, 'Carol', 35)] Update Data ----------- The ``update()`` construct creates UPDATE statements. Use ``where()`` to specify which rows to update and ``values()`` to set new column values. Without a WHERE clause, all rows in the table will be updated. The ``returning()`` method can retrieve the updated values. For bulk updates with different values per row, use ``bindparam()`` to create parameterized statements. .. code-block:: python >>> from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String >>> from sqlalchemy import select, insert, update >>> engine = create_engine("sqlite:///:memory:") >>> metadata = MetaData() >>> users = Table("users", metadata, ... Column("id", Integer, primary_key=True), ... Column("name", String(50))) >>> metadata.create_all(engine) >>> with engine.begin() as conn: ... conn.execute(insert(users), [{"name": "Alice"}, {"name": "Bob"}]) ... conn.execute(update(users).where(users.c.name == "Alice").values(name="Alicia")) >>> with engine.connect() as conn: ... result = conn.execute(select(users)) ... 
print(result.fetchall()) [(1, 'Alicia'), (2, 'Bob')] Delete Data ----------- The ``delete()`` construct creates DELETE statements. Always use ``where()`` to specify which rows to delete, unless you intend to delete all rows. Like other DML statements, ``delete()`` supports ``returning()`` to retrieve deleted rows. Be careful with DELETE statements as they permanently remove data from the database. .. code-block:: python >>> from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String >>> from sqlalchemy import select, insert, delete >>> engine = create_engine("sqlite:///:memory:") >>> metadata = MetaData() >>> users = Table("users", metadata, ... Column("id", Integer, primary_key=True), ... Column("name", String(50))) >>> metadata.create_all(engine) >>> with engine.begin() as conn: ... conn.execute(insert(users), [{"name": "Alice"}, {"name": "Bob"}, {"name": "Carol"}]) ... conn.execute(delete(users).where(users.c.name == "Bob")) >>> with engine.connect() as conn: ... result = conn.execute(select(users)) ... print(result.fetchall()) [(1, 'Alice'), (3, 'Carol')] SQL Expression Language ----------------------- SQLAlchemy's SQL Expression Language provides a Pythonic way to construct SQL statements. Column objects support comparison operators (``==``, ``!=``, ``>``, ``<``), logical operators (``&`` for AND, ``|`` for OR), and methods like ``in_()``, ``like()``, ``between()``, and ``is_()``. These expressions are composable and can be combined to build complex queries while maintaining readability and type safety. .. code-block:: python >>> from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String >>> from sqlalchemy import select, insert, and_, or_ >>> engine = create_engine("sqlite:///:memory:") >>> metadata = MetaData() >>> users = Table("users", metadata, ... Column("id", Integer, primary_key=True), ... Column("name", String(50)), ... Column("age", Integer)) >>> metadata.create_all(engine) >>> with engine.begin() as conn: ... 
conn.execute(insert(users), [ ... {"name": "Alice", "age": 30}, ... {"name": "Bob", "age": 25}, ... {"name": "Carol", "age": 35}]) >>> with engine.connect() as conn: ... # AND condition ... stmt = select(users).where(and_(users.c.age > 25, users.c.age < 35)) ... print(conn.execute(stmt).fetchall()) ... # OR condition ... stmt = select(users).where(or_(users.c.name == "Alice", users.c.name == "Bob")) ... print(conn.execute(stmt).fetchall()) ... # IN clause ... stmt = select(users).where(users.c.name.in_(["Alice", "Carol"])) ... print(conn.execute(stmt).fetchall()) [(1, 'Alice', 30)] [(1, 'Alice', 30), (2, 'Bob', 25)] [(1, 'Alice', 30), (3, 'Carol', 35)] Join Tables ----------- The ``join()`` method creates JOIN clauses between tables. SQLAlchemy can automatically determine join conditions based on foreign key relationships, or you can specify them explicitly. Use ``select_from()`` to specify the joined tables in a SELECT statement. SQLAlchemy supports INNER JOIN (default), LEFT OUTER JOIN, and FULL OUTER JOIN through the ``isouter`` and ``full`` parameters; there is no direct RIGHT OUTER JOIN, so swap the order of the tables and use a LEFT OUTER JOIN instead. .. code-block:: python >>> from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String, ForeignKey >>> from sqlalchemy import select, insert >>> engine = create_engine("sqlite:///:memory:") >>> metadata = MetaData() >>> users = Table("users", metadata, ... Column("id", Integer, primary_key=True), ... Column("name", String(50))) >>> orders = Table("orders", metadata, ... Column("id", Integer, primary_key=True), ... Column("user_id", Integer, ForeignKey("users.id")), ... Column("product", String(50))) >>> metadata.create_all(engine) >>> with engine.begin() as conn: ... conn.execute(insert(users), [{"name": "Alice"}, {"name": "Bob"}]) ... conn.execute(insert(orders), [ ... {"user_id": 1, "product": "Book"}, ... {"user_id": 1, "product": "Pen"}, ... {"user_id": 2, "product": "Laptop"}]) >>> with engine.connect() as conn: ...
stmt = select(users.c.name, orders.c.product).select_from( ... users.join(orders)) ... print(conn.execute(stmt).fetchall()) [('Alice', 'Book'), ('Alice', 'Pen'), ('Bob', 'Laptop')] Aggregate Functions ------------------- SQLAlchemy provides functions for SQL aggregates like ``count()``, ``sum()``, ``avg()``, ``min()``, and ``max()`` in the ``sqlalchemy.func`` namespace. These can be used in SELECT statements and combined with ``group_by()`` for grouped aggregations. The ``func`` object is a special namespace that generates SQL function calls for any function name you access on it. .. code-block:: python >>> from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String >>> from sqlalchemy import select, insert, func >>> engine = create_engine("sqlite:///:memory:") >>> metadata = MetaData() >>> sales = Table("sales", metadata, ... Column("id", Integer, primary_key=True), ... Column("product", String(50)), ... Column("amount", Integer)) >>> metadata.create_all(engine) >>> with engine.begin() as conn: ... conn.execute(insert(sales), [ ... {"product": "A", "amount": 100}, ... {"product": "A", "amount": 150}, ... {"product": "B", "amount": 200}]) >>> with engine.connect() as conn: ... # Count all rows ... result = conn.execute(select(func.count()).select_from(sales)) ... print(result.scalar()) ... # Sum with group by ... stmt = select(sales.c.product, func.sum(sales.c.amount)).group_by(sales.c.product) ... print(conn.execute(stmt).fetchall()) 3 [('A', 250), ('B', 200)] Drop Tables ----------- Tables can be dropped using the ``drop()`` method on a ``Table`` object or ``drop_all()`` on ``MetaData`` to drop all tables. The ``checkfirst`` parameter prevents errors if the table doesn't exist. Be careful with these operations in production as they permanently delete data and schema. Always backup your database before dropping tables. .. 
code-block:: python >>> from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String, inspect >>> engine = create_engine("sqlite:///:memory:") >>> metadata = MetaData() >>> users = Table("users", metadata, Column("id", Integer, primary_key=True)) >>> products = Table("products", metadata, Column("id", Integer, primary_key=True)) >>> metadata.create_all(engine) >>> inspector = inspect(engine) >>> sorted(inspector.get_table_names()) ['products', 'users'] >>> # Drop single table >>> users.drop(engine) >>> sorted(inspect(engine).get_table_names()) ['products'] >>> # Drop all tables >>> metadata.drop_all(engine) >>> inspect(engine).get_table_names() [] ================================================ FILE: docs/notes/extension/cpp-from-python.rst ================================================ .. meta:: :description lang=en: Learn modern C++ syntax from Python - side-by-side comparison of Python and C++ code snippets :keywords: C++, Python, C++11, C++14, C++17, C++20, modern C++, syntax comparison, learn C++ ======================= Learn C++ from Python ======================= .. contents:: Table of Contents :backlinks: none Modern C++ (C++11, C++14, C++17, C++20) has evolved to include features that make it syntactically similar to Python, making the transition easier for Python developers. This comprehensive guide provides side-by-side comparisons and 1-1 mappings between Python and modern C++ code snippets, covering essential programming concepts like variables, data structures, functions, lambdas, classes, and algorithms. Whether you're a Python developer looking to learn C++ for performance optimization, system programming, or expanding your programming skills, this tutorial demonstrates how familiar Python patterns translate to modern C++ syntax. Many popular frameworks like PyTorch, TensorFlow, and NumPy use C++ extensions for performance-critical operations, especially in deep learning, LLM training, and CUDA GPU programming. 
Understanding C++ enables you to write custom extensions, optimize bottlenecks, and contribute to these high-performance libraries. To learn more about C++ programming, refer to this `C++ cheatsheet `_ for additional reference and best practices. **Complete working examples:** See `cpp_from_py.cpp `_ for runnable code with integrated Google Test suite. Each function includes Doxygen comments showing the equivalent Python code. Hello World ----------- The traditional first program in any language. Both Python and C++ can print text to the console, though C++ requires including the iostream library and a main function. **Python** .. code-block:: python print("Hello, World!") **C++** .. code-block:: cpp #include int main() { std::cout << "Hello, World!" << std::endl; return 0; } Variables --------- Modern C++ supports automatic type inference with the ``auto`` keyword, making variable declarations as concise as Python. The compiler deduces types from initialization values. **Python** .. code-block:: python x = 10 y = 3.14 name = "Alice" is_valid = True **C++** .. code-block:: cpp auto x = 10; auto y = 3.14; auto name = "Alice"; auto is_valid = true; Lists and Vectors ----------------- Python lists and C++ vectors are dynamic arrays that can grow and shrink. Both support indexing, appending elements, and querying size. C++ vectors require specifying the element type, but modern C++ can infer it from initialization. **Python** .. code-block:: python numbers = [1, 2, 3, 4, 5] numbers.append(6) print(numbers[0]) print(len(numbers)) **C++** .. code-block:: cpp #include std::vector numbers = {1, 2, 3, 4, 5}; numbers.push_back(6); std::cout << numbers[0] << std::endl; std::cout << numbers.size() << std::endl; Array Slicing and Access ------------------------- Python supports powerful slicing syntax with negative indices and ranges. C++ doesn't have built-in slicing, but you can use iterators or create subvectors. Negative indexing requires manual calculation from the end. 
**Python** .. code-block:: python numbers = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] print(numbers[0]) print(numbers[-1]) print(numbers[2:5]) print(numbers[:3]) print(numbers[7:]) print(numbers[::2]) print(numbers[::-1]) **C++** .. code-block:: cpp #include #include std::vector numbers = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; std::cout << numbers[0] << std::endl; std::cout << numbers[numbers.size() - 1] << std::endl; std::vector slice1(numbers.begin() + 2, numbers.begin() + 5); std::vector slice2(numbers.begin(), numbers.begin() + 3); std::vector slice3(numbers.begin() + 7, numbers.end()); std::vector every_second; for (size_t i = 0; i < numbers.size(); i += 2) { every_second.push_back(numbers[i]); } std::vector reversed(numbers.rbegin(), numbers.rend()); Dictionaries and Maps --------------------- Dictionaries in Python and maps in C++ store key-value pairs. Both allow insertion, lookup, and modification using bracket notation. C++ maps keep keys sorted, while Python dicts maintain insertion order (Python 3.7+). **Python** .. code-block:: python ages = {"Alice": 30, "Bob": 25} ages["Charlie"] = 35 print(ages["Alice"]) **C++** .. code-block:: cpp #include #include std::map ages = {{"Alice", 30}, {"Bob", 25}}; ages["Charlie"] = 35; std::cout << ages["Alice"] << std::endl; For Loop -------- Both languages support traditional counting loops and range-based iteration. C++ range-based for loops (C++11) provide syntax similar to Python's for-in loops, making iteration over containers more readable. **Python** .. code-block:: python for i in range(5): print(i) for item in [1, 2, 3]: print(item) **C++** .. code-block:: cpp for (int i = 0; i < 5; i++) { std::cout << i << std::endl; } for (auto item : {1, 2, 3}) { std::cout << item << std::endl; } While Loop ---------- While loops execute as long as a condition is true. The syntax is nearly identical between Python and C++, with C++ requiring parentheses around the condition and braces for the body. **Python** .. 
code-block:: python i = 0 while i < 5: print(i) i += 1 **C++** .. code-block:: cpp int i = 0; while (i < 5) { std::cout << i << std::endl; i++; } If-Else ------- Conditional statements control program flow based on boolean expressions. C++ requires parentheses around conditions and uses braces for blocks, while Python uses indentation. Both support chained conditions with elif/else if. **Python** .. code-block:: python x = 10 if x > 5: print("Greater") elif x == 5: print("Equal") else: print("Less") **C++** .. code-block:: cpp auto x = 10; if (x > 5) { std::cout << "Greater" << std::endl; } else if (x == 5) { std::cout << "Equal" << std::endl; } else { std::cout << "Less" << std::endl; } Functions --------- Functions encapsulate reusable code. Modern C++ supports trailing return type syntax (-> type) similar to Python's type hints. The auto keyword allows type inference for return types when the function body is simple. **Python** .. code-block:: python def add(a, b): return a + b result = add(3, 5) **C++** .. code-block:: cpp auto add(int a, int b) -> int { return a + b; } auto result = add(3, 5); Lambda Functions ---------------- Lambda functions are anonymous functions that can capture variables from their surrounding scope. Both Python and C++ support lambdas, making functional programming patterns possible. C++ lambdas can specify capture modes (by value, by reference) for more control over variable lifetime and performance. **Python** .. code-block:: python square = lambda x: x * x print(square(5)) numbers = [1, 2, 3, 4] squared = list(map(lambda x: x * x, numbers)) # Capturing variables multiplier = 10 multiply = lambda x: x * multiplier print(multiply(5)) **C++** .. 
code-block:: cpp #include #include auto square = [](int x) { return x * x; }; std::cout << square(5) << std::endl; std::vector numbers = {1, 2, 3, 4}; std::vector squared; std::transform(numbers.begin(), numbers.end(), std::back_inserter(squared), [](int x) { return x * x; }); // Capturing variables by value [=] or by reference [&] int multiplier = 10; auto multiply = [multiplier](int x) { return x * multiplier; }; std::cout << multiply(5) << std::endl; Lambda Capture Modes -------------------- C++ lambdas provide explicit control over how variables are captured from the enclosing scope. This is more explicit than Python's implicit closure behavior and allows optimization by choosing between copying values or using references. **Python** .. code-block:: python x = 10 y = 20 # Implicitly captures x and y add_xy = lambda z: x + y + z print(add_xy(5)) **C++** .. code-block:: cpp int x = 10; int y = 20; // Capture by value auto add_xy_val = [x, y](int z) { return x + y + z; }; std::cout << add_xy_val(5) << std::endl; // Capture by reference auto add_xy_ref = [&x, &y](int z) { return x + y + z; }; // Capture all by value auto add_all_val = [=](int z) { return x + y + z; }; // Capture all by reference auto add_all_ref = [&](int z) { return x + y + z; }; List Comprehension ------------------ Python's list comprehensions provide concise syntax for creating lists. C++ doesn't have direct syntax for this, but you can achieve similar results using loops or STL algorithms like std::transform and std::copy_if. **Python** .. code-block:: python squares = [x * x for x in range(10)] evens = [x for x in range(10) if x % 2 == 0] **C++** .. code-block:: cpp #include #include std::vector squares; for (int x = 0; x < 10; x++) { squares.push_back(x * x); } std::vector evens; for (int x = 0; x < 10; x++) { if (x % 2 == 0) { evens.push_back(x); } } String Operations ----------------- Both languages provide rich string manipulation capabilities. 
C++ strings are mutable, unlike Python strings, which are immutable; even so, concatenation and individual character access look much the same in both languages. C++ requires explicit conversion functions for case changes. **Python** .. code-block:: python s = "Hello" s += " World" print(len(s)) print(s[0]) print(s.upper()) **C++** .. code-block:: cpp #include #include std::string s = "Hello"; s += " World"; std::cout << s.size() << std::endl; std::cout << s[0] << std::endl; std::string upper = s; std::transform(upper.begin(), upper.end(), upper.begin(), ::toupper); std::cout << upper << std::endl; Classes ------- Object-oriented programming works similarly in both languages. C++ requires explicit access specifiers (public, private) and constructor initialization lists. Both support member variables and methods, with C++ using :: for scope resolution. **Python** .. code-block:: python class Person: def __init__(self, name, age): self.name = name self.age = age def greet(self): return f"Hello, I'm {self.name}" p = Person("Alice", 30) print(p.greet()) **C++** .. code-block:: cpp #include class Person { public: std::string name; int age; Person(std::string name, int age) : name(name), age(age) {} std::string greet() { return "Hello, I'm " + name; } }; Person p("Alice", 30); std::cout << p.greet() << std::endl; Optional Values --------------- Python uses None to represent missing values, while C++ (C++17+) provides std::optional for type-safe optional values. This prevents null pointer errors and makes the absence of a value explicit in the type system. **Python** .. code-block:: python def find_value(key): data = {"a": 1, "b": 2} return data.get(key) result = find_value("a") if result is not None: print(result) **C++** .. 
code-block:: cpp #include #include std::optional find_value(std::string key) { std::map data = {{"a", 1}, {"b", 2}}; auto it = data.find(key); if (it != data.end()) { return it->second; } return std::nullopt; } auto result = find_value("a"); if (result.has_value()) { std::cout << result.value() << std::endl; } Smart Pointers -------------- Python handles memory automatically with garbage collection. C++ smart pointers (C++11+) provide automatic memory management through RAII. unique_ptr ensures single ownership, while shared_ptr allows multiple owners with reference counting. **Python** .. code-block:: python class Resource: def __init__(self, name): self.name = name resource = Resource("data") **C++** .. code-block:: cpp #include class Resource { public: std::string name; Resource(std::string name) : name(name) {} }; auto resource = std::make_unique("data"); auto shared = std::make_shared("data"); File I/O -------- **Python** .. code-block:: python with open("file.txt", "r") as f: content = f.read() with open("file.txt", "w") as f: f.write("Hello") **C++** .. code-block:: cpp #include #include std::ifstream file("file.txt"); std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); std::ofstream out("file.txt"); out << "Hello"; Exception Handling ------------------ Both languages support try-catch exception handling for error management. C++ uses typed exceptions and requires explicit exception types in catch blocks. **Python** .. code-block:: python try: result = 10 / 0 except ZeroDivisionError as e: print(f"Error: {e}") finally: print("Cleanup") **C++** .. code-block:: cpp #include try { if (divisor == 0) { throw std::runtime_error("Division by zero"); } result = 10 / divisor; } catch (const std::exception& e) { std::cout << "Error: " << e.what() << std::endl; } Tuples ------ Tuples group multiple values together. 
C++17 introduces structured bindings that allow tuple unpacking similar to Python, making it easy to return and destructure multiple values. **Python** .. code-block:: python point = (10, 20) x, y = point print(x, y) **C++** .. code-block:: cpp #include auto point = std::make_tuple(10, 20); auto [x, y] = point; std::cout << x << " " << y << std::endl; Enumerate --------- Python's enumerate provides index-value pairs during iteration. C++ doesn't have a direct equivalent, but you can achieve the same result with traditional indexed loops. **Python** .. code-block:: python items = ["a", "b", "c"] for i, item in enumerate(items): print(i, item) **C++** .. code-block:: cpp #include #include std::vector items = {"a", "b", "c"}; for (size_t i = 0; i < items.size(); i++) { std::cout << i << " " << items[i] << std::endl; } Filter and Map -------------- Python's filter and map functions apply transformations to sequences. C++ provides equivalent functionality through STL algorithms like ``std::copy_if`` and ``std::transform``. **Python** .. code-block:: python numbers = [1, 2, 3, 4, 5] evens = list(filter(lambda x: x % 2 == 0, numbers)) doubled = list(map(lambda x: x * 2, numbers)) **C++** .. code-block:: cpp #include #include std::vector numbers = {1, 2, 3, 4, 5}; std::vector evens; std::copy_if(numbers.begin(), numbers.end(), std::back_inserter(evens), [](int x) { return x % 2 == 0; }); std::vector doubled; std::transform(numbers.begin(), numbers.end(), std::back_inserter(doubled), [](int x) { return x * 2; }); Any and All ----------- Check if any or all elements in a sequence satisfy a condition. C++ provides ``std::any_of`` and ``std::all_of`` algorithms for these common operations. **Python** .. code-block:: python numbers = [1, 2, 3, 4, 5] has_even = any(x % 2 == 0 for x in numbers) all_positive = all(x > 0 for x in numbers) **C++** .. 
code-block:: cpp #include #include std::vector numbers = {1, 2, 3, 4, 5}; bool has_even = std::any_of(numbers.begin(), numbers.end(), [](int x) { return x % 2 == 0; }); bool all_positive = std::all_of(numbers.begin(), numbers.end(), [](int x) { return x > 0; }); Sorting ------- Sort sequences in ascending or descending order. Both languages provide in-place sorting and the ability to create sorted copies with custom comparison functions. **Python** .. code-block:: python numbers = [3, 1, 4, 1, 5] numbers.sort() sorted_nums = sorted(numbers, reverse=True) **C++** .. code-block:: cpp #include #include std::vector numbers = {3, 1, 4, 1, 5}; std::sort(numbers.begin(), numbers.end()); std::vector sorted_nums = numbers; std::sort(sorted_nums.begin(), sorted_nums.end(), std::greater()); Min and Max ----------- Find the minimum and maximum values in a sequence. C++ uses iterator-based algorithms that return iterators, requiring dereferencing to get the actual values. **Python** .. code-block:: python numbers = [3, 1, 4, 1, 5] print(min(numbers)) print(max(numbers)) **C++** .. code-block:: cpp #include #include std::vector numbers = {3, 1, 4, 1, 5}; std::cout << *std::min_element(numbers.begin(), numbers.end()) << std::endl; std::cout << *std::max_element(numbers.begin(), numbers.end()) << std::endl; Sum --- Calculate the sum of all elements in a sequence. C++ uses ``std::accumulate`` from the numeric library, which can also perform other reduction operations. **Python** .. code-block:: python numbers = [1, 2, 3, 4, 5] total = sum(numbers) **C++** .. code-block:: cpp #include #include std::vector numbers = {1, 2, 3, 4, 5}; int total = std::accumulate(numbers.begin(), numbers.end(), 0); Zip --- Iterate over multiple sequences in parallel. Python's zip is built-in, while C++ requires manual index-based iteration to achieve the same result. **Python** .. code-block:: python names = ["Alice", "Bob"] ages = [30, 25] for name, age in zip(names, ages): print(name, age) **C++** .. 
code-block:: cpp #include #include std::vector names = {"Alice", "Bob"}; std::vector ages = {30, 25}; for (size_t i = 0; i < std::min(names.size(), ages.size()); i++) { std::cout << names[i] << " " << ages[i] << std::endl; } Default Arguments ----------------- Functions can have default parameter values that are used when arguments aren't provided. Both languages support this feature with similar syntax. **Python** .. code-block:: python def greet(name, greeting="Hello"): return f"{greeting}, {name}" print(greet("Alice")) print(greet("Bob", "Hi")) **C++** .. code-block:: cpp #include std::string greet(std::string name, std::string greeting = "Hello") { return greeting + ", " + name; } std::cout << greet("Alice") << std::endl; std::cout << greet("Bob", "Hi") << std::endl; ================================================ FILE: docs/notes/extension/index.rst ================================================ .. meta:: :description lang=en: Python C/C++ extension tutorial covering pybind11, ctypes, cffi, Cython, and the Python C API for high-performance native code :keywords: Python, Python3, pybind11, ctypes, cffi, Cython, C extension, C API, native code, performance, NumPy, shared library Extension ========= Python's flexibility comes at a performance cost. When you need speed for numerical computing, system interfaces, or wrapping existing C/C++ libraries, native extensions bridge the gap. This section covers multiple approaches: - **ctypes** - Standard library FFI for calling C functions without compilation - **Python C API** - Traditional approach with maximum control (legacy) - **pybind11** (recommended for C++) - Clean C++11 syntax with automatic type conversions, used by PyTorch, TensorFlow, and SciPy - **cffi** - Cleaner alternative to ctypes with PyPy compatibility - **Cython** - Python-like syntax that compiles to C for gradual optimization For most new projects wrapping C++ code, pybind11 is the recommended choice. 
For calling existing C libraries without a build step, use ctypes or cffi. .. toctree:: :maxdepth: 1 python-ctypes python-capi python-cext-modern cpp-from-python ================================================ FILE: docs/notes/extension/python-capi.rst ================================================ .. meta:: :description lang=en: Comprehensive Python C API tutorial covering native C extension development, module creation, argument parsing, reference counting, GIL management, exception handling, and Python type manipulation including lists, dictionaries, tuples, sets, and strings :keywords: Python, C API, C Extension, PyObject, reference counting, GIL, module, CPython, native extension, PyList, PyDict, PyTuple, PySet, PyUnicode, memory management ============ Python C API ============ .. contents:: Table of Contents :backlinks: none The Python C API is the traditional and most powerful way to write native extensions for CPython. It provides direct access to Python's internals, giving developers complete control over memory management, object creation, and interpreter interaction. While more verbose than modern alternatives like pybind11 or cffi, the C API remains essential for maintaining legacy extensions, understanding how Python works internally, implementing performance-critical code paths, or accessing low-level features not exposed by higher-level tools. The C API is also the foundation upon which tools like Cython and pybind11 are built. .. warning:: The C extension interface is specific to CPython and may not work on alternative Python implementations like PyPy, Jython, or GraalPython. Additionally, the API can change between Python versions (especially between Python 2 and Python 3), requiring careful version handling and conditional compilation for compatibility. Simple setup.py --------------- Building C extensions requires a ``setup.py`` file that tells Python how to compile your C code. 
The ``distutils`` module (or its modern replacement ``setuptools``) handles cross-platform compilation, linking against the Python library, and generating the correct shared library format (.so on Linux and macOS, .pyd on Windows). Note that ``distutils`` was deprecated in Python 3.10 and removed in Python 3.12; on newer versions, import ``setup`` and ``Extension`` from ``setuptools`` instead. This minimal example compiles a single C file into a Python-importable module. .. code-block:: python from distutils.core import setup, Extension ext = Extension('foo', sources=['foo.c']) setup(name="Foo", version="1.0", ext_modules=[ext]) Build and install: .. code-block:: bash $ python setup.py build $ python setup.py install Customize CFLAGS ---------------- For production-quality extensions, you'll want to customize compiler flags to enable warnings, optimizations, or debugging symbols. The ``extra_compile_args`` parameter passes flags directly to the C compiler (gcc, clang, or MSVC). Common flags include ``-Wall`` and ``-Wextra`` for comprehensive warnings, ``-Werror`` to treat warnings as errors, ``-O3`` for aggressive optimization, and ``-g`` for debug symbols. .. code-block:: python import sysconfig from distutils.core import setup, Extension cflags = sysconfig.get_config_var("CFLAGS") extra_compile_args = cflags.split() extra_compile_args += ["-Wextra", "-Wall", "-Werror"] ext = Extension( "foo", ["foo.c"], extra_compile_args=extra_compile_args ) setup(name="foo", version="1.0", ext_modules=[ext]) Simple C Extension ------------------ :Source: `src/cext/capi/simple.c `_ Every Python C extension follows a standard structure with three essential components: a module definition (``PyModuleDef``) that describes the module's name and methods, a method table (``PyMethodDef``) that maps Python function names to C functions, and an initialization function (``PyInit_<name>``) that Python calls when importing the module. The ``PyDoc_STRVAR`` macro creates docstrings that appear in Python's ``help()`` system. This example demonstrates a minimal working extension. **foo.c:** .. 
code-block:: c #include PyDoc_STRVAR(doc_mod, "Module document\n"); PyDoc_STRVAR(doc_foo, "foo() -> None\n\nPrint 'foo' to stdout."); static PyObject* foo(PyObject* self) { PyObject* s = PyUnicode_FromString("foo"); PyObject_Print(s, stdout, 0); Py_DECREF(s); Py_RETURN_NONE; } static PyMethodDef methods[] = { {"foo", (PyCFunction)foo, METH_NOARGS, doc_foo}, {NULL, NULL, 0, NULL} }; static struct PyModuleDef module = { PyModuleDef_HEAD_INIT, "foo", /* module name */ doc_mod, /* docstring */ -1, /* size of per-interpreter state (-1 = global) */ methods }; PyMODINIT_FUNC PyInit_foo(void) { return PyModule_Create(&module); } Output: .. code-block:: bash $ python setup.py -q build && python setup.py -q install $ python -c "import foo; foo.foo()" 'foo' Parse Arguments --------------- :Source: `src/cext/capi/args.c `_ C extension functions receive arguments as ``PyObject*`` pointers and must parse them into C types. The ``PyArg_ParseTuple()`` function handles positional arguments using format codes: ``i`` for int, ``l`` for long, ``d`` for double, ``s`` for string, ``O`` for any Python object. For keyword arguments, use ``PyArg_ParseTupleAndKeywords()``. The method flags (``METH_NOARGS``, ``METH_O``, ``METH_VARARGS``, ``METH_KEYWORDS``) tell Python how to call your function and must match your implementation. .. 
code-block:: c #include // No arguments: METH_NOARGS static PyObject * foo(PyObject *self) { Py_RETURN_NONE; } // Single object argument: METH_O static PyObject * bar(PyObject *self, PyObject *arg) { return Py_BuildValue("O", arg); } // Multiple positional arguments: METH_VARARGS static PyObject * baz(PyObject *self, PyObject *args) { PyObject *x = NULL, *y = NULL; if (!PyArg_ParseTuple(args, "OO", &x, &y)) { return NULL; } return Py_BuildValue("OO", x, y); } // Keyword arguments: METH_VARARGS | METH_KEYWORDS static PyObject * qux(PyObject *self, PyObject *args, PyObject *kwargs) { static char *keywords[] = {"x", "y", NULL}; PyObject *x = NULL, *y = NULL; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", keywords, &x, &y)) { return NULL; } if (!y) { y = Py_None; } return Py_BuildValue("OO", x, y); } static PyMethodDef methods[] = { {"foo", (PyCFunction)foo, METH_NOARGS, NULL}, {"bar", (PyCFunction)bar, METH_O, NULL}, {"baz", (PyCFunction)baz, METH_VARARGS, NULL}, {"qux", (PyCFunction)qux, METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL, 0, NULL} }; Output: .. code-block:: bash >>> import foo >>> foo.foo() >>> foo.bar(3.7) 3.7 >>> foo.baz(3, 7) (3, 7) >>> foo.qux(x=3, y=7) (3, 7) Release the GIL --------------- :Source: `src/cext/capi/gil.c `_ The Global Interpreter Lock (GIL) is a mutex that protects access to Python objects, preventing multiple native threads from executing Python bytecode simultaneously. While the GIL simplifies CPython's implementation, it can become a bottleneck for CPU-bound multi-threaded code. For long-running C operations that don't access Python objects (file I/O, network calls, heavy computation), release the GIL using ``Py_BEGIN_ALLOW_THREADS`` and ``Py_END_ALLOW_THREADS`` macros. This allows other Python threads to run concurrently, dramatically improving multi-threaded performance. .. 
code-block:: c #include static PyObject* foo(PyObject* self) { Py_BEGIN_ALLOW_THREADS sleep(3); // Blocking operation - other threads can run Py_END_ALLOW_THREADS Py_RETURN_NONE; } **With GIL released** (threads run concurrently): .. code-block:: bash >>> import threading, foo >>> from datetime import datetime >>> def f(n): ... print(f'{datetime.now()}: thread {n}') ... foo.foo() >>> ts = [threading.Thread(target=f, args=(n,)) for n in range(3)] >>> [t.start() for t in ts]; [t.join() for t in ts] 2018-11-04 20:15:34.860454: thread 0 2018-11-04 20:15:34.860592: thread 1 # Same time! 2018-11-04 20:15:34.860705: thread 2 **Without GIL release** (threads run sequentially): .. code-block:: bash 2018-11-04 20:16:44.055932: thread 0 2018-11-04 20:16:47.059718: thread 1 # 3 seconds later 2018-11-04 20:16:50.063579: thread 2 # 3 seconds later .. warning:: Never call Python C API functions between ``Py_BEGIN_ALLOW_THREADS`` and ``Py_END_ALLOW_THREADS``. The GIL must be held to safely access Python objects. Acquire the GIL --------------- :Source: `src/cext/capi/gil.c `_ When threads are created from C code (using pthreads or platform threading APIs), they don't automatically hold the GIL. Before these threads can safely call any Python C API function or access Python objects, they must acquire the GIL using ``PyGILState_Ensure()``. After completing Python operations, release the GIL with ``PyGILState_Release()`` to allow other threads to run. Failing to acquire the GIL before accessing Python objects leads to crashes, data corruption, or undefined behavior. .. 
code-block:: c void *worker_thread(void *arg) { PyObject *result = NULL; PyObject *callback = (PyObject *)arg; // Do C work here (no GIL needed) do_heavy_computation(); // Acquire GIL before calling Python PyGILState_STATE state = PyGILState_Ensure(); result = PyObject_CallFunction(callback, "s", "Done!"); Py_XDECREF(result); // Release GIL PyGILState_Release(state); return NULL; } Raise Exception --------------- :Source: `src/cext/capi/errors.c `_ Error handling in C extensions follows a simple pattern: set an exception using ``PyErr_SetString()`` or ``PyErr_Format()``, then return ``NULL`` to signal failure. Python provides built-in exception types as global variables: ``PyExc_ValueError``, ``PyExc_TypeError``, ``PyExc_RuntimeError``, ``PyExc_KeyError``, ``PyExc_IndexError``, and many others. Always check return values from C API functions and propagate errors by returning ``NULL`` when an exception is already set. .. code-block:: c static PyObject* foo(PyObject* self) { PyErr_SetString(PyExc_NotImplementedError, "Not implemented"); return NULL; } Output: .. code-block:: bash >>> import foo; foo.foo() Traceback (most recent call last): File "", line 1, in NotImplementedError: Not implemented Custom Exception ---------------- :Source: `src/cext/capi/errors.c `_ For domain-specific errors, create custom exception classes using ``PyErr_NewException()``. The first argument is the fully-qualified name (``"module.ExceptionName"``), the second is the base class (``NULL`` defaults to ``Exception``), and the third is an optional dictionary of class attributes. Register the exception as a module attribute so Python code can catch it with ``except module.ExceptionName``. Remember to ``Py_INCREF()`` the exception object before adding it to the module to prevent premature garbage collection. .. 
code-block:: c static PyObject *FooError; static PyObject * foo(PyObject *self) { PyErr_SetString(FooError, "Something went wrong"); return NULL; } PyMODINIT_FUNC PyInit_foo(void) { PyObject *m = PyModule_Create(&module); if (!m) return NULL; FooError = PyErr_NewException("foo.FooError", NULL, NULL); Py_INCREF(FooError); PyModule_AddObject(m, "FooError", FooError); return m; } Output: .. code-block:: bash >>> import foo; foo.foo() Traceback (most recent call last): File "", line 1, in foo.FooError: Something went wrong Reference Counting ------------------ Python uses reference counting as its primary memory management strategy. Every ``PyObject*`` maintains a count of how many references point to it. When the count reaches zero, the object is deallocated. Use ``Py_INCREF()`` when storing a new reference to an object and ``Py_DECREF()`` when you're done with it. The variant ``Py_XDECREF()`` safely handles ``NULL`` pointers. Understanding reference counting is crucial for avoiding memory leaks (forgetting to decref) and use-after-free bugs (decrefing too early). Functions that return "new references" transfer ownership to the caller, while "borrowed references" should not be decrefed. .. code-block:: c static PyObject * getrefcount(PyObject *self, PyObject *a) { return PyLong_FromSsize_t(Py_REFCNT(a)); } Output: .. code-block:: bash >>> import sys, foo >>> l = [1, 2, 3] >>> sys.getrefcount(l[0]) 104 >>> foo.getrefcount(l[0]) 104 >>> i = l[0] # New reference >>> foo.getrefcount(l[0]) 105 Iterate a List -------------- :Source: `src/cext/capi/types_demo.c `_ The Python C API provides two approaches for iterating over sequences. The generic iterator protocol using ``PyObject_GetIter()`` and ``PyIter_Next()`` works with any iterable object (lists, tuples, generators, custom iterables). For lists specifically, you can use ``PyList_Size()`` and ``PyList_GetItem()`` for direct indexed access, which is slightly faster but less flexible. 
Note that ``PyList_GetItem()`` returns a borrowed reference, while ``PyIter_Next()`` returns a new reference that must be decrefed after use. .. code-block:: c static PyObject *iter_list(PyObject *self, PyObject *args) { PyObject *list, *iter, *item; if (!PyArg_ParseTuple(args, "O", &list)) { return NULL; } iter = PyObject_GetIter(list); if (!iter) return NULL; PyObject *result = PyList_New(0); while ((item = PyIter_Next(iter)) != NULL) { PyObject *doubled = PyLong_FromLong(PyLong_AsLong(item) * 2); PyList_Append(result, doubled); Py_DECREF(doubled); Py_DECREF(item); } Py_DECREF(iter); return result; } Output: .. code-block:: bash >>> import types_demo >>> types_demo.iter_list([1, 2, 3]) [2, 4, 6] Iterate a Dictionary -------------------- :Source: `src/cext/capi/types_demo.c `_ Dictionary iteration in C uses ``PyDict_Next()`` with a position variable that tracks the iteration state. Initialize ``pos`` to 0 before the loop, and the function updates it automatically. The key and value pointers receive borrowed references to the current item on each iteration—do not decref them unless you incref first. The dictionary must not be mutated during iteration: it is safe to change the values of existing keys, but only as long as the set of keys does not change (no insertions or deletions). For read-only iteration, this is the most efficient approach. .. code-block:: c static PyObject *iter_dict(PyObject *self, PyObject *args) { PyObject *dict; if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict)) { return NULL; } PyObject *result = PyList_New(0); PyObject *key, *value; Py_ssize_t pos = 0; while (PyDict_Next(dict, &pos, &key, &value)) { PyObject *pair = PyTuple_Pack(2, key, value); PyList_Append(result, pair); Py_DECREF(pair); } return result; } Output: .. 
code-block:: bash >>> import types_demo >>> types_demo.iter_dict({"a": 1, "b": 2}) [('a', 1), ('b', 2)] Create a List ------------- :Source: `src/cext/capi/types_demo.c `_ Creating Python lists from C involves ``PyList_New()`` to allocate the list and ``PyList_Append()`` or ``PyList_SetItem()`` to populate it. When using ``PyList_New(n)`` with a non-zero size, you must initialize all slots with ``PyList_SetItem()`` before the list is used. ``PyList_SetItem()`` steals a reference to the item, while ``PyList_Append()`` increments the reference count. For building lists dynamically, start with ``PyList_New(0)`` and use ``PyList_Append()``. .. code-block:: c static PyObject *list_demo(PyObject *self) { PyObject *list = PyList_New(0); PyList_Append(list, PyLong_FromLong(1)); PyList_Append(list, PyLong_FromLong(2)); PyList_Append(list, PyLong_FromLong(3)); return list; } Output: .. code-block:: bash >>> import types_demo >>> types_demo.list_demo() [1, 2, 3] Create a Dictionary ------------------- :Source: `src/cext/capi/types_demo.c `_ Python dictionaries are created with ``PyDict_New()`` and populated using ``PyDict_SetItem()`` for PyObject keys or ``PyDict_SetItemString()`` for C string keys. Both functions increment the reference count of the value, so you may need to decref temporary objects after insertion. For retrieving values, ``PyDict_GetItem()`` and ``PyDict_GetItemString()`` return borrowed references (or NULL if the key doesn't exist), while ``PyDict_GetItemWithError()`` distinguishes between missing keys and errors. .. code-block:: c static PyObject *dict_demo(PyObject *self) { PyObject *dict = PyDict_New(); PyDict_SetItemString(dict, "name", PyUnicode_FromString("Python")); PyDict_SetItemString(dict, "version", PyLong_FromLong(3)); return dict; } Output: .. 
code-block:: bash >>> import types_demo >>> types_demo.dict_demo() {'name': 'Python', 'version': 3} Create a Tuple -------------- :Source: `src/cext/capi/types_demo.c `_ Tuples are immutable sequences, and the C API reflects this: once created and populated, tuple contents cannot be changed. Use ``PyTuple_New(n)`` to create a tuple of size n, then ``PyTuple_SetItem()`` to fill each slot (this steals a reference). Alternatively, ``Py_BuildValue()`` with parentheses format creates tuples directly from C values: ``"(isd)"`` creates a tuple of (int, string, double). ``PyTuple_Pack(n, ...)`` is another convenient way to create tuples from existing PyObject pointers. .. code-block:: c static PyObject *tuple_demo(PyObject *self) { return Py_BuildValue("(isd)", 1, "hello", 3.14); } Output: .. code-block:: bash >>> import types_demo >>> types_demo.tuple_demo() (1, 'hello', 3.14) Create a Set ------------ :Source: `src/cext/capi/types_demo.c `_ Sets are unordered collections of unique hashable objects. Create an empty set with ``PySet_New(NULL)`` or initialize from an iterable with ``PySet_New(iterable)``. Add elements with ``PySet_Add()``, which returns 0 on success or -1 on error (e.g., if the object is unhashable). Check membership with ``PySet_Contains()``, remove elements with ``PySet_Discard()`` (no error if missing) or ``PySet_Pop()`` to remove and return an arbitrary element. Duplicate additions are silently ignored. .. code-block:: c static PyObject *set_demo(PyObject *self) { PyObject *set = PySet_New(NULL); PySet_Add(set, PyLong_FromLong(1)); PySet_Add(set, PyLong_FromLong(2)); PySet_Add(set, PyLong_FromLong(2)); /* duplicate ignored */ PySet_Add(set, PyLong_FromLong(3)); return set; } Output: .. 
code-block:: bash >>> import types_demo >>> types_demo.set_demo() {1, 2, 3} String Operations ----------------- :Source: `src/cext/capi/types_demo.c `_ Python 3 strings are Unicode objects, created with ``PyUnicode_FromString()`` for UTF-8 encoded C strings or ``PyUnicode_FromFormat()`` for printf-style formatting. Concatenate strings with ``PyUnicode_Concat()``, which returns a new string object. For extracting C strings, use ``PyUnicode_AsUTF8()`` (returns a borrowed pointer valid only while the object exists) or ``PyUnicode_AsUTF8AndSize()`` to also get the length. The format function supports ``%s`` for C strings, ``%S`` for Python objects (calls str()), ``%R`` for repr(), and ``%d``, ``%u``, ``%ld`` for integers. .. code-block:: c static PyObject *str_demo(PyObject *self) { PyObject *s1 = PyUnicode_FromString("Hello"); PyObject *s2 = PyUnicode_FromString(" World"); PyObject *result = PyUnicode_Concat(s1, s2); Py_DECREF(s1); Py_DECREF(s2); return result; } static PyObject *str_format(PyObject *self, PyObject *args) { const char *name; int age; if (!PyArg_ParseTuple(args, "si", &name, &age)) { return NULL; } return PyUnicode_FromFormat("%s is %d years old", name, age); } Output: .. code-block:: bash >>> import types_demo >>> types_demo.str_demo() 'Hello World' >>> types_demo.str_format("Alice", 30) 'Alice is 30 years old' Bytes Operations ---------------- :Source: `src/cext/capi/types_demo.c `_ Bytes objects represent immutable sequences of bytes, essential for binary data, file I/O, and network protocols. Create bytes from C strings with ``PyBytes_FromString()`` (copies until null terminator) or ``PyBytes_FromStringAndSize()`` for binary data with embedded nulls. Access the internal buffer with ``PyBytes_AsString()`` (borrowed pointer) and get the length with ``PyBytes_Size()``. For mutable byte sequences, use ``PyByteArray_*`` functions instead. When parsing arguments, use format code ``s#`` for bytes with length or ``S`` to accept only bytes objects. .. 
code-block:: c static PyObject *bytes_demo(PyObject *self) { return PyBytes_FromString("hello bytes"); } static PyObject *bytes_len(PyObject *self, PyObject *args) { PyObject *bytes; if (!PyArg_ParseTuple(args, "S", &bytes)) { return NULL; } return PyLong_FromSsize_t(PyBytes_Size(bytes)); } Output: .. code-block:: bash >>> import types_demo >>> types_demo.bytes_demo() b'hello bytes' >>> types_demo.bytes_len(b"hello") 5 Simple Class ------------ Defining custom Python types in C requires creating a ``PyTypeObject`` structure that describes the type's behavior, memory layout, and methods. The minimal type needs ``tp_name`` (fully qualified name like ``"module.ClassName"``), ``tp_basicsize`` (size of the instance struct), and ``tp_new`` (allocation function, often ``PyType_GenericNew``). Call ``PyType_Ready()`` to finalize the type before use, then add it to your module with ``PyModule_AddObject()``. The ``Py_TPFLAGS_DEFAULT`` flag enables standard type features. .. code-block:: c typedef struct { PyObject_HEAD } FooObject; static PyTypeObject FooType = { PyVarObject_HEAD_INIT(NULL, 0) .tp_name = "foo.Foo", .tp_doc = "Foo objects", .tp_basicsize = sizeof(FooObject), .tp_itemsize = 0, .tp_flags = Py_TPFLAGS_DEFAULT, .tp_new = PyType_GenericNew }; PyMODINIT_FUNC PyInit_foo(void) { PyObject *m = NULL; if (PyType_Ready(&FooType) < 0) return NULL; if ((m = PyModule_Create(&module)) == NULL) return NULL; Py_INCREF(&FooType); PyModule_AddObject(m, "Foo", (PyObject *)&FooType); return m; } Class with Members and Methods ------------------------------ Full-featured Python classes in C require implementing several type slots. Use ``PyMemberDef`` to expose C struct fields as Python attributes (with automatic type conversion), and ``PyMethodDef`` for instance methods. 
Implement ``tp_new`` for memory allocation (called before ``__init__``), ``tp_init`` for initialization (the ``__init__`` method), and ``tp_dealloc`` for cleanup (must decref all owned PyObject members and call ``tp_free``). The ``Py_TPFLAGS_BASETYPE`` flag allows the type to be subclassed in Python. .. code-block:: c #include #include typedef struct { PyObject_HEAD PyObject *foo; PyObject *bar; } FooObject; static void Foo_dealloc(FooObject *self) { Py_XDECREF(self->foo); Py_XDECREF(self->bar); Py_TYPE(self)->tp_free((PyObject *)self); } static PyObject * Foo_new(PyTypeObject *type, PyObject *args, PyObject *kw) { FooObject *self = (FooObject *)type->tp_alloc(type, 0); if (self) { self->foo = PyUnicode_FromString(""); self->bar = PyUnicode_FromString(""); } return (PyObject *)self; } static int Foo_init(FooObject *self, PyObject *args, PyObject *kw) { static char *keywords[] = {"foo", "bar", NULL}; PyObject *foo = NULL, *bar = NULL; if (!PyArg_ParseTupleAndKeywords(args, kw, "|OO", keywords, &foo, &bar)) return -1; if (foo) { Py_INCREF(foo); Py_XDECREF(self->foo); self->foo = foo; } if (bar) { Py_INCREF(bar); Py_XDECREF(self->bar); self->bar = bar; } return 0; } static PyMemberDef Foo_members[] = { {"foo", T_OBJECT_EX, offsetof(FooObject, foo), 0, "foo attribute"}, {"bar", T_OBJECT_EX, offsetof(FooObject, bar), 0, "bar attribute"}, {NULL} }; Output: .. code-block:: bash >>> import foo >>> o = foo.Foo('hello', 'world') >>> o.foo 'hello' >>> o.bar 'world' Properties (Getter/Setter) -------------------------- For computed attributes or attributes requiring validation, use ``PyGetSetDef`` to define properties with custom getter and setter functions. The getter receives the instance and an optional closure pointer, returning a new reference to the attribute value. The setter receives the instance, new value (or NULL for deletion), and closure, returning 0 on success or -1 on error. 
This provides the same functionality as Python's ``@property`` decorator but with C-level control over attribute access. .. code-block:: c static PyObject * Foo_getfoo(FooObject *self, void *closure) { Py_INCREF(self->foo); return self->foo; } static int Foo_setfoo(FooObject *self, PyObject *value, void *closure) { if (!value || !PyUnicode_Check(value)) { PyErr_SetString(PyExc_TypeError, "value must be a string"); return -1; } Py_INCREF(value); Py_XDECREF(self->foo); self->foo = value; return 0; } static PyGetSetDef Foo_getsetters[] = { {"foo", (getter)Foo_getfoo, (setter)Foo_setfoo, "foo property", NULL}, {NULL} }; Calling Python from C --------------------- C extensions often need to call back into Python code—invoking callbacks, calling methods on objects, or using Python library functions. Use ``PyObject_CallFunction()`` for calling with C-style format arguments, ``PyObject_CallObject()`` with a tuple of arguments, or ``PyObject_CallMethod()`` to call a method by name. Always check if the callable is valid with ``PyCallable_Check()`` before calling, and check the return value for NULL (indicating an exception was raised). The GIL must be held when calling Python functions. .. code-block:: c static PyObject * call_callback(PyObject *self, PyObject *args) { PyObject *callback = NULL; PyObject *result = NULL; if (!PyArg_ParseTuple(args, "O:callback", &callback)) return NULL; if (!PyCallable_Check(callback)) { PyErr_SetString(PyExc_TypeError, "argument must be callable"); return NULL; } // Call: callback("Hello from C!") result = PyObject_CallFunction(callback, "s", "Hello from C!"); return result; } Output: .. code-block:: bash >>> import foo >>> foo.call_callback(print) Hello from C! Performance Comparison ---------------------- :Source: `src/cext/capi/simple.c `_ C extensions provide dramatic speedups for CPU-bound operations by eliminating Python's interpreter overhead. 
This recursive Fibonacci benchmark demonstrates typical performance gains of 50-100x compared to pure Python. The speedup comes from avoiding Python object creation, method dispatch, and bytecode interpretation on each function call. For numerical code, the gains can be even larger when combined with SIMD instructions or multi-threading (with GIL released). However, the overhead of crossing the Python/C boundary means C extensions are most beneficial for compute-intensive inner loops rather than simple operations. .. code-block:: c static unsigned long fib(unsigned long n) { if (n < 2) return n; return fib(n - 1) + fib(n - 2); } static PyObject * py_fib(PyObject *self, PyObject *args) { unsigned long n = 0; if (!PyArg_ParseTuple(args, "k", &n)) return NULL; return PyLong_FromUnsignedLong(fib(n)); } .. code-block:: python >>> from time import time >>> def py_fib(n): ... if n < 2: return n ... return py_fib(n-1) + py_fib(n-2) ... >>> s = time(); _ = py_fib(35); e = time(); e - s 4.953313112258911 >>> import foo >>> s = time(); _ = foo.fib(35); e = time(); e - s 0.04628586769104004 ================================================ FILE: docs/notes/extension/python-cext-modern.rst ================================================ .. meta:: :description lang=en: Comprehensive guide to modern Python C/C++ extensions covering pybind11, ctypes, cffi, and Cython with practical examples for building high-performance Python modules, NumPy integration, GIL management, and class bindings :keywords: Python, Python3, pybind11, C Extension, C++, ctypes, cffi, Cython, NumPy, Performance, GIL, Native Code, PyTorch, TensorFlow, SciPy, shared library ========================= Modern C/C++ Extensions ========================= .. contents:: Table of Contents :backlinks: none Python's flexibility and ease of use come at a performance cost compared to compiled languages. 
When you need maximum speed for numerical computing, real-time processing, system interfaces, or wrapping existing C/C++ libraries, native extensions bridge the gap between Python's productivity and C/C++'s performance. This guide covers modern approaches to building Python extensions: **pybind11** (the recommended choice for C++ projects), **ctypes/cffi** (for calling C libraries without compilation), and **Cython** (for Python-like syntax that compiles to C). We compare each approach using the same Fibonacci benchmark to help you choose the right tool for your specific use case. .. note:: For most C++ projects, **pybind11** is the recommended choice. It's used by major machine learning frameworks like PyTorch, TensorFlow, and scientific computing libraries like SciPy. pybind11 provides clean C++11 syntax, automatic type conversions between Python and C++ types, excellent NumPy integration for numerical computing, and seamless exception handling across language boundaries. Comparison of Approaches ------------------------ :: ┌─────────────────────────────────────────────────────────────────────────┐ │ C/C++ EXTENSION APPROACHES │ ├─────────────────────────────────────────────────────────────────────────┤ │ APPROACH │ PROS │ CONS │ ├────────────────┼───────────────────────────┼────────────────────────────┤ │ pybind11 │ Clean C++11 syntax │ Requires C++ compiler │ │ (recommended) │ Automatic type conversion │ Compile step needed │ │ │ NumPy support built-in │ C++ only (not C) │ │ │ Used by PyTorch, SciPy │ │ ├────────────────┼───────────────────────────┼────────────────────────────┤ │ ctypes │ No compilation needed │ Manual type declarations │ │ │ Standard library │ Error-prone │ │ │ Works with any C library │ No C++ support │ ├────────────────┼───────────────────────────┼────────────────────────────┤ │ cffi │ No compilation needed │ Extra dependency │ │ │ Cleaner than ctypes │ No C++ support │ │ │ PyPy compatible │ │ 
├────────────────┼───────────────────────────┼────────────────────────────┤ │ Cython │ Python-like syntax │ New language to learn │ │ │ Gradual optimization │ Build complexity │ │ │ Good NumPy integration │ Debugging harder │ ├────────────────┼───────────────────────────┼────────────────────────────┤ │ Python C API │ Maximum control │ Very verbose │ │ (legacy) │ No dependencies │ Manual refcounting │ │ │ │ Error-prone │ └────────────────┴───────────────────────────┴────────────────────────────┘ When to use what: - Wrapping existing C++ library → pybind11 - Wrapping existing C library → ctypes or cffi - Writing new high-perf code → pybind11 or Cython - Need PyPy compatibility → cffi or Cython - Quick prototype → ctypes pybind11: Getting Started ------------------------- :Source: `src/cext/example.cpp `_ pybind11 is a lightweight header-only C++ library that creates Python bindings for existing C++ code. Unlike the traditional Python C API, pybind11 uses modern C++11 features like variadic templates and lambda expressions to provide a clean, intuitive syntax. Installation is simple with ``pip install pybind11``, and it requires only a C++11 compatible compiler (GCC 4.8+, Clang 3.3+, MSVC 2015+). The ``PYBIND11_MODULE`` macro defines the module entry point, and ``m.def()`` binds C++ functions to Python with automatic type conversion for common types like int, float, string, and STL containers. **Simple function binding:** .. code-block:: cpp // example.cpp #include int add(int a, int b) { return a + b; } // Fibonacci for performance comparison unsigned long fib(unsigned long n) { if (n < 2) return n; return fib(n - 1) + fib(n - 2); } PYBIND11_MODULE(example, m) { m.doc() = "Example pybind11 module"; m.def("add", &add, "Add two integers", pybind11::arg("a"), pybind11::arg("b")); m.def("fib", &fib, "Compute Fibonacci number"); } **Build with setup.py:** .. 
code-block:: python # setup.py from pybind11.setup_helpers import Pybind11Extension, build_ext from setuptools import setup ext_modules = [ Pybind11Extension( "example", ["example.cpp"], ), ] setup( name="example", ext_modules=ext_modules, cmdclass={"build_ext": build_ext}, ) .. code-block:: bash $ pip install pybind11 $ python setup.py build_ext --inplace $ python -c "import example; print(example.add(1, 2))" 3 $ python -c "import example; print(example.fib(35))" 9227465 pybind11: Classes ----------------- :Source: `src/cext/vector.cpp `_ pybind11 makes binding C++ classes to Python straightforward with the ``py::class_`` template. You can expose constructors with ``def(py::init<...>())``, member variables with ``def_readwrite()`` or ``def_readonly()``, and methods with ``def()``. Python special methods like ``__repr__``, ``__add__``, ``__eq__`` are bound by name, enabling natural Python syntax for your C++ objects. Default argument values are supported via ``py::arg()``, and you can add docstrings to improve the Python help() experience. .. code-block:: cpp // vector.cpp #include #include namespace py = pybind11; class Vector2D { public: double x, y; Vector2D(double x = 0, double y = 0) : x(x), y(y) {} double length() const { return std::sqrt(x * x + y * y); } Vector2D operator+(const Vector2D& other) const { return Vector2D(x + other.x, y + other.y); } std::string repr() const { return "Vector2D(" + std::to_string(x) + ", " + std::to_string(y) + ")"; } }; PYBIND11_MODULE(vector, m) { py::class_(m, "Vector2D") .def(py::init(), py::arg("x") = 0, py::arg("y") = 0) .def_readwrite("x", &Vector2D::x) .def_readwrite("y", &Vector2D::y) .def("length", &Vector2D::length) .def("__add__", &Vector2D::operator+) .def("__repr__", &Vector2D::repr); } .. 
code-block:: python >>> from vector import Vector2D >>> v1 = Vector2D(3, 4) >>> v1.length() 5.0 >>> v2 = Vector2D(1, 2) >>> v3 = v1 + v2 >>> v3 Vector2D(4.0, 6.0) pybind11: NumPy Integration --------------------------- :Source: `src/cext/numpy_example.cpp `_ pybind11 provides first-class NumPy support through the ``pybind11/numpy.h`` header, enabling high-performance numerical computing without data copying. The ``py::array_t`` template wraps NumPy arrays with type safety, and ``unchecked()`` provides fast, bounds-check-free access for performance-critical inner loops. Arrays can be modified in-place using ``mutable_unchecked()``, or new arrays can be created and returned. This zero-copy approach is essential for scientific computing where large datasets would be expensive to duplicate. .. code-block:: cpp // numpy_example.cpp #include #include namespace py = pybind11; // Element-wise multiply (modifies in place) void multiply_inplace(py::array_t arr, double factor) { auto buf = arr.mutable_unchecked<1>(); for (py::ssize_t i = 0; i < buf.shape(0); i++) { buf(i) *= factor; } } // Return new array py::array_t add_arrays(py::array_t a, py::array_t b) { auto buf_a = a.unchecked<1>(); auto buf_b = b.unchecked<1>(); if (buf_a.shape(0) != buf_b.shape(0)) { throw std::runtime_error("Arrays must have same length"); } auto result = py::array_t(buf_a.shape(0)); auto buf_r = result.mutable_unchecked<1>(); for (py::ssize_t i = 0; i < buf_a.shape(0); i++) { buf_r(i) = buf_a(i) + buf_b(i); } return result; } PYBIND11_MODULE(numpy_example, m) { m.def("multiply_inplace", &multiply_inplace); m.def("add_arrays", &add_arrays); } .. 
code-block:: python >>> import numpy as np >>> from numpy_example import multiply_inplace, add_arrays >>> arr = np.array([1.0, 2.0, 3.0]) >>> multiply_inplace(arr, 2.0) >>> arr array([2., 4., 6.]) >>> a = np.array([1.0, 2.0, 3.0]) >>> b = np.array([4.0, 5.0, 6.0]) >>> add_arrays(a, b) array([5., 7., 9.]) pybind11: Releasing the GIL --------------------------- :Source: `src/cext/gil_example.cpp `_ Python's Global Interpreter Lock (GIL) prevents true parallel execution of Python code across threads. For CPU-intensive C++ operations or blocking I/O that doesn't need Python objects, releasing the GIL allows other Python threads to run concurrently. pybind11 provides ``py::gil_scoped_release`` for RAII-style GIL management—the GIL is released when the object is created and automatically reacquired when it goes out of scope. This pattern is essential for multi-threaded applications where C++ code performs heavy computation while Python threads handle other tasks like UI updates or network I/O. .. code-block:: cpp #include #include #include namespace py = pybind11; // Slow operation that releases GIL void slow_operation(int seconds) { // Release GIL while doing CPU work py::gil_scoped_release release; // Simulate slow work std::this_thread::sleep_for(std::chrono::seconds(seconds)); } // CPU-intensive work unsigned long fib_nogil(unsigned long n) { py::gil_scoped_release release; std::function fib_impl; fib_impl = [&](unsigned long n) -> unsigned long { if (n < 2) return n; return fib_impl(n - 1) + fib_impl(n - 2); }; return fib_impl(n); } PYBIND11_MODULE(gil_example, m) { m.def("slow_operation", &slow_operation); m.def("fib_nogil", &fib_nogil); } .. 
code-block:: python import threading from datetime import datetime from gil_example import slow_operation def worker(n): print(f"{datetime.now()}: Thread {n} starting") slow_operation(1) print(f"{datetime.now()}: Thread {n} done") # Threads run in parallel because GIL is released threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)] for t in threads: t.start() for t in threads: t.join() ctypes: Quick C Library Access ------------------------------ :Source: `src/cext/fib.c `_ ctypes is Python's built-in foreign function interface (FFI) that lets you call C functions from shared libraries (.so, .dylib, .dll) without writing any C wrapper code or compiling Python extensions. It's ideal for quick prototyping, accessing system libraries, or wrapping existing C code when you don't want a build step. The key requirement is declaring function signatures with ``argtypes`` and ``restype``—without these declarations, ctypes assumes all arguments and return values are C ``int``, which causes silent bugs or crashes with other types like ``double`` or pointers. .. code-block:: c // fib.c - Compile: gcc -shared -fPIC -o libfib.so fib.c unsigned long fib(unsigned long n) { if (n < 2) return n; return fib(n - 1) + fib(n - 2); } double add_doubles(double a, double b) { return a + b; } .. code-block:: python import ctypes from ctypes import c_ulong, c_double # Load shared library # Linux: libfib.so, macOS: libfib.dylib, Windows: fib.dll lib = ctypes.CDLL("./libfib.so") # Declare function signatures (important for non-int types!) lib.fib.argtypes = [c_ulong] lib.fib.restype = c_ulong lib.add_doubles.argtypes = [c_double, c_double] lib.add_doubles.restype = c_double # Call functions print(lib.fib(35)) # 9227465 print(lib.add_doubles(1.5, 2.5)) # 4.0 ctypes: Structures and Pointers ------------------------------- Working with C structures and pointers in ctypes requires careful type declarations that mirror the C memory layout exactly. 
Define structures by subclassing ``Structure`` and specifying ``_fields_`` as a list of (name, type) tuples in the same order as the C struct. Use ``POINTER(Type)`` to create pointer types and ``byref(obj)`` to pass objects by reference (equivalent to ``&obj`` in C). This approach is more error-prone than pybind11 but works without any compilation step. .. code-block:: c // point.c typedef struct { double x; double y; } Point; double distance(Point* p1, Point* p2) { double dx = p2->x - p1->x; double dy = p2->y - p1->y; return sqrt(dx*dx + dy*dy); } void scale_point(Point* p, double factor) { p->x *= factor; p->y *= factor; } .. code-block:: python import ctypes from ctypes import Structure, c_double, POINTER, byref import math class Point(Structure): _fields_ = [("x", c_double), ("y", c_double)] lib = ctypes.CDLL("./libpoint.so") lib.distance.argtypes = [POINTER(Point), POINTER(Point)] lib.distance.restype = c_double lib.scale_point.argtypes = [POINTER(Point), c_double] lib.scale_point.restype = None # Create points p1 = Point(0, 0) p2 = Point(3, 4) # Pass by reference dist = lib.distance(byref(p1), byref(p2)) print(f"Distance: {dist}") # 5.0 # Modify in place lib.scale_point(byref(p2), 2.0) print(f"Scaled: ({p2.x}, {p2.y})") # (6.0, 8.0) cffi: Cleaner Foreign Function Interface ---------------------------------------- cffi (C Foreign Function Interface) provides a cleaner, more Pythonic API than ctypes for calling C code. Instead of Python type objects, you declare C function signatures using actual C syntax in ``ffi.cdef()``, which can often be copied directly from header files. cffi handles type conversions automatically and provides better error messages. It's also the recommended FFI for PyPy, where it runs significantly faster than ctypes. Install with ``pip install cffi``. .. 
code-block:: python from cffi import FFI ffi = FFI() # Declare C functions (copy from header file) ffi.cdef(""" unsigned long fib(unsigned long n); double add_doubles(double a, double b); """) # Load library lib = ffi.dlopen("./libfib.so") # Call functions - types are automatic! print(lib.fib(35)) # 9227465 print(lib.add_doubles(1.5, 2.5)) # 4.0 **cffi with inline C code (API mode):** .. code-block:: python from cffi import FFI ffi = FFI() ffi.cdef("unsigned long fib(unsigned long n);") # Compile C code inline ffi.set_source("_fib_cffi", """ unsigned long fib(unsigned long n) { if (n < 2) return n; return fib(n - 1) + fib(n - 2); } """) ffi.compile() # Now import and use from _fib_cffi import lib print(lib.fib(35)) Cython: Python-like Syntax -------------------------- Cython is a programming language that combines Python syntax with C data types, compiling to efficient C code. It's excellent for gradual optimization: start with pure Python code, then add type declarations to critical sections for dramatic speedups. Cython supports three levels of optimization: pure Python (``def``), typed Python (``def`` with C type declarations on the arguments), and pure C functions (``cdef``). The ``cdef`` functions run at C speed but can only be called from other Cython code, so you typically wrap them with a ``def`` function for Python access. Install with ``pip install cython``. .. code-block:: cython # fib.pyx def fib_py(n): """Pure Python - slow""" if n < 2: return n return fib_py(n - 1) + fib_py(n - 2) def fib_typed(long n): """With type hints - faster""" if n < 2: return n return fib_typed(n - 1) + fib_typed(n - 2) cdef unsigned long _fib_c(unsigned long n): """C function - fastest""" if n < 2: return n return _fib_c(n - 1) + _fib_c(n - 2) def fib_c(unsigned long n): """Python wrapper for C function""" return _fib_c(n) .. code-block:: python # setup.py from setuptools import setup from Cython.Build import cythonize setup( ext_modules=cythonize("fib.pyx"), ) ..
code-block:: bash $ python setup.py build_ext --inplace $ python -c "from fib import fib_c; print(fib_c(35))" 9227465 Performance Comparison ---------------------- Understanding the performance characteristics of each approach helps you choose the right tool. This benchmark compares all approaches using recursive Fibonacci (n=35), a CPU-bound task that highlights the overhead of Python's interpreter. Native code achieves 50-100x speedups by eliminating Python object creation, method dispatch, and bytecode interpretation. The actual speedup varies by workload—numerical code with NumPy integration can see even larger gains, while I/O-bound code benefits less. .. code-block:: python from time import time def benchmark(func, n=35, runs=3): times = [] for _ in range(runs): start = time() result = func(n) times.append(time() - start) return min(times), result # Pure Python def fib_python(n): if n < 2: return n return fib_python(n - 1) + fib_python(n - 2) # Results (approximate, varies by system): # # | Approach | Time (s) | Speedup | # |-------------------|----------|---------| # | Pure Python | 2.50 | 1x | # | Cython (typed) | 0.08 | 31x | # | Cython (cdef) | 0.05 | 50x | # | ctypes | 0.05 | 50x | # | cffi | 0.05 | 50x | # | pybind11 | 0.04 | 62x | # | pybind11 (no GIL) | 0.04 | 62x | Best Practices -------------- Following these guidelines will help you write efficient, maintainable, and safe native extensions. The most common mistakes are holding the GIL during long operations (blocking other threads), copying large arrays unnecessarily (killing performance), and ignoring error handling (causing crashes instead of Python exceptions). 
**Do:** - Use pybind11 for new C++ bindings - Release the GIL for CPU-intensive operations - Use NumPy arrays for numerical data (zero-copy with pybind11) - Declare types in ctypes/cffi (avoid silent bugs) - Profile before optimizing—find the real bottleneck **Don't:** - Write Python C API code for new projects (use pybind11) - Hold the GIL during blocking I/O or long computations - Copy large arrays between Python and C (use views) - Ignore error handling in C code - Optimize prematurely—Python is often fast enough **Error handling:** .. code-block:: cpp // pybind11 - exceptions automatically convert to Python double divide(double a, double b) { if (b == 0) { throw std::runtime_error("Division by zero"); } return a / b; } // In Python: raises RuntimeError **Memory management:** .. code-block:: cpp // pybind11 handles Python object refcounting automatically // For raw pointers, use return value policies: py::class_(m, "Parent") .def("get_child", &Parent::get_child, py::return_value_policy::reference_internal); // Child tied to Parent lifetime ================================================ FILE: docs/notes/extension/python-ctypes.rst ================================================ .. meta:: :description lang=en: Python ctypes tutorial for loading shared libraries, calling C functions, handling pointers, structures, and error handling :keywords: Python, ctypes, shared library, .so, .dylib, .dll, C library, FFI, foreign function interface, pointers, structures ====== ctypes ====== :Source: `src/basic/cext_.py `_ .. contents:: Table of Contents :backlinks: none ctypes is Python's built-in foreign function interface (FFI) library that allows calling functions in shared libraries (.so on Linux, .dylib on macOS, .dll on Windows) without writing any C code or compiling extensions. It's ideal for quick prototyping, accessing system libraries, or wrapping existing C libraries when you don't want a compilation step. 
However, ctypes requires manual type declarations and careful memory management, making it more error-prone than alternatives like cffi or pybind11 for complex use cases. Loading Shared Libraries ------------------------ ctypes provides platform-specific loaders for shared libraries. Use ``CDLL`` for standard C calling convention or ``WinDLL`` on Windows for stdcall convention. The library search follows system conventions: ``LD_LIBRARY_PATH`` on Linux, ``DYLD_LIBRARY_PATH`` on macOS, and ``PATH`` on Windows. .. code-block:: python import platform from ctypes import CDLL # Load platform-specific C library if platform.system() == "Darwin": libc = CDLL("libc.dylib") elif platform.system() == "Linux": libc = CDLL("libc.so.6") else: from ctypes import windll libc = windll.msvcrt # Call printf libc.printf(b"Hello from C: %d\n", 42) Loading Custom Libraries ------------------------ For your own compiled C libraries, provide the full path or ensure the library is in the system's library search path. The ``use_errno=True`` parameter enables proper errno handling for error detection. .. code-block:: python from ctypes import CDLL from ctypes.util import find_library import os # Load from current directory lib_path = os.path.join(os.path.dirname(__file__), "libfoo.so") lib = CDLL(lib_path, use_errno=True) # Or use find_library for system libraries libm_path = find_library("m") # finds libm.so or libm.dylib if libm_path: libm = CDLL(libm_path) Basic Type Mapping ------------------ ctypes provides Python equivalents for C types. Always declare argument types (``argtypes``) and return types (``restype``) explicitly to avoid crashes and ensure correct data conversion. Without these declarations, ctypes assumes all arguments and return values are C ``int``. .. 
code-block:: python import platform from ctypes import CDLL, c_double, c_char_p if platform.system() == "Darwin": libc = CDLL("libc.dylib") else: libc = CDLL("libc.so.6") # Declare function signature: double atof(const char *) libc.atof.argtypes = [c_char_p] libc.atof.restype = c_double result = libc.atof(b"3.14159") print(result) # 3.14159 # Common type mappings: # c_int -> int # c_long -> long # c_double -> double # c_char_p -> char* (bytes in Python) # c_void_p -> void* # c_bool -> _Bool Calling strlen and abs ---------------------- Simple examples calling standard C library functions with proper type declarations. .. code-block:: python import platform from ctypes import CDLL, c_char_p, c_size_t if platform.system() == "Darwin": libc = CDLL("libc.dylib") else: libc = CDLL("libc.so.6") # strlen libc.strlen.argtypes = [c_char_p] libc.strlen.restype = c_size_t assert libc.strlen(b"hello") == 5 # abs (default int types work) assert libc.abs(-42) == 42 Calling sqrt from libm ---------------------- .. code-block:: python import platform from ctypes import CDLL, c_double if platform.system() == "Darwin": libm = CDLL("libm.dylib") else: libm = CDLL("libm.so.6") libm.sqrt.argtypes = [c_double] libm.sqrt.restype = c_double result = libm.sqrt(16.0) assert abs(result - 4.0) < 1e-10 Calling C Functions ------------------- This example shows a complete workflow: compile a C library, load it with ctypes, and call functions with proper type declarations. The Fibonacci function demonstrates the performance benefit of C code called from Python. **C source (fib.c):** .. code-block:: c // Compile: // Linux: gcc -shared -fPIC -o libfib.so fib.c // macOS: clang -shared -fPIC -o libfib.dylib fib.c unsigned long fib(unsigned long n) { if (n < 2) return n; return fib(n - 1) + fib(n - 2); } **Python usage:** .. 
code-block:: python import platform from ctypes import CDLL, c_ulong # Load the library if platform.system() == "Darwin": lib = CDLL("./libfib.dylib") else: lib = CDLL("./libfib.so") # Declare types lib.fib.argtypes = [c_ulong] lib.fib.restype = c_ulong # Call the function print(lib.fib(35)) # 9227465 **Performance comparison:** .. code-block:: python >>> from time import time >>> def py_fib(n): ... if n < 2: return n ... return py_fib(n - 1) + py_fib(n - 2) ... >>> s = time(); _ = py_fib(35); e = time(); e - s 4.918856859207153 >>> s = time(); _ = lib.fib(35); e = time(); e - s 0.07283687591552734 Pointers and byref ------------------ Use ``byref()`` to pass arguments by reference (like ``&var`` in C) and ``POINTER()`` to create pointer types. ``byref()`` is more efficient than ``pointer()`` when you only need to pass a reference to a function. .. code-block:: python from ctypes import c_int, byref, pointer, POINTER # Pointer to integer value = c_int(42) ptr = pointer(value) assert ptr.contents.value == 42 # Modify through pointer ptr.contents.value = 100 assert value.value == 100 # byref creates a lightweight pointer for passing to C functions ref = byref(value) # Create pointer type and array IntPtr = POINTER(c_int) arr = (c_int * 3)(1, 2, 3) Structures ---------- Define C structures by subclassing ``Structure`` and specifying ``_fields_``. Field order must match the C struct exactly. Use ``_pack_`` to control alignment if needed (e.g., ``_pack_ = 1`` for packed structs). .. code-block:: python import ctypes import math class Point(ctypes.Structure): _fields_ = [ ("x", ctypes.c_double), ("y", ctypes.c_double), ] # Create and use p = Point(3.0, 4.0) assert p.x == 3.0 assert p.y == 4.0 # Calculate distance distance = math.sqrt(p.x ** 2 + p.y ** 2) assert abs(distance - 5.0) < 1e-10 Nested Structures ----------------- .. 
code-block:: python import ctypes class Point(ctypes.Structure): _fields_ = [ ("x", ctypes.c_double), ("y", ctypes.c_double), ] class Rectangle(ctypes.Structure): _fields_ = [ ("top_left", Point), ("bottom_right", Point), ] rect = Rectangle(Point(0, 10), Point(10, 0)) assert rect.top_left.x == 0 assert rect.top_left.y == 10 assert rect.bottom_right.x == 10 Arrays ------ Create C arrays using the multiplication syntax ``type * size``. Arrays can be initialized with values and accessed like Python lists. They automatically convert to pointers when passed to C functions. .. code-block:: python from ctypes import c_int, c_double, c_char # Integer array IntArray5 = c_int * 5 arr = IntArray5(1, 2, 3, 4, 5) assert arr[0] == 1 assert arr[4] == 5 # Modify elements arr[0] = 100 assert list(arr) == [100, 2, 3, 4, 5] # Character array (C string buffer) buf = (c_char * 256)() buf.value = b"Hello" assert buf.value == b"Hello" # Double array for numerical work data = (c_double * 3)(1.1, 2.2, 3.3) assert abs(sum(data) - 6.6) < 1e-10 Array in Structure ------------------ Structures can contain fixed-size arrays as members. .. code-block:: python import ctypes class Data(ctypes.Structure): _fields_ = [ ("values", ctypes.c_int * 5), ("count", ctypes.c_int) ] d = Data() d.count = 5 for i in range(5): d.values[i] = i * 10 assert list(d.values) == [0, 10, 20, 30, 40] assert d.count == 5 Using cffi ---------- cffi is a cleaner alternative to ctypes with better PyPy compatibility. It uses C-like declarations instead of Python type objects. .. 
code-block:: python import platform from cffi import FFI ffi = FFI() ffi.cdef(""" int abs(int x); size_t strlen(const char *s); double sqrt(double x); """) if platform.system() == "Darwin": libc = ffi.dlopen("libc.dylib") libm = ffi.dlopen("libm.dylib") else: libc = ffi.dlopen("libc.so.6") libm = ffi.dlopen("libm.so.6") assert libc.abs(-42) == 42 assert libc.strlen(b"hello") == 5 assert abs(libm.sqrt(16.0) - 4.0) < 1e-10 Error Handling -------------- When calling C functions that set errno on failure, use ``use_errno=True`` when loading the library and ``get_errno()`` to retrieve the error code. This is essential for proper error handling with system calls. .. code-block:: python import os import platform from ctypes import CDLL, get_errno if platform.system() == "Darwin": libc = CDLL("libc.dylib", use_errno=True) else: libc = CDLL("libc.so.6", use_errno=True) # Try to open a non-existent file fd = libc.open(b"/nonexistent/path", 0) if fd == -1: errno = get_errno() errmsg = f"open failed: {os.strerror(errno)}" print(errmsg) # open failed: No such file or directory Callbacks --------- ctypes can create C-callable function pointers from Python functions using ``CFUNCTYPE``. This is useful for C libraries that accept callback functions, such as ``qsort()`` or event handlers. .. 
code-block:: python import platform from ctypes import CDLL, CFUNCTYPE, POINTER, c_int, c_void_p, cast, sizeof if platform.system() == "Darwin": libc = CDLL("libc.dylib") else: libc = CDLL("libc.so.6") # Define callback type: int (*compare)(const void*, const void*) CMPFUNC = CFUNCTYPE(c_int, c_void_p, c_void_p) def py_compare(a, b): """Compare function for qsort""" a_val = cast(a, POINTER(c_int)).contents.value b_val = cast(b, POINTER(c_int)).contents.value return a_val - b_val # Create C callback from Python function c_compare = CMPFUNC(py_compare) # Use with qsort arr = (c_int * 5)(5, 2, 8, 1, 9) libc.qsort(arr, len(arr), sizeof(c_int), c_compare) print(list(arr)) # [1, 2, 5, 8, 9] String Handling --------------- C strings require careful handling in ctypes. Use ``c_char_p`` for immutable strings and ``create_string_buffer()`` for mutable buffers. Always use bytes (``b"string"``) not str when passing to C functions. .. code-block:: python import platform from ctypes import CDLL, c_char_p, c_int, create_string_buffer if platform.system() == "Darwin": libc = CDLL("libc.dylib") else: libc = CDLL("libc.so.6") # Immutable string (c_char_p) libc.puts.argtypes = [c_char_p] libc.puts(b"Hello, World!") # Mutable buffer for functions that modify strings buf = create_string_buffer(100) libc.strcpy(buf, b"Hello") libc.strcat(buf, b", World!") print(buf.value) # b'Hello, World!' # Get string length libc.strlen.argtypes = [c_char_p] libc.strlen.restype = c_int print(libc.strlen(b"Hello")) # 5 ================================================ FILE: docs/notes/hpc/index.rst ================================================ .. 
meta:: :description lang=en: High-Performance Computing (HPC) cheat sheet covering Slurm job scheduling, cluster management, distributed computing, and GPU workloads :keywords: HPC, High-Performance Computing, Slurm, job scheduler, cluster computing, distributed training, GPU cluster, supercomputing, workload management HPC === High-Performance Computing (HPC) enables large-scale computational workloads across clusters of powerful machines. This section covers essential HPC tools and workflows, with a focus on Slurm—the most widely used job scheduler in HPC environments. Whether you're running distributed machine learning training, scientific simulations, or batch processing jobs, these guides will help you efficiently manage and schedule workloads on HPC clusters. .. toctree:: :maxdepth: 1 slurm ================================================ FILE: docs/notes/hpc/slurm.rst ================================================ .. meta:: :description lang=en: Slurm cheat sheet for HPC job scheduling, batch jobs, distributed training, MPI, Enroot containers, and cluster management commands :keywords: Slurm, HPC, job scheduler, sbatch, srun, salloc, distributed training, MPI, Enroot, Pyxis, cluster computing, workload manager, GPU cluster, machine learning, PyTorch distributed ===== Slurm ===== .. contents:: Table of Contents :backlinks: none Slurm (Simple Linux Utility for Resource Management) is an open-source job scheduling and workload management system widely used in high-performance computing (HPC) clusters. It is designed to efficiently allocate resources, manage queues, and dispatch jobs across large numbers of compute nodes. Slurm is the de facto standard for HPC job scheduling, powering many of the world's largest supercomputers and GPU clusters used for scientific computing, machine learning, and AI research. 
For machine learning engineers, Slurm provides a straightforward way to launch distributed training jobs for large language models (LLMs) and deep learning workloads sharded across multiple nodes. Unlike container orchestration systems like Kubernetes—which often require additional components such as Kubeflow for ML workload scheduling—Slurm provides a simpler, HPC-focused workflow. Users can submit and manage jobs directly with commands like ``srun``, ``sbatch``, and ``squeue``, without needing to configure complex orchestration layers. This cheat sheet covers essential Slurm commands and workflows for submitting jobs, managing resources, running distributed training with PyTorch, launching MPI applications, and using containers with Enroot and Pyxis. Slurm Info ---------- ``sinfo`` is a command used to display general information about a Slurm-managed cluster, such as the number of available nodes and partitions. It also allows users to check the status of nodes, including identifying nodes that are down or in an error state. .. code-block:: bash # show slurm general info sinfo # show partition info sinfo -s sinfo --summarize # show partition info PARTITION=dev sinfo -p ${PARTITION} # show nodes in idle state sinfo --state=idle # show nodes in specific format sinfo -o "%n %P %t %C" # node, partition, state, CPUs # show GPU info (if configured) sinfo -o "%n %G" Node Info --------- ``scontrol show node`` provides detailed information about specific nodes in the cluster, including CPU count, memory, GPU resources, and current state. This is useful for debugging node issues or verifying hardware configurations. .. 
code-block:: bash # show all nodes info scontrol show nodes # show specific node info scontrol show node compute-01 # show node in parseable format scontrol show node compute-01 --oneliner # list all hostnames scontrol show hostnames # expand node range to list scontrol show hostnames compute-[0-5] # show nodes in idle state sinfo --state=idle Submit Jobs ----------- Launching a job across multiple nodes in the foreground is straightforward with ``srun``. For example, running ``srun hostname`` will execute the ``hostname`` command on multiple allocated nodes and wait for all nodes to return results. With ``srun``, users can easily specify: * Number of nodes to run the job on (``--nodes``) * Partition or queue to submit the job to (``--partition``) * Time limit for the job (``--time``), ensuring compute resources are automatically released when the job finishes or reaches its time limit By default, ``srun`` runs interactively in the foreground, making it ideal for quick tests or debugging. For longer or batch jobs, users typically pair ``srun`` with job scripts submitted via ``sbatch``. .. code-block:: bash # Submit a job to a compute node srun -N1 hostname # Submit a job on specific nodes srun --nodelist=compute-[0-5] hostname # Submit a job to a specific partition PARTITION=dev srun -p ${PARTITION} --nodelist=compute-[0-5] hostname # Submit a job via srun on 2 nodes (using dd to simulate a CPU-intensive job) srun -N2 dd if=/dev/zero of=/dev/null # Submit a job with a time constraint. # - minutes # - minutes:seconds # - hours:minutes:seconds # - days-hours # - days-hours:minutes # - days-hours:minutes:seconds # # ex: The following job will time out after 1m30s srun -N2 --time=01:30 dd if=/dev/zero of=/dev/null # login to a node srun -N 1 --pty /bin/bash Alloc Nodes ----------- In some scenarios, users may need exclusive, interactive access to specific nodes for experiments or testing.
For instance, a researcher running benchmarking tests might require all benchmarks to execute on the same fixed nodes to ensure consistent and reproducible results. The ``salloc`` command is used to request and allocate resources interactively. By using ``salloc``, users can reserve a specific number of nodes, ensuring that no other jobs are scheduled on them during the experiment. This isolation helps avoid resource contention that could affect benchmarking or performance measurements. For example, the following command allocates 2 nodes for an interactive session: .. code-block:: bash # Allocate 2 nodes and submit a job on those allocated nodes salloc -N 2 srun hostname exit # release allocated nodes # Allocate nodes on a specific partition PARTITION=dev salloc -N 2 -p ${PARTITION} .. image:: images/salloc.svg .. note:: ``salloc`` is particularly useful for: * Interactive debugging * Benchmarking and performance testing * Running exploratory workloads without writing a full job script Cancel Jobs ----------- Users may occasionally need to cancel their jobs for various reasons. For example, a cluster administrator may announce maintenance (such as upgrading system libraries), requiring users to terminate running jobs. In other cases, a job might hang or consume compute resources unnecessarily, making cancellation necessary. Slurm provides the ``scancel`` command to terminate jobs cleanly. Example usage: ..
code-block:: bash # cancel a job scancel "${jobid}" # cancel a job and disable warnings scancel -q "${jobid}" # cancel all jobs that belong to an account scancel --account="${account}" # cancel all jobs that belong to a partition scancel --partition="${partition}" # cancel all pending jobs scancel --state="PENDING" # cancel all running jobs scancel --state="RUNNING" # cancel all jobs squeue -l | awk '{ print $ 1}' | grep '[[:digit:]].*' | xargs scancel # cancel all jobs (using state option) for s in "RUNNING" "PENDING" "SUSPENDED"; do scancel --state="$s"; done Submit Batch Jobs ----------------- ``sbatch`` is a Slurm command used to submit batch jobs for execution on a cluster. Unlike ``srun``, which typically runs jobs interactively in the foreground, ``sbatch`` is designed for running long, non-interactive workloads in the background. This allows users to submit jobs without maintaining an active SSH session to the cluster's head node, making it ideal for large-scale or time-consuming tasks. A typical workflow involves writing a Slurm job script containing job specifications (such as the number of nodes, time limits, and partitions) and one or more ``srun`` commands to execute programs. Submitting this script with ``sbatch`` queues the job, and Slurm automatically schedules it based on available resources. Example sbatch script: ..
code-block:: bash #!/bin/bash #SBATCH --nodelist=compute-[0-1] #SBATCH --output=logs/%x_%j.out #SBATCH --error=logs/%x_%j.out #SBATCH --ntasks-per-node=8 master_addr="$(scontrol show hostnames | sort | head -n 1)" srun hostname srun torchrun \ --nproc-per-node="$SLURM_NPROCS" \ --nnodes="$SLURM_NNODES" --master-addr="${master_addr}" \ --master-port=29500 \ ${PWD}/train.py # sbatch job.sh Submit mpirun ------------- In some HPC environments, users may not be able to load the MPI module directly on the head (login) node due to security restrictions, minimal software installations, or site policies that restrict heavy workloads on login nodes. In such cases, the workflow is to use Slurm to allocate compute nodes and launch ``mpirun`` from within one of those nodes. From there, mpirun orchestrates the execution of the MPI program across all allocated nodes. .. image:: images/mpirun.svg .. code-block:: bash #!/bin/bash # Usage: # # rank_per_node=8 # salloc -N 4 # ./mpirun.sh ${rank_per_node} ${binary} launch() { local rank_per_node="${1}" local args=("${@:2}") local arr local hosts local cmd mapfile -t arr < <(scontrol show hostnames | sort) OLDIFS="${IFS}" IFS="," hosts="${arr[*]}" IFS="${OLDIFS}" cmd="$(cat <`_ provides a lightweight alternative to traditional container runtimes. It allows users to run an isolated filesystem in an HPC setting with minimal overhead, similar to ``chroot``, while still granting direct access to system hardware (e.g., GPUs, interconnects). This makes it ideal for ML and HPC workflows that require fine-tuned performance. Building on Enroot, `Pyxis <https://github.com/NVIDIA/pyxis>`_ is a Slurm plugin that enables launching jobs inside Enroot containers without writing additional wrapper scripts. Users can specify an Enroot squash file and runtime options directly in their sbatch or srun commands, integrating container workflows seamlessly into Slurm job submission. The following snippet shows several ways to launch a job through Enroot and Pyxis. ..
code-block:: bash # build an enroot sqsh file $ enroot import -o "${output_sqsh}" "dockerd://${image}" # submit a job with enroot srun --container-image "${output_sqsh}" \ --container-mounts "/fsx:/fsx,/nfs:/nfs" \ --ntasks-per-node=8 \ ${cmd} # submit a mpi job with enroot srun --container-image "${output_sqsh}" \ --container-mounts "/fsx:/fsx,/nfs:/nfs" \ --ntasks-per-node=8 \ --mpi=pmix \ ${cmd} Job Status ---------- To monitor the status of jobs in a Slurm-managed cluster, users can use the ``squeue`` command. This tool shows essential details about submitted jobs, such as job IDs, job names, partitions, allocated nodes, and job states. Common job states include: * RUNNING – The job is actively running on allocated resources. * PENDING – The job is waiting in the queue for resources to become available. * FAILED – The job has failed due to errors or unmet conditions. If a job is stuck, fails, or behaves unexpectedly, you can terminate it with the ``scancel`` command and resubmit after fixing the issue. .. code-block:: bash # check all Slurm jobs status squeue # check user's job status squeue --user=${USER} Reservation ----------- From an administrator’s perspective, it may be necessary to reserve specific nodes to prevent Slurm from scheduling jobs on them. For example, nodes experiencing hardware or software issues—such as network failures or disk errors—should be reserved to avoid job failures. Reserving nodes allows administrators to troubleshoot, repair, or perform maintenance without interfering with active workloads. The following snippet demonstrates how to create reservations through ``scontrol`` for nodes and check their reservation status. .. 
code-block:: bash # reserve nodes for a user to test # - minutes # - minutes:seconds # - hours:minutes:seconds # - days-hours # - days-hours:minutes # - days-hours:minutes:seconds # # ex: reserve all nodes 120m for maintenance scontrol create reservation ReservationName=maintenance \ starttime=now duration=120 user=root flags=maint,ignore_jobs nodes=ALL # must specify reservation; otherwise, the job will not run srun --reservation=maintenance ping 8.8.8.8 2>&1 > /dev/null # show reservations scontrol show res # delete a reservation scontrol delete ReservationName=maintenance # drain nodes for maintenance. ex: nodes=compute-[01-02],compute-08 scontrol update NodeName=compute-[01-02],compute-08 State=DOWN Reason="maintenance" # resume nodes scontrol update NodeName=compute-[01-02],compute-08 State=Resume Accounting ---------- Slurm includes a powerful accounting and resource management system that allows administrators to control how computing resources are allocated and ensure fair usage across all users. Through this system, administrators can configure fairshare scheduling, job priority policies, and resource limits to prevent individual users or groups from monopolizing cluster resources for extended periods. With ``fairshare``, Slurm dynamically adjusts job priorities based on historical resource usage, ensuring that users who have consumed fewer resources get higher priority in the job queue, while heavy users may experience lower priority until usage balances out. This helps maintain equitable access in multi-user HPC environments. Administrators manage these policies through Slurm’s database-backed accounting system (``slurmdbd``) and commands like: ..
code-block:: bash # create a cluster (the clustername should be identical to ClusterName in slurm.conf) sacctmgr add cluster clustername # create an account sacctmgr -i add account worker description="worker account" Organization="your.org" # create a user and add to an account sacctmgr create user name=worker DefaultAccount=default # create a user and add to additional accounts sacctmgr -i create user "worker" account="worker" adminlevel="None" # modify user fairshare configuration sacctmgr modify user where name="worker" account="worker" set fairshare=0 # remove a user from an account sacctmgr remove user "worker" where account="worker" # show all accounts sacctmgr show account # show all accounts with associations sacctmgr show account -s Job History ----------- ``sacct`` displays accounting data for completed and running jobs. This is essential for analyzing job performance, debugging failed jobs, and tracking resource usage over time. Unlike ``squeue``, which only shows active jobs, ``sacct`` can retrieve historical job information from the Slurm accounting database. .. code-block:: bash # show recent jobs for current user sacct # show jobs from specific date range sacct --starttime=2024-01-01 --endtime=2024-01-31 # show specific job details sacct -j ${jobid} --format=JobID,JobName,Partition,State,ExitCode,Elapsed # show detailed resource usage sacct -j ${jobid} --format=JobID,MaxRSS,MaxVMSize,AveRSS,AveCPU # show all fields sacct -j ${jobid} --format=ALL # show jobs with specific state sacct --state=FAILED --starttime=2024-01-01 # common format for debugging sacct -j ${jobid} --format=JobID,JobName,State,ExitCode,DerivedExitCode,Comment Environment Variables --------------------- Slurm sets various environment variables when a job runs, providing information about the job's allocation and configuration. These variables are essential for writing portable job scripts that adapt to different resource allocations. ..
code-block:: bash # Common Slurm environment variables echo $SLURM_JOB_ID # Job ID echo $SLURM_JOB_NAME # Job name echo $SLURM_JOB_NODELIST # List of allocated nodes echo $SLURM_JOB_NUM_NODES # Number of nodes allocated echo $SLURM_NNODES # Same as SLURM_JOB_NUM_NODES echo $SLURM_NTASKS # Total number of tasks echo $SLURM_NTASKS_PER_NODE # Tasks per node echo $SLURM_CPUS_PER_TASK # CPUs per task echo $SLURM_PROCID # MPI rank (global) echo $SLURM_LOCALID # Local task ID on node echo $SLURM_NODEID # Node ID in allocation echo $SLURM_SUBMIT_DIR # Directory where job was submitted echo $SLURM_GPUS # Number of GPUs (if allocated) echo $SLURM_GPUS_PER_NODE # GPUs per node GPU Jobs -------- For machine learning and deep learning workloads, requesting GPU resources is essential. Slurm supports GPU scheduling through the Generic Resource (GRES) plugin. Users can request specific numbers of GPUs, GPU types, or GPUs per task. .. code-block:: bash # request 1 GPU srun --gres=gpu:1 nvidia-smi # request 4 GPUs srun --gres=gpu:4 python train.py # request specific GPU type (if configured) srun --gres=gpu:a100:2 python train.py # request GPUs per task srun --ntasks=4 --gres=gpu:4 --gpus-per-task=1 python train.py # sbatch example with GPUs #!/bin/bash #SBATCH --nodes=2 #SBATCH --ntasks-per-node=8 #SBATCH --gres=gpu:8 #SBATCH --cpus-per-task=12 srun python train.py PyTorch Distributed Training ---------------------------- Launching distributed PyTorch training jobs on Slurm requires coordinating multiple processes across nodes. The ``torchrun`` launcher simplifies this by handling process spawning and environment setup. Here's a complete example for multi-node distributed training. .. 
code-block:: bash #!/bin/bash #SBATCH --job-name=distributed-train #SBATCH --nodes=4 #SBATCH --ntasks-per-node=1 #SBATCH --gpus-per-node=8 #SBATCH --cpus-per-task=96 #SBATCH --output=logs/%x_%j.out #SBATCH --error=logs/%x_%j.err # Get master node address MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) MASTER_PORT=29500 # Set environment variables for distributed training export NCCL_DEBUG=INFO export NCCL_IB_DISABLE=0 export NCCL_NET_GDR_LEVEL=2 srun torchrun \ --nnodes=$SLURM_NNODES \ --nproc_per_node=8 \ --rdzv_id=$SLURM_JOB_ID \ --rdzv_backend=c10d \ --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ train.py --batch-size=32 --epochs=100 For containerized environments using Enroot and Pyxis, add ``--container-image`` to run training inside a custom container with specific CUDA, PyTorch, or NCCL versions. Use ``--container-env`` to pass environment variables into the container: .. code-block:: bash #!/bin/bash #SBATCH --job-name=distributed-train #SBATCH --nodes=4 #SBATCH --ntasks-per-node=1 #SBATCH --gpus-per-node=8 #SBATCH --cpus-per-task=96 #SBATCH --output=logs/%x_%j.out MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) MASTER_PORT=29500 srun --container-image=/path/to/pytorch-24.01.sqsh \ --container-mounts="/data:/data,${PWD}:/workspace" \ --container-workdir=/workspace \ --container-env="NCCL_DEBUG=INFO,NCCL_IB_DISABLE=0,NCCL_NET_GDR_LEVEL=2" \ torchrun \ --nnodes=$SLURM_NNODES \ --nproc_per_node=8 \ --rdzv_id=$SLURM_JOB_ID \ --rdzv_backend=c10d \ --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ train.py --batch-size=32 --epochs=100 Array Jobs ---------- Array jobs allow submitting multiple similar jobs with a single ``sbatch`` command. Each job in the array runs independently with a unique ``SLURM_ARRAY_TASK_ID``, making it ideal for parameter sweeps, hyperparameter tuning, or processing multiple datasets. .. 
code-block:: bash #!/bin/bash #SBATCH --job-name=array-job #SBATCH --array=0-9 #SBATCH --output=logs/array_%A_%a.out # SLURM_ARRAY_JOB_ID - Job array's master job ID # SLURM_ARRAY_TASK_ID - Job array index (0-9 in this case) echo "Array task ID: $SLURM_ARRAY_TASK_ID" python train.py --seed=$SLURM_ARRAY_TASK_ID # Submit array job # sbatch array_job.sh # Submit with step size (0, 2, 4, 6, 8) #SBATCH --array=0-9:2 # Submit with max concurrent tasks #SBATCH --array=0-99%10 # max 10 running at once # Cancel specific array tasks scancel ${jobid}_5 # cancel task 5 scancel ${jobid}_[1-3] # cancel tasks 1-3 Job Dependencies ---------------- Job dependencies allow you to control the execution order of jobs. A job can wait for another job to complete, succeed, or fail before starting. This is useful for creating pipelines where preprocessing must finish before training. .. code-block:: bash # Submit first job JOB1=$(sbatch --parsable preprocess.sh) # Submit second job after first completes successfully JOB2=$(sbatch --parsable --dependency=afterok:$JOB1 train.sh) # Submit third job after second completes (regardless of status) sbatch --dependency=afterany:$JOB2 postprocess.sh # Dependency types: # after:jobid - start after job begins # afterok:jobid - start after job completes successfully # afternotok:jobid - start after job fails # afterany:jobid - start after job completes (any status) # singleton - only one job with same name runs at a time # Multiple dependencies sbatch --dependency=afterok:$JOB1:$JOB2 final.sh Resource Limits --------------- Setting appropriate resource limits helps ensure fair cluster usage and prevents jobs from consuming excessive resources. Slurm allows specifying memory, CPU, and time limits per job. .. 
code-block:: bash # Memory per node srun --mem=64G python train.py # Memory per CPU srun --mem-per-cpu=4G python train.py # CPU limit srun --cpus-per-task=8 python train.py # Time limit (job killed if exceeded) srun --time=24:00:00 python train.py # Exclusive node access (no sharing) srun --exclusive python train.py # sbatch example with limits #!/bin/bash #SBATCH --mem=128G #SBATCH --cpus-per-task=32 #SBATCH --time=48:00:00 #SBATCH --exclusive Debugging Failed Jobs --------------------- When jobs fail, Slurm provides several tools to diagnose the issue. Common problems include out-of-memory errors, time limits exceeded, and node failures. .. code-block:: bash # Check job exit code and state sacct -j ${jobid} --format=JobID,State,ExitCode,DerivedExitCode # Common exit codes: # 0 - Success # 1 - General error # 137 - OOM killed (128 + 9 SIGKILL) # 143 - Time limit (128 + 15 SIGTERM) # Check job's stderr/stdout cat slurm-${jobid}.out # Check why job is pending squeue -j ${jobid} --format="%i %r" # Common pending reasons: # Resources - Waiting for resources # Priority - Lower priority than other jobs # Dependency - Waiting for dependent job # QOSMaxJobsPerUserLimit - User job limit reached # Show detailed job info scontrol show job ${jobid} # Check node health where job ran scontrol show node ${nodename} ================================================ FILE: docs/notes/llm/index.rst ================================================ .. meta:: :description lang=en: Large Language Models (LLM) cheat sheet — PyTorch, distributed training, vLLM/SGLang serving, and benchmarking for GPU clusters. :keywords: LLM, Large Language Models, PyTorch, vLLM, SGLang, distributed training, model inference, model serving, GPU optimization, CUDA, transformer models, LLM tutorial, LLM cheat sheet LLM === Large Language Models (LLM) training, inference, and optimization. 
Covers PyTorch for model development, distributed training across GPUs, vLLM and SGLang for high-performance LLM inference and serving, and benchmarking tools for measuring serving performance. .. toctree:: :maxdepth: 1 pytorch megatron llm-serving llm-bench ================================================ FILE: docs/notes/llm/llm-bench.rst ================================================ .. meta:: :description lang=en: LLM benchmark suite — measure throughput, TTFT, ITL, latency for vLLM, SGLang, and TensorRT-LLM serving performance. :keywords: LLM benchmark, vLLM benchmark, SGLang benchmark, TensorRT-LLM benchmark, serving benchmark, throughput, latency, TTFT, time to first token, ITL, inter-token latency, prefill, decode, concurrency, ShareGPT, GPU benchmark, tokens per second ============= LLM Benchmark ============= .. contents:: Table of Contents :backlinks: none Benchmark suites for measuring LLM serving performance with vLLM, SGLang, and TensorRT-LLM. All use similar methodology — same test categories, workloads, and metrics — for easy comparison between the three inference engines. - **vLLM:** ``vllm bench serve`` via `bench.sh `__ - **SGLang:** ``python -m sglang.bench_serving`` via `bench.sh `__ - **TensorRT-LLM:** ``python -m tensorrt_llm.serve.scripts.benchmark_serving`` via `bench.sh `__ The scripts handle Docker image loading and container management automatically. If the CLI is not available on the host, they load the Docker image and re-execute inside the container. When running under a SLURM allocation, they use ``srun`` to dispatch to the compute node. Quick Start ----------- Launch a server in one terminal, then run benchmarks from another. The benchmark script auto-detects the model from the server. **vLLM:** .. 
code-block:: bash # Terminal 1: start server vllm serve Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 8000 # Terminal 2: run benchmarks bash bench.sh -H localhost -i vllm-serve:latest bash bench.sh -H localhost --type throughput,prefill **SGLang:** .. code-block:: bash # Terminal 1: start server python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000 # Terminal 2: run benchmarks bash bench.sh -H localhost -i sglang-serve:latest bash bench.sh -H localhost --type throughput,prefill **TensorRT-LLM:** .. code-block:: bash # Terminal 1: start server trtllm-serve /path/to/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 8000 # Terminal 2: run benchmarks (requires -m for tokenizer loading) bash bench.sh -H localhost -m /path/to/Qwen2.5-7B-Instruct -i tensorrt-llm-serve:latest bash bench.sh -H localhost -m /path/to/Qwen2.5-7B-Instruct --type throughput,prefill Multi-Node with Slurm --------------------- For benchmarking larger models (e.g., DeepSeek-V3, Llama-3.1-405B) that cannot fit on a single node, refer to `Distributed Serving on SLURM `__ for how to deploy multi-node serving with different parallelism strategies. Once the server is running, benchmark using ``bench.sh`` as shown in the Quick Start above. Throughput ---------- Measures peak output tokens/sec by saturating the server with requests. Uses ``request-rate=inf`` to send all prompts immediately, forcing the scheduler to batch aggressively. This reveals the server's maximum throughput under full load. ``512in/256out`` is a moderate workload that exercises both the prefill phase (processing the input) and the decode phase (generating tokens). .. 
code-block:: bash # vLLM vllm bench serve --dataset-name random --random-input-len 512 --random-output-len 256 \ --num-prompts 100 --request-rate inf # SGLang python -m sglang.bench_serving --dataset-name random --random-input 512 --random-output 256 \ --num-prompts 100 --request-rate inf # TensorRT-LLM python -m tensorrt_llm.serve.scripts.benchmark_serving \ --dataset-name random --random-input-len 512 --random-output-len 256 \ --num-prompts 100 --max-concurrency 100 Prefill (TTFT) -------------- Measures Time to First Token — how fast the server processes the input prompt before generating the first output token. ``output-len=1`` isolates prefill from decode since nearly all compute goes to processing the input. Sweeping input length (128→16K) reveals how TTFT scales with context size. Prefill compute per layer is dominated by the O(n) linear projections at short contexts, while self-attention adds an O(n²) term that takes over as the prompt grows, so TTFT should grow roughly linearly at first and superlinearly at long contexts. ``rate=4`` keeps the server lightly loaded so TTFT reflects compute time, not queueing delay. .. code-block:: bash # vLLM - sweeps input length: 128, 512, 2048, 4096, 16384 vllm bench serve --dataset-name random --random-input-len $LEN --random-output-len 1 \ --num-prompts 100 --request-rate 4 # SGLang python -m sglang.bench_serving --dataset-name random --random-input $LEN --random-output 1 \ --num-prompts 100 --request-rate 4 # TensorRT-LLM python -m tensorrt_llm.serve.scripts.benchmark_serving \ --dataset-name random --random-input-len $LEN --random-output-len 1 \ --num-prompts 100 --max-concurrency 4 Decode (ITL) ------------ Measures Inter-Token Latency — the time between consecutive output tokens during autoregressive generation. ``input-len=128`` keeps prefill minimal so the benchmark focuses on the decode phase. Sweeping output length (128→1024) reveals how ITL changes as the KV cache grows. Longer sequences increase memory pressure and may trigger PagedAttention block allocation or preemption. ``rate=4`` avoids batching interference so ITL reflects single-request decode speed. ..
code-block:: bash # vLLM - sweeps output length: 128, 256, 512, 1024 vllm bench serve --dataset-name random --random-input-len 128 --random-output-len $LEN \ --num-prompts 100 --request-rate 4 # SGLang python -m sglang.bench_serving --dataset-name random --random-input 128 --random-output $LEN \ --num-prompts 100 --request-rate 4 # TensorRT-LLM python -m tensorrt_llm.serve.scripts.benchmark_serving \ --dataset-name random --random-input-len 128 --random-output-len $LEN \ --num-prompts 100 --max-concurrency 4 Latency (E2E) ------------- Measures end-to-end request latency under minimal load — the "single user" experience. ``rate=1`` ensures requests are mostly processed alone with no batching, giving the baseline best-case latency (similar to ChatGPT-style usage where one user waits for a complete response). Three size classes (short/medium/long) show how total latency scales with request size. E2E latency = TTFT + (output_tokens × ITL). .. code-block:: bash # vLLM - tests short (128/128), medium (512/256), long (4096/512) vllm bench serve --dataset-name random --random-input-len $IN --random-output-len $OUT \ --num-prompts 100 --request-rate 1 # SGLang python -m sglang.bench_serving --dataset-name random --random-input $IN --random-output $OUT \ --num-prompts 100 --request-rate 1 # TensorRT-LLM python -m tensorrt_llm.serve.scripts.benchmark_serving \ --dataset-name random --random-input-len $IN --random-output-len $OUT \ --num-prompts 100 --max-concurrency 1 Concurrency ----------- Finds the server's saturation point by sweeping the number of concurrent requests. ``request-rate=inf`` with ``max-concurrency=N`` caps how many requests run in parallel, decoupling arrival rate from concurrency. At low concurrency (1–4), latency is good but throughput is low (GPU underutilized). At high concurrency (64–256), throughput plateaus and latency degrades (queueing). 
The "knee" where throughput stops improving is the optimal operating point — it tells you how many concurrent users the server can handle before quality degrades. .. code-block:: bash # vLLM - sweeps concurrency: 1, 4, 16, 64, 256 vllm bench serve --dataset-name random --random-input-len 512 --random-output-len 256 \ --num-prompts 100 --request-rate inf --max-concurrency $C # SGLang python -m sglang.bench_serving --dataset-name random --random-input 512 --random-output 256 \ --num-prompts 100 --request-rate inf --max-concurrency $C # TensorRT-LLM python -m tensorrt_llm.serve.scripts.benchmark_serving \ --dataset-name random --random-input-len 512 --random-output-len 256 \ --num-prompts 100 --max-concurrency $C ShareGPT -------- Realistic conversational workload from real user conversations with variable input/output lengths. Unlike random datasets with fixed lengths, ShareGPT captures the natural distribution of short and long prompts from actual ChatGPT conversations, making it the best proxy for production chat traffic. ShareGPT is the standard dataset used by vLLM CI, GPUStack perf lab, and most published benchmarks. The dataset is auto-downloaded from HuggingFace if not present locally. .. 
code-block:: bash # vLLM - throughput mode vllm bench serve --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \ --num-prompts 100 --request-rate inf # vLLM - realistic load vllm bench serve --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \ --num-prompts 100 --request-rate 4 # SGLang - throughput mode python -m sglang.bench_serving --dataset-name sharegpt \ --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \ --num-prompts 100 --request-rate inf # TensorRT-LLM - throughput mode python -m tensorrt_llm.serve.scripts.benchmark_serving \ --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \ --num-prompts 100 --max-concurrency 100 Sonnet (Prefix Caching) ----------------------- The sonnet dataset uses Shakespeare's sonnets with a shared prefix across all prompts. This tests prefix caching — if enabled, the shared prefix KV cache is computed once and reused across requests, reducing TTFT. .. code-block:: bash # Download sonnet dataset wget -q https://raw.githubusercontent.com/vllm-project/vllm/main/benchmarks/sonnet.txt # vLLM - prefill-heavy (short output isolates prefill) vllm bench serve --dataset-name sonnet --dataset-path sonnet.txt \ --sonnet-input-len 550 --sonnet-output-len 150 --sonnet-prefix-len 200 \ --num-prompts 100 --request-rate inf # vLLM - realistic load vllm bench serve --dataset-name sonnet --dataset-path sonnet.txt \ --sonnet-input-len 550 --sonnet-output-len 150 --sonnet-prefix-len 200 \ --num-prompts 100 --request-rate 4 Both ShareGPT and sonnet are used by the vLLM team's `v0.6.0 performance blog `__ to benchmark serving engines. To learn more about the methodology, see the `reproduction steps `__ and SGLang's `counter-benchmark `__, which uses ``sglang.bench_serving`` to compare both engines: .. 
code-block:: bash # Launch servers # vLLM with multi-step scheduling python -m vllm.entrypoints.openai.api_server \ --model meta-llama/Llama-3.1-8B-Instruct \ --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096 # SGLang python -m sglang.launch_server \ --model-path meta-llama/Llama-3.1-8B-Instruct \ --enable-torch-compile --disable-radix-cache # Online benchmark (realistic load) python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt \ --num-prompts 1200 --request-rate 4 python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt \ --num-prompts 1200 --request-rate 4 # Offline benchmark (max throughput) python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt \ --num-prompts 5000 python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt \ --num-prompts 5000 Key Metrics ----------- - **TTFT** (Time to First Token): Time from request arrival to first generated token. Dominated by prefill compute. Lower is better for interactive use. - **ITL** (Inter-Token Latency): Time between consecutive tokens. Reflects decode speed and consistency. - **TPOT** (Time Per Output Token): Average time per generated token. Similar to ITL but averaged across all tokens. - **E2E Latency**: Total time from request to completion. E2E ≈ TTFT + (tokens × ITL). - **Throughput**: Output tokens/sec across all requests. Higher is better for batch workloads. CLI Differences --------------- .. 
list-table:: :widths: 20 27 27 26 :header-rows: 1 * - Parameter - vLLM - SGLang - TensorRT-LLM * - Input length - ``--random-input-len`` - ``--random-input`` - ``--random-input-len`` * - Output length - ``--random-output-len`` - ``--random-output`` - ``--random-output-len`` * - Max rate - ``--request-rate inf`` - ``--request-rate inf`` - ``--max-concurrency`` * - Random dataset - (works by default) - (works by default) - ``--random-ids --random-prefix-len 0`` * - Model flag - auto-detected - auto-detected - ``-m`` required (tokenizer) * - Results - ``--result-dir ./results`` - ``--output-file ./results/out.json`` - ``--result-dir ./results`` Profiling --------- Benchmark runs can be combined with profiling to correlate performance metrics with GPU-level traces. Two profiling approaches are available for vLLM: **PyTorch profiler** — vLLM's built-in profiler triggered via REST endpoints. Start the server with ``--profile`` (or ``--profiler-config``), then pass ``--profile`` to the benchmark client to call ``/start_profile`` and ``/stop_profile`` around the workload. .. code-block:: bash # Server — start with profiling enabled bash run.sbatch --profile \ Qwen/Qwen3-30B-A3B-FP8 \ --tensor-parallel-size 8 # Client — benchmark with profiling bash bench.sh -H --type throughput --profile View traces at https://ui.perfetto.dev/ (supports ``.gz`` files directly). **Nsight Systems** — wraps ``vllm serve`` with ``nsys profile`` for CUDA kernel, NVTX, and memory tracing. Combine with ``--profiler-config '{"profiler": "cuda"}'`` to also capture vLLM's internal CUDA profiler markers. .. 
code-block:: bash # Server — enable nsys + CUDA profiler (terminal 0) bash run.sbatch --nsys \ Qwen/Qwen3-30B-A3B-FP8 \ --tensor-parallel-size 8 \ --enable-expert-parallel \ --profiler-config '{"profiler": "cuda"}' # Client — benchmark with profiling (terminal 1) bash bench.sh -H --type throughput --profile # Stop server with Ctrl+C (terminal 0) # Nsys finalizes profiles (~30s) # Profile files: nsys-vllm/profile-node*.nsys-rep Open ``.nsys-rep`` files with `Nsight Systems `_. See the `vLLM Serving Guide `_ for full ``run.sbatch`` flag reference and the `vLLM Profiling Guide `_ for more details. Offline Benchmarking -------------------- vLLM also supports offline benchmarking to measure raw inference performance without API server overhead. This is useful for: - Measuring peak throughput without network/serialization overhead - Multi-node distributed inference testing - Profiling with Nsight Systems or PyTorch profiler - Testing with custom datasets (ShareGPT, random prompts) For complete offline benchmarking documentation, see the `vLLM Offline Benchmark Guide `_. ================================================ FILE: docs/notes/llm/llm-serving.rst ================================================ .. meta:: :description lang=en: LLM serving guide — vLLM, SGLang, and TensorRT-LLM for single-node, multi-node SLURM, tensor/pipeline/data/expert parallelism on GPU clusters. :keywords: vLLM, SGLang, TensorRT-LLM, LLM serving, LLM inference, model serving, distributed inference, tensor parallelism, pipeline parallelism, data parallelism, expert parallelism, MoE serving, GPU inference, OpenAI compatible API, multi-node GPU, SLURM, HPC, EFA, NCCL, PagedAttention, RadixAttention, continuous batching, Docker, Qwen, Llama, DeepSeek =========== LLM Serving =========== .. 
contents:: Table of Contents :backlinks: none This guide covers LLM inference serving with three high-performance engines: - **vLLM** — High-throughput inference engine with PagedAttention for efficient KV cache memory management, continuous batching for maximizing GPU utilization, and optimized CUDA kernels. Provides an OpenAI-compatible API as a drop-in replacement for OpenAI services. - **SGLang** — Fast inference engine with RadixAttention for efficient prefix caching across requests with shared prompts. Optimized for multi-turn conversations and workloads with common system prompts. - **TensorRT-LLM** — NVIDIA's inference engine with PyTorch backend, optimized CUDA kernels, and FP8/INT4 quantization. Uses ``trtllm-serve`` for OpenAI-compatible serving. Supports TP, PP, EP, and attention DP via YAML config. All support distributed inference across multiple GPUs and nodes with tensor parallelism (TP), pipeline parallelism (PP), data parallelism (DP), and expert parallelism (EP) for Mixture-of-Experts (MoE) models. This guide covers everything from basic single-GPU deployment to advanced multi-node distributed serving on SLURM clusters. Scripts and examples: - vLLM: `src/llm/vllm/ `_ - SGLang: `src/llm/sglang/ `_ - TensorRT-LLM: `src/llm/tensorrt-llm/ `_ Quick Start ----------- Get started in minutes. Install the package, launch a server, and query it with standard HTTP requests. All three engines expose OpenAI-compatible ``/v1/chat/completions`` and ``/v1/completions`` endpoints. **vLLM** (port 8000): .. code-block:: bash pip install vllm vllm serve Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 8000 curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{"model": "Qwen/Qwen2.5-7B-Instruct", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}' **SGLang** (port 30000): ..
code-block:: bash pip install "sglang[all]" python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000 curl -X POST http://localhost:30000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{"model": "Qwen/Qwen2.5-7B-Instruct", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}' **TensorRT-LLM** (port 8000): .. code-block:: bash pip install tensorrt-llm trtllm-serve Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 8000 curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{"model": "Qwen/Qwen2.5-7B-Instruct", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}' Tensor Parallel (TP) -------------------- Tensor parallelism splits individual model layers across multiple GPUs, with each GPU holding a portion of the weight matrices. All GPUs participate in every forward pass, communicating via all-reduce operations. Essential for models that don't fit in a single GPU's memory. **Use when:** Model doesn't fit on a single GPU, or you need to reduce per-GPU memory. .. code-block:: bash # vLLM vllm serve Qwen/Qwen2.5-14B-Instruct --tensor-parallel-size 8 # SGLang python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --tp 8 # TensorRT-LLM trtllm-serve Qwen/Qwen2.5-14B-Instruct --tp_size 8 Pipeline Parallel (PP) ---------------------- Pipeline parallelism divides the model into sequential stages, with each stage assigned to different GPUs. Unlike tensor parallelism where all GPUs work on every layer, pipeline parallelism processes different parts of the model on different GPUs. This reduces communication overhead since GPUs only pass activations between stages. **Use when:** You want to reduce inter-GPU communication or scale across nodes with slower interconnects. .. 
code-block:: bash # vLLM: PP=2 splits model into 2 stages, TP=4 within each vllm serve Qwen/Qwen2.5-14B-Instruct --tensor-parallel-size 4 --pipeline-parallel-size 2 # SGLang python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --tp 4 --pp 2 # TensorRT-LLM trtllm-serve Qwen/Qwen2.5-14B-Instruct --tp_size 4 --pp_size 2 Data Parallel (DP) ------------------ Data parallelism creates multiple independent replicas of the model, each processing different requests simultaneously. This is the most effective way to increase throughput when you have sufficient GPU memory for multiple model copies. Each replica can use tensor parallelism internally. **Use when:** You need higher request throughput and have enough GPUs to replicate the model. .. code-block:: bash # vLLM: 2 replicas, each using 8 GPUs vllm serve Qwen/Qwen2.5-14B-Instruct --tensor-parallel-size 8 --data-parallel-size 2 # SGLang: multi-node DP requires --enable-dp-attention python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --tp 8 --dp 2 --enable-dp-attention # TensorRT-LLM: DP via multi-node with TP=8 per node trtllm-serve Qwen/Qwen2.5-14B-Instruct --tp_size 8 Expert Parallel (EP) -------------------- Expert parallelism is specifically designed for Mixture-of-Experts (MoE) models, where the model contains multiple expert sub-networks and a gating mechanism routes tokens to different experts. EP shards the experts across GPUs, allowing each GPU to hold a subset of experts. **vLLM:** EP is computed automatically (``EP = DP × TP``). **SGLang:** EP is a subdivision of TP. With ``--tp 8 --ep 2``, the 8 TP GPUs split into 2 expert groups of 4 GPUs each. **TensorRT-LLM:** EP subdivides TP (same as SGLang). With ``--tp_size 8 --ep_size 2``, experts are sharded across 2 groups of 4 GPUs. .. 
code-block:: bash # vLLM: EP auto-computed vllm serve Qwen/Qwen3-30B-A3B-FP8 --tensor-parallel-size 8 --enable-expert-parallel # SGLang: EP subdivides TP python -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --tp 8 --ep 2 # TensorRT-LLM: EP subdivides TP trtllm-serve Qwen/Qwen1.5-MoE-A2.7B --tp_size 8 --ep_size 2 Parallelism Formulas -------------------- All three engines use the same formula for computing total GPU requirements: .. code-block:: text Total GPUs = TP × PP × DP Expert parallelism (EP) is handled differently: - **vLLM**: EP is auto-computed (``EP = TP × DP``) when ``--enable-expert-parallel`` is set. All GPUs in the world participate in expert parallelism automatically. - **SGLang**: EP explicitly subdivides TP. For example, ``--tp 8 --ep 2`` splits the 8 TP GPUs into 2 expert groups of 4 GPUs each. Each group handles different experts while all 8 GPUs still perform tensor parallelism for non-expert layers. - **TensorRT-LLM**: EP subdivides TP (same as SGLang). ``--tp_size 8 --ep_size 2`` splits experts across 2 groups. Constraint: ``moe_tp × ep = tp_size``. .. _distributed-serving-on-slurm: Distributed Serving on SLURM ---------------------------- Some large models (e.g., DeepSeek-V3, Llama-3.1-405B) may not fit into a single node. All three engines support serving across multiple nodes with different parallelism strategies (TP, PP, EP, DP). Multi-node deployment can be tricky at the beginning — the ``run.sbatch`` examples below show how to use ``salloc`` with each engine to get started quickly on Slurm. The scripts handle Docker image distribution to all nodes, container launch with EFA/GPU passthrough, worker coordination, and health checking. The server runs until you stop it with ``Ctrl+C`` or ``scancel``. **vLLM:** ..
code-block:: bash salloc -N 2 --gpus-per-node=8 --exclusive # MoE with expert parallelism bash run.sbatch Qwen/Qwen3-30B-A3B-FP8 --tensor-parallel-size 8 --enable-expert-parallel # Dense with pipeline parallelism bash run.sbatch deepseek-ai/DeepSeek-V2-Lite --tensor-parallel-size 8 --pipeline-parallel-size 2 **SGLang:** .. code-block:: bash salloc -N 2 --gpus-per-node=8 --exclusive # Large model with TP=8 (uses 8 GPUs on first node) bash run.sbatch --model-path Qwen/Qwen2.5-72B-Instruct --tp 8 # MoE with expert parallelism (TP=8, EP=2 across 2 nodes) bash run.sbatch --model-path Qwen/Qwen1.5-MoE-A2.7B --tp 8 --ep 2 **TensorRT-LLM:** .. code-block:: bash salloc -N 2 --gpus-per-node=8 --exclusive # MoE with expert parallelism bash run.sbatch /path/to/Qwen1.5-MoE-A2.7B --tp_size 8 --ep_size 2 # Dense model bash run.sbatch /path/to/Qwen2.5-14B-Instruct --tp_size 8 See the READMEs for full script options: - `vLLM README `_ - `SGLang README `_ - `TensorRT-LLM README `_ ================================================ FILE: docs/notes/llm/megatron.rst ================================================ .. meta:: :description lang=en: Megatron Bridge cheat sheet — pretrain recipes, Nsys profiling, and distributed training on SLURM with EFA. :keywords: Megatron, Megatron-LM, Megatron Bridge, distributed training, pretrain, Nsys, profiling, EFA, SLURM, DeepSeek, MoE, NCCL, GPU, LLM =========== Megatron-LM =========== .. contents:: Table of Contents :backlinks: none `Megatron-LM `_ is NVIDIA's framework for training and fine-tuning large transformer models with tensor, pipeline, and expert parallelism. `Megatron Bridge `_ sits on top of Megatron-LM and provides a recipe-based interface — instead of passing dozens of CLI flags, you write a short Python recipe that returns a config object. Recipes can load `HuggingFace `_ pretrained weights directly via the ``hf_path`` parameter, so you can start from any checkpoint without manual conversion. 
In this note, we demonstrate how to use Megatron Bridge to load HuggingFace pretrained weights and train with Megatron-LM. The examples use the scripts in `src/megatron `__. How to Use Megatron Bridge -------------------------- The container image is built with Docker and exported as an `enroot `_ squashfs (``.sqsh``) file. Enroot is a lightweight container runtime designed for HPC — it converts Docker images into unprivileged sandboxes that integrate with SLURM via the `pyxis `_ plugin. When ``srun.sh`` runs, it passes ``--container-image`` and ``--container-mounts`` to ``srun``, and pyxis handles importing the ``.sqsh`` and launching each task inside the container. .. code-block:: bash # Build Docker image and export to enroot sqsh make build # This produces megatron-lm+latest.sqsh in the current directory. # srun.sh picks it up via the SQSH env var (default: ./megatron-lm+latest.sqsh). You can override the image path and container mounts with environment variables: .. list-table:: :header-rows: 1 * - Variable - Default - Description * - ``SQSH`` - ``./megatron-lm+latest.sqsh`` - Path to enroot image * - ``MOUNT`` - ``.:/workspace/megatron,/fsx:/fsx`` - Container mounts * - ``GPUS_PER_NODE`` - ``8`` - GPUs per node A recipe is a Python file that calls a Megatron Bridge config function and returns a ``ConfigContainer``. For example, `deepseek_v2_lite_pretrain.py `__: .. 
code-block:: python from megatron.bridge.recipes.deepseek.deepseek_v2 import ( deepseek_v2_lite_pretrain_config, ) def configure(hf_path=None, moe_token_dispatcher_type=None): cfg = deepseek_v2_lite_pretrain_config( **({"hf_path": hf_path} if hf_path else {}), tensor_model_parallel_size=8, pipeline_model_parallel_size=1, expert_model_parallel_size=2, sequence_parallel=True, seq_length=4096, train_iters=500, global_batch_size=64, micro_batch_size=1, eval_interval=100, lr_warmup_iters=50, save_interval=0, ) cfg.model.moe_permute_fusion = False if moe_token_dispatcher_type == "deepep": cfg.model.moe_token_dispatcher_type = "flex" cfg.model.moe_flex_dispatcher_backend = "deepep" cfg.model.moe_enable_deepep = True cfg.model.moe_shared_expert_overlap = False return cfg Launch a pretrain job with ``srun.sh``, which wraps ``srun`` + pyxis to run inside the enroot container: .. code-block:: bash # 2-node DeepSeek V2 Lite pretrain salloc -N 2 ./srun.sh recipes/deepseek_v2_lite_pretrain.py \ hf_path=/fsx/models/deepseek-ai/DeepSeek-V2-Lite # Override config with Hydra-style args ./srun.sh recipes/deepseek_v2_lite_pretrain.py \ train.train_iters=1000 The `entrypoint.py `__ loads the recipe, applies CLI overrides, and calls ``pretrain()``. The ``srun.sh`` script sets up EFA environment variables (``FI_PROVIDER=efa``, ``NCCL_NET_PLUGIN``, etc.) and maps ``SLURM_PROCID`` / ``SLURM_LOCALID`` to ``RANK`` / ``LOCAL_RANK`` so that Megatron's distributed init works without ``torchrun``. How to Enable Nsys Profiling ---------------------------- Pass ``--nsys`` to ``srun.sh`` and set the profiling overrides in the recipe: .. 
code-block:: bash ./srun.sh --nsys recipes/deepseek_v2_lite_pretrain.py \ hf_path=/fsx/models/deepseek-ai/DeepSeek-V2-Lite \ profiling.use_nsys_profiler=true \ profiling.profile_step_start=10 \ profiling.profile_step_end=15 \ profiling.profile_ranks=[0] When ``--nsys`` is enabled, ``srun.sh`` prepends ``nsys profile`` with ``--capture-range=cudaProfilerApi`` so that only the steps between ``profile_step_start`` and ``profile_step_end`` are captured. The output ``.nsys-rep`` files are written to ``nsys-megatron/`` inside the container mount. On AWS, if you want to monitor EFA network traffic in the Nsys timeline, add ``--enable efa_metrics`` to the nsys command (already included in ``srun.sh``). For a detailed walkthrough on monitoring EFA with NCCL GIN and Nsys, refer to :doc:`/notes/appendix/megatron-efa-monitoring`. Why ``python`` Instead of ``torchrun`` -------------------------------------- The ``srun.sh`` script launches ``python3 entrypoint.py`` directly rather than using ``torchrun``. This is intentional. ``torchrun`` spawns worker processes via ``multiprocessing``, and the spawn boundary can interfere with Nsys profiling — the profiler sometimes fails to capture the ``cudaProfilerStart`` and ``cudaProfilerStop`` calls issued by the child process, resulting in empty or incomplete traces. By running ``python`` directly under ``srun --ntasks-per-node=$GPUS_PER_NODE``, each GPU gets its own process managed by SLURM. The ``RANK``, ``LOCAL_RANK``, and ``WORLD_SIZE`` environment variables are derived from ``SLURM_PROCID``, ``SLURM_LOCALID``, and ``SLURM_NTASKS`` respectively. This avoids the spawn layer entirely, giving Nsys (and other profilers like VizTracer) a clean, single-process view of each rank. Custom Profilers (VizTracer) ---------------------------- Megatron Bridge's profiling hooks can be extended to support custom profilers. The `viztracer_plugin.py `__ shows how to do this by monkey-patching ``megatron.bridge.training.profiling``: 1. Add a new field (e.g.
``use_viztracer``) to ``ProfilingConfig`` via ``dataclasses.field`` so OmegaConf recognizes it as a valid override. 2. Patch ``handle_profiling_step`` to start the profiler at ``profile_step_start``. 3. Patch ``handle_profiling_stop`` to stop and save at ``profile_step_end``. The plugin is loaded in `entrypoint.py `__ before any Megatron imports: .. code-block:: python import viztracer_plugin viztracer_plugin.install() Then pass the override on the command line: .. code-block:: bash ./srun.sh recipes/deepseek_v2_lite_pretrain.py \ hf_path=/fsx/models/deepseek-ai/DeepSeek-V2-Lite \ profiling.use_viztracer=true \ profiling.profile_step_start=10 \ profiling.profile_step_end=15 \ profiling.profile_ranks=[0] The same pattern works for any profiler — implement ``start()`` / ``stop()`` / ``save()`` in the patched hooks and register a config field so it can be toggled via Hydra overrides. ================================================ FILE: docs/notes/llm/pytorch.rst ================================================ .. meta:: :description lang=en: PyTorch cheat sheet covering tensors, operations, gradients, GPU usage, neural networks, and NumPy interoperability :keywords: Python, Python3, PyTorch, tensor, deep learning, GPU, CUDA, gradient, autograd, numpy, neural network ======= PyTorch ======= .. contents:: Table of Contents :backlinks: none PyTorch is an open-source machine learning framework developed by Meta AI. It provides a flexible and intuitive interface for building and training neural networks with strong GPU acceleration support. PyTorch uses dynamic computation graphs, making it easier to debug and experiment with different model architectures compared to static graph frameworks. Check CUDA ---------- Before running GPU-accelerated computations, verify that CUDA is properly installed and accessible. These commands help check GPU availability and configure device settings. .. 
code-block:: python >>> import torch >>> torch.cuda.is_available() True >>> torch.cuda.nccl.version() (2, 26, 2) >>> torch.cuda.device_count() 2 >>> torch.cuda.get_device_name(0) 'NVIDIA GeForce RTX 3090' >>> torch.cuda.set_device(0) >>> torch.cuda.current_device() 0 Check Device ------------ Determine where a tensor is stored (CPU or GPU) to ensure computations run on the intended device. .. code-block:: python >>> tensor = torch.tensor([0, 1, 2, 3, 4, 5]) >>> tensor.device device(type='cpu') >>> tensor = tensor.to('cuda') >>> tensor.device device(type='cuda', index=0) Create Tensors -------------- Tensors are the fundamental data structure in PyTorch, similar to NumPy arrays but with GPU acceleration support. You can create tensors from Python lists, with specific values, or using various initialization methods. .. code-block:: python >>> x = torch.tensor([0, 1, 2, 3, 4, 5]) >>> x = torch.empty(2, 2) >>> x = torch.rand(2, 2) >>> x = torch.randn(2, 2) >>> x = torch.ones(2, 2) >>> x = torch.zeros(2, 2) >>> x = torch.arange(0, 10, 2) >>> x = torch.linspace(0, 1, 5) >>> device = torch.device("cuda:0") >>> x = torch.tensor([1, 2, 3, 4, 5], device=device) >>> x tensor([1, 2, 3, 4, 5], device='cuda:0') Tensor Properties ----------------- Understanding tensor properties like shape, data type, and device location is crucial for debugging and ensuring compatibility between operations. .. code-block:: python >>> x = torch.randn(2, 3, 4) >>> x.shape torch.Size([2, 3, 4]) >>> x.size() torch.Size([2, 3, 4]) >>> x.dtype torch.float32 >>> x.device device(type='cpu') >>> x.numel() 24 Contiguous Tensors ------------------ Tensors must be stored in contiguous memory blocks for certain operations. After operations like ``transpose`` or ``permute``, tensors may not be contiguous. Use ``contiguous()`` to create a contiguous copy when needed. .. 
code-block:: python >>> x = torch.randn(2, 3) >>> x.is_contiguous() True >>> y = x.transpose(0, 1) >>> y.is_contiguous() False >>> z = y.contiguous() >>> z.is_contiguous() True View vs Reshape --------------- ``view()`` requires tensors to be contiguous and returns a view sharing the same memory. ``reshape()`` works on both contiguous and non-contiguous tensors, creating a copy if needed. Use ``view()`` when you know the tensor is contiguous for better performance. .. code-block:: python >>> x = torch.randn(2, 3, 4) >>> x.is_contiguous() True >>> x.view(2, 12).shape torch.Size([2, 12]) >>> y = x.transpose(0, 1) >>> y.is_contiguous() False >>> y.view(3, 8) RuntimeError: view size is not compatible with input tensor's size and stride >>> y.reshape(3, 8).shape torch.Size([3, 8]) >>> y.contiguous().view(3, 8).shape torch.Size([3, 8]) Reshape Tensors --------------- Common reshaping operations for preparing data for neural network layers. .. code-block:: python x = torch.randn(2, 3, 4) x.reshape(6, 4) x.flatten() x.squeeze() x.unsqueeze(0) Move Tensors ------------ Transfer tensors between CPU and GPU, or between different GPU devices. This is necessary when working with models and data on different devices. .. code-block:: python x = torch.randn(2, 3) x_gpu = x.to('cuda') x_gpu = x.cuda() x_cpu = x_gpu.to('cpu') x_cpu = x_gpu.cpu() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') x = x.to(device) Arithmetic ---------- PyTorch supports element-wise operations and matrix operations. Most operations have both functional and method forms, and support broadcasting for tensors of different shapes. .. 
code-block:: python >>> x = torch.rand(2, 2, device=0) >>> y = torch.rand(2, 2, device=0) >>> x tensor([[0.2171, 0.7797], [0.7265, 0.6759]], device='cuda:0') >>> y tensor([[0.6766, 0.1862], [0.2438, 0.0076]], device='cuda:0') >>> x + y tensor([[0.8937, 0.9659], [0.9703, 0.6834]], device='cuda:0') >>> torch.add(x, y) tensor([[0.8937, 0.9659], [0.9703, 0.6834]], device='cuda:0') >>> x - y tensor([[-0.4594, 0.5935], [ 0.4827, 0.6683]], device='cuda:0') >>> x * y tensor([[0.1469, 0.1452], [0.1771, 0.0051]], device='cuda:0') >>> x / y tensor([[ 0.3209, 4.1880], [ 2.9796, 89.4011]], device='cuda:0') >>> x ** 2 tensor([[0.0471, 0.6079], [0.5278, 0.4568]], device='cuda:0') >>> x @ y tensor([[0.3370, 0.0463], [0.6563, 0.1404]], device='cuda:0') >>> torch.matmul(x, y) tensor([[0.3370, 0.0463], [0.6563, 0.1404]], device='cuda:0') In-place Operations ------------------- Operations ending with ``_`` modify tensors directly without creating new tensors, saving memory. Use these carefully as they can affect gradient computation. .. code-block:: python x = torch.rand(2, 2, device=0) y = torch.rand(2, 2, device=0) y.add_(x) y.sub_(x) y.mul_(x) y.div_(x) y.zero_() y.fill_(5) Transpose --------- Swap dimensions of multi-dimensional tensors. This is commonly used when preparing data for different neural network layers that expect specific input shapes. .. code-block:: python >>> x = torch.randn(2, 6, 2, device=0) >>> x.shape torch.Size([2, 6, 2]) >>> y = x.transpose(1, 2) >>> y.shape torch.Size([2, 2, 6]) >>> x.mT.shape torch.Size([2, 2, 6]) >>> x.permute(2, 0, 1).shape torch.Size([2, 2, 6]) Matrix Multiplication --------------------- Perform batch matrix operations on high-dimensional tensors. This is fundamental for neural network computations where you process multiple samples simultaneously. ..
code-block:: python >>> x = torch.randn(1, 2, 3, 4, device=0) >>> x @ x.transpose(2, 3) tensor([[[[ 1.0950, -0.1160, -1.4840], [-0.1160, 2.4025, 2.8093], [-1.4840, 2.8093, 5.9335]], [[ 3.2498, -0.3049, -0.5129], [-0.3049, 0.1591, -0.2596], [-0.5129, -0.2596, 2.9863]]]], device='cuda:0') >>> torch.bmm(x[0], x[0].transpose(1, 2)) >>> torch.einsum('bijk,bikl->bijl', x, x.transpose(2, 3)) Aggregation ----------- Reduce tensors along specified dimensions using aggregation functions. These operations are essential for computing statistics and reducing dimensionality. .. code-block:: python >>> x = torch.randn(2, 3, 4) >>> x.sum() tensor(5.2341) >>> x.sum(dim=0).shape torch.Size([3, 4]) >>> x.sum(dim=-1).shape torch.Size([2, 3]) >>> x.mean() tensor(0.2181) >>> x.std() tensor(1.0234) >>> x.max() tensor(2.3456) >>> x.min() tensor(-1.8765) >>> x.argmax() tensor(7) >>> x.argmin() tensor(15) >>> x.max(dim=1) torch.return_types.max( values=tensor([[1.2345, 0.9876, 1.5432, 0.8765], [0.7654, 1.3456, 0.6543, 1.2345]]), indices=tensor([[1, 2, 0, 1], [2, 0, 1, 2]])) Slicing ------- Extract parts of tensors using NumPy-style indexing. Slicing is essential for accessing specific elements, rows, columns, or sub-tensors without copying data. .. code-block:: python >>> x = torch.randn(2, 3, device=0) >>> x tensor([[-1.3921, 0.0475, 0.7571], [-0.1469, -0.3882, 0.2149]], device='cuda:0') >>> x[:, 1] tensor([ 0.0475, -0.3882], device='cuda:0') >>> x[1, :] tensor([-0.1469, -0.3882, 0.2149], device='cuda:0') >>> x[1, 1].item() -0.3882044851779938 >>> x = torch.triu(torch.ones(5, 5)) >>> x[:3, :3] tensor([[1., 1., 1.], [0., 1., 1.], [0., 0., 1.]]) Advanced Indexing ----------------- Use boolean masks and fancy indexing to select elements based on conditions or specific index arrays. ..
code-block:: python >>> x = torch.randn(3, 4) >>> mask = x > 0 >>> x[mask] tensor([0.5234, 1.2345, 0.8765, 0.3456, 1.5678]) x[x > 0] = 0 indices = torch.tensor([0, 2]) x[indices] x[[0, 1], [1, 2]] Concatenation ------------- Combine multiple tensors along existing or new dimensions. This is useful for building batches or combining features from different sources. .. code-block:: python >>> x = torch.randn(2, 3) >>> y = torch.randn(2, 3) >>> torch.cat([x, y], dim=0).shape torch.Size([4, 3]) >>> torch.cat([x, y], dim=1).shape torch.Size([2, 6]) >>> torch.stack([x, y], dim=0).shape torch.Size([2, 2, 3]) >>> torch.vstack([x, y]).shape torch.Size([4, 3]) >>> torch.hstack([x, y]).shape torch.Size([2, 6]) Splitting --------- Split tensors into multiple chunks or along specific dimensions. Useful for distributing data across multiple GPUs or processing in smaller batches. .. code-block:: python >>> x = torch.randn(6, 4) >>> chunks = torch.split(x, 2, dim=0) >>> len(chunks) 3 >>> chunks = torch.chunk(x, 3, dim=0) >>> len(chunks) 3 >>> tensors = x.unbind(dim=0) >>> len(tensors) 6 Gradient -------- Automatic differentiation is PyTorch's core feature for training neural networks. Enable gradient tracking on tensors to compute derivatives during backpropagation. .. code-block:: python >>> x = torch.randn(3, requires_grad=True, device=0) >>> x tensor([-1.1442, -0.8709, -0.2581], device='cuda:0', requires_grad=True) >>> y = x.detach() >>> y tensor([-1.1442, -0.8709, -0.2581], device='cuda:0') >>> x.requires_grad_(False) tensor([-1.1442, -0.8709, -0.2581], device='cuda:0') Disable Gradient ---------------- Temporarily disable gradient computation for inference or when you don't need gradients. This reduces memory usage and speeds up computations. .. code-block:: python >>> x = torch.randn(3, requires_grad=True, device=0) >>> with torch.no_grad(): ... y = x + 1 ... print(y) ... 
tensor([1.2969, 1.5251, 0.7915], device='cuda:0') with torch.inference_mode(): y = x * 2 @torch.no_grad() def predict(x): return model(x) Backpropagation --------------- Compute gradients using automatic differentiation. PyTorch builds a computation graph during the forward pass and computes gradients during the backward pass using the chain rule. .. code-block:: python >>> x = torch.randn(3, requires_grad=True) >>> y = x + 1 >>> z = y * y * 3 >>> z = z.mean() >>> z.backward(retain_graph=True) >>> x.grad tensor([1.2036, 5.0103, 0.5143]) x.grad.zero_() z.backward() Gradient Accumulation --------------------- Accumulate gradients over multiple backward passes. This is useful for simulating larger batch sizes when GPU memory is limited. .. code-block:: python optimizer.zero_grad() for i, (inputs, targets) in enumerate(dataloader): outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() if (i + 1) % accumulation_steps == 0: optimizer.step() optimizer.zero_grad() Neural Network Module --------------------- Define neural networks by subclassing ``nn.Module``. This provides a clean interface for building complex models with automatic parameter management. .. code-block:: python import torch.nn as nn class Net(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(784, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = torch.relu(self.fc1(x)) x = self.fc2(x) return x model = Net() model.to('cuda') Common Layers ------------- PyTorch provides a wide variety of pre-built layers for constructing neural networks. ..
code-block:: python import torch.nn as nn linear = nn.Linear(10, 5) conv2d = nn.Conv2d(3, 64, kernel_size=3, padding=1) maxpool = nn.MaxPool2d(2, 2) dropout = nn.Dropout(0.5) batchnorm = nn.BatchNorm2d(64) relu = nn.ReLU() sigmoid = nn.Sigmoid() softmax = nn.Softmax(dim=1) lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2) gru = nn.GRU(input_size=10, hidden_size=20) embedding = nn.Embedding(1000, 128) Loss Functions -------------- Loss functions measure how well the model's predictions match the target values. Choose the appropriate loss function based on your task. .. code-block:: python import torch.nn as nn mse_loss = nn.MSELoss() cross_entropy = nn.CrossEntropyLoss() bce_loss = nn.BCELoss() bce_with_logits = nn.BCEWithLogitsLoss() l1_loss = nn.L1Loss() nll_loss = nn.NLLLoss() outputs = model(inputs) loss = cross_entropy(outputs, targets) Optimizers ---------- Optimizers update model parameters based on computed gradients. Different optimizers use different strategies for parameter updates. .. code-block:: python import torch.optim as optim optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) optimizer = optim.Adam(model.parameters(), lr=0.001) optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01) optimizer = optim.RMSprop(model.parameters(), lr=0.01) optimizer.zero_grad() loss.backward() optimizer.step() Learning Rate Scheduler ----------------------- Adjust the learning rate during training to improve convergence and model performance. .. code-block:: python from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau scheduler = StepLR(optimizer, step_size=30, gamma=0.1) scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=10) for epoch in range(num_epochs): train(...) val_loss = validate(...) scheduler.step(val_loss) Training Loop ------------- A typical training loop involves forward pass, loss computation, backward pass, and parameter updates. .. 
code-block:: python model.train() for epoch in range(num_epochs): for inputs, targets in train_loader: inputs, targets = inputs.to(device), targets.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() print(f'Loss: {loss.item():.4f}') Evaluation Mode --------------- Switch between training and evaluation modes to control behavior of layers like dropout and batch normalization. .. code-block:: python model.eval() with torch.no_grad(): for inputs, targets in test_loader: inputs = inputs.to(device) outputs = model(inputs) predictions = outputs.argmax(dim=1) Save and Load Models -------------------- Save trained models to disk and load them later for inference or continued training. .. code-block:: python torch.save(model.state_dict(), 'model.pth') model.load_state_dict(torch.load('model.pth')) torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, }, 'checkpoint.pth') checkpoint = torch.load('checkpoint.pth') model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) DataLoader ---------- Efficiently load and batch data for training. DataLoader handles shuffling, batching, and parallel data loading. .. code-block:: python from torch.utils.data import DataLoader, TensorDataset dataset = TensorDataset(x_train, y_train) loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4) for batch_x, batch_y in loader: outputs = model(batch_x) loss = criterion(outputs, batch_y) Custom Dataset -------------- Create custom datasets by subclassing ``Dataset`` for loading data from various sources. .. 
code-block:: python from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self, data, labels): self.data = data self.labels = labels def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx], self.labels[idx] dataset = CustomDataset(x_train, y_train) loader = DataLoader(dataset, batch_size=32) NumPy Conversion ---------------- Convert between PyTorch tensors and NumPy arrays for interoperability with other libraries. Note that tensors on GPU must be moved to CPU first. .. code-block:: python >>> x = torch.randn([1, 2, 3], device=0) >>> y = x.cpu().numpy() >>> y array([[[-0.11979043, 0.13762406, -1.2633433 ], [-0.380241 , 1.5320604 , -1.0828359 ]]], dtype=float32) import numpy as np arr = np.array([[1, 2], [3, 4]]) tensor = torch.from_numpy(arr) Shared Memory ------------- CPU tensors and NumPy arrays can share the same underlying memory, so modifications to one will affect the other. .. code-block:: python >>> x = torch.randn(1, 2, 3) >>> y = x.numpy() >>> x.add_(1) tensor([[[ 1.8195, 3.0259, 0.6733], [ 2.6539, 1.1562, -0.9821]]]) >>> y array([[[ 1.8194908 , 3.0258512 , 0.67326605], [ 2.6539469 , 1.1561831 , -0.98211455]]], dtype=float32) Random Seed ----------- Set random seeds for reproducibility across different runs. This ensures consistent results when debugging or comparing experiments. .. code-block:: python torch.manual_seed(42) torch.cuda.manual_seed(42) torch.cuda.manual_seed_all(42) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False import numpy as np import random np.random.seed(42) random.seed(42) GPU Memory Management --------------------- Monitor and manage GPU memory usage to avoid out-of-memory errors and optimize performance. .. 
code-block:: python >>> torch.cuda.empty_cache() >>> torch.cuda.memory_allocated() 1073741824 >>> torch.cuda.memory_reserved() 2147483648 >>> torch.cuda.max_memory_allocated() 3221225472 torch.cuda.reset_peak_memory_stats() print(torch.cuda.memory_summary()) Distributed Training -------------------- NCCL (NVIDIA Collective Communication Library) enables efficient multi-GPU and multi-node training. Initialize the distributed process group and check its status. .. code-block:: python >>> import torch.distributed as dist >>> dist.is_available() True >>> dist.is_nccl_available() True >>> dist.is_initialized() False dist.init_process_group(backend='nccl') >>> dist.is_initialized() True >>> dist.get_rank() 0 >>> dist.get_world_size() 4 >>> dist.get_backend() 'nccl' Launch with torchrun -------------------- Use ``torchrun`` to launch distributed training across multiple GPUs or nodes. It automatically sets up environment variables for distributed training. .. code-block:: bash # single node, multiple GPUs torchrun --nproc_per_node=4 train.py # multiple nodes torchrun --nproc_per_node=4 \ --nnodes=2 \ --node_rank=0 \ --master_addr="192.168.1.1" \ --master_port=29500 \ train.py Launch with Slurm ----------------- Submit distributed training jobs to Slurm clusters. Slurm manages resource allocation and node assignment. See :doc:`../hpc/slurm` for more Slurm examples. .. code-block:: bash #!/bin/bash #SBATCH --job-name=pytorch_ddp #SBATCH --nodes=2 #SBATCH --ntasks-per-node=4 #SBATCH --gres=gpu:4 #SBATCH --time=24:00:00 export MASTER_ADDR=$(scontrol show hostname $SLURM_NODELIST | head -n 1) export MASTER_PORT=29500 srun torchrun --nproc_per_node=4 \ --nnodes=$SLURM_NNODES \ --node_rank=$SLURM_NODEID \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ train.py DistributedDataParallel ------------------------ Wrap your model with DDP for multi-GPU training. DDP uses NCCL for efficient gradient synchronization across GPUs. .. 
code-block:: python import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP model = Net().to(device) model = DDP(model, device_ids=[local_rank]) for inputs, targets in train_loader: optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() Collective Operations --------------------- NCCL provides collective communication operations for distributed training. .. code-block:: python import torch.distributed as dist tensor = torch.randn(2, 3).cuda() dist.all_reduce(tensor, op=dist.ReduceOp.SUM) dist.broadcast(tensor, src=0) dist.all_gather(tensor_list, tensor) dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM) dist.barrier() ================================================ FILE: docs/notes/network/index.rst ================================================ .. meta:: :description lang=en: Python network programming tutorial covering sockets, TCP/UDP servers, async I/O, SSL/TLS, packet sniffing, and SSH tunneling :keywords: Python, Python3, socket, networking, TCP, UDP, SSL, TLS, async, select, epoll, SSH, packet sniffing, server, client ======= Network ======= Network programming is fundamental to modern software development, enabling communication between processes across machines, data centers, and the internet. This section covers Python's networking capabilities from low-level socket programming to secure communications. You'll find practical examples for building TCP/UDP servers, handling multiple connections with async I/O (select, poll, epoll), implementing TLS/SSL encryption, packet sniffing for network analysis, and SSH for secure remote access and tunneling. Whether you're building web services, IoT applications, network tools, or automation scripts, these cheat sheets provide the essential patterns and code snippets you need. ..
toctree:: :maxdepth: 1 python-socket python-socket-server python-socket-async python-socket-ssl python-socket-sniffer python-ssh ================================================ FILE: docs/notes/network/python-socket-async.rst ================================================ .. meta:: :description lang=en: Python asynchronous socket programming tutorial with select, poll, epoll, kqueue, and selectors module for building high-performance non-blocking servers :keywords: Python, socket, async, select, poll, epoll, kqueue, selectors, non-blocking, event-driven, I/O multiplexing, concurrent connections, high-performance server ================= Async Socket I/O ================= .. contents:: Table of Contents :backlinks: none Asynchronous I/O is essential for building high-performance network servers that can handle thousands of concurrent connections efficiently. Traditional blocking I/O requires one thread per connection, which doesn't scale well due to memory overhead and context switching costs. Asynchronous I/O solves this by allowing a single thread to handle multiple connections using I/O multiplexing—the program monitors multiple sockets simultaneously and processes whichever ones are ready for reading or writing. This section covers the evolution of I/O multiplexing in Python, from the classic ``select()`` system call (portable but limited) to modern high-performance mechanisms like ``epoll`` (Linux) and ``kqueue`` (BSD/macOS) that can efficiently handle tens of thousands of connections. We also cover the ``selectors`` module, which provides a high-level, platform-independent interface that automatically uses the best available mechanism. Understanding these primitives is valuable even if you use higher-level frameworks like ``asyncio``, as they build upon these same concepts. 
Async TCP Server - select ------------------------- ``select()`` is the oldest and most portable I/O multiplexing mechanism, available on virtually all platforms including Windows, Linux, and macOS. It monitors file descriptors for three conditions: readability (data available to read), writability (buffer space available to write), and exceptional conditions (errors). While portable, ``select()`` has limitations: it typically supports only up to 1024 file descriptors and has O(n) performance as it must scan all monitored descriptors on each call. .. code-block:: python from select import select import socket host = ('localhost', 5566) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(host) sock.listen(5) read_list = [sock] write_list = [] messages = {} try: while True: readable, writable, _ = select(read_list, write_list, []) for s in readable: if s == sock: conn, addr = sock.accept() read_list.append(conn) else: msg = s.recv(1024) if msg: messages[s.fileno()] = msg write_list.append(s) else: read_list.remove(s) s.close() for s in writable: msg = messages.pop(s.fileno(), None) if msg: s.send(msg) write_list.remove(s) except KeyboardInterrupt: sock.close() Async TCP Server - poll ----------------------- ``poll()`` is similar to ``select()`` but more efficient for large numbers of file descriptors. Available on Unix systems. .. 
code-block:: python import socket import select import contextlib host = 'localhost' port = 5566 connections = {} requests = {} responses = {} @contextlib.contextmanager def create_server(host, port): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.setblocking(False) s.bind((host, port)) s.listen(10) try: yield s finally: s.close() def accept(server, poll): conn, addr = server.accept() conn.setblocking(False) fd = conn.fileno() poll.register(fd, select.POLLIN) requests[fd] = conn connections[fd] = conn def recv(fd, poll): conn = requests.pop(fd, None) if not conn: return msg = conn.recv(1024) if msg: responses[fd] = msg poll.modify(fd, select.POLLOUT) else: poll.unregister(fd) conn.close() connections.pop(fd, None) def send(fd, poll): conn = connections.get(fd) msg = responses.pop(fd, None) if conn and msg: conn.send(msg) requests[fd] = conn poll.modify(fd, select.POLLIN) with create_server(host, port) as server: poll = select.poll() poll.register(server.fileno(), select.POLLIN) try: while True: events = poll.poll(1000) for fd, event in events: if fd == server.fileno(): accept(server, poll) elif event & (select.POLLIN | select.POLLPRI): recv(fd, poll) elif event & select.POLLOUT: send(fd, poll) except KeyboardInterrupt: pass Async TCP Server - epoll ------------------------ ``epoll`` is Linux-specific and the most efficient for handling thousands of connections. It uses edge-triggered or level-triggered notifications. .. 
code-block:: python import socket import select import contextlib host = 'localhost' port = 5566 connections = {} requests = {} responses = {} @contextlib.contextmanager def create_server(host, port): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.setblocking(False) s.bind((host, port)) s.listen(10) try: yield s finally: s.close() def accept(server, epoll): conn, addr = server.accept() conn.setblocking(False) fd = conn.fileno() epoll.register(fd, select.EPOLLIN) requests[fd] = conn connections[fd] = conn def recv(fd, epoll): conn = requests.pop(fd, None) if not conn: return msg = conn.recv(1024) if msg: responses[fd] = msg epoll.modify(fd, select.EPOLLOUT) else: epoll.unregister(fd) conn.close() connections.pop(fd, None) def send(fd, epoll): conn = connections.get(fd) msg = responses.pop(fd, None) if conn and msg: conn.send(msg) requests[fd] = conn epoll.modify(fd, select.EPOLLIN) with create_server(host, port) as server: epoll = select.epoll() epoll.register(server.fileno(), select.EPOLLIN) try: while True: events = epoll.poll(1) for fd, event in events: if fd == server.fileno(): accept(server, epoll) elif event & select.EPOLLIN: recv(fd, epoll) elif event & select.EPOLLOUT: send(fd, epoll) except KeyboardInterrupt: pass finally: epoll.close() Async TCP Server - kqueue ------------------------- ``kqueue`` is the BSD/macOS equivalent of epoll, providing efficient event notification for large numbers of file descriptors. .. 
code-block:: python import socket import select import contextlib if not hasattr(select, 'kqueue'): print("kqueue not supported on this platform") exit(1) host = 'localhost' port = 5566 connections = {} requests = {} responses = {} @contextlib.contextmanager def create_server(host, port): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.setblocking(False) s.bind((host, port)) s.listen(10) try: yield s finally: s.close() def accept(server, kq): conn, addr = server.accept() conn.setblocking(False) fd = conn.fileno() ke = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) kq.control([ke], 0) requests[fd] = conn connections[fd] = conn def recv(fd, kq): conn = requests.pop(fd, None) if not conn: return msg = conn.recv(1024) if msg: responses[fd] = msg # Switch from read to write ke_del = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) ke_add = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) kq.control([ke_del, ke_add], 0) requests[fd] = conn else: ke = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) kq.control([ke], 0) conn.close() connections.pop(fd, None) def send(fd, kq): conn = connections.get(fd) msg = responses.pop(fd, None) if conn and msg: conn.send(msg) # Switch from write to read ke_del = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) ke_add = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) kq.control([ke_del, ke_add], 0) requests[fd] = conn with create_server(host, port) as server: kq = select.kqueue() ke = select.kevent(server.fileno(), select.KQ_FILTER_READ, select.KQ_EV_ADD) kq.control([ke], 0) try: while True: events = kq.control(None, 1024, 1) for e in events: fd = e.ident if fd == server.fileno(): accept(server, kq) elif e.filter == select.KQ_FILTER_READ: recv(fd, kq) elif e.filter == select.KQ_FILTER_WRITE: send(fd, kq) except KeyboardInterrupt: pass finally: kq.close() High-Level API - selectors 
-------------------------- The ``selectors`` module (Python 3.4+) provides a high-level, platform-independent interface that automatically uses the best available mechanism (epoll, kqueue, etc.). .. code-block:: python import selectors import socket import contextlib @contextlib.contextmanager def create_server(host, port): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind((host, port)) s.listen(10) sel = selectors.DefaultSelector() try: yield s, sel finally: s.close() sel.close() def accept_handler(sock, sel): conn, addr = sock.accept() sel.register(conn, selectors.EVENT_READ, read_handler) def read_handler(conn, sel): msg = conn.recv(1024) if msg: conn.send(msg) else: sel.unregister(conn) conn.close() host = 'localhost' port = 5566 with create_server(host, port) as (sock, sel): sel.register(sock, selectors.EVENT_READ, accept_handler) try: while True: events = sel.select() for key, mask in events: handler = key.data handler(key.fileobj, sel) except KeyboardInterrupt: pass Comparison of I/O Multiplexing Methods -------------------------------------- +------------+------------+------------------+---------------------------+ | Method | Platform | Scalability | Notes | +============+============+==================+===========================+ | select | All | O(n) - Limited | Max ~1024 FDs | +------------+------------+------------------+---------------------------+ | poll | Unix | O(n) - Better | No FD limit | +------------+------------+------------------+---------------------------+ | epoll | Linux | O(1) - Excellent | Edge/level triggered | +------------+------------+------------------+---------------------------+ | kqueue | BSD/macOS | O(1) - Excellent | Similar to epoll | +------------+------------+------------------+---------------------------+ | selectors | All | Best available | Recommended for new code | +------------+------------+------------------+---------------------------+ .. 
note:: For new code, use the ``selectors`` module or ``asyncio`` for async I/O. The low-level APIs (select, poll, epoll, kqueue) are mainly useful for understanding how async I/O works or when you need fine-grained control. ================================================ FILE: docs/notes/network/python-socket-server.rst ================================================ .. meta:: :description lang=en: Python TCP and UDP server tutorial with examples for echo servers, IPv6 dual-stack, Unix domain sockets, SocketServer module, threaded servers, and zero-copy sendfile :keywords: Python, socket, TCP server, UDP server, echo server, IPv6, dual-stack, Unix domain socket, SocketServer, network programming, threaded server, sendfile, IPC ============== Socket Servers ============== .. contents:: Table of Contents :backlinks: none Building network servers is one of the most common applications of socket programming. A server listens for incoming connections on a specific port, accepts client connections, and processes requests. Python's socket module provides all the primitives needed to build robust TCP and UDP servers, from simple single-threaded echo servers to complex multi-client applications. This section covers building TCP and UDP servers in Python using the socket module, including simple echo servers for learning the basics, IPv6 and dual-stack servers for modern network compatibility, Unix domain sockets for high-performance local IPC, the ``socketserver`` module for rapid development, threaded servers for handling multiple clients, and zero-copy file transfer with sendfile. These patterns form the foundation for building production-ready network services. Simple TCP Echo Server ---------------------- A basic TCP echo server demonstrates the fundamental server pattern: create a socket, bind to an address, listen for connections, accept clients, and process data.
This example uses Python's context manager protocol for proper resource cleanup, ensuring the socket is closed even if an exception occurs. The ``SO_REUSEADDR`` option allows the server to restart immediately without waiting for the TIME_WAIT state to expire. .. code-block:: python import socket class Server: def __init__(self, host, port): self._host = host self._port = port def __enter__(self): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind((self._host, self._port)) sock.listen(10) self._sock = sock return self._sock def __exit__(self, *exc_info): if exc_info[0]: import traceback traceback.print_exception(*exc_info) self._sock.close() if __name__ == '__main__': with Server('localhost', 5566) as s: while True: conn, addr = s.accept() msg = conn.recv(1024) conn.send(msg) conn.close() Output: .. code-block:: console $ nc localhost 5566 Hello World Hello World TCP Echo Server via IPv6 ------------------------ IPv6 server using ``AF_INET6`` address family. Note the different address tuple format which includes flow info and scope ID. .. code-block:: python import contextlib import socket host = "::1" port = 5566 @contextlib.contextmanager def server(host, port): s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM, 0) try: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind((host, port)) s.listen(10) yield s finally: s.close() with server(host, port) as s: try: while True: conn, addr = s.accept() msg = conn.recv(1024) if msg: conn.send(msg) conn.close() except KeyboardInterrupt: pass Output: .. code-block:: bash $ python3 ipv6.py & $ nc -6 ::1 5566 Hello IPv6 Hello IPv6 Dual-Stack Server (IPv4 and IPv6) --------------------------------- A server that accepts both IPv4 and IPv6 connections by binding to ``::`` and disabling ``IPV6_V6ONLY``. IPv4 clients appear as IPv4-mapped IPv6 addresses. .. 
code-block:: python import contextlib import socket host = "::" port = 5566 @contextlib.contextmanager def server(host, port): s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM, 0) try: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) s.bind((host, port)) s.listen(10) yield s finally: s.close() with server(host, port) as s: try: while True: conn, addr = s.accept() print(f"Connection from: {conn.getpeername()}") msg = conn.recv(1024) if msg: conn.send(msg) conn.close() except KeyboardInterrupt: pass Output: .. code-block:: bash $ python3 dual_stack.py & $ nc -4 127.0.0.1 5566 Connection from: ('::ffff:127.0.0.1', 42604, 0, 0) Hello IPv4 Hello IPv4 $ nc -6 ::1 5566 Connection from: ('::1', 50882, 0, 0) Hello IPv6 Hello IPv6 TCP Server via SocketServer --------------------------- The ``socketserver`` module provides a higher-level interface for building servers. It handles the socket setup and connection management automatically. .. code-block:: python import socketserver class EchoHandler(socketserver.BaseRequestHandler): def handle(self): data = self.request.recv(1024) print(f"Client: {self.client_address}") self.request.sendall(data) if __name__ == '__main__': with socketserver.TCPServer(('localhost', 5566), EchoHandler) as server: server.serve_forever() Simple UDP Echo Server ---------------------- UDP servers use ``SOCK_DGRAM`` and don't require connection handling. Each ``recvfrom()`` returns the data and sender address. .. 
code-block:: python import socket class UDPServer: def __init__(self, host, port): self._host = host self._port = port def __enter__(self): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.bind((self._host, self._port)) self._sock = sock return sock def __exit__(self, *exc_info): if exc_info[0]: import traceback traceback.print_exception(*exc_info) self._sock.close() if __name__ == '__main__': with UDPServer('localhost', 5566) as s: while True: msg, addr = s.recvfrom(1024) s.sendto(msg, addr) Output: .. code-block:: console $ nc -u localhost 5566 Hello World Hello World UDP Server via SocketServer --------------------------- .. code-block:: python import socketserver class UDPHandler(socketserver.BaseRequestHandler): def handle(self): data, sock = self.request print(f"Client: {self.client_address}") sock.sendto(data, self.client_address) if __name__ == '__main__': with socketserver.UDPServer(('localhost', 5566), UDPHandler) as server: server.serve_forever() UDP Client - Sender ------------------- Simple UDP client that sends periodic messages. .. code-block:: python import socket import time sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) host = ('localhost', 5566) while True: sock.sendto(b"Hello\n", host) time.sleep(5) Broadcast UDP Packets --------------------- Send UDP packets to all hosts on the local network using the broadcast address. .. code-block:: python import socket import time sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.bind(('', 0)) sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) while True: msg = f'{time.time()}\n'.encode() sock.sendto(msg, ('<broadcast>', 5566)) time.sleep(5) Output: .. code-block:: console $ nc -k -w 1 -ul 5566 1431473025.72 Unix Domain Socket ------------------ Unix domain sockets provide inter-process communication on the same machine, faster than TCP/IP loopback as they bypass the network stack. ..
@contextlib.contextmanager
def domain_server(addr):
    """Yield a listening Unix domain stream socket bound to *addr*.

    A stale socket file left at *addr* by a previous run is removed before
    binding. On exit the socket is closed and the socket file is removed.

    Raises:
        OSError: if *addr* cannot be unlinked or bound.
    """
    # Remove a leftover socket file; bind() fails if the path already exists.
    with contextlib.suppress(FileNotFoundError):
        os.unlink(addr)

    # Create the socket *before* entering the try block so the finally
    # clause never refers to an unassigned name when setup fails early.
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        sock.bind(addr)
        sock.listen(10)
        yield sock
    finally:
        sock.close()
        with contextlib.suppress(FileNotFoundError):
            os.unlink(addr)
def send_file(conn, filepath):
    """Send the whole file at *filepath* over *conn* using ``os.sendfile()``.

    Args:
        conn: connected socket to write to.
        filepath: path of the file to transfer.

    The file is sent in chunks of at most 64 KiB. ``os.sendfile()`` may send
    fewer bytes than requested, so the loop advances by what was actually
    sent.
    """
    with open(filepath, 'rb') as f:
        in_fd = f.fileno()
        out_fd = conn.fileno()
        remaining = os.fstat(in_fd).st_size
        offset = 0
        while remaining > 0:
            sent = os.sendfile(out_fd, in_fd, offset, min(remaining, 65536))
            if sent == 0:
                # EOF reached before the size reported by fstat(): the file
                # was truncated meanwhile. Stop instead of spinning forever.
                break
            offset += sent
            remaining -= sent
You'll learn to decode IP headers to extract source and destination addresses, parse TCP headers to analyze connection states and flags, capture ARP packets to monitor address resolution, and use the Linux kernel's AF_ALG interface for hardware-accelerated cryptography. These techniques form the foundation for building tools like network monitors, protocol analyzers, and security scanners. .. warning:: Raw socket operations typically require root/administrator privileges on most operating systems. Use these techniques responsibly and only on networks you own or have explicit permission to analyze. Unauthorized packet sniffing may violate laws and regulations in your jurisdiction. Sniffer IP Packets ------------------ Capturing IP packets requires creating a raw socket with ``SOCK_RAW`` and specifying the protocol to capture (e.g., ``IPPROTO_ICMP`` for ICMP packets). The IP header is a 20-byte structure (without options) containing version, header length, type of service, total length, identification, flags, fragment offset, TTL, protocol, checksum, and source/destination addresses. Using ``ctypes.Structure``, we can define a Python class that maps directly to this binary layout for easy field access. .. 
from ctypes import Structure, c_ubyte, c_uint8, c_uint16, c_uint32
import socket
import struct

# IP protocol numbers
PROTO_MAP = {
    1: "ICMP",
    2: "IGMP",
    6: "TCP",
    17: "UDP",
    27: "RDP",
}


class IP(Structure):
    """IP header structure (20 bytes, options not included).

    Build an instance from at least 20 raw bytes: ``IP(buf[:20])``.
    Besides the raw ``ip_*`` fields the instance exposes:

    * ``src`` / ``dst`` - dotted-quad source and destination addresses
    * ``proto`` - protocol name from ``PROTO_MAP``, or the number as a string

    Multi-byte fields are read in the host's native byte order, so values
    such as ``ip_len`` need ``socket.ntohs()`` before being interpreted.
    """

    _fields_ = [
        ("ip_hl", c_ubyte, 4),   # Header length
        ("ip_v", c_ubyte, 4),    # Version
        ("ip_tos", c_uint8),     # Type of service
        ("ip_len", c_uint16),    # Total length
        ("ip_id", c_uint16),     # Identification
        ("ip_off", c_uint16),    # Fragment offset
        ("ip_ttl", c_uint8),     # Time to live
        ("ip_p", c_uint8),       # Protocol
        ("ip_sum", c_uint16),    # Checksum
        ("ip_src", c_uint32),    # Source address
        ("ip_dst", c_uint32),    # Destination address
    ]

    def __new__(cls, buf=None):
        # Copy the raw bytes straight into the structure's memory.
        return cls.from_buffer_copy(buf)

    def __init__(self, buf=None):
        # ctypes read the address fields in native byte order; packing them
        # back natively ("=L") restores the original wire bytes on both
        # little- and big-endian hosts.
        src = struct.pack("=L", self.ip_src)
        self.src = socket.inet_ntoa(src)
        dst = struct.pack("=L", self.ip_dst)
        self.dst = socket.inet_ntoa(dst)
        self.proto = PROTO_MAP.get(self.ip_p, str(self.ip_p))


def main():
    """Print protocol, source and destination of every captured ICMP packet."""
    # NOTE(review): the socket setup below was reconstructed from a garbled
    # listing - confirm host/protocol against the intended example.
    host = "0.0.0.0"
    s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_ICMP)
    s.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
    s.bind((host, 0))

    print("Sniffer start...")
    try:
        while True:
            buf = s.recvfrom(65535)[0]
            ip_header = IP(buf[:20])
            print(f'{ip_header.proto}: {ip_header.src} -> {ip_header.dst}')
    except KeyboardInterrupt:
        pass
    finally:
        s.close()


if __name__ == "__main__":
    main()
def parse_tcp_packet(pkt):
    """Decode the TCP header of a raw IPv4 packet into a dict.

    *pkt* starts with the IP header. The low nibble of its first byte gives
    the IP header length in 32-bit words, which locates the TCP header.
    Flag entries hold the masked bit value (0 when the flag is clear).
    """
    ip_fields = unpack('!BBHHHBBH4s4s', pkt[0:20])
    tcp_start = (ip_fields[0] & 0xf) * 4

    # Fixed 20-byte part of the TCP header; the urgent pointer is unused.
    (src_port, dst_port, seq, ack, offset_byte, flag_bits,
     window, checksum, _urgent) = unpack('!HHLLBBHHH',
                                         pkt[tcp_start:tcp_start + 20])

    flag_masks = (
        ('FIN', 0x01),
        ('SYN', 0x02),
        ('RST', 0x04),
        ('PSH', 0x08),
        ('ACK', 0x10),
        ('URG', 0x20),
    )
    return {
        'src_port': src_port,
        'dst_port': dst_port,
        'seq': seq,
        'ack': ack,
        # High nibble: TCP header length in 32-bit words.
        'data_offset': offset_byte >> 4,
        'flags': {name: flag_bits & mask for name, mask in flag_masks},
        'window': window,
        'checksum': checksum,
    }
import binascii
import socket
import struct

ETH_P_ALL = 0x0003  # Capture frames of every protocol
ETH_P_ARP = 0x0806  # EtherType carried by ARP frames
ARP_OPCODES = {1: "Request", 2: "Reply"}


def _hex(raw):
    """Return *raw* bytes as a lowercase hex string."""
    return binascii.hexlify(raw).decode()


def format_arp_frame(frame):
    """Describe an Ethernet frame that carries an ARP packet.

    Returns a list of printable lines, or ``None`` when *frame* is shorter
    than an Ethernet + ARP header (42 bytes) or is not an ARP frame.
    """
    if len(frame) < 42:
        return None

    # Ethernet header (14 bytes)
    dst_mac, src_mac, eth_type = struct.unpack("!6s6sH", frame[:14])
    if eth_type != ETH_P_ARP:
        return None

    # ARP header (28 bytes). The opcode is decoded as an integer so it can
    # be compared directly instead of against an escaped bytes literal.
    (hw_type, proto_type, _hw_len, _proto_len, opcode,
     sender_mac, sender_ip, target_mac, target_ip) = struct.unpack(
        "!HHBBH6s4s6s4s", frame[14:42])

    return [
        "=" * 50,
        "ETHERNET FRAME",
        f"  Dest MAC:   {_hex(dst_mac)}",
        f"  Source MAC: {_hex(src_mac)}",
        "ARP HEADER",
        f"  Hardware:   {hw_type:04x}",
        f"  Protocol:   {proto_type:04x}",
        f"  Opcode:     {opcode:04x} "
        f"({ARP_OPCODES.get(opcode, 'Unknown')})",
        f"  Sender MAC: {_hex(sender_mac)}",
        f"  Sender IP:  {socket.inet_ntoa(sender_ip)}",
        f"  Target MAC: {_hex(target_mac)}",
        f"  Target IP:  {socket.inet_ntoa(target_ip)}",
    ]


def main():
    """Capture frames on a raw packet socket and print the ARP ones."""
    # Create raw socket for all packets
    with socket.socket(socket.AF_PACKET, socket.SOCK_RAW,
                       socket.htons(ETH_P_ALL)) as raw_socket:
        print("ARP Sniffer started...")
        while True:
            frame, _ = raw_socket.recvfrom(2048)
            lines = format_arp_frame(frame)
            if lines:
                print("\n".join(lines))


if __name__ == "__main__":
    main()
def parse_udp_header(data):
    """Decode the 8-byte UDP header at the start of *data* into a dict."""
    # Four unsigned 16-bit fields in network byte order.
    names = ('src_port', 'dst_port', 'length', 'checksum')
    values = struct.unpack('!HHHH', data[:8])
    return dict(zip(names, values))
BS = 16  # Block size


def pad(s):
    """Pad bytes *s* to a multiple of ``BS`` (PKCS#7 style).

    Appends ``n`` bytes of value ``n``, where ``n`` is the distance to the
    next block boundary. Input that is already aligned gets a full extra
    block, so the padding can always be removed again.
    """
    n = BS - len(s) % BS
    return s + bytes([n]) * n


def unpad(s):
    """Strip the padding added by ``pad()``.

    The last byte gives the number of padding bytes. The padding is not
    validated.
    """
    return s[:-s[-1]]
def packet_callback(pkt):
    """Print ``src:sport -> dst:dport`` for packets with IP and TCP layers."""
    # Ignore anything that does not carry both layers.
    if IP not in pkt or TCP not in pkt:
        return
    ip_layer = pkt[IP]
    tcp_layer = pkt[TCP]
    print(f"{ip_layer.src}:{tcp_layer.sport} -> "
          f"{ip_layer.dst}:{tcp_layer.dport}")
import socket
import ssl

sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfile='cert.pem', keyfile='key.pem')

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('localhost', 5566))
    sock.listen(10)

    while True:
        conn, addr = sock.accept()
        try:
            # wrap_socket() performs the TLS handshake. The context manager
            # closes the connection whether or not the echo succeeds.
            with sslctx.wrap_socket(conn, server_side=True) as sslconn:
                msg = sslconn.recv(1024)
                if msg:
                    # sendall() keeps writing until everything is sent.
                    sslconn.sendall(msg)
        except OSError as exc:
            # ssl.SSLError is an OSError subclass. A failed handshake or a
            # reset from one client must not stop the server.
            print(f"Connection from {addr} failed: {exc}")
import json
import socket
import ssl

sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfile='cert.pem', keyfile='key.pem')

# Set specific ciphers (TLS 1.2 and earlier). TLS 1.3 cipher suites are
# not controlled by set_ciphers().
sslctx.set_ciphers('ECDHE+AESGCM:DHE+AESGCM')

# Print configured ciphers
print(json.dumps(sslctx.get_ciphers(), indent=2))

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('localhost', 5566))
    sock.listen(10)

    while True:
        conn, addr = sock.accept()
        try:
            # wrap_socket() performs the TLS handshake. The context manager
            # closes the connection whether or not the echo succeeds.
            with sslctx.wrap_socket(conn, server_side=True) as sslconn:
                print(f"Cipher: {sslconn.cipher()}")
                msg = sslconn.recv(1024)
                if msg:
                    sslconn.sendall(msg)
        except OSError as exc:
            # ssl.SSLError is an OSError subclass. A client with no cipher
            # in common must not stop the server.
            print(f"Connection from {addr} failed: {exc}")
import socket
import ssl

context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain('server-cert.pem', 'server-key.pem')
context.load_verify_locations('ca-cert.pem')
context.verify_mode = ssl.CERT_REQUIRED  # Require client cert

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('localhost', 5566))
    sock.listen(10)

    while True:
        conn, addr = sock.accept()
        try:
            # The handshake inside wrap_socket() fails when the client
            # presents no certificate or one not signed by the trusted CA.
            with context.wrap_socket(conn, server_side=True) as sslconn:
                # Get client certificate info
                cert = sslconn.getpeercert()
                print(f"Client: {cert.get('subject')}")

                msg = sslconn.recv(1024)
                if msg:
                    sslconn.sendall(msg)
        except OSError as exc:
            # ssl.SSLError is an OSError subclass. Rejecting one client
            # must not stop the server.
            print(f"Connection from {addr} rejected: {exc}")
def accept(sock, sel):
    """Accept a client and register it for a non-blocking TLS handshake.

    Args:
        sock: the listening socket reported readable by the selector.
        sel: the selector driving the event loop.
    """
    conn, addr = sock.accept()
    # The accepted socket must be non-blocking. Otherwise do_handshake(),
    # recv() and send() block the whole event loop instead of raising
    # SSLWantReadError / SSLWantWriteError.
    conn.setblocking(False)
    # Defer the handshake so it is driven by selector events.
    sslconn = sslctx.wrap_socket(conn, server_side=True,
                                 do_handshake_on_connect=False)
    sel.register(sslconn, selectors.EVENT_READ, do_handshake)
import ssl

context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)

# Set minimum TLS version (TLS 1.2+). This alone disables SSLv2, SSLv3,
# TLS 1.0 and TLS 1.1. The per-protocol ssl.OP_NO_SSLv2 / OP_NO_SSLv3 /
# OP_NO_TLSv1 / OP_NO_TLSv1_1 flags are deprecated in favour of
# minimum_version and are not needed.
context.minimum_version = ssl.TLSVersion.TLSv1_2

# Disable compression (CRIME attack mitigation)
context.options |= ssl.OP_NO_COMPRESSION

# Use server's cipher preference
context.options |= ssl.OP_CIPHER_SERVER_PREFERENCE

# Load certificate
context.load_cert_chain('cert.pem', 'key.pem')

# Set strong ciphers only
context.set_ciphers('ECDHE+AESGCM:DHE+AESGCM:!aNULL:!MD5:!DSS')

print(f"Min version: {context.minimum_version}")
print(f"Ciphers: {len(context.get_ciphers())}")
meta:: :description lang=en: Python socket programming tutorial with examples for DNS resolution, TCP/UDP clients, IP address conversion, network byte order, timeouts, multicast, and SOCKS proxy :keywords: Python, socket, networking, TCP, UDP, DNS, IP address, hostname, getaddrinfo, inet_aton, inet_ntoa, network programming, timeout, multicast, SOCKS proxy, HTTP client ============== Socket Basics ============== .. contents:: Table of Contents :backlinks: none Socket programming is the foundation of network communication in Python and virtually all networked applications. A socket is an endpoint for sending and receiving data across a network, providing a bidirectional communication channel between processes on the same machine or across different machines over the Internet. While Python provides high-level networking interfaces like ``urllib``, ``requests``, and ``asyncio``, understanding low-level socket operations is essential for building custom protocols, debugging network issues, implementing network tools, and interfacing with system-level networking APIs. This cheat sheet covers the fundamentals of socket programming in Python, including hostname and DNS resolution, IP address manipulation, network byte order conversion, timeout handling, multicast communication, and proxy support. Whether you're building a simple client-server application, implementing a custom protocol, or troubleshooting network connectivity issues, these examples provide the building blocks you need. Get Hostname ------------ The ``socket.gethostname()`` function returns the current machine's hostname as configured in the operating system, while ``socket.gethostbyname()`` performs a DNS lookup to resolve a hostname to its IPv4 address. These functions are the most basic building blocks for network programming, allowing your application to identify itself and resolve other hosts on the network. 
Note that ``gethostbyname()`` only returns IPv4 addresses; use ``getaddrinfo()`` for IPv6 support. .. code-block:: python >>> import socket >>> socket.gethostname() 'MacBookPro-4380.local' >>> hostname = socket.gethostname() >>> socket.gethostbyname(hostname) '172.20.10.4' >>> socket.gethostbyname('localhost') '127.0.0.1' Get Address Info (DNS Resolution) --------------------------------- ``socket.getaddrinfo()`` is the most versatile and recommended function for DNS resolution in modern Python code. Unlike ``gethostbyname()``, it supports both IPv4 and IPv6 addresses, returns multiple results when available, and provides complete information including address family, socket type, protocol, canonical name, and socket address. This function is essential for writing protocol-agnostic code that works seamlessly with both IPv4 and IPv6 networks. .. code-block:: python import socket import sys try: for res in socket.getaddrinfo(sys.argv[1], None, proto=socket.IPPROTO_TCP): family = res[0] sockaddr = res[4] print(family, sockaddr) except socket.gaierror: print("Invalid") Output: .. code-block:: console $ python gai.py 192.0.2.244 AddressFamily.AF_INET ('192.0.2.244', 0) $ python gai.py 2001:db8:f00d::1:d AddressFamily.AF_INET6 ('2001:db8:f00d::1:d', 0, 0, 0) $ python gai.py www.google.com AddressFamily.AF_INET6 ('2607:f8b0:4006:818::2004', 0, 0, 0) AddressFamily.AF_INET ('172.217.10.132', 0) It handles unusual cases, valid and invalid: .. 
code-block:: console $ python gai.py 10.0.0.256 # octet overflow Invalid $ python gai.py not-exist.example.com # unresolvable Invalid $ python gai.py fe80::1%eth0 # scoped AddressFamily.AF_INET6 ('fe80::1%eth0', 0, 0, 2) $ python gai.py ::ffff:192.0.2.128 # IPv4-Mapped AddressFamily.AF_INET6 ('::ffff:192.0.2.128', 0, 0, 0) $ python gai.py 0xc000027b # IPv4 in hex AddressFamily.AF_INET ('192.0.2.123', 0) Advanced DNS Queries -------------------- While ``socket.getaddrinfo()`` handles basic hostname resolution, many applications require more advanced DNS operations like querying specific record types. MX records identify mail servers for a domain, TXT records store SPF and DKIM data for email authentication, NS records list authoritative name servers, and SRV records enable service discovery. The ``dnspython`` library provides a comprehensive DNS toolkit that supports all record types, custom nameservers, DNSSEC validation, and zone transfers. This is essential for building email validation systems, service discovery mechanisms, and DNS monitoring tools. .. code-block:: python # pip install dnspython import dns.resolver # Query MX records (mail servers) answers = dns.resolver.resolve('google.com', 'MX') for rdata in answers: print(f"MX: {rdata.exchange} (priority: {rdata.preference})") # Query TXT records (SPF, DKIM, etc.) answers = dns.resolver.resolve('google.com', 'TXT') for rdata in answers: print(f"TXT: {rdata}") # Query NS records (name servers) answers = dns.resolver.resolve('google.com', 'NS') for rdata in answers: print(f"NS: {rdata}") # Query A records with custom nameserver resolver = dns.resolver.Resolver() resolver.nameservers = ['8.8.8.8'] # Use Google DNS answers = resolver.resolve('example.com', 'A') for rdata in answers: print(f"A: {rdata}") Reverse DNS Lookup ------------------ Reverse DNS (rDNS) lookup converts an IP address back to its associated hostname by querying PTR records in the in-addr.arpa (IPv4) or ip6.arpa (IPv6) domains. 
This is commonly used for logging to make IP addresses human-readable, security analysis to verify that a connecting client's IP matches its claimed hostname, spam filtering to check if mail servers have valid reverse DNS, and network troubleshooting to identify devices on a network. .. code-block:: python >>> import socket >>> # Reverse lookup returns (hostname, aliases, addresses) >>> socket.gethostbyaddr('8.8.8.8') ('dns.google', [], ['8.8.8.8']) >>> socket.gethostbyaddr('140.82.112.4') ('github.com', [], ['140.82.112.4']) >>> # Using getfqdn for fully qualified domain name >>> socket.getfqdn('8.8.8.8') 'dns.google' Network Byte Order Conversion ----------------------------- Network protocols universally use big-endian byte order (most significant byte first), also called "network byte order," while most modern CPUs (x86, ARM in little-endian mode) use little-endian (least significant byte first). When sending multi-byte integers over the network, you must convert from host byte order to network byte order, and vice versa when receiving. The ``htons``/``htonl`` functions convert host to network order for short (16-bit) and long (32-bit) integers, while ``ntohs``/``ntohl`` convert network to host order. Failing to perform these conversions causes subtle bugs where values appear corrupted on machines with different endianness. .. code-block:: python # little-endian machine >>> import socket >>> a = 1 # host endian >>> socket.htons(a) # host to network short (16-bit) 256 >>> socket.htonl(a) # host to network long (32-bit) 16777216 >>> socket.ntohs(256) # network to host short 1 >>> socket.ntohl(16777216) # network to host long 1 IP Address String/Binary Conversion ----------------------------------- IP addresses are typically displayed as human-readable strings (dotted-quad for IPv4 like "192.168.1.1", or colon-hex for IPv6 like "2001:db8::1"), but network protocols transmit them as binary data (4 bytes for IPv4, 16 bytes for IPv6). 
The ``inet_aton`` and ``inet_ntoa`` functions convert between string and binary formats for IPv4 only. For code that needs to support both IPv4 and IPv6, use ``inet_pton`` (presentation to network) and ``inet_ntop`` (network to presentation), which take an address family parameter to specify the IP version. .. code-block:: python >>> import socket >>> # IPv4: string to binary >>> addr = socket.inet_aton('127.0.0.1') >>> addr b'\x7f\x00\x00\x01' >>> # IPv4: binary to string >>> socket.inet_ntoa(addr) '127.0.0.1' >>> # IPv4/IPv6: use inet_pton/inet_ntop >>> socket.inet_pton(socket.AF_INET, '192.168.1.1') b'\xc0\xa8\x01\x01' >>> socket.inet_pton(socket.AF_INET6, '::1') b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01' >>> socket.inet_ntop(socket.AF_INET6, b'\x00' * 15 + b'\x01') '::1' MAC Address Conversion ---------------------- MAC (Media Access Control) addresses are 48-bit hardware identifiers assigned to network interface cards, typically displayed as six colon-separated hexadecimal pairs like "00:11:22:33:44:55". When working with raw Ethernet frames or ARP packets, you need to convert between this human-readable format and the 6-byte binary format used in network protocols. The ``binascii`` module provides ``hexlify`` and ``unhexlify`` functions for this conversion. .. code-block:: python >>> import binascii >>> mac = '00:11:32:3c:c3:0b' >>> byte = binascii.unhexlify(mac.replace(':', '')) >>> byte b'\x00\x112<\xc3\x0b' >>> binascii.hexlify(byte) b'0011323cc30b' >>> # Format back to colon-separated >>> ':'.join(f'{b:02x}' for b in byte) '00:11:32:3c:c3:0b' Check Port Availability ----------------------- Before starting a server, you often need to verify that the desired port is available for binding. Similarly, network monitoring tools need to check if remote services are reachable. 
def is_port_open(host, port, timeout=3):
    """Check if a port is open on a remote host.

    Args:
        host: hostname or IP address (IPv4 or IPv6).
        port: TCP port number.
        timeout: seconds to wait for the connection attempt.

    Returns:
        True when a TCP connection could be established, False when the
        attempt was refused, timed out, or the host could not be resolved.
    """
    try:
        # create_connection() resolves the host, tries each returned
        # address in turn and applies the timeout. The context manager
        # closes the socket.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # socket.timeout, ConnectionRefusedError and socket.gaierror are
        # all OSError subclasses.
        return False


def is_port_available(port):
    """Check if a local port is available for binding.

    Returns True when a TCP socket could be bound to *port* on all IPv4
    interfaces, False when the bind failed with OSError.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind(('', port))
        except OSError:
            return False
        return True
code-block:: python import socket def get_local_ips(): """Get all local IP addresses.""" ips = [] hostname = socket.gethostname() try: # Get all addresses for hostname for info in socket.getaddrinfo(hostname, None): ip = info[4][0] if ip not in ips: ips.append(ip) except socket.gaierror: pass return ips # For more detailed interface info, use netifaces # pip install netifaces import netifaces for iface in netifaces.interfaces(): addrs = netifaces.ifaddresses(iface) if netifaces.AF_INET in addrs: for addr in addrs[netifaces.AF_INET]: print(f"{iface}: {addr['addr']}") Socket Options -------------- Socket options control low-level socket behavior and are essential for building robust network applications. ``SO_REUSEADDR`` allows immediate restart of servers without waiting for TIME_WAIT to expire. ``SO_REUSEPORT`` enables multiple processes to bind to the same port for load balancing. Buffer size options (``SO_SNDBUF``, ``SO_RCVBUF``) tune throughput for high-bandwidth applications. ``SO_KEEPALIVE`` detects dead connections by sending periodic probes. Understanding these options helps you optimize performance and handle edge cases in production systems. .. 
code-block:: python import socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # Reuse address (avoid "Address already in use" error) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # Reuse port (multiple processes can bind to same port) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # Set send/receive buffer sizes sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536) sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536) # Enable TCP keepalive sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) # Set timeout (seconds) sock.settimeout(10.0) # Non-blocking mode sock.setblocking(False) # Get current option value print(sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)) Troubleshooting: Connection Refused ----------------------------------- "Connection refused" is one of the most common network errors, but its cause isn't always obvious. It can mean the target port has no listening service, a firewall is actively rejecting connections, or the service crashed. Other errors like "Connection timed out" suggest the host is unreachable or a firewall is silently dropping packets, while "Network unreachable" indicates routing problems. This diagnostic function categorizes different error types to help identify the root cause, which is essential for debugging network connectivity issues in development and production environments. .. 
code-block:: python import socket import errno def diagnose_connection(host, port): """Diagnose connection issues.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(5) try: sock.connect((host, port)) print(f"✓ Connected to {host}:{port}") except socket.timeout: print(f"✗ Timeout - host may be unreachable or firewalled") except ConnectionRefusedError: print(f"✗ Connection refused - no service on {port}") except socket.gaierror as e: print(f"✗ DNS error - cannot resolve {host}: {e}") except OSError as e: if e.errno == errno.ENETUNREACH: print(f"✗ Network unreachable") elif e.errno == errno.EHOSTUNREACH: print(f"✗ Host unreachable") else: print(f"✗ OS error: {e}") finally: sock.close() diagnose_connection('localhost', 8080) Timeout Handling ---------------- Network operations can block indefinitely if a remote host becomes unresponsive, a network path fails, or packets are lost. Without proper timeout handling, your application may hang forever waiting for data that will never arrive. Python sockets support timeouts at multiple levels: a per-socket timeout via ``settimeout()`` that applies to every subsequent blocking operation on that socket, or per-operation timeouts using ``select()`` for more precise control. Always set appropriate timeouts based on your application's requirements—too short causes false failures, too long delays error detection. .. 
code-block:: python import socket import errno def connect_with_timeout(host, port, timeout=5): """Connect with timeout and proper error handling.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(timeout) try: sock.connect((host, port)) return sock except socket.timeout: print(f"Connection to {host}:{port} timed out") sock.close() return None except OSError as e: print(f"Connection failed: {e}") sock.close() return None def recv_with_timeout(sock, bufsize=4096, timeout=10): """Receive data with timeout.""" sock.settimeout(timeout) try: return sock.recv(bufsize) except socket.timeout: return None # Timeout, no data # Per-operation timeout using select import select def recv_timeout(sock, bufsize, timeout): """Receive with timeout using select (more precise).""" ready, _, _ = select.select([sock], [], [], timeout) if ready: return sock.recv(bufsize) raise socket.timeout("recv timed out") Graceful Shutdown ----------------- Simply calling ``close()`` on a socket may lose data still in transit. The TCP protocol requires a proper four-way handshake (FIN-ACK sequence) to ensure both sides have finished sending. The ``shutdown()`` method provides fine-grained control: ``SHUT_WR`` sends a FIN packet signaling "I'm done sending" while still allowing reads, ``SHUT_RD`` stops receiving, and ``SHUT_RDWR`` does both. For clean termination, call ``shutdown(SHUT_WR)`` first, drain any remaining incoming data, then ``close()``. This pattern is especially important for protocols where the server waits for client EOF before sending its final response. .. 
code-block:: python import socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('example.com', 80)) sock.send(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') # Shutdown write side - signals EOF to server sock.shutdown(socket.SHUT_WR) # Read remaining data response = b'' while True: data = sock.recv(4096) if not data: break response += data # Now close the socket sock.close() # shutdown options: # SHUT_RD - no more reads # SHUT_WR - no more writes (sends FIN) # SHUT_RDWR - no more reads or writes Multicast UDP ------------- Multicast is a one-to-many communication model where a single packet is delivered to multiple receivers simultaneously. Unlike broadcast (which floods the entire network) or unicast (one sender, one receiver), multicast uses special IP addresses (224.0.0.0 to 239.255.255.255) and IGMP protocol to efficiently route packets only to interested receivers. Receivers must explicitly join a multicast group to receive traffic. The TTL (Time To Live) controls how far packets travel—TTL=1 stays on the local subnet, higher values cross routers. Multicast is ideal for streaming media, real-time data feeds, and service discovery where the same data goes to many clients. .. 
code-block:: python import socket import struct MCAST_GROUP = '224.1.1.1' MCAST_PORT = 5007 # Sender def multicast_sender(): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 2) sock.sendto(b'Hello Multicast!', (MCAST_GROUP, MCAST_PORT)) sock.close() # Receiver def multicast_receiver(): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(('', MCAST_PORT)) # Join multicast group mreq = struct.pack('4sl', socket.inet_aton(MCAST_GROUP), socket.INADDR_ANY) sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq) data, addr = sock.recvfrom(1024) print(f"Received: {data} from {addr}") sock.close() HTTP Client with Sockets ------------------------ While high-level libraries like ``urllib`` and ``requests`` handle HTTP elegantly, understanding raw HTTP over sockets is invaluable for debugging, implementing custom protocols, or working in constrained environments. HTTP/1.1 is a text-based protocol: you send a request line (``GET /path HTTP/1.1``), headers (key-value pairs), a blank line, and optionally a body. The server responds similarly. Key headers include ``Host`` (required in HTTP/1.1), ``Connection: close`` (to signal single request), and ``Content-Length`` for bodies. This low-level approach reveals exactly what's happening on the wire, making it easier to diagnose issues like malformed headers, encoding problems, or TLS handshake failures. .. 
code-block:: python import socket import ssl def http_get(host, path='/', port=80, use_ssl=False): """Simple HTTP GET using raw sockets.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) if use_ssl: context = ssl.create_default_context() sock = context.wrap_socket(sock, server_hostname=host) port = 443 sock.connect((host, port)) request = f"GET {path} HTTP/1.1\r\nHost: {host}\r\nConnection: close\r\n\r\n" sock.send(request.encode()) response = b'' while True: data = sock.recv(4096) if not data: break response += data sock.close() # Split headers and body header_end = response.find(b'\r\n\r\n') headers = response[:header_end].decode() body = response[header_end + 4:] return headers, body # Usage headers, body = http_get('example.com', '/', use_ssl=True) print(headers) SOCKS Proxy Support ------------------- SOCKS (Socket Secure) is a protocol that routes network traffic through a proxy server, providing anonymity and the ability to bypass firewalls or geographic restrictions. Unlike HTTP proxies that only handle HTTP traffic, SOCKS operates at a lower level and can proxy any TCP (and with SOCKS5, UDP) traffic. SOCKS5 adds authentication and IPv6 support. Common use cases include routing traffic through Tor (which uses SOCKS5 on port 9050), accessing internal networks via SSH tunnels (``ssh -D``), or corporate proxy requirements. The ``PySocks`` library makes it easy to route Python socket connections through SOCKS proxies, either globally (patching all sockets) or per-connection. .. 
code-block:: python # pip install PySocks import socks import socket # Method 1: Patch all sockets globally socks.set_default_proxy(socks.SOCKS5, "localhost", 9050) socket.socket = socks.socksocket # Now all socket connections go through the proxy s = socket.socket() s.connect(("example.com", 80)) # Method 2: Create proxy socket directly s = socks.socksocket() s.set_proxy(socks.SOCKS5, "localhost", 9050) s.connect(("example.com", 80)) # Method 3: With authentication s = socks.socksocket() s.set_proxy(socks.SOCKS5, "proxy.example.com", 1080, username="user", password="pass") ================================================ FILE: docs/notes/network/python-ssh.rst ================================================ .. meta:: :description lang=en: Comprehensive SSH cheat sheet for Python developers covering Paramiko library, SSH tunneling (local, reverse, dynamic), port forwarding, SFTP file transfers, jump hosts, and key management with practical examples :keywords: Python, Python3, SSH, Paramiko, SFTP, SSH Tunnel, Port Forwarding, Reverse Tunnel, Jump Host, ProxyJump, Bastion Host, SOCKS Proxy, SSH Agent, Key Authentication, Remote Execution ====================== SSH and Secure Tunnels ====================== .. contents:: Table of Contents :backlinks: none SSH (Secure Shell) is the standard protocol for secure remote access, providing encrypted communication between machines for command execution, file transfer, and network tunneling. Originally developed as a secure replacement for telnet and rsh, SSH has become essential infrastructure for system administration, deployment automation, and secure network access. Python's ``paramiko`` library provides a complete implementation of SSHv2 protocol, enabling programmatic SSH connections, SFTP file transfers, and sophisticated port forwarding scenarios. 
This cheat sheet covers the full spectrum of SSH operations—from basic password and key authentication to advanced tunneling techniques like reverse tunnels for NAT traversal, jump hosts for accessing isolated networks, and dynamic SOCKS proxies for routing arbitrary traffic through secure channels. Basic SSH Connection -------------------- The foundation of SSH is establishing a secure, authenticated connection to a remote host. The ``SSHClient`` class in Paramiko manages the entire connection lifecycle including TCP connection, cryptographic handshake, host key verification, user authentication, and channel multiplexing. Once connected, you can execute commands, open interactive shells, or establish SFTP sessions. The context manager pattern (``with`` statement) ensures connections are properly closed even if exceptions occur, preventing resource leaks in long-running applications. .. code-block:: python from paramiko.client import SSHClient # Basic password authentication with SSHClient() as ssh: ssh.load_system_host_keys() ssh.connect("example.com", username="user", password="secret") stdin, stdout, stderr = ssh.exec_command("uname -a") print(stdout.read().decode()) .. code-block:: python # Connect on non-standard port with SSHClient() as ssh: ssh.load_system_host_keys() ssh.connect("example.com", port=2222, username="user", password="secret") stdin, stdout, stderr = ssh.exec_command("hostname") print(stdout.read().decode()) Host Key Verification --------------------- SSH's security model relies on verifying the server's identity before sending credentials. Each SSH server has a unique host key pair, and clients store known host public keys in ``~/.ssh/known_hosts``. On first connection, SSH warns about unknown hosts—this is the "fingerprint" prompt you see. Blindly accepting unknown keys defeats this protection and enables man-in-the-middle attacks where an attacker intercepts your connection. 
For automation, ``AutoAddPolicy`` is convenient but should only be used in trusted networks or with additional verification. In production, pre-populate known_hosts or use certificate-based host authentication. .. code-block:: python import paramiko from paramiko.client import SSHClient # Auto-add unknown host keys (use cautiously) # Equivalent to: ssh -o StrictHostKeyChecking=no with SSHClient() as ssh: ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.connect("example.com", username="user", password="secret") stdin, stdout, stderr = ssh.exec_command("whoami") print(stdout.read().decode()) # Reject unknown hosts (default, most secure) with SSHClient() as ssh: ssh.set_missing_host_key_policy(paramiko.RejectPolicy()) ssh.load_system_host_keys() # Load ~/.ssh/known_hosts ssh.connect("example.com", username="user", password="secret") Key-Based Authentication ------------------------ SSH key pairs provide significantly stronger security than passwords while enabling passwordless automation. A key pair consists of a private key (kept secret on your machine) and a public key (copied to servers you want to access). Authentication works by proving you possess the private key without transmitting it. Modern best practice recommends Ed25519 keys for their security and performance, though RSA (4096-bit) remains widely compatible. Protect private keys with a passphrase for defense-in-depth—if the key file is stolen, the passphrase provides an additional barrier. Use ``ssh-agent`` to cache decrypted keys in memory, avoiding repeated passphrase entry. .. code-block:: python from paramiko.client import SSHClient # Using private key file # ssh-keygen -t ed25519 -f mykey # ssh-copy-id -i mykey.pub user@example.com with SSHClient() as ssh: ssh.load_system_host_keys() ssh.connect("example.com", username="user", key_filename="mykey") stdin, stdout, stderr = ssh.exec_command("id") print(stdout.read().decode()) .. 
code-block:: python # Key with passphrase with SSHClient() as ssh: ssh.load_system_host_keys() ssh.connect( "example.com", username="user", key_filename="mykey", passphrase="my-key-passphrase" ) .. code-block:: python # Using RSAKey object directly from paramiko import RSAKey pkey = RSAKey.from_private_key_file("mykey", password="passphrase") with SSHClient() as ssh: ssh.load_system_host_keys() ssh.connect("example.com", username="user", pkey=pkey) SFTP File Transfer ------------------ SFTP (SSH File Transfer Protocol) runs over an SSH connection, providing secure, encrypted file operations without requiring a separate service or port. Unlike FTP, which sends credentials in plaintext and requires complex firewall rules for passive mode, SFTP tunnels everything through the existing SSH connection on port 22. Paramiko's SFTP client supports the full range of file operations: uploading, downloading, directory listing, file metadata, permissions, and remote file manipulation. For large transfers, ``put`` and ``get`` accept a progress callback, and an interrupted transfer can be resumed manually by opening the remote file and seeking to the offset already transferred. It's the standard choice for automated file transfers in deployment scripts, backup systems, and data pipelines where security is required. .. 
code-block:: python from paramiko.client import SSHClient with SSHClient() as ssh: ssh.load_system_host_keys() ssh.connect("example.com", username="user", key_filename="mykey") sftp = ssh.open_sftp() # Upload file sftp.put("local_file.txt", "/remote/path/file.txt") # Download file sftp.get("/remote/path/file.txt", "downloaded.txt") # List directory for entry in sftp.listdir("/remote/path"): print(entry) # File operations sftp.mkdir("/remote/newdir") sftp.rename("/remote/old.txt", "/remote/new.txt") sftp.remove("/remote/unwanted.txt") # Get file stats stat = sftp.stat("/remote/file.txt") print(f"Size: {stat.st_size}, Modified: {stat.st_mtime}") sftp.close() SSH Tunneling Overview ---------------------- SSH tunneling (port forwarding) is one of SSH's most powerful features, allowing you to securely route network traffic through an encrypted SSH connection. This enables accessing services behind firewalls, encrypting otherwise insecure protocols, and bypassing network restrictions. There are three types: local forwarding brings a remote service to your machine, remote (reverse) forwarding exposes your local service to the remote network, and dynamic forwarding creates a SOCKS proxy for routing arbitrary traffic. Understanding these patterns is essential for secure access to databases, internal web applications, and services in private networks. The diagrams below illustrate the traffic flow for each type. 
:: ┌─────────────────────────────────────────────────────────────────┐ │ SSH TUNNEL TYPES │ ├─────────────────────────────────────────────────────────────────┤ │ │ │ LOCAL FORWARDING (-L) │ │ Access remote service through local port │ │ │ │ [You] ──► localhost:8080 ══SSH══► [Server] ──► db:5432 │ │ │ │ ssh -L 8080:database.internal:5432 user@server │ │ Then connect to localhost:8080 to reach database │ │ │ ├─────────────────────────────────────────────────────────────────┤ │ │ │ REMOTE/REVERSE FORWARDING (-R) │ │ Expose local service to remote server │ │ │ │ [You:3000] ◄── [Server]:9000 ◄══SSH══◄ [You initiate] │ │ │ │ ssh -R 9000:localhost:3000 user@server │ │ Server's port 9000 forwards to your localhost:3000 │ │ │ ├─────────────────────────────────────────────────────────────────┤ │ │ │ DYNAMIC FORWARDING (-D) - SOCKS Proxy │ │ Route any traffic through SSH server │ │ │ │ [You] ──► localhost:1080 ══SSH══► [Server] ──► anywhere │ │ │ │ ssh -D 1080 user@server │ │ Configure browser/app to use SOCKS5 proxy localhost:1080 │ │ │ └─────────────────────────────────────────────────────────────────┘ Local Port Forwarding --------------------- Local forwarding (``-L``) is the most common tunnel type, binding a port on your local machine that forwards traffic through the SSH server to a destination host. This is invaluable for accessing services in private networks—databases, internal web applications, admin interfaces—that aren't exposed to the internet. The SSH server acts as a relay: your local application connects to ``localhost:port``, SSH encrypts and forwards the traffic to the server, which then connects to the final destination. The destination doesn't need to be the SSH server itself; it can be any host reachable from the server, making this perfect for bastion/jump host scenarios where you SSH to a gateway machine to reach internal resources. 
:: Scenario: Access internal database through bastion host ┌──────────┐ ┌──────────────┐ ┌────────────────┐ │ Your │ SSH │ Bastion │ │ Database │ │ Machine │─────►│ Server │─────►│ (internal) │ │ │ │ │ │ db:5432 │ └──────────┘ └──────────────┘ └────────────────┘ │ │ Connect to localhost:5432 │ Traffic tunneled to db:5432 ▼ Command: ssh -L 5432:db.internal:5432 user@bastion Then: psql -h localhost -p 5432 mydb .. code-block:: python # Local port forwarding with Paramiko import paramiko from paramiko.client import SSHClient import socket import select import threading def forward_tunnel(local_port, remote_host, remote_port, ssh_client): """Forward local_port to remote_host:remote_port via SSH.""" transport = ssh_client.get_transport() # Create local listening socket server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) server.bind(('127.0.0.1', local_port)) server.listen(5) print(f"Forwarding localhost:{local_port} -> {remote_host}:{remote_port}") while True: client, addr = server.accept() # Open channel to remote destination channel = transport.open_channel( 'direct-tcpip', (remote_host, remote_port), client.getpeername() ) if channel is None: client.close() continue # Bidirectional forwarding in thread threading.Thread( target=_forward_data, args=(client, channel), daemon=True ).start() def _forward_data(sock, channel): """Forward data between socket and SSH channel.""" while True: r, w, x = select.select([sock, channel], [], []) if sock in r: data = sock.recv(4096) if not data: break channel.send(data) if channel in r: data = channel.recv(4096) if not data: break sock.send(data) sock.close() channel.close() # Usage with SSHClient() as ssh: ssh.load_system_host_keys() ssh.connect("bastion.example.com", username="user", key_filename="mykey") # Forward localhost:5432 to internal-db:5432 forward_tunnel(5432, "internal-db.local", 5432, ssh) Reverse Port Forwarding ----------------------- Reverse 
forwarding (``-R``) solves the opposite problem: exposing a service on your local machine to the remote server's network. This is essential when you're behind NAT, a corporate firewall, or any network that blocks incoming connections. You initiate an outbound SSH connection (which firewalls typically allow), and the SSH server opens a listening port that tunnels back to your machine. Common use cases include sharing a local development server with remote colleagues, providing temporary access to a local service for debugging, or creating a "callback" channel when direct inbound connections are impossible. Note that by default, the server only binds to ``127.0.0.1``; to allow external access, the server's ``sshd_config`` needs ``GatewayPorts yes``. :: Scenario: Expose local dev server to public server ┌──────────────┐ ┌──────────────┐ │ Your Machine│ SSH Connection │ Public Server│ │ (behind NAT)│═══════════════════►│ │ │ │ You initiate │ │ │ localhost │ │ 0.0.0.0 │ │ :3000 │◄───────────────────│ :9000 │ │ (your app) │ Tunnel back │ (exposed) │ └──────────────┘ └──────────────┘ Command: ssh -R 9000:localhost:3000 user@public-server Result: Anyone connecting to public-server:9000 reaches your localhost:3000 Note: Server needs "GatewayPorts yes" in sshd_config to allow binding to 0.0.0.0 (not just 127.0.0.1) .. 
code-block:: python # Reverse tunnel with Paramiko import paramiko from paramiko.client import SSHClient import socket import select import threading def reverse_tunnel(server_port, local_host, local_port, ssh_client): """Expose local_host:local_port on SSH server's server_port.""" transport = ssh_client.get_transport() # Request remote port forwarding transport.request_port_forward('', server_port) print(f"Reverse tunnel: server:{server_port} -> {local_host}:{local_port}") while True: # Accept forwarded connection from server channel = transport.accept(1000) if channel is None: continue # Connect to local service sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: sock.connect((local_host, local_port)) except Exception as e: print(f"Local connection failed: {e}") channel.close() continue # Bidirectional forwarding threading.Thread( target=_forward_data, args=(sock, channel), daemon=True ).start() # Usage: Expose local web server on remote port 9000 with SSHClient() as ssh: ssh.load_system_host_keys() ssh.connect("public-server.com", username="user", key_filename="mykey") reverse_tunnel(9000, "localhost", 3000, ssh) Dynamic Port Forwarding (SOCKS Proxy) ------------------------------------- Dynamic forwarding (``-D``) creates a local SOCKS proxy server that routes traffic through the SSH connection. Unlike local forwarding where you specify a fixed destination, dynamic forwarding lets applications connect to any host reachable from the SSH server—the destination is determined per-connection by the SOCKS protocol. This is incredibly versatile: configure your browser to use the SOCKS proxy and all web traffic flows through the SSH server, effectively browsing from that server's network location. Use cases include accessing geo-restricted content, browsing internal websites from outside the office, or encrypting traffic on untrusted networks (coffee shop WiFi). SOCKS5 supports both TCP and UDP, plus authentication, making it more capable than HTTP proxies. 
:: ┌──────────┐ ┌──────────────┐ ┌─────────────┐ │ Your │ SSH │ SSH │ │ Any │ │ Machine │═════►│ Server │─────►│ Destination│ │ │ │ │ │ │ └──────────┘ └──────────────┘ └─────────────┘ │ │ SOCKS5 proxy on localhost:1080 │ Browser/apps route traffic through it ▼ Command: ssh -D 1080 user@server Config: Set browser proxy to SOCKS5 localhost:1080 All browsing now goes through SSH server .. code-block:: bash # Command line usage ssh -D 1080 -N -f user@server # -D 1080: Dynamic forwarding on port 1080 # -N: No remote command (just forwarding) # -f: Background after authentication # Use with curl curl --socks5 localhost:1080 http://internal-site.local Jump Hosts (ProxyJump) ---------------------- Jump hosts (also called bastion hosts or gateway servers) are hardened machines that provide the only entry point into a private network. Instead of exposing internal servers directly to the internet, organizations route all SSH access through a jump host that can be heavily monitored and secured. SSH's ``ProxyJump`` (``-J``) option makes this seamless—you specify the jump host, and SSH automatically chains the connections, authenticating to each hop. The connection to the final destination is end-to-end encrypted; the jump host only sees encrypted traffic passing through. You can chain multiple jump hosts for deeply segmented networks. This pattern is fundamental to secure infrastructure access in cloud environments where production servers should never have public IP addresses. :: ┌──────────┐ ┌──────────────┐ ┌──────────────┐ │ Your │ SSH │ Bastion │ SSH │ Internal │ │ Machine │═════►│ (jump) │═════►│ Server │ │ │ │ │ │ │ └──────────┘ └──────────────┘ └──────────────┘ Command: ssh -J user@bastion user@internal-server Or in ~/.ssh/config: Host internal-server HostName 10.0.0.50 User admin ProxyJump user@bastion.example.com .. 
code-block:: python # Jump host with Paramiko from paramiko.client import SSHClient # Connect to bastion first bastion = SSHClient() bastion.load_system_host_keys() bastion.connect("bastion.example.com", username="user", key_filename="mykey") # Get transport and open channel to internal host bastion_transport = bastion.get_transport() dest_addr = ("internal-server.local", 22) local_addr = ("127.0.0.1", 0) channel = bastion_transport.open_channel("direct-tcpip", dest_addr, local_addr) # Connect to internal server through the channel internal = SSHClient() internal.load_system_host_keys() internal.connect( "internal-server.local", username="admin", key_filename="mykey", sock=channel # Use bastion channel as socket ) # Execute command on internal server stdin, stdout, stderr = internal.exec_command("hostname") print(stdout.read().decode()) internal.close() bastion.close() SSH Config File --------------- The SSH config file (``~/.ssh/config``) eliminates repetitive command-line options by defining per-host settings. Instead of typing ``ssh -i ~/.ssh/mykey -p 2222 user@long.hostname.example.com``, you define a host alias and simply type ``ssh myserver``. The config file supports wildcards, allowing you to set defaults for groups of hosts (all ``*.internal`` hosts use a specific jump server). You can also define automatic port forwarding, so connecting to a host automatically sets up your database tunnels. For teams, a shared config file ensures everyone uses consistent, secure settings. The config is processed top-to-bottom with first match winning, so put specific hosts before wildcards. .. 
code-block:: text # ~/.ssh/config # Default settings for all hosts Host * ServerAliveInterval 60 ServerAliveCountMax 3 AddKeysToAgent yes # Simple host alias Host myserver HostName server.example.com User admin Port 2222 IdentityFile ~/.ssh/mykey # Jump host configuration Host bastion HostName bastion.example.com User jumpuser IdentityFile ~/.ssh/bastion_key Host internal-* ProxyJump bastion User admin IdentityFile ~/.ssh/internal_key Host internal-db HostName 10.0.0.50 Host internal-web HostName 10.0.0.51 # Local port forwarding on connect Host db-tunnel HostName bastion.example.com User admin LocalForward 5432 db.internal:5432 LocalForward 6379 redis.internal:6379 SSH Agent Forwarding -------------------- SSH agent forwarding lets you use your local private keys on remote servers without copying the keys there. When you SSH to a server with agent forwarding enabled (``-A``), the remote server can request signatures from your local ssh-agent for subsequent SSH connections. This is essential for workflows like cloning private git repositories from a server or hopping through multiple machines. However, agent forwarding has security implications: anyone with root access on the remote server can use your forwarded agent to authenticate as you to other systems while you're connected. For sensitive environments, consider ``ProxyJump`` instead, which keeps your keys local, or use per-host deploy keys. .. code-block:: bash # Enable agent forwarding ssh -A user@server # On server, your local keys are available git clone git@github.com:user/repo.git # Uses forwarded key .. 
code-block:: python # Agent forwarding with Paramiko from paramiko.client import SSHClient from paramiko.agent import Agent # Get keys from local SSH agent agent = Agent() agent_keys = agent.get_keys() with SSHClient() as ssh: ssh.load_system_host_keys() # Connect using agent key ssh.connect("server.example.com", username="user", pkey=agent_keys[0]) # Enable agent forwarding for this session transport = ssh.get_transport() paramiko.agent.AgentRequestHandler(transport.open_session()) Keepalive and Connection Stability ---------------------------------- SSH connections can silently die due to network issues, NAT gateway timeouts, or stateful firewalls that drop idle connections. Without keepalives, you won't know the connection is dead until you try to use it—resulting in hung terminals or failed operations. SSH provides two keepalive mechanisms: ``TCPKeepAlive`` uses TCP-level keepalive packets (can be blocked by some firewalls), while ``ServerAliveInterval`` sends SSH-protocol messages through the encrypted channel (more reliable). ``ServerAliveCountMax`` determines how many missed responses trigger disconnect. For reliable long-running connections—tunnels, interactive sessions, or automation—configure both client and server keepalives. A 30-60 second interval works well for most NAT environments. .. code-block:: python from paramiko.client import SSHClient with SSHClient() as ssh: ssh.load_system_host_keys() ssh.connect("server.example.com", username="user", key_filename="mykey") # Configure keepalive transport = ssh.get_transport() transport.set_keepalive(30) # Send keepalive every 30 seconds # Long-running operations... .. code-block:: text # In ~/.ssh/config Host * ServerAliveInterval 30 ServerAliveCountMax 3 TCPKeepAlive yes Common SSH Commands Reference ----------------------------- A quick reference for essential SSH commands covering connections, tunneling, key management, and file transfer. 
These commands form the foundation of secure remote administration and are worth committing to memory. The verbose flags (``-v`` to ``-vvv``) are invaluable for debugging connection issues, showing the authentication methods tried, key exchanges, and where failures occur. .. code-block:: bash # Basic connection ssh user@host ssh -p 2222 user@host # Custom port ssh -i ~/.ssh/mykey user@host # Specific key # Tunneling ssh -L 8080:localhost:80 user@host # Local forward ssh -R 9000:localhost:3000 user@host # Remote forward ssh -D 1080 user@host # SOCKS proxy # Jump hosts ssh -J jump@bastion user@internal # ProxyJump ssh -o ProxyCommand="ssh -W %h:%p jump@bastion" user@internal # Background tunnels ssh -N -f -L 5432:db:5432 user@host # -N no command, -f background # Key management ssh-keygen -t ed25519 -C "comment" # Generate key (ed25519 recommended) ssh-keygen -t rsa -b 4096 # RSA 4096-bit ssh-copy-id -i ~/.ssh/mykey user@host # Copy public key to server ssh-add ~/.ssh/mykey # Add key to agent # Debugging ssh -v user@host # Verbose ssh -vvv user@host # Very verbose # File transfer scp local.txt user@host:/path/ # Copy to remote scp user@host:/path/file.txt . # Copy from remote scp -r dir/ user@host:/path/ # Recursive copy rsync -avz -e ssh dir/ user@host:/path/ # Efficient sync ================================================ FILE: docs/notes/os/index.rst ================================================ .. meta:: :description lang=en: Python system programming tutorial covering file operations, datetime, process management, environment variables, and path manipulation :keywords: Python, Python3, os, file, directory, datetime, subprocess, pathlib, environment, process, system, platform ====== System ====== Python provides powerful modules for interacting with the operating system, making it an excellent choice for system administration, automation, and scripting tasks. 
The ``os`` module offers portable access to file systems, processes, and environment variables, while ``datetime`` handles time and date operations. The ``pathlib`` module provides an object-oriented interface for filesystem paths, and ``subprocess`` enables running external commands. This section covers essential system operations including file manipulation, directory traversal, process management, and working with dates and times across different platforms. .. toctree:: :maxdepth: 1 python-date python-os python-io ================================================ FILE: docs/notes/os/python-date.rst ================================================ .. meta:: :description lang=en: Python datetime tutorial covering timestamps, date formatting, parsing, timezones, timedelta calculations, calendar operations, and time arithmetic :keywords: Python, Python3, datetime, date, time, timestamp, timezone, timedelta, strftime, strptime, calendar, UTC, ISO 8601, dateutil, zoneinfo ======== Datetime ======== :Source: `src/basic/datetime_.py `_ .. contents:: Table of Contents :backlinks: none Introduction ------------ Python's ``datetime`` module provides classes for manipulating dates and times. The module includes ``date`` for calendar dates, ``time`` for clock times, ``datetime`` for combined date and time, ``timedelta`` for durations, and ``timezone`` for UTC offsets. Python 3.9+ also includes ``zoneinfo`` for IANA timezone support. Understanding these classes is essential for logging, scheduling, data analysis, and any application that works with temporal data. Current Date and Time --------------------- Getting the current date and time is one of the most common operations. Use ``datetime.now()`` for local time and ``datetime.now(timezone.utc)`` for UTC. Avoid ``datetime.utcnow()``: it returns a naive datetime with no timezone attached and is deprecated since Python 3.12. .. 
code-block:: python from datetime import datetime, date, time, timezone # Current local datetime now = datetime.now() print(now) # 2024-01-15 10:30:45.123456 # Current UTC datetime (timezone-aware) utc_now = datetime.now(timezone.utc) print(utc_now) # 2024-01-15 02:30:45.123456+00:00 # Current date only today = date.today() print(today) # 2024-01-15 # Current time only current_time = datetime.now().time() print(current_time) # 10:30:45.123456 Creating Datetime Objects ------------------------- You can create datetime objects by specifying year, month, day, and optionally hour, minute, second, and microsecond. The ``date`` and ``time`` classes work similarly for date-only or time-only values. .. code-block:: python from datetime import datetime, date, time # Create specific datetime dt = datetime(2024, 1, 15, 10, 30, 45) print(dt) # 2024-01-15 10:30:45 # Create date only d = date(2024, 1, 15) print(d) # 2024-01-15 # Create time only t = time(10, 30, 45) print(t) # 10:30:45 # Combine date and time combined = datetime.combine(d, t) print(combined) # 2024-01-15 10:30:45 # Get date or time from datetime print(dt.date()) # 2024-01-15 print(dt.time()) # 10:30:45 Timestamps ---------- Unix timestamps represent seconds since January 1, 1970 (the Unix epoch). Converting between timestamps and datetime objects is common when working with APIs, databases, and log files. .. 
code-block:: python import time from datetime import datetime, timezone # Current timestamp ts = time.time() print(ts) # 1705312245.123456 # Timestamp to datetime (local time) dt = datetime.fromtimestamp(ts) print(dt) # 2024-01-15 10:30:45.123456 # Timestamp to datetime (UTC) dt_utc = datetime.fromtimestamp(ts, tz=timezone.utc) print(dt_utc) # 2024-01-15 02:30:45.123456+00:00 # Datetime to timestamp ts_back = dt.timestamp() print(ts_back) # 1705312245.123456 # Millisecond timestamp (common in JavaScript/APIs) ts_ms = int(ts * 1000) print(ts_ms) # 1705312245123 Formatting Dates (strftime) --------------------------- The ``strftime()`` method formats datetime objects as strings using format codes. This is essential for displaying dates to users, generating filenames, or formatting data for APIs. .. code-block:: python from datetime import datetime dt = datetime(2024, 1, 15, 14, 30, 45) # Common formats print(dt.strftime("%Y-%m-%d")) # 2024-01-15 print(dt.strftime("%d/%m/%Y")) # 15/01/2024 print(dt.strftime("%B %d, %Y")) # January 15, 2024 print(dt.strftime("%Y-%m-%d %H:%M:%S")) # 2024-01-15 14:30:45 print(dt.strftime("%I:%M %p")) # 02:30 PM print(dt.strftime("%A, %B %d")) # Monday, January 15 # ISO 8601 format print(dt.isoformat()) # 2024-01-15T14:30:45 # For filenames (no special characters) print(dt.strftime("%Y%m%d_%H%M%S")) # 20240115_143045 Common format codes: - ``%Y`` - 4-digit year (2024) - ``%m`` - Month as zero-padded number (01-12) - ``%d`` - Day as zero-padded number (01-31) - ``%H`` - Hour 24-hour format (00-23) - ``%I`` - Hour 12-hour format (01-12) - ``%M`` - Minute (00-59) - ``%S`` - Second (00-59) - ``%p`` - AM/PM - ``%A`` - Full weekday name - ``%B`` - Full month name - ``%z`` - UTC offset (+0000) - ``%Z`` - Timezone name Parsing Dates (strptime) ------------------------ The ``strptime()`` method parses strings into datetime objects. The format string must match the input exactly. 
This is commonly used when reading dates from files, user input, or APIs. .. code-block:: python from datetime import datetime # Parse various formats dt1 = datetime.strptime("2024-01-15", "%Y-%m-%d") dt2 = datetime.strptime("15/01/2024", "%d/%m/%Y") dt3 = datetime.strptime("January 15, 2024", "%B %d, %Y") dt4 = datetime.strptime("2024-01-15 14:30:45", "%Y-%m-%d %H:%M:%S") print(dt1) # 2024-01-15 00:00:00 print(dt4) # 2024-01-15 14:30:45 # Parse ISO 8601 format dt5 = datetime.fromisoformat("2024-01-15T14:30:45") print(dt5) # 2024-01-15 14:30:45 # Parse with UTC offset (a trailing "Z" suffix requires Python 3.11+) dt6 = datetime.fromisoformat("2024-01-15T14:30:45+00:00") print(dt6) # 2024-01-15 14:30:45+00:00 Date Arithmetic with timedelta ------------------------------ The ``timedelta`` class represents a duration—the difference between two dates or times. Use it to add or subtract time from datetime objects, or to calculate the difference between two dates. .. code-block:: python from datetime import datetime, timedelta now = datetime.now() # Add time tomorrow = now + timedelta(days=1) next_week = now + timedelta(weeks=1) in_2_hours = now + timedelta(hours=2) in_90_minutes = now + timedelta(minutes=90) # Subtract time yesterday = now - timedelta(days=1) last_month = now - timedelta(days=30) # Combine units future = now + timedelta(days=5, hours=3, minutes=30) # Calculate difference between dates date1 = datetime(2024, 1, 1) date2 = datetime(2024, 12, 31) diff = date2 - date1 print(diff.days) # 365 print(diff.total_seconds()) # 31536000.0 # Compare dates print(date2 > date1) # True Timezones --------- Working with timezones correctly is crucial for applications serving users across different regions. Python 3.9+ includes ``zoneinfo`` for IANA timezone support. For earlier versions, use ``pytz`` or ``dateutil``. .. 
code-block:: python from datetime import datetime, timezone, timedelta # UTC timezone utc = timezone.utc dt_utc = datetime.now(utc) print(dt_utc) # 2024-01-15 02:30:45.123456+00:00 # Fixed offset timezone pst = timezone(timedelta(hours=-8)) dt_pst = datetime.now(pst) print(dt_pst) # 2024-01-14 18:30:45.123456-08:00 # Convert between timezones dt_converted = dt_utc.astimezone(pst) print(dt_converted) # Python 3.9+ with zoneinfo from zoneinfo import ZoneInfo eastern = ZoneInfo("America/New_York") tokyo = ZoneInfo("Asia/Tokyo") dt_eastern = datetime.now(eastern) dt_tokyo = dt_eastern.astimezone(tokyo) print(dt_tokyo) # Make naive datetime timezone-aware naive = datetime(2024, 1, 15, 10, 30) aware = naive.replace(tzinfo=utc) Comparing Dates --------------- Datetime objects support comparison operators. When comparing timezone-aware and naive datetimes, Python raises a TypeError to prevent subtle bugs. .. code-block:: python from datetime import datetime, date, timedelta dt1 = datetime(2024, 1, 15, 10, 0) dt2 = datetime(2024, 1, 15, 14, 0) dt3 = datetime(2024, 1, 16, 10, 0) # Comparisons print(dt1 < dt2) # True print(dt1 == dt2) # False print(dt3 > dt2) # True # Check if date is in range start = datetime(2024, 1, 1) end = datetime(2024, 12, 31) check = datetime(2024, 6, 15) print(start <= check <= end) # True # Check if date is today today = date.today() some_date = date(2024, 1, 15) print(some_date == today) # Days until a date future = date(2024, 12, 25) days_until = (future - today).days print(f"Days until: {days_until}") Working with Weeks ------------------ Getting week numbers, weekdays, and working with ISO week dates is common for reporting and scheduling applications. .. 
code-block:: python from datetime import datetime, date, timedelta dt = datetime(2024, 1, 15) # Day of week (0=Monday, 6=Sunday) print(dt.weekday()) # 0 (Monday) print(dt.isoweekday()) # 1 (Monday, ISO format 1-7) # Week number print(dt.isocalendar()) # (2024, 3, 1) - year, week, weekday # Get start of week (Monday) start_of_week = dt - timedelta(days=dt.weekday()) print(start_of_week) # 2024-01-15 00:00:00 # Get end of week (Sunday) end_of_week = start_of_week + timedelta(days=6) print(end_of_week) # 2024-01-21 00:00:00 # Check if weekend is_weekend = dt.weekday() >= 5 print(is_weekend) # False Start and End of Day/Month/Year ------------------------------- Getting the start or end of a time period is useful for date range queries and reporting. .. code-block:: python from datetime import datetime, date, time, timedelta import calendar dt = datetime(2024, 1, 15, 14, 30, 45) # Start of day start_of_day = datetime.combine(dt.date(), time.min) print(start_of_day) # 2024-01-15 00:00:00 # End of day end_of_day = datetime.combine(dt.date(), time.max) print(end_of_day) # 2024-01-15 23:59:59.999999 # Start of month start_of_month = dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0) print(start_of_month) # 2024-01-01 00:00:00 # End of month last_day = calendar.monthrange(dt.year, dt.month)[1] end_of_month = dt.replace(day=last_day, hour=23, minute=59, second=59) print(end_of_month) # 2024-01-31 23:59:59 # Start of year start_of_year = dt.replace(month=1, day=1, hour=0, minute=0, second=0, microsecond=0) print(start_of_year) # 2024-01-01 00:00:00 Calendar Operations ------------------- The ``calendar`` module provides functions for working with calendars, including checking leap years, getting month ranges, and generating calendar displays. .. 
code-block:: python import calendar from datetime import date # Check leap year print(calendar.isleap(2024)) # True print(calendar.isleap(2023)) # False # Days in month print(calendar.monthrange(2024, 2)) # (3, 29) - weekday of 1st, days in month days_in_feb = calendar.monthrange(2024, 2)[1] print(days_in_feb) # 29 # Generate month calendar print(calendar.month(2024, 1)) # Iterate through month days cal = calendar.Calendar() for day in cal.itermonthdays(2024, 1): if day != 0: print(day, end=" ") # 1 2 3 ... 31 Date Ranges ----------- Generating sequences of dates is useful for reports, charts, and scheduling. .. code-block:: python from datetime import datetime, date, timedelta def date_range(start, end, step=timedelta(days=1)): """Generate dates from start to end.""" current = start while current <= end: yield current current += step # Daily dates start = date(2024, 1, 1) end = date(2024, 1, 7) for d in date_range(start, end): print(d) # Weekly dates for d in date_range(start, date(2024, 1, 31), timedelta(weeks=1)): print(d) # Business days (skip weekends) def business_days(start, end): current = start while current <= end: if current.weekday() < 5: # Monday=0 to Friday=4 yield current current += timedelta(days=1) Age Calculation --------------- Calculating age from a birthdate requires handling the edge case where the birthday hasn't occurred yet this year. .. 
code-block:: python from datetime import date def calculate_age(birthdate): """Calculate age in years from birthdate.""" today = date.today() age = today.year - birthdate.year # Subtract 1 if birthday hasn't occurred this year if (today.month, today.day) < (birthdate.month, birthdate.day): age -= 1 return age birthdate = date(1990, 6, 15) age = calculate_age(birthdate) print(f"Age: {age}") # Days until next birthday def days_until_birthday(birthdate): today = date.today() next_birthday = birthdate.replace(year=today.year) if next_birthday < today: next_birthday = next_birthday.replace(year=today.year + 1) return (next_birthday - today).days Human Readable Time Differences ------------------------------- Converting timedelta to human-readable strings like "2 hours ago" or "in 3 days" improves user experience. .. code-block:: python from datetime import datetime, timedelta def time_ago(dt): """Convert datetime to human-readable relative time.""" now = datetime.now() diff = now - dt seconds = diff.total_seconds() if seconds < 60: return "just now" elif seconds < 3600: minutes = int(seconds // 60) return f"{minutes} minute{'s' if minutes != 1 else ''} ago" elif seconds < 86400: hours = int(seconds // 3600) return f"{hours} hour{'s' if hours != 1 else ''} ago" elif seconds < 604800: days = int(seconds // 86400) return f"{days} day{'s' if days != 1 else ''} ago" else: return dt.strftime("%B %d, %Y") # Examples print(time_ago(datetime.now() - timedelta(minutes=5))) # 5 minutes ago print(time_ago(datetime.now() - timedelta(hours=2))) # 2 hours ago print(time_ago(datetime.now() - timedelta(days=3))) # 3 days ago Using dateutil for Flexible Parsing ----------------------------------- The ``python-dateutil`` library provides powerful parsing that handles many date formats automatically, plus relative delta calculations. .. 
code-block:: python # pip install python-dateutil from dateutil import parser from dateutil.relativedelta import relativedelta from datetime import datetime # Flexible parsing - handles many formats automatically dt1 = parser.parse("January 15, 2024") dt2 = parser.parse("15/01/2024") dt3 = parser.parse("2024-01-15T14:30:45Z") dt4 = parser.parse("Jan 15 2024 2:30 PM") # Relative delta - handles months and years correctly now = datetime.now() # Add 1 month (handles varying month lengths) next_month = now + relativedelta(months=1) # Add 1 year next_year = now + relativedelta(years=1) # Last day of next month last_of_next_month = now + relativedelta(months=1, day=31) # Complex relative: 2 months and 3 days ago past = now - relativedelta(months=2, days=3) ================================================ FILE: docs/notes/os/python-io.rst ================================================ .. meta:: :description lang=en: Python file I/O tutorial covering reading, writing, binary files, pathlib, context managers, file modes, temporary files, and efficient file handling patterns :keywords: Python, Python3, file, I/O, read, write, binary, text, pathlib, Path, context manager, open, with statement, encoding, shutil, tempfile, glob ============= Files and I/O ============= :Source: `src/basic/fileio_.py `_ .. contents:: Table of Contents :backlinks: none Introduction ------------ Python provides comprehensive support for file operations and filesystem manipulation through several built-in modules. The ``open()`` function is the foundation for reading and writing files, supporting text and binary modes with configurable encoding. The ``pathlib`` module (Python 3.4+) offers a modern, object-oriented interface for path manipulation that works consistently across operating systems. For high-level operations like copying directory trees or moving files across filesystems, the ``shutil`` module provides convenient functions. 
The ``tempfile`` module handles creation of temporary files and directories with automatic cleanup, essential for secure handling of intermediate data. Together, these modules cover virtually all file I/O needs from simple text processing to complex filesystem operations. Reading Files ------------- The ``open()`` function returns a file object that supports multiple read methods. Always use the ``with`` statement (context manager) to ensure files are properly closed even if an exception occurs. The ``read()`` method loads the entire file into memory, while iterating over the file object processes one line at a time, which is more memory-efficient for large files. Always specify ``encoding="utf-8"`` explicitly to avoid platform-dependent behavior. .. code-block:: python # Read entire file as string with open("example.txt", encoding="utf-8") as f: content = f.read() # Read all lines as list with open("example.txt", encoding="utf-8") as f: lines = f.readlines() # Iterate line by line (memory efficient) with open("example.txt", encoding="utf-8") as f: for line in f: print(line.rstrip()) # Read specific number of characters with open("example.txt", encoding="utf-8") as f: first_100 = f.read(100) # Read single line with open("example.txt", encoding="utf-8") as f: first_line = f.readline() Writing Files ------------- Python offers several modes for writing files. Mode ``"w"`` creates a new file or truncates an existing one, ``"a"`` appends to the end without truncating, and ``"x"`` creates exclusively (raising ``FileExistsError`` if the file already exists). The ``write()`` method writes a single string, while ``writelines()`` writes an iterable of strings. Note that neither method adds newlines automatically—you must include ``\n`` in your strings. You can also redirect ``print()`` output to a file using the ``file`` parameter. .. 
code-block:: python # Write string to file (overwrites) with open("output.txt", "w", encoding="utf-8") as f: f.write("Hello, World!\n") # Write multiple lines lines = ["line 1\n", "line 2\n", "line 3\n"] with open("output.txt", "w", encoding="utf-8") as f: f.writelines(lines) # Append to file with open("output.txt", "a", encoding="utf-8") as f: f.write("Appended line\n") # Create new file (fails if exists) with open("new_file.txt", "x", encoding="utf-8") as f: f.write("New content") # Print to file with open("output.txt", "w", encoding="utf-8") as f: print("Hello", "World", sep=", ", file=f) Binary Files ------------ Binary mode (``"rb"``, ``"wb"``) reads and writes raw bytes without any encoding or newline translation. This is essential for non-text files like images, PDFs, executables, or any file where byte-level accuracy matters. Binary data is represented as ``bytes`` objects in Python. When processing large binary files, read in chunks to avoid loading the entire file into memory at once. .. code-block:: python # Read binary file with open("image.png", "rb") as f: data = f.read() print(type(data)) # <class 'bytes'> # Write binary file with open("copy.png", "wb") as f: f.write(data) # Read binary in chunks chunk_size = 8192 with open("large_file.bin", "rb") as f: while chunk := f.read(chunk_size): process(chunk) File Modes ---------- The ``open()`` function accepts a mode string that controls how the file is opened. The mode combines access type (read, write, append) with content type (text or binary). Text mode performs encoding/decoding and newline translation, while binary mode works with raw bytes. The ``+`` modifier enables both reading and writing on the same file handle. 
Common file modes: - ``"r"`` - Read text (default) - ``"w"`` - Write text (truncates) - ``"a"`` - Append text - ``"x"`` - Exclusive create (fails if exists) - ``"rb"`` - Read binary - ``"wb"`` - Write binary - ``"r+"`` - Read and write - ``"w+"`` - Write and read (truncates) Reading File in Chunks ---------------------- When processing files larger than available memory, reading in chunks prevents memory exhaustion. A generator function that yields chunks is memory-efficient and works well with streaming processing. The walrus operator (``:=``) provides a clean way to read until the file is exhausted. The ``iter()`` function with a sentinel value offers an alternative pattern for chunk-based reading. .. code-block:: python def read_chunks(filepath, chunk_size=8192): """Read file in chunks.""" with open(filepath, "rb") as f: while chunk := f.read(chunk_size): yield chunk # Process large file for chunk in read_chunks("large_file.bin"): process(chunk) # Using iter with sentinel with open("file.txt", encoding="utf-8") as f: for chunk in iter(lambda: f.read(1024), ""): print(chunk, end="") pathlib Basics -------------- The ``pathlib`` module, introduced in Python 3.4, provides an object-oriented approach to filesystem paths. Unlike string-based path manipulation, ``Path`` objects handle platform differences automatically (forward slashes on Unix, backslashes on Windows). The ``/`` operator joins path components intuitively, and methods like ``resolve()`` convert relative paths to absolute. ``Path`` objects are the recommended way to work with filesystem paths in modern Python. .. code-block:: python from pathlib import Path # Create path objects p = Path("folder/file.txt") p = Path.home() / "Documents" / "file.txt" # Current and home directories cwd = Path.cwd() home = Path.home() # Absolute path abs_path = Path("file.txt").resolve() Path Properties --------------- ``Path`` objects expose various properties to extract components of a path. 
The ``name`` property returns the final component, ``stem`` returns the name without the suffix, and ``suffix`` returns the file extension including the dot. The ``parent`` property returns the directory containing the path, and ``parts`` returns a tuple of all path components. Methods like ``with_suffix()`` and ``with_name()`` create new paths with modified components without affecting the original. .. code-block:: python from pathlib import Path p = Path("/home/user/documents/report.pdf") print(p.name) # report.pdf print(p.stem) # report print(p.suffix) # .pdf print(p.parent) # /home/user/documents print(p.parts) # ('/', 'home', 'user', 'documents', 'report.pdf') print(p.anchor) # / # Multiple suffixes p2 = Path("archive.tar.gz") print(p2.suffixes) # ['.tar', '.gz'] # Change suffix p3 = p.with_suffix(".txt") print(p3) # /home/user/documents/report.txt # Change name p4 = p.with_name("summary.pdf") print(p4) # /home/user/documents/summary.pdf Path Operations --------------- ``Path`` objects provide methods to check file existence and type, retrieve file metadata, and perform read/write operations. The ``exists()``, ``is_file()``, and ``is_dir()`` methods test path status without raising exceptions. The ``stat()`` method returns detailed file information including size and modification time. For simple file operations, ``read_text()``, ``write_text()``, ``read_bytes()``, and ``write_bytes()`` provide convenient one-liner alternatives to the ``open()`` context manager pattern. .. 
code-block:: python from pathlib import Path p = Path("example.txt") # Check existence and type p.exists() # True/False p.is_file() # True if regular file p.is_dir() # True if directory p.is_symlink() # True if symbolic link # File stats stat = p.stat() print(stat.st_size) # File size in bytes print(stat.st_mtime) # Modification time # Read and write with pathlib content = p.read_text(encoding="utf-8") p.write_text("New content", encoding="utf-8") # Binary read/write data = p.read_bytes() p.write_bytes(b"binary data") Listing Directories ------------------- Python offers several ways to list directory contents, each with different trade-offs. The ``pathlib`` method ``iterdir()`` returns an iterator of ``Path`` objects, allowing you to check file types and access properties directly. The ``os.scandir()`` function is highly efficient because it retrieves file type information during directory iteration without additional system calls. The simpler ``os.listdir()`` returns just filenames as strings, requiring additional calls to get file information. .. code-block:: python from pathlib import Path import os # pathlib - iterate directory p = Path(".") for item in p.iterdir(): print(item.name, "dir" if item.is_dir() else "file") # pathlib - glob patterns for py_file in Path(".").glob("*.py"): print(py_file) # Recursive glob for py_file in Path(".").rglob("*.py"): print(py_file) # os.scandir (efficient, returns DirEntry) with os.scandir(".") as entries: for entry in entries: print(entry.name, entry.is_file()) # os.listdir (simple list) files = os.listdir(".") Glob Patterns ------------- Glob patterns provide a shell-like syntax for matching multiple files. The ``*`` wildcard matches any characters except path separators, ``?`` matches a single character, and ``**`` matches any number of directories recursively. The ``pathlib`` methods ``glob()`` and ``rglob()`` (recursive glob) return iterators of matching ``Path`` objects. 
Note that ``pathlib`` glob doesn't support brace expansion like ``{py,txt}``—use multiple glob calls or the ``glob`` module for complex patterns. .. code-block:: python from pathlib import Path # All Python files in current directory list(Path(".").glob("*.py")) # All Python files recursively list(Path(".").rglob("*.py")) # Multiple extensions list(Path(".").glob("*.{py,txt}")) # Won't work # Use instead: py_files = list(Path(".").glob("*.py")) txt_files = list(Path(".").glob("*.txt")) # Single character wildcard list(Path(".").glob("file?.txt")) # file1.txt, file2.txt # Using glob module import glob glob.glob("**/*.py", recursive=True) Creating Directories -------------------- Creating directories is straightforward with both ``pathlib`` and ``os``. The ``mkdir()`` method creates a single directory, raising ``FileExistsError`` if it already exists. The ``parents=True`` parameter creates all intermediate directories (like ``mkdir -p`` in Unix), and ``exist_ok=True`` suppresses the error if the directory already exists. These options together make directory creation idempotent and safe for concurrent execution. .. code-block:: python from pathlib import Path import os # pathlib - create directory Path("new_dir").mkdir() # Create with parents (like mkdir -p) Path("path/to/nested/dir").mkdir(parents=True, exist_ok=True) # os.makedirs os.makedirs("path/to/dir", exist_ok=True) shutil - High-Level File Operations ------------------------------------ The ``shutil`` module provides high-level operations for copying, moving, and removing files and directory trees. **Copying Files:** .. 
code-block:: python import shutil # Copy file (content and permission bits) shutil.copy("source.txt", "dest.txt") # Copy file preserving metadata (timestamps, permissions) shutil.copy2("source.txt", "dest.txt") # Copy to directory (keeps original filename) shutil.copy("file.txt", "backup/") # -> backup/file.txt # Copy file object to file object with open("src.txt", "rb") as src, open("dst.txt", "wb") as dst: shutil.copyfileobj(src, dst) # Copy only file content (no metadata) shutil.copyfile("source.txt", "dest.txt") **Copying Directory Trees:** .. code-block:: python import os import shutil # Copy entire directory tree shutil.copytree("source_dir", "dest_dir") # Copy with ignore patterns shutil.copytree( "source", "dest", ignore=shutil.ignore_patterns("*.pyc", "__pycache__", ".git") ) # Copy into existing directory (Python 3.8+) shutil.copytree("source", "existing_dest", dirs_exist_ok=True) # Custom ignore function def ignore_large_files(directory, files): """Ignore files larger than 1MB.""" ignored = [] for f in files: path = os.path.join(directory, f) if os.path.isfile(path) and os.path.getsize(path) > 1_000_000: ignored.append(f) return ignored shutil.copytree("source", "dest", ignore=ignore_large_files) # Copy with symlinks preserved shutil.copytree("source", "dest", symlinks=True) **Moving Files and Directories:** .. code-block:: python import shutil from pathlib import Path # Move file (works across filesystems) shutil.move("old_name.txt", "new_name.txt") # Move file to directory shutil.move("file.txt", "archive/") # -> archive/file.txt # Move entire directory shutil.move("old_dir", "new_dir") # pathlib rename (same filesystem only) Path("old.txt").rename("new.txt") # pathlib replace (overwrites destination) Path("source.txt").replace("dest.txt") **Removing Files and Directories:** .. 
code-block:: python import shutil from pathlib import Path # Delete entire directory tree shutil.rmtree("dir_with_contents") # Delete with error handler def on_error(func, path, exc_info): print(f"Error deleting {path}: {exc_info[1]}") shutil.rmtree("dir", onerror=on_error) # Delete ignoring errors shutil.rmtree("dir", ignore_errors=True) # Delete file Path("file.txt").unlink() Path("file.txt").unlink(missing_ok=True) # No error if missing # Delete empty directory Path("empty_dir").rmdir() **Disk Usage:** .. code-block:: python import shutil # Get disk usage statistics usage = shutil.disk_usage("/") print(f"Total: {usage.total // (1024**3)} GB") print(f"Used: {usage.used // (1024**3)} GB") print(f"Free: {usage.free // (1024**3)} GB") **Finding Executables:** .. code-block:: python import shutil # Find executable in PATH python_path = shutil.which("python") print(python_path) # /usr/bin/python # Returns None if not found result = shutil.which("nonexistent") print(result) # None **Archiving:** .. code-block:: python import shutil # Create archive (zip, tar, gztar, bztar, xztar) shutil.make_archive("backup", "zip", "source_dir") # -> backup.zip shutil.make_archive("backup", "gztar", "source_dir") # -> backup.tar.gz # Extract archive shutil.unpack_archive("backup.zip", "extract_dir") shutil.unpack_archive("backup.tar.gz", "extract_dir") # List supported formats print(shutil.get_archive_formats()) print(shutil.get_unpack_formats()) Temporary Files --------------- The ``tempfile`` module creates temporary files and directories with unique names in a secure manner. ``NamedTemporaryFile`` creates a file that is automatically deleted when closed (unless ``delete=False``). The ``suffix`` parameter adds a file extension, useful when other programs need to identify the file type. ``TemporaryDirectory`` creates a directory that is recursively deleted when the context manager exits, perfect for test fixtures or intermediate processing. 
These functions use the system's temp directory by default, which you can query with ``gettempdir()``. .. code-block:: python import tempfile from pathlib import Path # Temporary file (auto-deleted when closed) with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=True) as f: f.write("temporary content") print(f.name) # /tmp/tmpXXXXXX.txt # Temporary file that persists with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: temp_path = f.name f.write("persistent temp") # Clean up manually later Path(temp_path).unlink() # Temporary directory with tempfile.TemporaryDirectory() as tmpdir: temp_file = Path(tmpdir) / "file.txt" temp_file.write_text("content") # Directory deleted when exiting context # Get temp directory path print(tempfile.gettempdir()) # /tmp Symbolic Links -------------- Symbolic links (symlinks) are special files that point to another file or directory. They're useful for creating shortcuts, managing multiple versions, or organizing files without duplication. The ``symlink_to()`` method creates a symlink pointing to the specified target. The ``is_symlink()`` method checks if a path is a symlink, and ``readlink()`` returns the target path. The ``resolve()`` method follows all symlinks to return the canonical absolute path. .. code-block:: python from pathlib import Path import os # Create symlink with pathlib Path("link_name").symlink_to("target_file") # Create symlink with os os.symlink("target", "link") # Read symlink target target = Path("link_name").readlink() target = os.readlink("link") # Check if symlink Path("link_name").is_symlink() # Resolve symlink to real path real_path = Path("link_name").resolve() File Permissions ---------------- Unix-like systems use permission bits to control file access. The ``stat()`` method returns file metadata including the permission mode. 
The ``chmod()`` method modifies permissions using octal notation (e.g., ``0o644`` for owner read/write, group/other read-only) or by combining ``stat`` module constants. The ``os.access()`` function checks if the current process has specific permissions on a file, useful for pre-flight checks before attempting operations. .. code-block:: python from pathlib import Path import os import stat p = Path("script.sh") # Get permissions mode = p.stat().st_mode print(oct(mode)) # 0o100644 # Make executable p.chmod(p.stat().st_mode | stat.S_IXUSR) # Set specific permissions (owner rw, group r, other r) p.chmod(0o644) # Check if readable/writable os.access("file.txt", os.R_OK) # Readable os.access("file.txt", os.W_OK) # Writable os.access("file.txt", os.X_OK) # Executable Working with CSV Files ---------------------- CSV (Comma-Separated Values) is a common format for tabular data exchange. Python's ``csv`` module handles the complexities of CSV parsing, including quoted fields, different delimiters, and proper escaping. The ``writer`` object writes rows as lists, while ``DictWriter`` writes dictionaries using column headers as keys. Similarly, ``reader`` yields rows as lists, and ``DictReader`` yields dictionaries. Always open CSV files with ``newline=""`` to let the csv module handle line endings correctly across platforms. .. 
code-block:: python import csv # Write CSV with open("data.csv", "w", newline="", encoding="utf-8") as f: writer = csv.writer(f) writer.writerow(["name", "age", "city"]) writer.writerow(["Alice", 30, "NYC"]) writer.writerows([["Bob", 25, "LA"], ["Carol", 35, "Chicago"]]) # Read CSV with open("data.csv", newline="", encoding="utf-8") as f: reader = csv.reader(f) header = next(reader) for row in reader: print(row) # DictReader/DictWriter with open("data.csv", newline="", encoding="utf-8") as f: reader = csv.DictReader(f) for row in reader: print(row["name"], row["age"]) Working with JSON Files ----------------------- JSON (JavaScript Object Notation) is the standard format for data interchange in web APIs and configuration files. Python's ``json`` module serializes Python objects (dicts, lists, strings, numbers, booleans, None) to JSON strings and deserializes JSON back to Python objects. The ``dump()`` and ``load()`` functions work directly with file objects, while ``dumps()`` and ``loads()`` work with strings. The ``indent`` parameter produces human-readable formatted output. .. code-block:: python import json from pathlib import Path data = {"name": "Alice", "scores": [95, 87, 92]} # Write JSON with open("data.json", "w", encoding="utf-8") as f: json.dump(data, f, indent=2) # Read JSON with open("data.json", encoding="utf-8") as f: loaded = json.load(f) # pathlib shorthand Path("data.json").write_text(json.dumps(data, indent=2)) loaded = json.loads(Path("data.json").read_text()) Compressed Files ---------------- Python supports several compression formats through dedicated modules. The ``gzip`` module handles gzip compression, commonly used for log files and web content. Use ``"rt"`` and ``"wt"`` modes for text, ``"rb"`` and ``"wb"`` for binary. The ``zipfile`` module creates and extracts ZIP archives, supporting multiple files in a single archive. The ``writestr()`` method adds content directly from strings without creating temporary files. 
Both modules integrate seamlessly with Python's file handling patterns. .. code-block:: python import gzip import zipfile from pathlib import Path # Write gzip file with gzip.open("file.txt.gz", "wt", encoding="utf-8") as f: f.write("compressed content") # Read gzip file with gzip.open("file.txt.gz", "rt", encoding="utf-8") as f: content = f.read() # Create zip archive with zipfile.ZipFile("archive.zip", "w") as zf: zf.write("file1.txt") zf.write("file2.txt") zf.writestr("new.txt", "content from string") # Extract zip archive with zipfile.ZipFile("archive.zip", "r") as zf: zf.extractall("output_dir") # Extract single file zf.extract("file1.txt", "output_dir") # List zip contents with zipfile.ZipFile("archive.zip", "r") as zf: print(zf.namelist()) File Locking ------------ File locking prevents data corruption when multiple processes access the same file. On Unix systems, ``fcntl.flock()`` provides advisory locking—processes must cooperatively check locks. ``LOCK_EX`` requests an exclusive lock for writing, while ``LOCK_SH`` allows shared read access. The ``LOCK_NB`` flag makes the call non-blocking, raising ``BlockingIOError`` if the lock isn't immediately available. Always release locks in a ``finally`` block to prevent deadlocks. Note that Windows uses different locking mechanisms (``msvcrt``). .. code-block:: python import fcntl import time # Exclusive lock (Unix) with open("data.txt", "w") as f: fcntl.flock(f.fileno(), fcntl.LOCK_EX) try: f.write("exclusive write") time.sleep(1) # Simulate work finally: fcntl.flock(f.fileno(), fcntl.LOCK_UN) # Non-blocking lock attempt with open("data.txt", "w") as f: try: fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) f.write("got lock") except BlockingIOError: print("File is locked by another process") Watching File Changes (inotify) ------------------------------- The Linux inotify API provides efficient filesystem event monitoring without polling. 
Applications can watch directories for file creation, deletion, modification, and other events. This is useful for auto-reloading configuration files, triggering builds on source changes, or synchronizing directories. The example below demonstrates direct inotify access via ``ctypes``; for production use, consider the ``watchdog`` library which provides a cross-platform abstraction. .. code-block:: python import ctypes import os import struct import selectors from ctypes.util import find_library from pathlib import Path # inotify constants IN_CREATE = 0x00000100 IN_DELETE = 0x00000200 IN_MODIFY = 0x00000002 libc = ctypes.CDLL(find_library("c")) class Inotify: def __init__(self, path, mask=IN_CREATE | IN_DELETE | IN_MODIFY): self.path = path self.mask = mask self.fd = None self.wd = None def __enter__(self): self.fd = libc.inotify_init() path_bytes = str(self.path).encode("utf-8") self.wd = libc.inotify_add_watch(self.fd, path_bytes, self.mask) return self def __exit__(self, *args): libc.inotify_rm_watch(self.fd, self.wd) os.close(self.fd) def read_events(self): data = os.read(self.fd, 4096) offset = 0 while offset < len(data): wd, mask, cookie, length = struct.unpack_from("iIII", data, offset) offset += 16 name = data[offset:offset + length].rstrip(b"\0").decode("utf-8") offset += length yield mask, name # Usage with Inotify(Path("/tmp")) as inotify: for mask, filename in inotify.read_events(): print(f"Event {mask}: {filename}") ================================================ FILE: docs/notes/os/python-os.rst ================================================ .. 
meta:: :description lang=en: Python operating system interface tutorial covering file operations, process management, environment variables, path manipulation, system information, and cross-platform OS interactions :keywords: Python, os module, operating system, file system, process, environment variables, path, directory, subprocess, platform, system info, CPU, memory ================ Operating System ================ :Source: `src/basic/os_.py `_ .. contents:: Table of Contents :backlinks: none Python's ``os`` module provides a portable way to interact with the operating system, abstracting platform-specific details behind a consistent API. Whether you're managing files and directories, spawning processes, reading environment variables, or querying system information, the ``os`` module handles the differences between Windows, Linux, and macOS. For path manipulation, the ``os.path`` submodule (or the modern ``pathlib``) provides cross-platform path handling. This cheat sheet covers common OS operations with practical examples. Get System Information ---------------------- Retrieve basic information about the operating system, platform, and current process. These functions are useful for writing cross-platform code that adapts to the runtime environment. .. code-block:: python import os import platform # Operating system name os.name # 'posix' (Linux/macOS) or 'nt' (Windows) # Platform details platform.system() # 'Linux', 'Darwin', 'Windows' platform.release() # '5.15.0-generic', '22.1.0', '10' platform.machine() # 'x86_64', 'arm64', 'AMD64' platform.processor() # 'x86_64', 'arm', '' platform.python_version() # '3.12.0' # Current process os.getpid() # Process ID os.getppid() # Parent process ID os.getcwd() # Current working directory os.getlogin() # Current username Get Number of CPUs ------------------ Determine the number of CPU cores available for parallel processing. 
This is essential for configuring thread pools, multiprocessing workers, or understanding system capacity. Note that ``cpu_count()`` returns logical cores (including hyperthreading), not physical cores. .. code-block:: python import os # Number of logical CPUs cpu_count = os.cpu_count() print(f"CPUs: {cpu_count}") # CPUs: 8 # For physical cores (Linux only) # cat /proc/cpuinfo | grep "cpu cores" | uniq Set CPU Affinity ---------------- CPU affinity binds a process to specific CPU cores, useful for performance optimization, reducing cache misses, or isolating workloads. This feature is Linux-specific and not available on macOS or Windows through the ``os`` module. .. code-block:: python import os # Linux only - set process to run on specific CPUs pid = os.getpid() affinity = {0, 1} # Run on CPU 0 and 1 only os.sched_setaffinity(pid, affinity) # Get current affinity current = os.sched_getaffinity(pid) print(f"Running on CPUs: {current}") # Running on CPUs: {0, 1} Environment Variables --------------------- Environment variables store configuration that persists across process invocations. Use ``os.environ`` as a dictionary to read, set, or delete variables. Changes only affect the current process and its children, not the parent shell. .. code-block:: python import os # Read environment variable home = os.environ.get('HOME') # Returns None if not set path = os.environ['PATH'] # Raises KeyError if not set debug = os.getenv('DEBUG', 'false') # With default value # Set environment variable os.environ['MY_VAR'] = 'my_value' # Delete environment variable del os.environ['MY_VAR'] os.unsetenv('MY_VAR') # Alternative # List all environment variables for key, value in os.environ.items(): print(f"{key}={value}") Path Operations --------------- Path manipulation is one of the most common OS tasks. The ``os.path`` module provides cross-platform functions that handle path separators (``/`` vs ``\``) automatically. 
For modern Python (3.4+), consider using ``pathlib`` for an object-oriented approach. .. code-block:: python import os # Join paths (handles separators automatically) path = os.path.join('/home', 'user', 'file.txt') # Linux: '/home/user/file.txt' # Windows: '\\home\\user\\file.txt' # Split path components dirname = os.path.dirname('/home/user/file.txt') # '/home/user' basename = os.path.basename('/home/user/file.txt') # 'file.txt' name, ext = os.path.splitext('file.txt') # ('file', '.txt') # Absolute and relative paths abs_path = os.path.abspath('.') # Full path rel_path = os.path.relpath('/home/user') # Relative to cwd real_path = os.path.realpath('link') # Resolve symlinks # Path checks os.path.exists('/path/to/file') # True if exists os.path.isfile('/path/to/file') # True if regular file os.path.isdir('/path/to/dir') # True if directory os.path.islink('/path/to/link') # True if symbolic link # Path info os.path.getsize('/path/to/file') # Size in bytes os.path.getmtime('/path/to/file') # Modification time (timestamp) Directory Operations -------------------- Create, remove, and navigate directories. These operations are fundamental for file management, build systems, and data processing pipelines. .. code-block:: python import os # Create directory os.mkdir('new_dir') # Single directory os.makedirs('path/to/new_dir') # Create parent dirs too os.makedirs('path/to/dir', exist_ok=True) # Don't error if exists # Remove directory os.rmdir('empty_dir') # Must be empty # For non-empty: use shutil.rmtree() # Change directory os.chdir('/path/to/dir') print(os.getcwd()) # Print current directory # List directory contents entries = os.listdir('.') # List of names for entry in entries: print(entry) # Walk directory tree for root, dirs, files in os.walk('.'): for file in files: path = os.path.join(root, file) print(path) File Operations --------------- Low-level file operations using file descriptors. 
For most use cases, Python's built-in ``open()`` function is preferred, but ``os`` functions are useful for special cases like non-blocking I/O or when you need precise control over file descriptors. .. code-block:: python import os # Rename/move file os.rename('old_name.txt', 'new_name.txt') os.replace('src.txt', 'dst.txt') # Atomic, overwrites dst # Remove file os.remove('file.txt') os.unlink('file.txt') # Same as remove # File permissions (Unix) os.chmod('file.txt', 0o644) # rw-r--r-- os.chown('file.txt', uid, gid) # Change owner # Create symbolic link os.symlink('target', 'link_name') # Low-level file operations fd = os.open('file.txt', os.O_RDONLY) data = os.read(fd, 1024) # Read up to 1024 bytes os.close(fd) Execute Commands ---------------- Run external commands and programs. For simple cases, ``os.system()`` works, but ``subprocess`` module is recommended for more control over input/output and error handling. .. code-block:: python import os import subprocess # Simple command (returns exit code) exit_code = os.system('ls -la') # Better: use subprocess result = subprocess.run(['ls', '-la'], capture_output=True, text=True) print(result.stdout) print(result.returncode) # Run and capture output output = subprocess.check_output(['date'], text=True) print(output.strip()) # Run with input result = subprocess.run( ['grep', 'pattern'], input='line1\npattern here\nline3', capture_output=True, text=True ) Process Management ------------------ Create and manage child processes. The ``os.fork()`` function is Unix-specific; for cross-platform process creation, use the ``multiprocessing`` module instead. .. 
code-block:: python import os # Fork process (Unix only) pid = os.fork() if pid == 0: # Child process print(f"Child PID: {os.getpid()}") os._exit(0) else: # Parent process print(f"Parent PID: {os.getpid()}, Child PID: {pid}") os.waitpid(pid, 0) # Wait for child # Execute new program (replaces current process) # os.execv('/bin/ls', ['ls', '-la']) # Send signal to process os.kill(pid, signal.SIGTERM) Temporary Files --------------- Create temporary files and directories that are automatically cleaned up. The ``tempfile`` module provides secure, cross-platform temporary file handling. .. code-block:: python import tempfile import os # Temporary file (auto-deleted when closed) with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: f.write('temporary data') temp_path = f.name print(f"Temp file: {temp_path}") os.unlink(temp_path) # Manual cleanup if delete=False # Temporary directory with tempfile.TemporaryDirectory() as tmpdir: temp_file = os.path.join(tmpdir, 'file.txt') with open(temp_file, 'w') as f: f.write('data') # Directory and contents deleted after with block # Get temp directory path print(tempfile.gettempdir()) # /tmp or C:\Users\...\Temp Using pathlib (Modern Alternative) ---------------------------------- The ``pathlib`` module (Python 3.4+) provides an object-oriented interface to paths, making code more readable and less error-prone than string-based ``os.path`` operations. .. 
code-block:: python from pathlib import Path # Create path objects p = Path('/home/user/file.txt') p = Path.home() / 'documents' / 'file.txt' # Use / operator # Path components p.name # 'file.txt' p.stem # 'file' p.suffix # '.txt' p.parent # Path('/home/user') p.parts # ('/', 'home', 'user', 'file.txt') # Path checks p.exists() p.is_file() p.is_dir() # Read/write files content = p.read_text() p.write_text('new content') data = p.read_bytes() # Directory operations Path('new_dir').mkdir(parents=True, exist_ok=True) for child in Path('.').iterdir(): print(child) # Glob patterns for py_file in Path('.').glob('**/*.py'): print(py_file) System Monitoring with psutil ----------------------------- The ``psutil`` (process and system utilities) library provides cross-platform access to system monitoring data that isn't available through the standard ``os`` module. It covers CPU, memory, disk, network, and process information. Install with ``pip install psutil``. .. code-block:: python import psutil # CPU information psutil.cpu_count() # Logical CPUs psutil.cpu_count(logical=False) # Physical cores psutil.cpu_percent(interval=1) # CPU usage % psutil.cpu_percent(percpu=True) # Per-CPU usage psutil.cpu_freq() # CPU frequency # Memory information mem = psutil.virtual_memory() mem.total # Total RAM in bytes mem.available # Available RAM mem.percent # Usage percentage mem.used # Used RAM # Swap memory swap = psutil.swap_memory() swap.total, swap.used, swap.percent # Disk information psutil.disk_partitions() # List partitions usage = psutil.disk_usage('/') usage.total, usage.used, usage.percent # Network information psutil.net_io_counters() # Bytes sent/received psutil.net_connections() # Active connections psutil.net_if_addrs() # Interface addresses # Process information for proc in psutil.process_iter(['pid', 'name', 'cpu_percent']): print(proc.info) # Current process p = psutil.Process() p.pid p.name() p.cpu_percent() p.memory_info() p.num_threads() # System boot time import 
datetime boot = datetime.datetime.fromtimestamp(psutil.boot_time()) ================================================ FILE: docs/notes/python-new-py3.rst ================================================ .. meta:: :description lang=en: Python 3 new features guide covering f-strings, walrus operator, dataclasses, async/await, type hints, and major improvements from Python 2 :keywords: Python, Python3, new features, f-strings, walrus operator, dataclasses, async await, type hints, pathlib, migration ============== New in Python3 ============== .. contents:: Table of Contents :backlinks: none The source code can be found on `py3.py `_. Type Parameter Syntax ---------------------- **New in Python 3.12** - PEP 695_ - Type Parameter Syntax Python 3.12 introduces a cleaner, more intuitive syntax for defining generic classes and functions. Instead of importing ``TypeVar`` and ``Generic`` from the typing module, you can now use the ``[]`` bracket notation directly in class and function definitions. This makes generic code more readable and reduces boilerplate significantly. .. code-block:: python # Old way (before Python 3.12) from typing import TypeVar, Generic T = TypeVar('T') class Box(Generic[T]): def __init__(self, item: T) -> None: self.item = item # New way (Python 3.12+) class Box[T]: def __init__(self, item: T) -> None: self.item = item # Generic functions def first[T](items: list[T]) -> T: return items[0] f-string Improvements ---------------------- **New in Python 3.12** - PEP 701_ - Syntactic formalization of f-strings F-strings have been significantly improved in Python 3.12. They now support nested quotes of the same type, backslash escapes, and multi-line expressions without any limitations. This makes f-strings much more flexible and eliminates many edge cases that previously required workarounds. .. 
code-block:: python >>> songs = ["Take me back to Eden", "&", "Satellite"] >>> f"This is the playlist: {", ".join(songs)}" 'This is the playlist: Take me back to Eden, &, Satellite' # Nested quotes now work >>> f"He said {"hello"}" 'He said hello' Exception Groups ----------------- **New in Python 3.11** - PEP 654_ - Exception Groups and except* Exception groups allow you to raise and handle multiple unrelated exceptions simultaneously. This is particularly useful for concurrent operations where multiple tasks might fail independently. The new ``except*`` syntax lets you handle specific exception types from a group while letting others propagate. .. code-block:: python >>> def raise_multiple(): ... raise ExceptionGroup("multiple errors", [ ... ValueError("invalid value"), ... TypeError("wrong type"), ... ]) ... >>> try: ... raise_multiple() ... except* ValueError as e: ... print(f"ValueError: {e.exceptions}") ... except* TypeError as e: ... print(f"TypeError: {e.exceptions}") ... ValueError: (ValueError('invalid value'),) TypeError: (TypeError('wrong type'),) Structural Pattern Matching ---------------------------- **New in Python 3.10** - PEP 634_ - Structural Pattern Matching: Specification - PEP 635_ - Structural Pattern Matching: Motivation and Rationale Pattern matching provides a powerful way to destructure and match complex data structures. It's similar to switch statements in other languages but far more expressive, supporting sequence patterns, mapping patterns, class patterns, and guards. The wildcard ``_`` matches anything and serves as a default case. .. code-block:: python >>> def http_status(status): ... match status: ... case 200: ... return "OK" ... case 404: ... return "Not Found" ... case 500: ... return "Internal Server Error" ... case _: ... return "Unknown" ... >>> http_status(200) 'OK' >>> http_status(404) 'Not Found' # Pattern matching with destructuring >>> def describe_point(point): ... match point: ... case (0, 0): ... return "Origin" ... 
case (x, 0): ... return f"On x-axis at {x}" ... case (0, y): ... return f"On y-axis at {y}" ... case (x, y): ... return f"Point at ({x}, {y})" ... >>> describe_point((0, 0)) 'Origin' >>> describe_point((5, 0)) 'On x-axis at 5' Dictionary Merge ---------------- **New in Python 3.9** - PEP 584_ - Add Union Operators To dict The ``|`` operator provides a cleaner, more intuitive way to merge dictionaries. The ``|=`` operator updates a dictionary in place. This is more readable than using ``{**a, **b}`` or ``dict.update()`` and follows the pattern of set operations. .. code-block:: python >>> a = {"foo": "Foo"} >>> b = {"bar": "Bar"} # old way >>> {**a, **b} {'foo': 'Foo', 'bar': 'Bar'} >>> a.update(b) >>> a {'foo': 'Foo', 'bar': 'Bar'} # new way >>> a = {"foo": "Foo"} >>> a | b {'foo': 'Foo', 'bar': 'Bar'} >>> a |= b >>> a {'foo': 'Foo', 'bar': 'Bar'} Positional-only parameters --------------------------- **New in Python 3.8** - PEP 570_ - Python Positional-Only Parameters Parameters before the ``/`` marker must be passed positionally and cannot be used as keyword arguments. This gives library authors more flexibility in API design and allows parameter names to be changed without breaking backward compatibility. .. code-block:: python >>> def f(a, b, /, c, d): ... print(a, b, c, d) ... >>> f(1, 2, 3, 4) 1 2 3 4 >>> f(1, 2, c=3, d=4) 1 2 3 4 >>> f(1, b=2, c=3, d=4) Traceback (most recent call last): File "", line 1, in TypeError: f() got some positional-only arguments passed as keyword arguments: 'b' The walrus operator -------------------- **New in Python 3.8** - PEP 572_ - Assignment Expressions The walrus operator ``:=`` allows you to assign values to variables as part of an expression. This reduces code duplication when you need to both compute a value and use it in a condition. After completing PEP 572, Guido van Rossum, commonly known as BDFL, decided to resign as Python's dictator. .. 
code-block:: python >>> f = (0, 1) >>> [(f := (f[1], sum(f)))[0] for i in range(10)] [1, 1, 2, 3, 5, 8, 13, 21, 34, 55] # Useful in while loops >>> while (line := input("Enter: ")) != "quit": ... print(f"You entered: {line}") # Useful in if statements >>> if (n := len("hello")) > 3: ... print(f"Length {n} is greater than 3") Data Classes ------------- **New in Python 3.7** - PEP 557_ - Data Classes Dataclasses automatically generate boilerplate code like ``__init__``, ``__repr__``, ``__eq__``, and optionally ``__hash__`` for classes that primarily store data. This reduces repetitive code and makes data-holding classes more concise and readable. Mutable Data Class .. code-block:: python >>> from dataclasses import dataclass >>> @dataclass ... class DCls(object): ... x: str ... y: str ... >>> d = DCls("foo", "bar") >>> d DCls(x='foo', y='bar') >>> d = DCls(x="foo", y="baz") >>> d DCls(x='foo', y='baz') >>> d.z = "bar" Immutable Data Class .. code-block:: python >>> from dataclasses import dataclass >>> from dataclasses import FrozenInstanceError >>> @dataclass(frozen=True) ... class DCls(object): ... x: str ... y: str ... >>> try: ... d.x = "baz" ... except FrozenInstanceError as e: ... print(e) ... cannot assign to field 'x' >>> try: ... d.z = "baz" ... except FrozenInstanceError as e: ... print(e) ... cannot assign to field 'z' Built-in ``breakpoint()`` -------------------------- **New in Python 3.7** - PEP 553_ - Built-in breakpoint() The ``breakpoint()`` function provides a convenient way to drop into the debugger. It respects the ``PYTHONBREAKPOINT`` environment variable, allowing you to customize or disable debugging behavior without modifying code. .. code-block:: python >>> for x in range(3): ... print(x) ... breakpoint() ... 
0 > (1)()->None (Pdb) c 1 > (1)()->None (Pdb) c 2 > (1)()->None (Pdb) c Core support for typing module and generic types ------------------------------------------------- **New in Python 3.7** - PEP 560_ - Core support for typing module and generic types Python 3.7 added core support for the typing module, making generic types faster and enabling classes to customize how they're subscripted via ``__class_getitem__``. Before Python 3.7 .. code-block:: python >>> from typing import Generic, TypeVar >>> from typing import Iterable >>> T = TypeVar('T') >>> class C(Generic[T]): ... ... >>> def func(l: Iterable[C[int]]) -> None: ... for i in l: ... print(i) ... >>> func([1,2,3]) 1 2 3 Python 3.7 or above .. code-block:: python >>> from typing import Iterable >>> class C: ... def __class_getitem__(cls, item): ... return f"{cls.__name__}[{item.__name__}]" ... >>> def func(l: Iterable[C[int]]) -> None: ... for i in l: ... print(i) ... >>> func([1,2,3]) 1 2 3 Variable annotations -------------------- **New in Python 3.6** - PEP 526_ - Syntax for Variable Annotations Variables can now be annotated with types using the ``:`` syntax, even without immediate assignment. This enables better static analysis and IDE support. .. code-block:: python >>> from typing import List >>> x: List[int] = [1, 2, 3] >>> x [1, 2, 3] >>> from typing import List, Dict >>> class Cls(object): ... x: List[int] = [1, 2, 3] ... y: Dict[str, str] = {"foo": "bar"} ... >>> o = Cls() >>> o.x [1, 2, 3] >>> o.y {'foo': 'bar'} f-string --------- **New in Python 3.6** - PEP 498_ - Literal String Interpolation F-strings (formatted string literals) provide a concise and readable way to embed expressions inside string literals. They are typically faster than ``%`` formatting and ``str.format()`` because the literal is parsed at compile time into efficient bytecode, rather than the format string being parsed by a method call at runtime. .. code-block:: python >>> py = "Python3" >>> f'Awesome {py}' 'Awesome Python3' >>> x = [1, 2, 3, 4, 5] >>> f'{x}' '[1, 2, 3, 4, 5]' >>> def foo(x:int) -> int: ... return x + 1 ...
>>> f'{foo(0)}' '1' >>> f'{123.567:1.3}' '1.24e+02' Asynchronous generators ------------------------ **New in Python 3.6** - PEP 525_ - Asynchronous Generators Asynchronous generators combine the power of generators with async/await syntax, allowing you to yield values asynchronously. This is useful for streaming data from async sources. .. code-block:: python >>> import asyncio >>> async def fib(n: int): ... a, b = 0, 1 ... for _ in range(n): ... await asyncio.sleep(1) ... yield a ... b, a = a + b , b ... >>> async def coro(n: int): ... ag = fib(n) ... f = await ag.asend(None) ... print(f) ... f = await ag.asend(None) ... print(f) ... >>> loop = asyncio.get_event_loop() >>> loop.run_until_complete(coro(5)) 0 1 Asynchronous comprehensions ---------------------------- **New in Python 3.6** - PEP 530_ - Asynchronous Comprehensions Async comprehensions allow using ``async for`` in list, set, dict comprehensions and generator expressions. You can also use ``await`` expressions within comprehensions. .. code-block:: python >>> import asyncio >>> async def fib(n: int): ... a, b = 0, 1 ... for _ in range(n): ... await asyncio.sleep(1) ... yield a ... b, a = a + b , b ... # async for ... else >>> async def coro(n: int): ... async for f in fib(n): ... print(f, end=" ") ... else: ... print() ... >>> loop = asyncio.get_event_loop() >>> loop.run_until_complete(coro(5)) 0 1 1 2 3 # async for in list >>> async def coro(n: int): ... return [f async for f in fib(n)] ... >>> loop.run_until_complete(coro(5)) [0, 1, 1, 2, 3] # await in list >>> async def slowfmt(n: int) -> str: ... await asyncio.sleep(0.5) ... return f'{n}' ... >>> async def coro(n: int): ... return [await slowfmt(f) async for f in fib(n)] ... 
>>> loop.run_until_complete(coro(5)) ['0', '1', '1', '2', '3'] New dict implementation ------------------------ **New in Python 3.6** - PEP 468_ - Preserving the order of \*\*kwargs in a function - PEP 520_ - Preserving Class Attribute Definition Order - bpo 27350_ - More compact dictionaries with faster iteration Python 3.6 introduced a new dictionary implementation that uses 20-25% less memory and preserves insertion order. This was an implementation detail in 3.6 but became a language guarantee in Python 3.7. Python 3.5 and earlier .. code-block:: python >>> import sys >>> sys.getsizeof({str(i):i for i in range(1000)}) 49248 >>> d = {'timmy': 'red', 'barry': 'green', 'guido': 'blue'} >>> d # without order-preserving {'barry': 'green', 'timmy': 'red', 'guido': 'blue'} Python 3.6 - Memory usage is smaller than in Python 3.5 - Preserves insertion order .. code-block:: python >>> import sys >>> sys.getsizeof({str(i):i for i in range(1000)}) 36968 >>> d = {'timmy': 'red', 'barry': 'green', 'guido': 'blue'} >>> d # preserve insertion ordered {'timmy': 'red', 'barry': 'green', 'guido': 'blue'} ``async`` and ``await`` syntax ------------------------------- **New in Python 3.5** - PEP 492_ - Coroutines with async and await syntax The ``async`` and ``await`` keywords provide native syntax for writing coroutines, making asynchronous code much more readable than the previous generator-based approach. This is the foundation of modern Python async programming. Before Python 3.5 .. code-block:: python >>> import asyncio >>> @asyncio.coroutine ... def fib(n: int): ... a, b = 0, 1 ... for _ in range(n): ... b, a = a + b, b ... return a ... >>> @asyncio.coroutine ... def coro(n: int): ... for x in range(n): ... yield from asyncio.sleep(1) ... f = yield from fib(x) ... print(f) ... >>> loop = asyncio.get_event_loop() >>> loop.run_until_complete(coro(3)) 0 1 1 Python 3.5 or above .. code-block:: python >>> import asyncio >>> async def fib(n: int): ... a, b = 0, 1 ...
for _ in range(n): ... b, a = a + b, b ... return a ... >>> async def coro(n: int): ... for x in range(n): ... await asyncio.sleep(1) ... f = await fib(x) ... print(f) ... >>> loop = asyncio.get_event_loop() >>> loop.run_until_complete(coro(3)) 0 1 1 General unpacking ------------------ **New in Python 3.5** - PEP 448_ - Additional Unpacking Generalizations Python 3.5 extended the ``*`` and ``**`` unpacking operators to work in more contexts, including function calls with multiple unpacking operations and in list/dict literals. Python 2 .. code-block:: python >>> def func(*a, **k): ... print(a) ... print(k) ... >>> func(*[1,2,3,4,5], **{"foo": "bar"}) (1, 2, 3, 4, 5) {'foo': 'bar'} Python 3 .. code-block:: python >>> print(*[1, 2, 3], 4, *[5, 6]) 1 2 3 4 5 6 >>> [*range(4), 4] [0, 1, 2, 3, 4] >>> {"foo": "Foo", "bar": "Bar", **{"baz": "baz"}} {'foo': 'Foo', 'bar': 'Bar', 'baz': 'baz'} >>> def func(*a, **k): ... print(a) ... print(k) ... >>> func(*[1], *[4,5], **{"foo": "FOO"}, **{"bar": "BAR"}) (1, 4, 5) {'foo': 'FOO', 'bar': 'BAR'} Matrix multiplication ---------------------- **New in Python 3.5** - PEP 465_ - A dedicated infix operator for matrix multiplication The ``@`` operator was added for matrix multiplication, primarily benefiting scientific computing libraries like NumPy. Classes can implement ``__matmul__`` and ``__imatmul__`` to support this operator. .. code-block:: python >>> # "@" represent matrix multiplication >>> class Arr: ... def __init__(self, *arg): ... self._arr = arg ... def __matmul__(self, other): ... if not isinstance(other, Arr): ... raise TypeError ... if len(self) != len(other): ... raise ValueError ... return sum([x*y for x, y in zip(self._arr, other._arr)]) ... def __imatmul__(self, other): ... if not isinstance(other, Arr): ... raise TypeError ... if len(self) != len(other): ... raise ValueError ... res = sum([x*y for x, y in zip(self._arr, other._arr)]) ... self._arr = [res] ... return self ... def __len__(self): ... 
return len(self._arr) ... def __str__(self): ... return self.__repr__() ... def __repr__(self): ... return "Arr({})".format(repr(self._arr)) ... >>> a = Arr(9, 5, 2, 7) >>> b = Arr(5, 5, 6, 6) >>> a @ b # __matmul__ 124 >>> a @= b # __imatmul__ >>> a Arr([124]) Format byte string ------------------- **New in Python 3.5** - PEP 461_ - Adding ``%`` formatting to bytes and bytearray The ``%`` formatting operator now works with bytes and bytearray objects, making it easier to work with binary protocols and formats. .. code-block:: python >>> b'abc %b %b' % (b'foo', b'bar') b'abc foo bar' >>> b'%d %f' % (1, 3.14) b'1 3.140000' >>> class Cls(object): ... def __repr__(self): ... return "repr" ... def __str__(self): ... return "str" ... >>> repr(Cls()) 'repr' >>> b'%a' % Cls() b'repr' Suppressing exception ---------------------- **New in Python 3.3** - PEP 409_ - Suppressing exception context When re-raising exceptions, Python shows the chain of exceptions by default. Using ``raise ... from None`` suppresses the context, showing only the new exception. Without ``raise Exception from None`` .. code-block:: python >>> def func(): ... try: ... 1 / 0 ... except ZeroDivisionError: ... raise ArithmeticError ... >>> func() Traceback (most recent call last): File "<stdin>", line 3, in func ZeroDivisionError: division by zero During handling of the above exception, another exception occurred: Traceback (most recent call last): File "<stdin>", line 1, in <module> File "<stdin>", line 5, in func ArithmeticError With ``raise Exception from None`` .. code-block:: python >>> def func(): ... try: ... 1 / 0 ... except ZeroDivisionError: ... raise ArithmeticError from None ... >>> func() Traceback (most recent call last): File "<stdin>", line 1, in <module> File "<stdin>", line 5, in func ArithmeticError # debug >>> try: ... func() ... except ArithmeticError as e: ... print(e.__context__) ...
division by zero Generator delegation ---------------------- **New in Python 3.3** - PEP 380_ - Syntax for Delegating to a Subgenerator The ``yield from`` expression allows a generator to delegate part of its operations to another generator. This simplifies writing generators that consume other generators. .. code-block:: python >>> def fib(n: int): ... a, b = 0, 1 ... for _ in range(n): ... yield a ... b, a = a + b, b ... >>> def delegate(n: int): ... yield from fib(n) ... >>> list(delegate(10)) [0, 1, 1, 2, 3, 5, 8, 13, 21, 34] BDFL retirement --------------- **New in Python 3.1** - PEP 401_ - BDFL Retirement An April Fools' joke PEP that added a humorous easter egg. When you import ``barry_as_FLUFL`` from ``__future__``, the ``!=`` operator is replaced with ``<>``. .. code-block:: python >>> from __future__ import barry_as_FLUFL >>> 1 != 2 File "<stdin>", line 1 1 != 2 ^ SyntaxError: with Barry as BDFL, use '<>' instead of '!=' >>> 1 <> 2 True Function annotations -------------------- **New in Python 3.0** - PEP 3107_ - Function Annotations - PEP 484_ - Type Hints - PEP 483_ - The Theory of Type Hints Function annotations allow attaching metadata to function parameters and return values. While Python doesn't enforce these at runtime, they enable static type checking tools and better IDE support. .. code-block:: python >>> import types >>> generator = types.GeneratorType >>> def fib(n: int) -> generator: ... a, b = 0, 1 ... for _ in range(n): ... yield a ... b, a = a + b, b ... >>> [f for f in fib(10)] [0, 1, 1, 2, 3, 5, 8, 13, 21, 34] Extended iterable unpacking ---------------------------- **New in Python 3.0** - PEP 3132_ - Extended Iterable Unpacking The ``*`` operator in unpacking captures remaining items into a list. This works in assignments and for loops, making it easy to split sequences. .. code-block:: python >>> a, *b, c = range(5) >>> a, b, c (0, [1, 2, 3], 4) >>> for a, *b in [(1, 2, 3), (4, 5, 6, 7)]: ... print(a, b) ...
1 [2, 3] 4 [5, 6, 7] Keyword-Only Arguments ----------------------- **New in Python 3.0** - PEP 3102_ - Keyword-Only Arguments Parameters defined after ``*`` in a function signature must be passed as keyword arguments. This improves API clarity and prevents accidental positional usage. .. code-block:: python >>> def f(a, b, *, kw): ... print(a, b, kw) ... >>> f(1, 2, 3) Traceback (most recent call last): File "<stdin>", line 1, in <module> TypeError: f() takes 2 positional arguments but 3 were given >>> f(1, 2) Traceback (most recent call last): File "<stdin>", line 1, in <module> TypeError: f() missing 1 required keyword-only argument: 'kw' >>> f(1, 2, kw=3) 1 2 3 New Super ---------- **New in Python 3.0** - PEP 3135_ - New Super Python 3 simplified the ``super()`` call by making it work without arguments in most cases. The interpreter automatically determines the class and instance. Python 2 .. code-block:: python >>> class ParentCls(object): ... def foo(self): ... print "call parent" ... >>> class ChildCls(ParentCls): ... def foo(self): ... super(ChildCls, self).foo() ... print "call child" ... >>> p = ParentCls() >>> c = ChildCls() >>> p.foo() call parent >>> c.foo() call parent call child Python 3 .. code-block:: python >>> class ParentCls(object): ... def foo(self): ... print("call parent") ... >>> class ChildCls(ParentCls): ... def foo(self): ... super().foo() ... print("call child") ... >>> p = ParentCls() >>> c = ChildCls() >>> p.foo() call parent >>> c.foo() call parent call child Add ``nonlocal`` keyword ------------------------- **New in Python 3.0** - PEP 3104_ - Access to Names in Outer Scopes The ``nonlocal`` keyword allows assigning to variables in an enclosing (but non-global) scope. This is useful for closures that need to modify outer variables. .. code-block:: python >>> def outf(): ... o = "out" ... def inf(): ... nonlocal o ... o = "change out" ... inf() ... print(o) ...
>>> outf() change out Not allow ``from module import *`` inside function --------------------------------------------------- **New in Python 3.0** Star imports are now only allowed at module level. This prevents namespace pollution and makes code more predictable. .. code-block:: python >>> def f(): ... from os import * ... File "<stdin>", line 1 SyntaxError: import * only allowed at module level Remove ``<>`` -------------- **New in Python 3.0** The ``<>`` operator (alternative to ``!=``) was removed in Python 3 to simplify the language. Use ``!=`` for inequality comparisons. Python 2 .. code-block:: python >>> a = "Python2" >>> a <> "Python3" True # equal to != >>> a != "Python3" True Python 3 .. code-block:: python >>> a = "Python3" >>> a != "Python2" True ``print`` is a function ------------------------- **New in Python 3.0** - PEP 3105_ - Make print a function In Python 3, ``print`` became a function instead of a statement. This allows more flexibility with keyword arguments like ``end``, ``sep``, and ``file``. Python 2 .. code-block:: python >>> print "print is a statement" print is a statement >>> for x in range(3): ... print x, ... 0 1 2 Python 3 .. code-block:: python >>> print("print is a function") print is a function >>> print() >>> for x in range(3): ... print(x, end=' ') ... else: ... print() ... 0 1 2 String is unicode ------------------- **New in Python 3.0** - PEP 3138_ - String representation in Python 3000 - PEP 3120_ - Using UTF-8 as the default source encoding - PEP 3131_ - Supporting Non-ASCII Identifiers In Python 3, all strings are Unicode by default. The ``str`` type represents Unicode text, while ``bytes`` represents binary data. This eliminates many encoding-related bugs common in Python 2. Python 2 .. code-block:: python >>> s = 'Café' # byte string >>> s 'Caf\xc3\xa9' >>> type(s) <type 'str'> >>> u = u'Café' # unicode string >>> u u'Caf\xe9' >>> type(u) <type 'unicode'> >>> len([_c for _c in 'Café']) 5 Python 3 ..
code-block:: python >>> s = 'Café' >>> s 'Café' >>> type(s) <class 'str'> >>> s.encode('utf-8') b'Caf\xc3\xa9' >>> s.encode('utf-8').decode('utf-8') 'Café' >>> len([_c for _c in 'Café']) 4 Division Operator ------------------ **New in Python 3.0** - PEP 238_ - Changing the Division Operator In Python 3, the ``/`` operator always performs true division (returning a float), while ``//`` performs floor division. This eliminates a common source of bugs. Python 2 .. code-block:: python >>> 1 / 2 0 >>> 1 // 2 0 >>> 1. / 2 0.5 # back port "true division" to python2 >>> from __future__ import division >>> 1 / 2 0.5 >>> 1 // 2 0 Python 3 .. code-block:: python >>> 1 / 2 0.5 >>> 1 // 2 0 .. _695: https://www.python.org/dev/peps/pep-0695/ .. _701: https://www.python.org/dev/peps/pep-0701/ .. _654: https://www.python.org/dev/peps/pep-0654/ .. _634: https://www.python.org/dev/peps/pep-0634/ .. _635: https://www.python.org/dev/peps/pep-0635/ .. _584: https://www.python.org/dev/peps/pep-0584/ .. _570: https://www.python.org/dev/peps/pep-0570/ .. _572: https://www.python.org/dev/peps/pep-0572/ .. _557: https://www.python.org/dev/peps/pep-0557/ .. _553: https://www.python.org/dev/peps/pep-0553/ .. _560: https://www.python.org/dev/peps/pep-0560/ .. _526: https://www.python.org/dev/peps/pep-0526/ .. _498: https://www.python.org/dev/peps/pep-0498/ .. _525: https://www.python.org/dev/peps/pep-0525/ .. _530: https://www.python.org/dev/peps/pep-0530/ .. _468: https://www.python.org/dev/peps/pep-0468/ .. _520: https://www.python.org/dev/peps/pep-0520/ .. _27350: https://bugs.python.org/issue27350 .. _492: https://www.python.org/dev/peps/pep-0492/ .. _448: https://www.python.org/dev/peps/pep-0448/ .. _465: https://www.python.org/dev/peps/pep-0465/ .. _461: https://www.python.org/dev/peps/pep-0461/ .. _409: https://www.python.org/dev/peps/pep-0409/ .. _380: https://www.python.org/dev/peps/pep-0380/ .. _401: https://www.python.org/dev/peps/pep-0401/ .. _3107: https://www.python.org/dev/peps/pep-3107/ ..
_484: https://www.python.org/dev/peps/pep-0484/ .. _483: https://www.python.org/dev/peps/pep-0483/ .. _3132: https://www.python.org/dev/peps/pep-3132/ .. _3102: https://www.python.org/dev/peps/pep-3102/ .. _3135: https://www.python.org/dev/peps/pep-3135/ .. _3104: https://www.python.org/dev/peps/pep-3104/ .. _3105: https://www.python.org/dev/peps/pep-3105/ .. _3138: https://www.python.org/dev/peps/pep-3138/ .. _3120: https://www.python.org/dev/peps/pep-3120/ .. _3131: https://www.python.org/dev/peps/pep-3131/ .. _238: https://www.python.org/dev/peps/pep-0238/ ================================================ FILE: docs/notes/security/index.rst ================================================ .. meta:: :description lang=en: Python security and cryptography guide covering modern encryption, TLS/SSL, common vulnerabilities, and secure coding practices :keywords: Python, Python3, security, cryptography, encryption, AES, RSA, TLS, SSL, vulnerability, padding oracle, injection, secure coding ======== Security ======== Security is essential for protecting data in transit and at rest. This section covers modern cryptographic practices using well-maintained libraries like ``cryptography`` and ``argon2-cffi``, as well as common security vulnerabilities and how to avoid them. We emphasize secure defaults: authenticated encryption (AES-GCM), proper key derivation (PBKDF2, Argon2), secure signatures (Ed25519, RSA-PSS), and correct TLS configuration. Understanding vulnerabilities is equally important—knowing why legacy patterns like AES-CBC without authentication or PKCS#1 v1.5 padding are dangerous helps you recognize and fix insecure code in existing systems. .. toctree:: :maxdepth: 1 python-crypto python-tls python-vulnerability ================================================ FILE: docs/notes/security/python-crypto.rst ================================================ .. 
meta:: :description lang=en: Modern Python cryptography guide covering symmetric encryption (AES-GCM), asymmetric encryption (RSA-OAEP), digital signatures, key derivation, and secure random generation using the cryptography library :keywords: Python, Python3, Cryptography, AES-GCM, RSA, OAEP, Digital Signature, PBKDF2, Argon2, Key Derivation, Encryption, Decryption, HMAC, Hashing =================== Modern Cryptography =================== .. contents:: Table of Contents :backlinks: none This guide covers modern cryptographic practices in Python using the ``cryptography`` library, which is the recommended choice for new projects. The library provides both high-level recipes (Fernet) for common use cases and low-level primitives for advanced needs. We focus on secure defaults: AES-GCM for symmetric encryption (provides both confidentiality and integrity), RSA-OAEP for asymmetric encryption, Ed25519 for signatures, and proper key derivation functions. Avoid deprecated libraries like PyCrypto—use ``cryptography`` or ``PyCryptodome`` instead. .. warning:: Cryptography is difficult to implement correctly. Prefer high-level APIs like Fernet when possible. Never invent your own cryptographic schemes. Always use authenticated encryption (AES-GCM, ChaCha20-Poly1305) instead of unauthenticated modes (AES-CBC, AES-CTR alone). Algorithm Recommendations ------------------------- Quick reference for choosing secure algorithms. When in doubt, use the recommended options—they represent current best practices as of 2024. 
:: ┌─────────────────────────────────────────────────────────────────────────┐ │ ALGORITHM RECOMMENDATIONS │ ├─────────────────────────────────────────────────────────────────────────┤ │ USE CASE │ RECOMMENDED │ AVOID │ ├────────────────────────┼──────────────────────┼─────────────────────────┤ │ Symmetric Encryption │ AES-256-GCM │ AES-CBC, AES-ECB, │ │ │ ChaCha20-Poly1305 │ DES, 3DES, Blowfish, │ │ │ │ RC4 │ ├────────────────────────┼──────────────────────┼─────────────────────────┤ │ Asymmetric Encryption │ RSA-OAEP (≥3072-bit) │ RSA PKCS#1 v1.5, │ │ │ ECIES │ RSA < 2048-bit, │ │ │ │ ElGamal │ ├────────────────────────┼──────────────────────┼─────────────────────────┤ │ Digital Signatures │ Ed25519 │ RSA PKCS#1 v1.5, │ │ │ RSA-PSS (≥3072-bit) │ DSA, ECDSA with P-256 │ │ │ Ed448 │ (if possible) │ ├────────────────────────┼──────────────────────┼─────────────────────────┤ │ Key Exchange │ X25519 │ DH < 2048-bit, │ │ │ X448 │ Static DH │ │ │ ECDH (P-384+) │ │ ├────────────────────────┼──────────────────────┼─────────────────────────┤ │ Password Hashing │ Argon2id │ MD5, SHA-1, SHA-256, │ │ │ scrypt │ bcrypt (less preferred),│ │ │ PBKDF2 (≥600k iter) │ plain hashes │ ├────────────────────────┼──────────────────────┼─────────────────────────┤ │ General Hashing │ SHA-256, SHA-3 │ MD5, SHA-1 │ │ │ BLAKE2, BLAKE3 │ │ ├────────────────────────┼──────────────────────┼─────────────────────────┤ │ MAC │ HMAC-SHA256 │ HMAC-MD5, HMAC-SHA1 │ │ │ KMAC, Poly1305 │ │ ├────────────────────────┼──────────────────────┼─────────────────────────┤ │ Random Generation │ secrets module │ random module │ │ │ os.urandom() │ time-based seeds │ └────────────────────────┴──────────────────────┴─────────────────────────┘ Key Size Recommendations ------------------------ Minimum key sizes for security through 2030+. Larger keys provide more security margin but slower performance. 
:: ┌─────────────────────────────────────────────────────────────────────────┐ │ KEY SIZE GUIDELINES │ ├─────────────────────────────────────────────────────────────────────────┤ │ ALGORITHM │ MINIMUM │ RECOMMENDED │ NOTES │ ├────────────────────────┼────────────┼──────────────┼────────────────────┤ │ AES │ 128-bit │ 256-bit │ 256 for long-term │ │ ChaCha20 │ 256-bit │ 256-bit │ Only size available│ │ RSA (encryption) │ 2048-bit │ 3072-4096 │ 4096 for long-term │ │ RSA (signatures) │ 2048-bit │ 3072-4096 │ 4096 for long-term │ │ ECDSA/ECDH │ P-256 │ P-384 │ Prefer Ed25519 │ │ Ed25519 │ 256-bit │ 256-bit │ Fixed size │ │ X25519 │ 256-bit │ 256-bit │ Fixed size │ │ HMAC key │ 256-bit │ 256-bit │ Match hash output │ │ Salt (password) │ 128-bit │ 128-bit │ 16 bytes minimum │ │ Nonce (AES-GCM) │ 96-bit │ 96-bit │ 12 bytes, unique! │ │ IV (AES-CBC) │ 128-bit │ 128-bit │ 16 bytes, random │ └────────────────────────┴────────────┴──────────────┴────────────────────┘ Common Mistakes (Don't Do This) ------------------------------- Examples of insecure patterns to avoid. Each "BAD" example shows a common mistake; the "GOOD" example shows the secure alternative. **❌ Using random module for security:** .. code-block:: python # BAD: Predictable random - can be reverse-engineered! import random token = ''.join(random.choices('abcdef0123456789', k=32)) key = bytes([random.randint(0, 255) for _ in range(32)]) # GOOD: Cryptographically secure random import secrets token = secrets.token_hex(16) key = secrets.token_bytes(32) **❌ Using ECB mode:** .. code-block:: python # BAD: ECB mode reveals patterns in data! from Crypto.Cipher import AES cipher = AES.new(key, AES.MODE_ECB) # NEVER use ECB # GOOD: Use authenticated encryption from cryptography.hazmat.primitives.ciphers.aead import AESGCM aesgcm = AESGCM(key) ciphertext = aesgcm.encrypt(nonce, plaintext, None) **❌ AES-CBC without authentication:** .. code-block:: python # BAD: No integrity check - vulnerable to padding oracle attacks! 
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes cipher = Cipher(algorithms.AES(key), modes.CBC(iv)) encryptor = cipher.encryptor() ciphertext = encryptor.update(padded_data) + encryptor.finalize() # Attacker can modify ciphertext without detection! # GOOD: AES-GCM provides authentication from cryptography.hazmat.primitives.ciphers.aead import AESGCM aesgcm = AESGCM(key) ciphertext = aesgcm.encrypt(nonce, plaintext, None) # Any modification will be detected during decryption **❌ Reusing nonces:** .. code-block:: python # BAD: Reusing nonce completely breaks AES-GCM security! nonce = b'fixed_nonce!' # NEVER do this ct1 = aesgcm.encrypt(nonce, msg1, None) ct2 = aesgcm.encrypt(nonce, msg2, None) # Catastrophic! # GOOD: Generate unique nonce for each encryption import os nonce1 = os.urandom(12) ct1 = aesgcm.encrypt(nonce1, msg1, None) nonce2 = os.urandom(12) ct2 = aesgcm.encrypt(nonce2, msg2, None) **❌ RSA with PKCS#1 v1.5 padding:** .. code-block:: python # BAD: Vulnerable to Bleichenbacher's attack! from Crypto.Cipher import PKCS1_v1_5 cipher = PKCS1_v1_5.new(key) ciphertext = cipher.encrypt(message) # GOOD: Use OAEP padding from cryptography.hazmat.primitives.asymmetric import padding ciphertext = public_key.encrypt( message, padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None, ), ) **❌ Hashing passwords with SHA-256:** .. code-block:: python # BAD: Too fast - billions of guesses per second! import hashlib password_hash = hashlib.sha256(password.encode()).hexdigest() # GOOD: Use slow password hashing function from argon2 import PasswordHasher ph = PasswordHasher() password_hash = ph.hash(password) **❌ Comparing MACs with ==:** .. code-block:: python # BAD: Timing attack reveals information about correct MAC! 
if received_mac == computed_mac: # VULNERABLE process_message() # GOOD: Constant-time comparison import hmac if hmac.compare_digest(received_mac, computed_mac): process_message() **❌ Hardcoding keys/secrets:** .. code-block:: python # BAD: Keys in source code end up in version control! SECRET_KEY = "super_secret_key_12345" API_KEY = "ak_live_xxxxxxxxxxxxx" # GOOD: Use environment variables or secret management import os SECRET_KEY = os.environ.get('SECRET_KEY') # Or use: AWS Secrets Manager, HashiCorp Vault, etc. Security Checklist ------------------ Use this checklist when implementing cryptography in your application. **Before You Start:** .. code-block:: text □ Do I really need custom crypto? (Consider existing solutions first) □ Am I using a well-maintained library? (cryptography, not PyCrypto) □ Have I read the library's security documentation? **Key Management:** .. code-block:: text □ Keys generated with cryptographically secure random (secrets/os.urandom) □ Keys are appropriate size (AES-256, RSA ≥3072-bit) □ Keys stored securely (not in source code, use env vars or secret manager) □ Key rotation plan in place □ Different keys for different purposes (encryption vs signing) □ Keys protected at rest (encrypted with master key or HSM) **Encryption:** .. code-block:: text □ Using authenticated encryption (AES-GCM or ChaCha20-Poly1305) □ Nonces are unique per encryption (random 12 bytes for AES-GCM) □ Not reusing key+nonce combinations □ Associated data (AAD) used where appropriate □ Ciphertext includes nonce for decryption □ Using RSA-OAEP (not PKCS#1 v1.5) for asymmetric encryption **Signatures:** .. code-block:: text □ Using Ed25519 or RSA-PSS (not PKCS#1 v1.5) □ Signing the right data (include all relevant fields) □ Verifying signatures before trusting data □ Handling verification failures securely **Password Storage:** .. 
code-block:: text □ Using Argon2id, scrypt, or PBKDF2 (not plain hashes) □ Unique salt per password (≥16 bytes) □ Sufficient iterations/memory cost (tune for ~100ms-500ms) □ Rehashing when parameters change □ Not logging passwords or hashes **TLS/Network:** .. code-block:: text □ TLS 1.2 or 1.3 only (no SSL, TLS 1.0, TLS 1.1) □ Certificate validation enabled □ Using trusted CA certificates □ Hostname verification enabled □ Certificate pinning for high-security apps **General:** .. code-block:: text □ No sensitive data in logs □ Secure memory handling (clear keys after use where possible) □ Error messages don't leak sensitive information □ Timing-safe comparisons for secrets □ Dependencies up to date (check for CVEs) Secure Random Generation ------------------------ Cryptographic operations require unpredictable random numbers. Python's ``secrets`` module (Python 3.6+) provides cryptographically secure random generation, suitable for tokens, passwords, and keys. Never use the ``random`` module for security-sensitive applications—it uses a predictable PRNG (Mersenne Twister) that can be reverse-engineered from observed outputs. .. code-block:: python import secrets import os # Generate random bytes (for keys, IVs, salts) key = secrets.token_bytes(32) # 256-bit key iv = secrets.token_bytes(16) # 128-bit IV # Generate URL-safe token (for session IDs, API keys) token = secrets.token_urlsafe(32) # ~43 characters # Generate hex token hex_token = secrets.token_hex(16) # 32 hex characters # Secure random integer n = secrets.randbelow(100) # 0 <= n < 100 # Secure choice from sequence password_char = secrets.choice('abcdefghijklmnopqrstuvwxyz0123456789') # Alternative: os.urandom (works on all Python versions) key = os.urandom(32) Cryptographic Hashing --------------------- Hash functions produce fixed-size digests from arbitrary data. Use SHA-256 or SHA-3 for general hashing. For password storage, use dedicated password hashing functions (see Key Derivation section). 
Hash functions are one-way: you cannot recover the original data from a hash. They're used for data integrity verification, digital signatures, and as building blocks for other cryptographic operations. .. code-block:: python import hashlib data = b"Hello, World!" # SHA-256 (recommended for general use) digest = hashlib.sha256(data).hexdigest() print(f"SHA-256: {digest}") # SHA-3 (newer, different internal structure) digest = hashlib.sha3_256(data).hexdigest() print(f"SHA3-256: {digest}") # BLAKE2 (fast, secure, supports keying) digest = hashlib.blake2b(data, digest_size=32).hexdigest() print(f"BLAKE2b: {digest}") # Keyed BLAKE2 (MAC without separate HMAC construction) key = b"secret-key-here!" mac = hashlib.blake2b(data, key=key, digest_size=32).hexdigest() # Incremental hashing (for large files) h = hashlib.sha256() with open("largefile.bin", "rb") as f: for chunk in iter(lambda: f.read(8192), b""): h.update(chunk) print(f"File hash: {h.hexdigest()}") HMAC (Hash-based Message Authentication Code) --------------------------------------------- HMAC provides message authentication—verifying both integrity and authenticity. Unlike plain hashes, HMAC requires a secret key, so only parties with the key can create or verify the MAC. Use HMAC when you need to ensure data hasn't been tampered with and came from someone who knows the secret. Always use constant-time comparison to prevent timing attacks when verifying MACs. .. 
code-block:: python import hmac import hashlib import secrets key = secrets.token_bytes(32) message = b"Important message" # Create HMAC mac = hmac.new(key, message, hashlib.sha256).digest() # Verify HMAC (constant-time comparison) received_mac = mac # In practice, received from sender if hmac.compare_digest(mac, received_mac): print("Message is authentic") else: print("Message was tampered with!") # HMAC with hexdigest mac_hex = hmac.new(key, message, hashlib.sha256).hexdigest() Key Derivation Functions ------------------------ Key derivation functions (KDFs) derive cryptographic keys from passwords or other key material. For passwords, use slow KDFs (PBKDF2, Argon2, scrypt) that resist brute-force attacks. For deriving multiple keys from a master key, use HKDF. Never use plain hashes (SHA-256) for password storage—they're too fast, allowing billions of guesses per second. .. code-block:: python import os from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.hazmat.primitives.kdf.scrypt import Scrypt # PBKDF2 - widely supported, use >= 600,000 iterations (OWASP 2023) password = b"user-password" salt = os.urandom(16) kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=32, salt=salt, iterations=600_000, ) key = kdf.derive(password) # To verify a password, derive again and compare kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=32, salt=salt, # Must use same salt iterations=600_000, ) try: kdf.verify(password, key) print("Password correct") except Exception: print("Password incorrect") # Scrypt - memory-hard, better resistance to GPU/ASIC attacks kdf = Scrypt(salt=salt, length=32, n=2**17, r=8, p=1) key = kdf.derive(password) # HKDF - for deriving multiple keys from master key (not for passwords) master_key = os.urandom(32) hkdf = HKDF( algorithm=hashes.SHA256(), length=32, salt=salt, info=b"encryption-key", ) derived_key = 
hkdf.derive(master_key) Symmetric Encryption: AES-GCM ----------------------------- AES-GCM (Galois/Counter Mode) is the recommended symmetric encryption mode. It provides authenticated encryption: both confidentiality (data is encrypted) and integrity (tampering is detected). The authentication tag ensures ciphertext hasn't been modified. Always use a unique nonce (number used once) for each encryption with the same key—reusing nonces completely breaks security. :: ┌─────────────────────────────────────────────────────────────────┐ │ AES-GCM ENCRYPTION │ ├─────────────────────────────────────────────────────────────────┤ │ │ │ Plaintext ──┬──► AES-GCM ──► Ciphertext │ │ │ │ │ │ Key ────────┤ ├──► Authentication Tag (16 bytes) │ │ │ │ │ │ Nonce ──────┘ └──► Associated Data (AAD) authenticated │ │ (12 bytes) but not encrypted │ │ │ │ Security: Nonce MUST be unique per key. Random 12-byte │ │ nonce is safe for ~2^32 encryptions per key. │ │ │ └─────────────────────────────────────────────────────────────────┘ .. 
code-block:: python import os from cryptography.hazmat.primitives.ciphers.aead import AESGCM # Generate a random 256-bit key key = AESGCM.generate_key(bit_length=256) # Create cipher aesgcm = AESGCM(key) # Encrypt nonce = os.urandom(12) # 96-bit nonce, MUST be unique per encryption plaintext = b"Secret message" associated_data = b"header" # Authenticated but not encrypted (optional) ciphertext = aesgcm.encrypt(nonce, plaintext, associated_data) # ciphertext includes the 16-byte authentication tag # Decrypt decrypted = aesgcm.decrypt(nonce, ciphertext, associated_data) assert decrypted == plaintext # Tampering detection - modifying ciphertext raises exception try: tampered = bytearray(ciphertext) tampered[0] ^= 1 # Flip one bit aesgcm.decrypt(nonce, bytes(tampered), associated_data) except Exception as e: print(f"Tampering detected: {e}") Symmetric Encryption: ChaCha20-Poly1305 --------------------------------------- ChaCha20-Poly1305 is an alternative to AES-GCM, offering similar security with better performance on systems without AES hardware acceleration (common on mobile devices and older CPUs). It's used by TLS 1.3, WireGuard, and many modern protocols. Like AES-GCM, it provides authenticated encryption with associated data (AEAD). .. code-block:: python import os from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 # Generate key (256-bit) key = ChaCha20Poly1305.generate_key() # Create cipher chacha = ChaCha20Poly1305(key) # Encrypt nonce = os.urandom(12) # 96-bit nonce plaintext = b"Secret message" aad = b"additional authenticated data" ciphertext = chacha.encrypt(nonce, plaintext, aad) # Decrypt decrypted = chacha.decrypt(nonce, ciphertext, aad) assert decrypted == plaintext High-Level Encryption: Fernet ----------------------------- Fernet provides a high-level, easy-to-use symmetric encryption API. It uses AES-128-CBC with HMAC for authentication, handles IV generation, and includes timestamp for optional TTL-based expiration. 
Use Fernet when you need simple, secure encryption without worrying about low-level details. The downside is slightly larger ciphertext and no associated data support. .. code-block:: python from cryptography.fernet import Fernet, InvalidToken import time # Generate key (store this securely!) key = Fernet.generate_key() print(f"Key: {key.decode()}") # Base64-encoded # Create Fernet instance f = Fernet(key) # Encrypt plaintext = b"Secret message" token = f.encrypt(plaintext) print(f"Token: {token.decode()}") # Decrypt decrypted = f.decrypt(token) assert decrypted == plaintext # Decrypt with TTL (time-to-live in seconds) try: # Token must have been created within last 60 seconds decrypted = f.decrypt(token, ttl=60) except InvalidToken: print("Token expired or invalid") # Key rotation with MultiFernet from cryptography.fernet import MultiFernet old_key = Fernet.generate_key() new_key = Fernet.generate_key() # MultiFernet tries keys in order for decryption # Always encrypts with first key multi = MultiFernet([Fernet(new_key), Fernet(old_key)]) # Can decrypt tokens from either key old_token = Fernet(old_key).encrypt(b"old data") decrypted = multi.decrypt(old_token) # Works! # Re-encrypt with new key new_token = multi.rotate(old_token) RSA Key Generation ------------------ RSA is an asymmetric algorithm using public/private key pairs. The public key encrypts data or verifies signatures; the private key decrypts or signs. Modern recommendations: use at least 2048-bit keys (3072 or 4096 for long-term security), public exponent 65537, and OAEP padding for encryption or PSS for signatures. For new systems, consider Ed25519 (signatures) or X25519 (key exchange) instead. .. 
code-block:: python from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives import serialization # Generate RSA key pair private_key = rsa.generate_private_key( public_exponent=65537, key_size=4096, # 2048 minimum, 4096 for long-term ) public_key = private_key.public_key() # Serialize private key (PEM format) private_pem = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.BestAvailableEncryption(b"passphrase"), ) # Serialize private key without encryption private_pem_unencrypted = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ) # Serialize public key public_pem = public_key.public_bytes( encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) # Save to files with open("private_key.pem", "wb") as f: f.write(private_pem) with open("public_key.pem", "wb") as f: f.write(public_pem) # Load keys from files with open("private_key.pem", "rb") as f: loaded_private = serialization.load_pem_private_key( f.read(), password=b"passphrase", ) with open("public_key.pem", "rb") as f: loaded_public = serialization.load_pem_public_key(f.read()) RSA Encryption (OAEP) --------------------- RSA encryption should always use OAEP (Optimal Asymmetric Encryption Padding). Never use PKCS#1 v1.5 padding for new applications—it's vulnerable to padding oracle attacks. RSA can only encrypt small amounts of data (key_size/8 - padding overhead), so it's typically used to encrypt a symmetric key, which then encrypts the actual data (hybrid encryption). 
:: ┌─────────────────────────────────────────────────────────────────┐ │ RSA-OAEP ENCRYPTION │ ├─────────────────────────────────────────────────────────────────┤ │ │ │ Plaintext ──► OAEP Padding ──► RSA ──► Ciphertext │ │ │ │ │ │ │ Public Key │ │ │ │ │ OAEP uses: │ │ - MGF1 (Mask Generation Function) │ │ - Hash algorithm (SHA-256 recommended) │ │ - Optional label │ │ │ │ Max plaintext size: key_bytes - 2*hash_bytes - 2 │ │ For 4096-bit key with SHA-256: 512 - 64 - 2 = 446 bytes │ │ │ └─────────────────────────────────────────────────────────────────┘ .. code-block:: python from cryptography.hazmat.primitives.asymmetric import rsa, padding from cryptography.hazmat.primitives import hashes # Generate keys private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096) public_key = private_key.public_key() # Encrypt with public key (OAEP padding) plaintext = b"Secret message for RSA" ciphertext = public_key.encrypt( plaintext, padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None, ), ) # Decrypt with private key decrypted = private_key.decrypt( ciphertext, padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None, ), ) assert decrypted == plaintext Hybrid Encryption ----------------- RSA has size limits and is slow. Hybrid encryption combines RSA's key distribution benefits with symmetric encryption's speed: generate a random symmetric key, encrypt the data with AES-GCM, then encrypt the symmetric key with RSA. The recipient decrypts the symmetric key with their RSA private key, then decrypts the data. .. 
code-block:: python import os from cryptography.hazmat.primitives.asymmetric import rsa, padding from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.ciphers.aead import AESGCM def hybrid_encrypt(public_key, plaintext): """Encrypt data using hybrid RSA + AES-GCM.""" # Generate random AES key and nonce aes_key = AESGCM.generate_key(bit_length=256) nonce = os.urandom(12) # Encrypt data with AES-GCM aesgcm = AESGCM(aes_key) ciphertext = aesgcm.encrypt(nonce, plaintext, None) # Encrypt AES key with RSA-OAEP encrypted_key = public_key.encrypt( aes_key, padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None, ), ) return encrypted_key, nonce, ciphertext def hybrid_decrypt(private_key, encrypted_key, nonce, ciphertext): """Decrypt data using hybrid RSA + AES-GCM.""" # Decrypt AES key with RSA aes_key = private_key.decrypt( encrypted_key, padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None, ), ) # Decrypt data with AES-GCM aesgcm = AESGCM(aes_key) return aesgcm.decrypt(nonce, ciphertext, None) # Usage private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096) public_key = private_key.public_key() message = b"This message can be arbitrarily long!" * 1000 encrypted_key, nonce, ciphertext = hybrid_encrypt(public_key, message) decrypted = hybrid_decrypt(private_key, encrypted_key, nonce, ciphertext) assert decrypted == message Digital Signatures: RSA-PSS --------------------------- Digital signatures prove authenticity and integrity. The signer uses their private key to create a signature; anyone with the public key can verify it. Use PSS (Probabilistic Signature Scheme) padding for RSA signatures—it's provably secure unlike PKCS#1 v1.5. For new applications, consider Ed25519 instead of RSA. .. 
code-block:: python from cryptography.hazmat.primitives.asymmetric import rsa, padding from cryptography.hazmat.primitives import hashes # Generate keys private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096) public_key = private_key.public_key() message = b"Message to sign" # Sign with private key (PSS padding) signature = private_key.sign( message, padding.PSS( mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH, ), hashes.SHA256(), ) # Verify with public key try: public_key.verify( signature, message, padding.PSS( mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH, ), hashes.SHA256(), ) print("Signature valid") except Exception: print("Signature invalid!") Digital Signatures: Ed25519 --------------------------- Ed25519 is a modern signature algorithm based on elliptic curves. It offers excellent security with small keys (32 bytes) and signatures (64 bytes), fast operations, and resistance to many implementation pitfalls. Prefer Ed25519 over RSA for new applications unless you need compatibility with legacy systems. .. 
code-block:: python from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey from cryptography.hazmat.primitives import serialization # Generate key pair private_key = Ed25519PrivateKey.generate() public_key = private_key.public_key() # Sign message message = b"Message to sign" signature = private_key.sign(message) # Verify signature try: public_key.verify(signature, message) print("Signature valid") except Exception: print("Signature invalid!") # Serialize keys private_bytes = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ) public_bytes = public_key.public_bytes( encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) # OpenSSH format for public key public_ssh = public_key.public_bytes( encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH, ) print(public_ssh.decode()) # ssh-ed25519 AAAA... Elliptic Curve Diffie-Hellman (ECDH) ------------------------------------ ECDH allows two parties to establish a shared secret over an insecure channel. Each party generates a key pair, exchanges public keys, and derives the same shared secret. Use X25519 (Curve25519) for modern applications—it's fast, secure, and resistant to timing attacks. The shared secret should be passed through a KDF before use as an encryption key. .. 
code-block:: python from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.hazmat.primitives import hashes # Alice generates her key pair alice_private = X25519PrivateKey.generate() alice_public = alice_private.public_key() # Bob generates his key pair bob_private = X25519PrivateKey.generate() bob_public = bob_private.public_key() # Exchange public keys (over insecure channel) # Alice computes shared secret alice_shared = alice_private.exchange(bob_public) # Bob computes shared secret bob_shared = bob_private.exchange(alice_public) # Both arrive at the same shared secret assert alice_shared == bob_shared # Derive encryption key from shared secret using HKDF def derive_key(shared_secret, info): hkdf = HKDF( algorithm=hashes.SHA256(), length=32, salt=None, info=info, ) return hkdf.derive(shared_secret) encryption_key = derive_key(alice_shared, b"encryption") mac_key = derive_key(alice_shared, b"authentication") Password Hashing with Argon2 ---------------------------- Argon2 is the winner of the Password Hashing Competition (2015) and the recommended algorithm for password storage. It's memory-hard, making GPU/ASIC attacks expensive. Use the ``argon2-cffi`` library for Python. Store the full hash string (includes salt and parameters) in your database. .. code-block:: python # pip install argon2-cffi from argon2 import PasswordHasher from argon2.exceptions import VerifyMismatchError ph = PasswordHasher( time_cost=3, # Number of iterations memory_cost=65536, # Memory usage in KiB (64 MB) parallelism=4, # Number of parallel threads ) # Hash a password (for storage) password = "user-password" hash_str = ph.hash(password) print(f"Hash: {hash_str}") # $argon2id$v=19$m=65536,t=3,p=4$... 
# Verify a password (during login) try: ph.verify(hash_str, password) print("Password correct") # Check if rehash needed (parameters changed) if ph.check_needs_rehash(hash_str): new_hash = ph.hash(password) # Update stored hash except VerifyMismatchError: print("Password incorrect") ================================================ FILE: docs/notes/security/python-tls.rst ================================================ .. meta:: :description lang=en: Python TLS/SSL and X.509 certificate guide covering secure HTTPS servers, certificate generation, CSR creation, and certificate verification using the cryptography library :keywords: Python, Python3, TLS, SSL, HTTPS, X.509, Certificate, CSR, Certificate Authority, Self-Signed, SSLContext, cryptography ======================== TLS/SSL and Certificates ======================== .. contents:: Table of Contents :backlinks: none Transport Layer Security (TLS) provides encrypted, authenticated communication over networks. This guide covers creating secure HTTPS servers, generating certificates, and proper TLS configuration in Python. We use the ``ssl`` module's ``SSLContext`` API and its ``SSLContext.wrap_socket()`` method (not the module-level ``ssl.wrap_socket()`` function, which was deprecated in Python 3.7 and removed in Python 3.12) and the ``cryptography`` library for certificate operations. Always use TLS 1.2 or 1.3—older versions have known vulnerabilities. .. warning:: For production, always use certificates from a trusted Certificate Authority (CA) like Let's Encrypt. Self-signed certificates are only for development and testing. Never disable certificate verification in production code. Secure HTTPS Server ------------------- Create an HTTPS server using ``SSLContext`` with secure defaults. The context configures TLS version, cipher suites, and certificate verification. Always load both the certificate chain and private key. For production, use certificates from a real CA. .. 
code-block:: python import ssl from http.server import HTTPServer, SimpleHTTPRequestHandler def create_secure_context(certfile, keyfile): """Create SSLContext with secure defaults.""" # TLS 1.2+ only, secure ciphers context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) context.minimum_version = ssl.TLSVersion.TLSv1_2 # Load certificate and private key context.load_cert_chain(certfile=certfile, keyfile=keyfile) # Disable insecure options context.options |= ssl.OP_NO_SSLv2 context.options |= ssl.OP_NO_SSLv3 context.options |= ssl.OP_NO_TLSv1 context.options |= ssl.OP_NO_TLSv1_1 return context # Create server host, port = "localhost", 8443 context = create_secure_context("cert.pem", "key.pem") httpd = HTTPServer((host, port), SimpleHTTPRequestHandler) httpd.socket = context.wrap_socket(httpd.socket, server_side=True) print(f"Serving HTTPS on https://{host}:{port}") httpd.serve_forever() Secure HTTPS Client ------------------- When making HTTPS requests, Python verifies certificates by default. For custom CA certificates or client authentication, configure an ``SSLContext``. Never set ``verify=False`` or disable hostname checking in production. .. 
code-block:: python import ssl import urllib.request # Default secure context (verifies certificates) context = ssl.create_default_context() # Make HTTPS request url = "https://example.com" with urllib.request.urlopen(url, context=context) as response: print(response.read().decode()) # Custom CA certificate (e.g., internal CA) context = ssl.create_default_context() context.load_verify_locations("internal-ca.pem") # Client certificate authentication (mTLS) context = ssl.create_default_context() context.load_cert_chain(certfile="client.pem", keyfile="client-key.pem") # Using requests library (recommended for HTTP) import requests # Default (secure) response = requests.get("https://example.com") # Custom CA response = requests.get("https://internal.example.com", verify="internal-ca.pem") # Client certificate response = requests.get( "https://secure.example.com", cert=("client.pem", "client-key.pem"), ) Generate Self-Signed Certificate -------------------------------- Self-signed certificates are useful for development and testing. The certificate is signed by its own private key rather than a CA. Browsers will show warnings for self-signed certificates. Use the ``cryptography`` library for certificate generation—it's more Pythonic than calling OpenSSL. .. 
code-block:: python import ipaddress from datetime import datetime, timedelta from cryptography import x509 from cryptography.x509.oid import NameOID from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives import serialization # Generate private key private_key = rsa.generate_private_key( public_exponent=65537, key_size=4096, ) # Certificate subject and issuer (same for self-signed) subject = issuer = x509.Name([ x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"), x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "My Organization"), x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), ]) # Build certificate cert = ( x509.CertificateBuilder() .subject_name(subject) .issuer_name(issuer) .public_key(private_key.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(datetime.utcnow()) .not_valid_after(datetime.utcnow() + timedelta(days=365)) .add_extension( x509.SubjectAlternativeName([ x509.DNSName("localhost"), x509.DNSName("*.localhost"), x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), ]), critical=False, ) .add_extension( x509.BasicConstraints(ca=False, path_length=None), critical=True, ) .sign(private_key, hashes.SHA256()) ) # Save private key with open("key.pem", "wb") as f: f.write(private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption(), )) # Save certificate with open("cert.pem", "wb") as f: f.write(cert.public_bytes(serialization.Encoding.PEM)) print("Generated key.pem and cert.pem") Generate Certificate Signing Request (CSR) ------------------------------------------ A CSR is sent to a Certificate Authority to obtain a signed certificate. It contains your public key and identity information. 
The CA verifies your identity and returns a signed certificate. Keep your private key secret—never send it to the CA. .. code-block:: python from cryptography import x509 from cryptography.x509.oid import NameOID from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives import serialization # Generate private key (keep this secret!) private_key = rsa.generate_private_key( public_exponent=65537, key_size=4096, ) # Build CSR csr = ( x509.CertificateSigningRequestBuilder() .subject_name(x509.Name([ x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"), x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "My Company"), x509.NameAttribute(NameOID.COMMON_NAME, "www.example.com"), ])) .add_extension( x509.SubjectAlternativeName([ x509.DNSName("www.example.com"), x509.DNSName("example.com"), x509.DNSName("api.example.com"), ]), critical=False, ) .sign(private_key, hashes.SHA256()) ) # Save private key with open("private.key", "wb") as f: f.write(private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption(), )) # Save CSR (send this to CA) with open("request.csr", "wb") as f: f.write(csr.public_bytes(serialization.Encoding.PEM)) print("Generated private.key and request.csr") print("Send request.csr to your CA, keep private.key secret!") Read Certificate Information ---------------------------- Parse and inspect X.509 certificates to view subject, issuer, validity period, extensions, and other attributes. Useful for debugging certificate issues. .. 
code-block:: python from cryptography import x509 from cryptography.hazmat.primitives import serialization # Load certificate from file with open("cert.pem", "rb") as f: cert = x509.load_pem_x509_certificate(f.read()) # Basic information print(f"Subject: {cert.subject}") print(f"Issuer: {cert.issuer}") print(f"Serial: {cert.serial_number}") print(f"Not Before: {cert.not_valid_before}") print(f"Not After: {cert.not_valid_after}") # Get specific subject attributes cn = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME) if cn: print(f"Common Name: {cn[0].value}") # Check extensions try: san = cert.extensions.get_extension_for_oid( x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME ) print(f"SANs: {san.value.get_values_for_type(x509.DNSName)}") except x509.ExtensionNotFound: print("No SAN extension") # Check if self-signed is_self_signed = cert.subject == cert.issuer print(f"Self-signed: {is_self_signed}") # Verify certificate signature (self-signed only) if is_self_signed: try: # This verifies the certificate was signed by its own key. # verify_directly_issued_by (cryptography >= 40) checks the signature # according to the key type, so no padding/hash arguments are needed. cert.verify_directly_issued_by(cert) print("Signature valid") except Exception as e: print(f"Signature invalid: {e}") Create a Certificate Authority ------------------------------ For internal use, you can create your own CA to sign certificates. The CA certificate is distributed to clients, which then trust any certificate signed by the CA. This is useful for development environments or internal services. .. 
code-block:: python import ipaddress from datetime import datetime, timedelta from cryptography import x509 from cryptography.x509.oid import NameOID from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives import serialization def create_ca(): """Create a Certificate Authority.""" # Generate CA private key ca_key = rsa.generate_private_key( public_exponent=65537, key_size=4096, ) # CA certificate (self-signed) ca_name = x509.Name([ x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "My Internal CA"), x509.NameAttribute(NameOID.COMMON_NAME, "My Internal Root CA"), ]) ca_cert = ( x509.CertificateBuilder() .subject_name(ca_name) .issuer_name(ca_name) .public_key(ca_key.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(datetime.utcnow()) .not_valid_after(datetime.utcnow() + timedelta(days=3650)) # 10 years .add_extension( x509.BasicConstraints(ca=True, path_length=0), critical=True, ) .add_extension( x509.KeyUsage( digital_signature=True, key_cert_sign=True, crl_sign=True, key_encipherment=False, content_commitment=False, data_encipherment=False, key_agreement=False, encipher_only=False, decipher_only=False, ), critical=True, ) .sign(ca_key, hashes.SHA256()) ) return ca_key, ca_cert def sign_csr(ca_key, ca_cert, csr_path, days=365): """Sign a CSR with the CA.""" with open(csr_path, "rb") as f: csr = x509.load_pem_x509_csr(f.read()) cert = ( x509.CertificateBuilder() .subject_name(csr.subject) .issuer_name(ca_cert.subject) .public_key(csr.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(datetime.utcnow()) .not_valid_after(datetime.utcnow() + timedelta(days=days)) .add_extension( x509.BasicConstraints(ca=False, path_length=None), critical=True, ) ) # Copy extensions from CSR for ext in csr.extensions: cert = cert.add_extension(ext.value, ext.critical) return cert.sign(ca_key, 
hashes.SHA256()) # Create CA ca_key, ca_cert = create_ca() # Save CA files with open("ca-key.pem", "wb") as f: f.write(ca_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.BestAvailableEncryption(b"ca-password"), )) with open("ca-cert.pem", "wb") as f: f.write(ca_cert.public_bytes(serialization.Encoding.PEM)) print("Created ca-key.pem (keep secret!) and ca-cert.pem (distribute to clients)") TLS Version and Cipher Information ---------------------------------- Inspect TLS connection details including protocol version, cipher suite, and peer certificate. Useful for debugging and security auditing. .. code-block:: python import ssl import socket def get_tls_info(hostname, port=443): """Get TLS connection information for a host.""" context = ssl.create_default_context() with socket.create_connection((hostname, port)) as sock: with context.wrap_socket(sock, server_hostname=hostname) as ssock: print(f"TLS Version: {ssock.version()}") print(f"Cipher: {ssock.cipher()}") # Peer certificate cert = ssock.getpeercert() print(f"Subject: {dict(x[0] for x in cert['subject'])}") print(f"Issuer: {dict(x[0] for x in cert['issuer'])}") print(f"Not Before: {cert['notBefore']}") print(f"Not After: {cert['notAfter']}") # Subject Alternative Names if 'subjectAltName' in cert: sans = [x[1] for x in cert['subjectAltName']] print(f"SANs: {sans}") get_tls_info("www.google.com") Certificate Pinning ------------------- Certificate pinning adds an extra layer of security by verifying the server's certificate matches an expected value. This prevents attacks using fraudulently issued certificates. Pin the public key (SPKI) rather than the certificate to survive certificate renewals. .. 
code-block:: python import ssl import socket import hashlib from cryptography import x509 from cryptography.hazmat.primitives import serialization def get_certificate_pin(hostname, port=443): """Get the SPKI pin for a certificate.""" context = ssl.create_default_context() with socket.create_connection((hostname, port)) as sock: with context.wrap_socket(sock, server_hostname=hostname) as ssock: # Get certificate in DER format der_cert = ssock.getpeercert(binary_form=True) # Parse certificate cert = x509.load_der_x509_certificate(der_cert) # Get public key in DER format (SPKI) spki = cert.public_key().public_bytes( encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) # SHA-256 hash of SPKI pin = hashlib.sha256(spki).digest() return pin def verify_pin(hostname, expected_pin, port=443): """Verify certificate matches expected pin.""" actual_pin = get_certificate_pin(hostname, port) if actual_pin != expected_pin: raise ssl.SSLError(f"Certificate pin mismatch for {hostname}") print(f"Pin verified for {hostname}") # Get pin (do this once, store the result) pin = get_certificate_pin("www.google.com") print(f"Pin (base64): {__import__('base64').b64encode(pin).decode()}") # Verify pin on subsequent connections verify_pin("www.google.com", pin) ================================================ FILE: docs/notes/security/python-vulnerability.rst ================================================ .. meta:: :description lang=en: Common Python security vulnerabilities and why legacy cryptographic patterns are insecure, with attack demonstrations :keywords: Python, Python3, security, vulnerability, PyCrypto, padding oracle, PKCS1 v1.5, AES-CBC, timing attack, insecure =============================== Common Security Vulnerabilities =============================== :Source: `src/security/vulnerability_.py `_ .. 
contents:: Table of Contents :backlinks: none Introduction ------------ This page explains why certain cryptographic patterns are insecure and how attackers can exploit them. Understanding these vulnerabilities helps you recognize dangerous code in legacy systems and avoid introducing similar weaknesses in new projects. For secure implementations, see :doc:`python-crypto` and :doc:`python-tls`. AES-CBC Without Authentication (Padding Oracle) ----------------------------------------------- AES-CBC mode encrypts data but provides no integrity protection. An attacker who can modify ciphertext and observe whether decryption succeeds can recover the plaintext byte-by-byte through a **padding oracle attack**. This attack exploits the PKCS#7 padding validation to leak information. **Vulnerable Code:** .. code-block:: python # INSECURE: AES-CBC without authentication from Crypto.Cipher import AES def encrypt_cbc(key, iv, plaintext): cipher = AES.new(key, AES.MODE_CBC, iv) # Manual PKCS#7 padding pad_len = 16 - (len(plaintext) % 16) padded = plaintext + bytes([pad_len] * pad_len) return cipher.encrypt(padded) def decrypt_cbc(key, iv, ciphertext): cipher = AES.new(key, AES.MODE_CBC, iv) padded = cipher.decrypt(ciphertext) # VULNERABLE: Padding validation leaks information pad_len = padded[-1] if not all(b == pad_len for b in padded[-pad_len:]): raise ValueError("Invalid padding") # Oracle! return padded[:-pad_len] **Why It's Vulnerable:** The padding validation error reveals whether the decrypted padding is valid. An attacker can: 1. Intercept a ciphertext block 2. Modify the previous block's last byte 3. Submit to the server and observe whether a padding error occurs 4. Repeat with up to 256 guesses to determine one plaintext byte 5. Continue for all bytes .. code-block:: python # Simplified padding oracle attack concept def padding_oracle_attack(ciphertext, oracle_func): """ oracle_func returns True if padding is valid, False otherwise. 
This leaks enough information to decrypt without the key. """ # For each block, XOR previous block to control decrypted value # Try all 256 values until padding is valid # Valid padding reveals the intermediate state # XOR with known value gives plaintext pass # Full implementation is complex but well-documented **Secure Alternative:** Use AES-GCM which provides authenticated encryption: .. code-block:: python from cryptography.hazmat.primitives.ciphers.aead import AESGCM key = AESGCM.generate_key(bit_length=256) aesgcm = AESGCM(key) nonce = os.urandom(12) # Encryption includes authentication tag - tampering is detected ciphertext = aesgcm.encrypt(nonce, plaintext, associated_data) RSA PKCS#1 v1.5 Padding (Bleichenbacher Attack) ----------------------------------------------- RSA with PKCS#1 v1.5 padding is vulnerable to the **Bleichenbacher attack** (also called the "million message attack"). If a server reveals whether decryption produced valid PKCS#1 v1.5 padding, an attacker can decrypt messages or forge signatures. **Vulnerable Code:** .. code-block:: python # INSECURE: PKCS#1 v1.5 padding from Crypto.Cipher import PKCS1_v1_5 from Crypto.PublicKey import RSA def decrypt_rsa_v15(private_key_pem, ciphertext): key = RSA.import_key(private_key_pem) cipher = PKCS1_v1_5.new(key) # VULNERABLE: Different errors for padding vs other failures plaintext = cipher.decrypt(ciphertext, sentinel=None) if plaintext is None: raise ValueError("Decryption failed") # Oracle! return plaintext **Why It's Vulnerable:** PKCS#1 v1.5 padding has a specific structure: ``0x00 0x02 [random] 0x00 [message]``. When decryption fails due to invalid padding vs. other reasons, the different error responses create an oracle. An attacker can: 1. Choose a ciphertext ``c`` 2. Compute ``c' = c * s^e mod n`` for various ``s`` values 3. Submit ``c'`` and check if padding is valid 4. Use valid/invalid responses to narrow down the plaintext **Secure Alternative:** Use RSA-OAEP padding: .. 
code-block:: python from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives import hashes ciphertext = public_key.encrypt( plaintext, padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None ) ) Timing Attacks on String Comparison ----------------------------------- Comparing secrets using ``==`` is vulnerable to **timing attacks**. The comparison stops at the first different byte, so the time taken reveals how many bytes match. An attacker can guess secrets byte-by-byte. **Vulnerable Code:** .. code-block:: python # INSECURE: Regular string comparison def verify_token(user_token, stored_token): return user_token == stored_token # Timing leak! def verify_signature(computed_sig, provided_sig): return computed_sig == provided_sig # Timing leak! **Why It's Vulnerable:** .. code-block:: python # Demonstration of timing difference import time secret = b"correct_secret_token_here" def insecure_compare(a, b): if len(a) != len(b): return False for x, y in zip(a, b): if x != y: return False # Returns early - timing leak return True # Attacker measures time for different guesses: # "a..." - fails fast (wrong first byte) # "c..." - takes slightly longer (first byte correct) # "co..." - even longer (two bytes correct) # Eventually recovers entire secret **Secure Alternative:** Use constant-time comparison: .. code-block:: python import hmac def verify_token(user_token, stored_token): # hmac.compare_digest runs in constant time return hmac.compare_digest(user_token, stored_token) Weak Random Number Generation ----------------------------- Using ``random`` module for security purposes is dangerous. It uses a deterministic PRNG (Mersenne Twister) that can be predicted if an attacker observes enough outputs. **Vulnerable Code:** .. 
code-block:: python # INSECURE: Using random for security import random import string def generate_token(): # VULNERABLE: Predictable after ~624 outputs observed chars = string.ascii_letters + string.digits return ''.join(random.choice(chars) for _ in range(32)) def generate_session_id(): # VULNERABLE: Can be predicted return random.randint(0, 2**64) **Why It's Vulnerable:** Mersenne Twister has 624 32-bit state values. After observing 624 outputs, an attacker can reconstruct the internal state and predict all future outputs. .. code-block:: python # Mersenne Twister state recovery (conceptual) # After collecting 624 consecutive 32-bit outputs, # attacker can "untemper" them to recover internal state # Then predict all future random() calls **Secure Alternative:** Use ``secrets`` module: .. code-block:: python import secrets def generate_token(): return secrets.token_urlsafe(32) def generate_session_id(): return secrets.token_hex(16) Hardcoded Secrets and Keys -------------------------- Embedding secrets in source code exposes them through version control, logs, error messages, and decompilation. **Vulnerable Code:** .. code-block:: python # INSECURE: Hardcoded secrets API_KEY = "sk_live_abc123xyz789" # Exposed in git history! DB_PASSWORD = "super_secret_password" ENCRYPTION_KEY = b"0123456789abcdef" def connect_to_api(): return requests.get(url, headers={"Authorization": API_KEY}) **Why It's Vulnerable:** - Secrets in git history persist even after deletion - Error messages may include variable values - Compiled Python (.pyc) can be decompiled - Logs may capture the values **Secure Alternative:** Use environment variables or secret managers: .. 
code-block:: python import os API_KEY = os.environ.get("API_KEY") if not API_KEY: raise RuntimeError("API_KEY environment variable required") # Or use a secrets manager from aws_secretsmanager import get_secret secrets = get_secret("my-app/production") SQL Injection ------------- Building SQL queries with string concatenation allows attackers to inject malicious SQL commands. **Vulnerable Code:** .. code-block:: python # INSECURE: String concatenation in SQL def get_user(username): query = f"SELECT * FROM users WHERE username = '{username}'" cursor.execute(query) # SQL injection! return cursor.fetchone() # Attacker input: "admin' OR '1'='1" # Results in: SELECT * FROM users WHERE username = 'admin' OR '1'='1' # Returns all users! # Worse: "admin'; DROP TABLE users; --" # Deletes the entire table! **Secure Alternative:** Use parameterized queries: .. code-block:: python def get_user(username): query = "SELECT * FROM users WHERE username = ?" cursor.execute(query, (username,)) # Safe - parameterized return cursor.fetchone() # Or with SQLAlchemy from sqlalchemy import select stmt = select(User).where(User.username == username) Command Injection ----------------- Passing user input to shell commands allows arbitrary command execution. **Vulnerable Code:** .. code-block:: python # INSECURE: Shell injection import os import subprocess def ping_host(hostname): os.system(f"ping -c 1 {hostname}") # Command injection! # Attacker input: "google.com; rm -rf /" # Executes: ping -c 1 google.com; rm -rf / def get_file_info(filename): # Also vulnerable with subprocess and shell=True result = subprocess.run( f"file {filename}", shell=True, # DANGEROUS capture_output=True ) **Secure Alternative:** Avoid shell, use argument lists: .. 
code-block:: python import subprocess import shlex def ping_host(hostname): # Validate input first if not hostname.replace('.', '').replace('-', '').isalnum(): raise ValueError("Invalid hostname") # Use list of arguments, not shell string subprocess.run(["ping", "-c", "1", hostname], check=True) def get_file_info(filename): # shell=False (default) prevents injection result = subprocess.run( ["file", filename], capture_output=True, check=True ) Insecure Deserialization (Pickle) --------------------------------- Python's ``pickle`` module can execute arbitrary code during deserialization. Never unpickle data from untrusted sources. **Vulnerable Code:** .. code-block:: python # INSECURE: Unpickling untrusted data import pickle def load_user_data(data): return pickle.loads(data) # Remote code execution! # Attacker can craft malicious pickle: import os class Exploit: def __reduce__(self): return (os.system, ("rm -rf /",)) malicious = pickle.dumps(Exploit()) # When unpickled, executes: os.system("rm -rf /") **Secure Alternative:** Use safe formats like JSON: .. 
code-block:: python import json def load_user_data(data): return json.loads(data) # Safe - no code execution # If you must use pickle, restrict classes import pickle import io class RestrictedUnpickler(pickle.Unpickler): ALLOWED_CLASSES = {('mymodule', 'SafeClass')} def find_class(self, module, name): if (module, name) not in self.ALLOWED_CLASSES: raise pickle.UnpicklingError(f"Forbidden: {module}.{name}") return super().find_class(module, name) Summary: Legacy vs Modern ------------------------- +------------------------+---------------------------+---------------------------+ | Vulnerability | Legacy (Insecure) | Modern (Secure) | +========================+===========================+===========================+ | Symmetric Encryption | AES-CBC without auth | AES-GCM | +------------------------+---------------------------+---------------------------+ | RSA Padding | PKCS#1 v1.5 | OAEP | +------------------------+---------------------------+---------------------------+ | Secret Comparison | ``==`` | ``hmac.compare_digest`` | +------------------------+---------------------------+---------------------------+ | Random Numbers | ``random`` | ``secrets`` | +------------------------+---------------------------+---------------------------+ | Password Hashing | MD5, SHA1 | Argon2, bcrypt | +------------------------+---------------------------+---------------------------+ | Crypto Library | PyCrypto | ``cryptography`` | +------------------------+---------------------------+---------------------------+ | SSL/TLS | ``ssl.wrap_socket`` | ``SSLContext`` | +------------------------+---------------------------+---------------------------+ ================================================ FILE: requirements.txt ================================================ coverage==7.13.4 cryptography==46.0.5 argon2-cffi==25.1.0 cffi==2.0.0 SQLAlchemy==2.0.48 bandit==1.9.4 coveralls==4.1.0 Flask==3.1.3 Flask-SSLify==0.1.5 Flask-Testing==0.8.1 Flask-SeaSurf==2.0.0 flask-talisman==1.1.0 
gunicorn==25.1.0 pycodestyle==2.14.0 pydocstyle==6.3.0 pytest==9.0.2 requests==2.32.5 Sphinx==8.1.3 Werkzeug==3.1.6 setuptools>=70.0.0 # not directly required, pinned by Snyk to avoid a vulnerability jinja2>=3.1.3 # not directly required, pinned by Snyk to avoid a vulnerability zipp>=3.19.1 # not directly required, pinned by Snyk to avoid a vulnerability myst-parser==4.0.1 sphinx-copybutton==0.5.2 sphinx_design==0.6.1 sphinx-book-theme==1.1.4 urllib3>=2.6.3 # not directly required, pinned by Snyk to avoid a vulnerability ================================================ FILE: runtime.txt ================================================ python-3.12 ================================================ FILE: skills/py/SKILL.md ================================================ --- name: py description: Comprehensive Python programming reference covering syntax, concurrency, networking, databases, ML/LLM development, and HPC. Use for: Python questions, debugging, performance optimization, async patterns, library examples, code review, best practices, MLOps workflows, distributed computing, security implementations, and any Python development tasks. --- # Python Cheat Sheets (/py) Help users write functional, correct Python code and answer Python questions by fetching proven patterns and examples from pythonsheets.com. ## How It Works When a user asks a Python question or wants to write a Python script: 1. Look up the relevant topic(s) in [Structure](references/structure.md) to find the matching URL(s) 2. **Always fetch** the URL(s) using WebFetch to get real examples and patterns from the site 3. Use the fetched content to: - **Write code**: Apply the patterns to produce functional, correct code that solves the user's task - **Answer questions**: Provide thorough explanations backed by the examples and information from the site 4. 
Follow the [Guidelines](references/guidelines.md) for code quality ## Key Principle **Functionality first, cleanliness second.** The code must work correctly and handle the task properly. Fetching from pythonsheets.com ensures solutions use battle-tested patterns rather than guessing. The site contains rich examples covering edge cases, common pitfalls, and practical usage that go beyond basic documentation. ## Coverage Areas **Core:** Syntax, typing, OOP, functions, data structures, sets, heap, regex, unicode **System:** File I/O, datetime, OS interfaces **Concurrency:** Threading, multiprocessing, asyncio **Network:** Sockets, SSL/TLS, SSH, async I/O, packet sniffing **Database:** SQLAlchemy ORM, queries, transactions **Security:** Cryptography, TLS, vulnerabilities **Extensions:** C/C++ integration, pybind11, Cython **ML/LLM:** PyTorch, Megatron, distributed training, inference, serving, benchmarking **HPC:** Slurm, cluster computing, job scheduling, EFA monitoring, NCCL **Appendix:** Walrus operator, GDB debugging, disaggregated prefill/decode ## References - **[Structure](references/structure.md)** - Topic-to-URL map for fetching examples - **[Guidelines](references/guidelines.md)** - Code quality standards to apply after ensuring correctness ## Examples - "How does asyncio work?" → Fetch https://www.pythonsheets.com/notes/asyncio/python-asyncio-guide.html and explain with the site's examples - "Write a socket server" → Fetch https://www.pythonsheets.com/notes/network/python-socket-server.html, use the patterns to write a working server - "What's the walrus operator?" 
→ Fetch https://www.pythonsheets.com/notes/appendix/python-walrus.html and explain with practical examples - "Set up Megatron distributed training" → Fetch https://www.pythonsheets.com/notes/llm/megatron.html, use the patterns to write a correct training script ================================================ FILE: skills/py/references/guidelines.md ================================================ # Python Development Guidelines Always fetch relevant examples from pythonsheets.com first to ensure correctness, then apply these guidelines when writing code. ## Choosing the Right Approach - **Concurrency model**: Use `asyncio` for I/O-bound tasks with many connections, `threading` for simpler I/O-bound work, `multiprocessing` for CPU-bound work - **Data structures**: Use `set` for membership tests, `heapq` for priority queues, `deque` for FIFO queues, `defaultdict` to avoid key existence checks - **Database access**: Use SQLAlchemy ORM for complex queries, parameterized queries for raw SQL — never string-interpolate SQL - **File operations**: Use `pathlib.Path` instead of `os.path` for cleaner, cross-platform file handling - **Network programming**: Use `asyncio` streams for high-concurrency servers, `socket` for low-level control, `ssl` context for secure connections ## Writing Correct Code - Handle errors at the right level — catch specific exceptions where you can recover, let others propagate - Use context managers (`with`) for files, connections, locks, and any resource that needs cleanup - Avoid mutable default arguments (`def f(x=[])`) — use `None` and initialize inside the function - Use `logging` instead of `print` for diagnostics — it supports levels, formatting, and output routing - Validate inputs at system boundaries (user input, external APIs), trust internal code ## Writing Clean Code - Use type hints for function signatures to clarify intent and enable static analysis - Use f-strings for string formatting - Use `dataclasses` or `NamedTuple` for 
structured data instead of raw dicts or tuples - Use `enum.Enum` for fixed sets of values - Prefer early returns over deep nesting - Keep functions short and focused on a single responsibility ## Performance Considerations - Profile before optimizing — use `cProfile`, `timeit`, or `line_profiler` to identify actual bottlenecks - Consider memory usage: generators for large sequences, `__slots__` for memory-heavy classes - Use connection pooling for database and network connections - Use appropriate caching (`functools.lru_cache`, `functools.cache`) for expensive pure functions - For CPU-bound hot paths, consider C extensions via `ctypes`, `pybind11`, or Cython ## Security Practices - Use `secrets` module for tokens and passwords, not `random` - Always use parameterized queries to prevent SQL injection - Use `ssl.create_default_context()` for TLS — don't disable certificate verification - Avoid `pickle` for untrusted data — use `json` or other safe serialization - Store secrets in environment variables or secret managers, never in code ## Related Documentation This skill is based on the comprehensive Python reference available at https://www.pythonsheets.com/ which includes working code snippets, performance benchmarks, real-world patterns, and integration guides. The reference is continuously updated with the latest Python features and best practices. ================================================ FILE: skills/py/references/structure.md ================================================ # Python Topics Reference Map Complete reference guide organized by topic, with direct links to live documentation. 
## Core Python - **Basics** → https://www.pythonsheets.com/notes/basic/python-basic.html - **Type Hints** → https://www.pythonsheets.com/notes/basic/python-typing.html - **Classes & OOP** → https://www.pythonsheets.com/notes/basic/python-object.html - **Functions** → https://www.pythonsheets.com/notes/basic/python-func.html - **Lists** → https://www.pythonsheets.com/notes/basic/python-list.html - **Dictionaries** → https://www.pythonsheets.com/notes/basic/python-dict.html - **Sets** → https://www.pythonsheets.com/notes/basic/python-set.html - **Generators** → https://www.pythonsheets.com/notes/basic/python-generator.html - **Heap** → https://www.pythonsheets.com/notes/basic/python-heap.html - **Regular Expressions** → https://www.pythonsheets.com/notes/basic/python-rexp.html - **Unicode** → https://www.pythonsheets.com/notes/basic/python-unicode.html - **__future__** → https://www.pythonsheets.com/notes/basic/python-future.html ## What's New - **What's New in Python 3** → https://www.pythonsheets.com/notes/python-new-py3.html ## System Programming - **Date/Time** → https://www.pythonsheets.com/notes/os/python-date.html - **File I/O** → https://www.pythonsheets.com/notes/os/python-io.html - **OS Interfaces** → https://www.pythonsheets.com/notes/os/python-os.html ## Concurrency - **Threading** → https://www.pythonsheets.com/notes/concurrency/python-threading.html - **Multiprocessing** → https://www.pythonsheets.com/notes/concurrency/python-multiprocessing.html - **Futures** → https://www.pythonsheets.com/notes/concurrency/python-futures.html ## Asyncio - **Async Guide** → https://www.pythonsheets.com/notes/asyncio/python-asyncio-guide.html - **Async Basics** → https://www.pythonsheets.com/notes/asyncio/python-asyncio-basic.html - **Async Servers** → https://www.pythonsheets.com/notes/asyncio/python-asyncio-server.html - **Async Advanced** → https://www.pythonsheets.com/notes/asyncio/python-asyncio-advanced.html ## Network Programming - **Socket Basics** → 
https://www.pythonsheets.com/notes/network/python-socket.html - **Socket Servers** → https://www.pythonsheets.com/notes/network/python-socket-server.html - **Async Sockets** → https://www.pythonsheets.com/notes/network/python-socket-async.html - **Packet Sniffer** → https://www.pythonsheets.com/notes/network/python-socket-sniffer.html - **SSL/TLS** → https://www.pythonsheets.com/notes/network/python-socket-ssl.html - **SSH** → https://www.pythonsheets.com/notes/network/python-ssh.html ## Database - **SQLAlchemy Basics** → https://www.pythonsheets.com/notes/database/python-sqlalchemy.html - **SQLAlchemy ORM** → https://www.pythonsheets.com/notes/database/python-sqlalchemy-orm.html - **Query Patterns** → https://www.pythonsheets.com/notes/database/python-sqlalchemy-query.html ## Security - **Cryptography** → https://www.pythonsheets.com/notes/security/python-crypto.html - **TLS/SSL** → https://www.pythonsheets.com/notes/security/python-tls.html - **Vulnerabilities** → https://www.pythonsheets.com/notes/security/python-vulnerability.html ## C/C++ Extensions - **ctypes** → https://www.pythonsheets.com/notes/extension/python-ctypes.html - **C API** → https://www.pythonsheets.com/notes/extension/python-capi.html - **Modern Extensions** → https://www.pythonsheets.com/notes/extension/python-cext-modern.html - **C++ from Python** → https://www.pythonsheets.com/notes/extension/cpp-from-python.html ## LLM & Machine Learning - **PyTorch** → https://www.pythonsheets.com/notes/llm/pytorch.html - **Megatron / Distributed Training** → https://www.pythonsheets.com/notes/llm/megatron.html - **LLM Serving** → https://www.pythonsheets.com/notes/llm/llm-serving.html - **LLM Benchmarking** → https://www.pythonsheets.com/notes/llm/llm-bench.html ## High-Performance Computing - **Slurm HPC** → https://www.pythonsheets.com/notes/hpc/slurm.html ## Appendix - **Disaggregated Prefill/Decode** → https://www.pythonsheets.com/notes/appendix/disaggregated-prefill-decode.html - **Megatron EFA 
Monitoring** → https://www.pythonsheets.com/notes/appendix/megatron-efa-monitoring.html - **NCCL GIN** → https://www.pythonsheets.com/notes/appendix/nccl-gin.html - **Walrus Operator** → https://www.pythonsheets.com/notes/appendix/python-walrus.html - **Python GDB Debugging** → https://www.pythonsheets.com/notes/appendix/python-gdb.html ================================================ FILE: src/basic/asyncio_.py ================================================ """Tests for asyncio examples.""" import asyncio import pytest class TestAsyncioBasics: """Test basic asyncio operations.""" def test_asyncio_run(self): """Test basic coroutine execution.""" async def hello(): return "hello" result = asyncio.run(hello()) assert result == "hello" def test_create_task(self): """Test task creation and execution.""" async def compute(x): await asyncio.sleep(0.01) return x * 2 async def main(): task = asyncio.create_task(compute(5)) return await task result = asyncio.run(main()) assert result == 10 def test_gather(self): """Test gathering multiple coroutines.""" async def fetch(n): await asyncio.sleep(0.01) return n async def main(): return await asyncio.gather(fetch(1), fetch(2), fetch(3)) results = asyncio.run(main()) assert results == [1, 2, 3] def test_wait_for_timeout(self): """Test timeout handling.""" async def slow(): await asyncio.sleep(10) async def main(): await asyncio.wait_for(slow(), timeout=0.01) with pytest.raises(asyncio.TimeoutError): asyncio.run(main()) def test_wait_first_completed(self): """Test waiting for first completed task.""" async def fast(): await asyncio.sleep(0.01) return "fast" async def slow(): await asyncio.sleep(1) return "slow" async def main(): tasks = [asyncio.create_task(fast()), asyncio.create_task(slow())] done, pending = await asyncio.wait( tasks, return_when=asyncio.FIRST_COMPLETED ) for task in pending: task.cancel() return len(done), len(pending) done_count, pending_count = asyncio.run(main()) assert done_count == 1 assert pending_count 
class TestAsyncIterator:
    """Exercise the two ways of feeding values to ``async for``."""

    def test_async_iterator(self):
        """A class with ``__aiter__``/``__anext__`` yields 0, 1, 2."""

        class AsyncRange:
            """Asynchronous counterpart of ``range(stop)``."""

            def __init__(self, stop):
                self.stop = stop
                self.current = 0

            def __aiter__(self):
                # The object acts as its own iterator.
                return self

            async def __anext__(self):
                if self.current < self.stop:
                    await asyncio.sleep(0.001)
                    produced = self.current
                    self.current = produced + 1
                    return produced
                # Exhausted: tell ``async for`` to stop.
                raise StopAsyncIteration

        async def collect():
            return [item async for item in AsyncRange(3)]

        assert asyncio.run(collect()) == [0, 1, 2]

    def test_async_generator(self):
        """An ``async def`` containing ``yield`` yields 0, 1, 2."""

        async def async_range(stop):
            counter = 0
            while counter < stop:
                await asyncio.sleep(0.001)
                yield counter
                counter += 1

        async def collect():
            gathered = []
            async for value in async_range(3):
                gathered.append(value)
            return gathered

        assert asyncio.run(collect()) == [0, 1, 2]
class TestQueue:
    """Exercise the FIFO and priority queues provided by asyncio."""

    def test_queue(self):
        """Items put on an ``asyncio.Queue`` are received in put order."""

        async def scenario():
            channel = asyncio.Queue()
            received = []

            async def producer():
                number = 0
                while number < 3:
                    await channel.put(number)
                    number += 1

            async def consumer():
                remaining = 3
                while remaining:
                    received.append(await channel.get())
                    # Acknowledge every item taken off the queue.
                    channel.task_done()
                    remaining -= 1

            await asyncio.gather(producer(), consumer())
            return received

        assert asyncio.run(scenario()) == [0, 1, 2]

    def test_priority_queue(self):
        """``asyncio.PriorityQueue`` hands back the smallest priority first."""

        async def scenario():
            heap = asyncio.PriorityQueue()
            for entry in ((3, "low"), (1, "high"), (2, "medium")):
                await heap.put(entry)

            labels = []
            while not heap.empty():
                entry = await heap.get()
                labels.append(entry[1])  # keep the label, drop the priority
            return labels

        assert asyncio.run(scenario()) == ["high", "medium", "low"]
class TestExecutor:
    """Test running blocking code in an executor."""

    def test_run_in_executor(self):
        """Offload a blocking call to the default executor and await it.

        ``blocking`` sleeps synchronously, so it is handed to
        ``loop.run_in_executor`` instead of being called directly from the
        coroutine. The coroutine's result must be the function's return
        value.
        """
        import time

        def blocking():
            time.sleep(0.01)
            return "done"

        async def main():
            # Inside a coroutine a loop is always running;
            # get_running_loop() is the call the asyncio documentation
            # prefers over get_event_loop() in coroutines and callbacks.
            loop = asyncio.get_running_loop()
            # Passing None selects the loop's default executor.
            return await loop.run_in_executor(None, blocking)

        result = asyncio.run(main())
        assert result == "done"
def loop_completed(items: list, target) -> bool:
    """Report whether a scan of *items* finished without meeting *target*.

    Demonstrates ``for``/``else``: the ``else`` suite runs only when the
    loop ends without ``break``.

    Args:
        items: Values to scan, compared with ``==``.
        target: Value to look for.

    Returns:
        ``True`` when no element equals *target* (including when *items*
        is empty), ``False`` as soon as a match is found.
    """
    for item in items:
        if item == target:
            break  # found: skip the else suite
    else:
        # Loop ran to completion, so the target never appeared.
        return True
    return False
def even_numbers(n: int) -> list:
    """Return the even integers of ``range(n)`` via a filtered comprehension."""
    candidates = range(n)
    # A zero remainder is falsy, so ``not value % 2`` keeps the even values.
    return [value for value in candidates if not value % 2]
def has_c_compiler():
    """Report whether a usable C compiler driver can be run.

    ``gcc``, ``clang`` and ``cc`` are tried in that order by running
    ``<compiler> --version``. The first one that exits with status 0
    counts as available.

    Returns:
        bool: ``True`` if one of the candidates ran successfully,
        ``False`` otherwise.
    """
    for compiler in ("gcc", "clang", "cc"):
        try:
            result = subprocess.run(
                [compiler, "--version"], capture_output=True, timeout=5
            )
        except (OSError, subprocess.TimeoutExpired):
            # OSError covers a missing binary (FileNotFoundError) and an
            # entry that exists but cannot be executed (PermissionError).
            # Either way this candidate is unusable, so try the next one
            # rather than letting the probe itself raise.
            continue
        if result.returncode == 0:
            return True
    return False
class TestCtypesStructures:
    """Exercise ``ctypes.Structure`` layouts: flat, nested and with arrays."""

    def test_simple_structure(self):
        """Fields of a two-double structure read back as given."""

        class Point(ctypes.Structure):
            _fields_ = [("x", ctypes.c_double), ("y", ctypes.c_double)]

        point = Point(3.0, 4.0)
        assert (point.x, point.y) == (3.0, 4.0)

        # 3-4-5 triangle: distance from the origin should be 5.
        squared = point.x**2 + point.y**2
        assert abs(math.sqrt(squared) - 5.0) < 1e-10

    def test_nested_structure(self):
        """A structure may embed other structures as fields."""

        class Point(ctypes.Structure):
            _fields_ = [("x", ctypes.c_double), ("y", ctypes.c_double)]

        class Rectangle(ctypes.Structure):
            _fields_ = [("top_left", Point), ("bottom_right", Point)]

        upper = Point(0, 10)
        lower = Point(10, 0)
        rect = Rectangle(upper, lower)

        assert (rect.top_left.x, rect.top_left.y) == (0, 10)
        assert (rect.bottom_right.x, rect.bottom_right.y) == (10, 0)

    def test_array_in_structure(self):
        """A fixed-size C array field can be filled and read back."""

        class Data(ctypes.Structure):
            _fields_ = [("values", ctypes.c_int * 5), ("count", ctypes.c_int)]

        record = Data()
        record.count = 5
        for slot, number in enumerate(range(0, 50, 10)):
            record.values[slot] = number

        assert list(record.values) == [0, 10, 20, 30, 40]
        assert record.count == 5
@requires_compiler
class TestCtypesCustomLibrary:
    """Test ctypes with custom compiled C code."""

    @pytest.fixture
    def fib_library(self, tmp_path):
        """Compile a small C library into *tmp_path* and load it via ctypes.

        The library exports ``fib``, ``add`` and ``multiply``. Their
        ``argtypes``/``restype`` are declared before the handle is returned
        so the tests can call them with plain Python values.

        The test is skipped when no compiler driver is found on PATH or
        when compilation fails.
        """
        import shutil  # local import: not among this file's top-level imports

        c_code = """
        unsigned long fib(unsigned long n) {
            if (n < 2) return n;
            return fib(n - 1) + fib(n - 2);
        }

        int add(int a, int b) {
            return a + b;
        }

        double multiply(double a, double b) {
            return a * b;
        }
        """
        c_file = tmp_path / "fib.c"
        c_file.write_text(c_code)

        on_darwin = platform.system() == "Darwin"
        lib_file = tmp_path / ("libfib.dylib" if on_darwin else "libfib.so")

        # Keep the platform's usual driver first, but fall back to the other
        # names that has_c_compiler() accepts; hard-coding one name raised
        # FileNotFoundError on hosts that only provide a different one.
        candidates = (
            ("clang", "cc", "gcc") if on_darwin else ("gcc", "cc", "clang")
        )
        compiler = next((c for c in candidates if shutil.which(c)), None)
        if compiler is None:
            pytest.skip("No C compiler found on PATH")

        cmd = [compiler, "-shared", "-fPIC", "-o", str(lib_file), str(c_file)]
        result = subprocess.run(cmd, capture_output=True)
        if result.returncode != 0:
            pytest.skip(f"Compilation failed: {result.stderr.decode()}")

        lib = ctypes.CDLL(str(lib_file))

        # Set up function signatures
        lib.fib.argtypes = [ctypes.c_ulong]
        lib.fib.restype = ctypes.c_ulong
        lib.add.argtypes = [ctypes.c_int, ctypes.c_int]
        lib.add.restype = ctypes.c_int
        lib.multiply.argtypes = [ctypes.c_double, ctypes.c_double]
        lib.multiply.restype = ctypes.c_double
        return lib

    def test_fib(self, fib_library):
        """Test Fibonacci function."""
        assert fib_library.fib(0) == 0
        assert fib_library.fib(1) == 1
        assert fib_library.fib(10) == 55
        assert fib_library.fib(20) == 6765

    def test_add(self, fib_library):
        """Test add function."""
        assert fib_library.add(1, 2) == 3
        assert fib_library.add(-5, 10) == 5
        assert fib_library.add(0, 0) == 0

    def test_multiply(self, fib_library):
        """Test multiply function with doubles."""
        assert abs(fib_library.multiply(2.5, 4.0) - 10.0) < 1e-10
        assert abs(fib_library.multiply(-1.5, 2.0) - (-3.0)) < 1e-10
class TestPythonPerformance:
    """Pure-Python Fibonacci variants kept as a baseline for comparison."""

    def test_python_fib(self):
        """Naive doubly-recursive Fibonacci gives the known values."""

        def fib(n):
            return n if n < 2 else fib(n - 1) + fib(n - 2)

        for arg, expected in ((0, 0), (1, 1), (10, 55), (20, 6765)):
            assert fib(arg) == expected

    def test_python_fib_iterative(self):
        """Loop-based Fibonacci (faster) gives the known values."""

        def fib_iter(n):
            if n < 2:
                return n
            previous, latest = 0, 1
            steps_left = n - 1
            while steps_left > 0:
                previous, latest = latest, previous + latest
                steps_left -= 1
            return latest

        for arg, expected in ((0, 0), (1, 1), (10, 55), (50, 12586269025)):
            assert fib_iter(arg) == expected

    def test_python_fib_memoized(self):
        """Recursive Fibonacci cached with ``lru_cache`` gives the known values."""
        from functools import lru_cache

        @lru_cache(maxsize=None)
        def fib_memo(n):
            # The recursive calls go through the cached wrapper.
            return n if n < 2 else fib_memo(n - 1) + fib_memo(n - 2)

        for arg, expected in ((0, 0), (1, 1), (10, 55), (50, 12586269025)):
            assert fib_memo(arg) == expected
from concurrent.futures import ThreadPoolExecutor, as_completed # Module-level functions for multiprocessing (must be picklable) def _mp_square(x): return x * x def _mp_add(a, b): return a + b def _mp_worker(q, n): q.put(n * n) def _mp_increment(counter): for _ in range(100): with counter.get_lock(): counter.value += 1 def _mp_double(arr): for i in range(len(arr)): arr[i] *= 2 class TestThreading: """Test threading operations.""" def test_thread_creation(self): """Test basic thread creation.""" results = [] def task(n): results.append(n) threads = [Thread(target=task, args=(i,)) for i in range(5)] for t in threads: t.start() for t in threads: t.join() assert sorted(results) == [0, 1, 2, 3, 4] def test_thread_with_return_value(self): """Test getting return values from threads.""" results = {} def compute(n, res): res[n] = n * n threads = [Thread(target=compute, args=(i, results)) for i in range(5)] for t in threads: t.start() for t in threads: t.join() assert results == {0: 0, 1: 1, 2: 4, 3: 9, 4: 16} def test_lock(self): """Test Lock for mutual exclusion.""" counter = [0] lock = Lock() def increment(): for _ in range(1000): with lock: counter[0] += 1 threads = [Thread(target=increment) for _ in range(10)] for t in threads: t.start() for t in threads: t.join() assert counter[0] == 10000 def test_rlock(self): """Test RLock for reentrant locking.""" lock = RLock() results = [] def outer(): with lock: results.append("outer") inner() def inner(): with lock: # Same thread can acquire again results.append("inner") t = Thread(target=outer) t.start() t.join() assert results == ["outer", "inner"] def test_semaphore(self): """Test Semaphore for resource limiting.""" max_concurrent = [0] current = [0] sem = Semaphore(3) def task(): with sem: current[0] += 1 max_concurrent[0] = max(max_concurrent[0], current[0]) time.sleep(0.01) current[0] -= 1 threads = [Thread(target=task) for _ in range(10)] for t in threads: t.start() for t in threads: t.join() assert max_concurrent[0] <= 3 
def test_event(self): """Test Event for thread signaling.""" event = Event() results = [] def waiter(n): event.wait() results.append(n) threads = [Thread(target=waiter, args=(i,)) for i in range(3)] for t in threads: t.start() time.sleep(0.1) assert len(results) == 0 # All waiting event.set() for t in threads: t.join() assert sorted(results) == [0, 1, 2] def test_condition(self): """Test Condition for complex synchronization.""" items = [] condition = Condition() consumed = [] def producer(): for i in range(3): with condition: items.append(i) condition.notify() def consumer(): for _ in range(3): with condition: while not items: condition.wait() consumed.append(items.pop(0)) t1 = Thread(target=producer) t2 = Thread(target=consumer) t2.start() time.sleep(0.01) t1.start() t1.join() t2.join() assert consumed == [0, 1, 2] def test_barrier(self): """Test Barrier for synchronization point.""" barrier = Barrier(3) order = [] def worker(n): order.append(f"before_{n}") barrier.wait() order.append(f"after_{n}") threads = [Thread(target=worker, args=(i,)) for i in range(3)] for t in threads: t.start() for t in threads: t.join() # All "before" should come before all "after" before_count = sum(1 for x in order[:3] if x.startswith("before")) assert before_count == 3 def test_queue(self): """Test thread-safe Queue.""" q = Queue() results = [] def producer(): for i in range(5): q.put(i) def consumer(): for _ in range(5): results.append(q.get()) q.task_done() t1 = Thread(target=producer) t2 = Thread(target=consumer) t1.start() t2.start() t1.join() t2.join() assert sorted(results) == [0, 1, 2, 3, 4] class TestMultiprocessing: """Test multiprocessing operations.""" def test_process_creation(self): """Test basic process creation.""" from multiprocessing import Process, Queue as MPQueue q = MPQueue() processes = [Process(target=_mp_worker, args=(q, i)) for i in range(4)] for p in processes: p.start() for p in processes: p.join() results = [q.get() for _ in range(4)] assert 
sorted(results) == [0, 1, 4, 9] def test_pool_map(self): """Test Pool.map for parallel execution.""" from multiprocessing import Pool with Pool(2) as pool: results = pool.map(_mp_square, range(5)) assert results == [0, 1, 4, 9, 16] def test_pool_starmap(self): """Test Pool.starmap for multiple arguments.""" from multiprocessing import Pool with Pool(2) as pool: results = pool.starmap(_mp_add, [(1, 2), (3, 4), (5, 6)]) assert results == [3, 7, 11] def test_shared_value(self): """Test shared Value between processes.""" from multiprocessing import Process, Value counter = Value("i", 0) processes = [ Process(target=_mp_increment, args=(counter,)) for _ in range(4) ] for p in processes: p.start() for p in processes: p.join() assert counter.value == 400 def test_shared_array(self): """Test shared Array between processes.""" from multiprocessing import Process, Array arr = Array("i", [1, 2, 3, 4]) p = Process(target=_mp_double, args=(arr,)) p.start() p.join() assert list(arr) == [2, 4, 6, 8] class TestConcurrentFutures: """Test concurrent.futures operations.""" def test_thread_pool_map(self): """Test ThreadPoolExecutor.map.""" def square(x): return x * x with ThreadPoolExecutor(max_workers=3) as executor: results = list(executor.map(square, range(5))) assert results == [0, 1, 4, 9, 16] def test_thread_pool_submit(self): """Test ThreadPoolExecutor.submit.""" def compute(x): return x * 2 with ThreadPoolExecutor(max_workers=2) as executor: futures = [executor.submit(compute, i) for i in range(5)] results = [f.result() for f in futures] assert results == [0, 2, 4, 6, 8] def test_as_completed(self): """Test as_completed for processing results.""" def task(n): time.sleep(0.1 - n * 0.02) # Varying delays return n with ThreadPoolExecutor(max_workers=3) as executor: futures = [executor.submit(task, i) for i in range(3)] results = [f.result() for f in as_completed(futures)] # Results may be in any order assert sorted(results) == [0, 1, 2] def test_future_callback(self): """Test 
Future.add_done_callback.""" results = [] def on_complete(future): results.append(future.result()) def compute(x): return x * x with ThreadPoolExecutor(max_workers=2) as executor: for i in range(3): future = executor.submit(compute, i) future.add_done_callback(on_complete) time.sleep(0.1) # Wait for callbacks assert sorted(results) == [0, 1, 4] def test_future_exception(self): """Test exception handling in futures.""" def failing_task(): raise ValueError("test error") with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(failing_task) with pytest.raises(ValueError, match="test error"): future.result() def test_future_timeout(self): """Test timeout on future.result().""" def slow_task(): time.sleep(10) with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(slow_task) from concurrent.futures import TimeoutError with pytest.raises(TimeoutError): future.result(timeout=0.1) def test_future_cancel(self): """Test cancelling a future.""" def slow_task(): time.sleep(10) with ThreadPoolExecutor(max_workers=1) as executor: # First task blocks the worker future1 = executor.submit(slow_task) # Second task is queued future2 = executor.submit(slow_task) time.sleep(0.01) # Let first task start # Can cancel queued task assert future2.cancel() == True assert future2.cancelled() == True def test_process_pool_map(self): """Test ProcessPoolExecutor.map.""" from concurrent.futures import ProcessPoolExecutor with ProcessPoolExecutor(max_workers=2) as executor: results = list(executor.map(_mp_square, range(5))) assert results == [0, 1, 4, 9, 16] class TestProducerConsumer: """Test producer-consumer patterns.""" def test_basic_producer_consumer(self): """Test basic producer-consumer with Queue.""" q = Queue() produced = [] consumed = [] def producer(): for i in range(5): produced.append(i) q.put(i) q.put(None) # Sentinel def consumer(): while True: item = q.get() if item is None: break consumed.append(item) t1 = Thread(target=producer) t2 = 
class TestSecureRandom:
    """Exercise the ``secrets`` module's random generators."""

    def test_token_bytes(self):
        """Raw tokens have the requested length and differ between calls."""
        first = secrets.token_bytes(32)
        second = secrets.token_bytes(32)
        assert len(first) == 32
        # Two independent draws should not be equal.
        assert first != second

    def test_token_urlsafe(self):
        """URL-safe tokens are at least as long as the requested byte count."""
        text = secrets.token_urlsafe(32)
        # Base64 encoding makes the text longer than the 32 source bytes.
        assert len(text) >= 32

    def test_token_hex(self):
        """Hex tokens use two characters per byte."""
        text = secrets.token_hex(16)
        assert len(text) == 32  # 16 bytes = 32 hex chars

    def test_randbelow(self):
        """``randbelow(100)`` always lands in the half-open range [0, 100)."""
        draws = [secrets.randbelow(100) for _ in range(100)]
        for value in draws:
            assert 0 <= value < 100
class TestHMAC:
    """Create and compare HMAC-SHA256 tags."""

    def test_hmac_create_verify(self):
        """Recomputing the tag with the same key and message matches."""
        secret = secrets.token_bytes(32)
        payload = b"Important message"

        original_tag = hmac.new(secret, payload, hashlib.sha256).digest()
        recomputed_tag = hmac.new(secret, payload, hashlib.sha256).digest()

        assert hmac.compare_digest(original_tag, recomputed_tag)

    def test_hmac_tamper_detection(self):
        """A different message under the same key yields a different tag."""
        secret = secrets.token_bytes(32)

        genuine_tag = hmac.new(
            secret, b"Important message", hashlib.sha256
        ).digest()
        forged_tag = hmac.new(
            secret, b"Tampered message", hashlib.sha256
        ).digest()

        matches = hmac.compare_digest(genuine_tag, forged_tag)
        assert not matches
class TestAESGCM:
    """Test AES-GCM authenticated encryption."""

    def test_encrypt_decrypt(self):
        """Round-trip a message with a 256-bit key and a 12-byte nonce."""
        key = AESGCM.generate_key(bit_length=256)
        aesgcm = AESGCM(key)
        nonce = os.urandom(12)
        plaintext = b"Secret message"

        ciphertext = aesgcm.encrypt(nonce, plaintext, None)
        decrypted = aesgcm.decrypt(nonce, ciphertext, None)

        assert decrypted == plaintext

    def test_with_associated_data(self):
        """Round-trip a message with associated data passed to both calls."""
        key = AESGCM.generate_key(bit_length=256)
        aesgcm = AESGCM(key)
        nonce = os.urandom(12)
        plaintext = b"Secret message"
        aad = b"header"

        ciphertext = aesgcm.encrypt(nonce, plaintext, aad)
        decrypted = aesgcm.decrypt(nonce, ciphertext, aad)

        assert decrypted == plaintext

    def test_tamper_detection(self):
        """Flipping one ciphertext bit must make decryption fail."""
        # Imported locally so this block does not rely on a new
        # file-level import.
        from cryptography.exceptions import InvalidTag

        key = AESGCM.generate_key(bit_length=256)
        aesgcm = AESGCM(key)
        nonce = os.urandom(12)
        plaintext = b"Secret message"

        ciphertext = aesgcm.encrypt(nonce, plaintext, None)
        tampered = bytearray(ciphertext)
        tampered[0] ^= 1

        # Expect the specific authentication failure.  A bare ``Exception``
        # would also accept unrelated errors such as a TypeError.
        with pytest.raises(InvalidTag):
            aesgcm.decrypt(nonce, bytes(tampered), None)
def test_invalid_signature(self):
    """Verifying a signature against a different message must fail."""
    # Imported locally so this block does not rely on a new file-level
    # import.
    from cryptography.exceptions import InvalidSignature

    private_key = ed25519.Ed25519PrivateKey.generate()
    public_key = private_key.public_key()

    message = b"Message to sign"
    signature = private_key.sign(message)

    # Expect the specific verification failure.  A bare ``Exception``
    # would also accept unrelated errors.
    with pytest.raises(InvalidSignature):
        public_key.verify(signature, b"Different message")
def test_strftime_formatting():
    """Render one datetime with several strftime patterns and as ISO 8601."""
    moment = datetime(2024, 1, 15, 14, 30, 45)

    expected_by_pattern = {
        "%Y-%m-%d": "2024-01-15",
        "%d/%m/%Y": "15/01/2024",
        "%Y-%m-%d %H:%M:%S": "2024-01-15 14:30:45",
        "%Y%m%d_%H%M%S": "20240115_143045",
    }
    for pattern, expected in expected_by_pattern.items():
        assert moment.strftime(pattern) == expected

    assert moment.isoformat() == "2024-01-15T14:30:45"
def test_start_end_of_month():
    """Get start and end of month.

    The start is the first day at 00:00:00.000000 and the end is the
    last day at 23:59:59.999999, so the two bound the whole month.
    """
    dt = datetime(2024, 1, 15, 14, 30, 45)

    start_of_month = dt.replace(
        day=1, hour=0, minute=0, second=0, microsecond=0
    )
    assert start_of_month.day == 1
    assert start_of_month == datetime(2024, 1, 1)

    # monthrange() returns (weekday of the 1st, number of days in month).
    last_day = calendar.monthrange(dt.year, dt.month)[1]
    # Replacing only the day would keep the 14:30:45 time of day, which
    # is not the end of the month, so pin the time to the last instant.
    end_of_month = datetime.combine(
        dt.date().replace(day=last_day), dt_time.max
    )
    assert end_of_month.day == 31
    assert end_of_month == datetime(2024, 1, 31, 23, 59, 59, 999999)
def test_time_ago():
    """Render elapsed time as a short human-readable phrase."""

    def time_ago(dt, now=None):
        reference = datetime.now() if now is None else now
        elapsed = (reference - dt).total_seconds()

        if elapsed < 60:
            return "just now"

        # (exclusive upper bound in seconds, seconds per unit, unit name)
        scales = ((3600, 60, "minute"), (86400, 3600, "hour"))
        for limit, unit_seconds, unit in scales:
            if elapsed < limit:
                count = int(elapsed // unit_seconds)
                return f"{count} {unit}{'s' if count != 1 else ''} ago"

        count = int(elapsed // 86400)
        return f"{count} day{'s' if count != 1 else ''} ago"

    now = datetime(2024, 1, 15, 12, 0, 0)
    cases = (
        (timedelta(seconds=30), "just now"),
        (timedelta(minutes=5), "5 minutes ago"),
        (timedelta(hours=2), "2 hours ago"),
        (timedelta(days=3), "3 days ago"),
    )
    for delta, expected in cases:
        assert time_ago(now - delta, now) == expected
class EmuDict:
    """Custom dictionary-like class.

    Wraps a plain ``dict`` and forwards item access, deletion,
    membership tests, iteration and ``len()`` to it.
    """

    def __init__(self, data=None):
        """Wrap ``data``, or start with a new empty dict when it is None."""
        # Test for None explicitly.  ``data or {}`` would replace a
        # caller-supplied *empty* dict with a fresh one, because an empty
        # dict is falsy, while a non-empty dict is used as-is.
        self._dict = {} if data is None else data

    def __repr__(self):
        return f"EmuDict({self._dict})"

    def __getitem__(self, key):
        return self._dict[key]

    def __setitem__(self, key, val):
        self._dict[key] = val

    def __delitem__(self, key):
        del self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)
def test_read_lines(tmp_path):
    """Iterate over a text file one line at a time."""
    target = tmp_path / "lines.txt"
    target.write_text("line1\nline2\nline3\n")

    with open(target, encoding="utf-8") as handle:
        # rstrip() drops the trailing newline kept by file iteration.
        stripped = [raw.rstrip() for raw in handle]

    assert stripped == ["line1", "line2", "line3"]
def test_path_exists(tmp_path):
    """Distinguish an existing file from an existing directory."""
    regular = tmp_path / "file.txt"
    regular.touch()
    folder = tmp_path / "dir"
    folder.mkdir()

    # Each tuple is (exists, is the expected kind, is the other kind).
    file_facts = (regular.exists(), regular.is_file(), regular.is_dir())
    dir_facts = (folder.exists(), folder.is_dir(), folder.is_file())

    assert file_facts == (True, True, False)
    assert dir_facts == (True, True, False)
def test_zipfile_create_extract(tmp_path):
    """Build a zip archive, list its members, then extract and verify them."""
    archive = tmp_path / "archive.zip"
    on_disk = tmp_path / "file1.txt"
    on_disk.write_text("content 1")

    # One member comes from a real file, the other from an in-memory string.
    with zipfile.ZipFile(archive, "w") as bundle:
        bundle.write(on_disk, "file1.txt")
        bundle.writestr("file2.txt", "content 2")

    with zipfile.ZipFile(archive, "r") as bundle:
        members = bundle.namelist()
    for member in ("file1.txt", "file2.txt"):
        assert member in members

    destination = tmp_path / "extracted"
    destination.mkdir()
    with zipfile.ZipFile(archive, "r") as bundle:
        bundle.extractall(destination)

    expected_text = {"file1.txt": "content 1", "file2.txt": "content 2"}
    for member, text in expected_text.items():
        assert (destination / member).read_text() == text
def test_shutil_which():
    """Find executable in PATH.

    The running interpreter is looked up by its own file name, with its
    directory placed first on the search path. A bare ``python`` command is
    not guaranteed to exist (some systems only provide ``python3``), which
    made the previous hard-coded lookup environment-dependent.
    """
    import os
    import shutil
    import sys

    exe_dir, exe_name = os.path.split(sys.executable)
    search_path = os.pathsep.join(
        [exe_dir, os.environ.get("PATH", os.defpath)]
    )
    python = shutil.which(exe_name, path=search_path)
    assert python is not None

    nonexistent = shutil.which("nonexistent_command_xyz")
    assert nonexistent is None
def good_default(items=None):
    """Append ``1`` to *items*, creating a fresh list when none is given."""
    # The ``None`` sentinel avoids sharing one mutable default across calls.
    target = [] if items is None else items
    target.append(1)
    return target
@lru_cache(maxsize=None)
def fib_cached(n: int) -> int:
    """Return the n-th Fibonacci number, memoizing every intermediate result."""
    # 0 and 1 are their own Fibonacci values; everything else recurses.
    return n if n < 2 else fib_cached(n - 1) + fib_cached(n - 2)
def test_mutable_default(self):
    """Two argument-less calls must each start from a fresh list."""
    first = good_default()
    second = good_default()
    assert first == [1]
    assert second == [1]  # not [1, 1]
def get_feature_info(name: str) -> tuple | None:
    """Get feature info (optional, mandatory versions).

    Args:
        name: Name of a ``__future__`` feature, e.g. ``"annotations"``.

    Returns:
        A ``(optional, mandatory)`` pair taken from the feature object, or
        ``None`` when *name* is not listed in
        ``__future__.all_feature_names``.
    """
    # The module also exposes attributes that are not features
    # (``all_feature_names`` itself, the ``CO_*`` compiler flags); those have
    # no ``optional``/``mandatory`` attributes, so only listed names pass.
    if name not in __future__.all_feature_names:
        return None
    feature = getattr(__future__, name)
    return (feature.optional, feature.mandatory)
def get_version_string() -> str:
    """Return the running interpreter version as ``major.minor.micro``."""
    major, minor, micro = sys.version_info[:3]
    return f"{major}.{minor}.{micro}"
def average():
    """Calculate average, return via StopIteration.

    Prime the generator with ``next()``, feed numbers with ``send()``, and
    finish by sending ``None``. The result is carried as the ``value`` of the
    raised ``StopIteration``; it is always a float, ``0.0`` when nothing was
    sent.
    """
    total = 0.0
    count = 0
    while True:
        value = yield
        if value is None:  # ``None`` is the end-of-input marker
            break
        total += value
        count += 1
    # Return a float in the empty case too, so the result type is uniform.
    return total / count if count else 0.0
class TestGeneratorReturn:
    def test_average(self):
        """The average is delivered as the value of the final StopIteration."""
        g = average()
        next(g)  # advance to the first ``yield`` so ``send`` is accepted
        g.send(10)
        g.send(20)
        g.send(30)
        # ``pytest.raises`` fails the test when the generator does not
        # finish; a bare try/except would pass without checking anything.
        with pytest.raises(StopIteration) as excinfo:
            g.send(None)
        assert excinfo.value.value == 20.0
Token = namedtuple("Token", ["type", "value"])


def tokenize(text: str):
    """Tokenize arithmetic expression.

    Args:
        text: Expression made of integers, ``+ - * /`` and whitespace.

    Returns:
        A lazy generator of ``Token(type, value)`` records with whitespace
        dropped. Scanning ends at the first position where no token pattern
        matches.
    """
    # Each alternative must be a *named* group: the name of the group that
    # matched is exposed as ``m.lastgroup`` and becomes the token type.
    tokens = [
        r"(?P<NUMBER>\d+)",
        r"(?P<PLUS>\+)",
        r"(?P<MINUS>-)",
        r"(?P<TIMES>\*)",
        r"(?P<DIVIDE>/)",
        r"(?P<WS>\s+)",
    ]
    lex = re.compile("|".join(tokens))
    scan = lex.scanner(text)
    # ``iter(callable, sentinel)`` keeps calling ``scan.match`` until it
    # returns ``None``.
    return (
        Token(m.lastgroup, m.group())
        for m in iter(scan.match, None)
        if m.lastgroup != "WS"
    )
def heapify_list(items: list) -> list:
    """Return a new heap-ordered list built from *items*.

    The input is copied before ``heapq.heapify`` runs, so *items* itself is
    left untouched. Any iterable is accepted.
    """
    h = list(items)
    heapq.heapify(h)
    return h
def heap_sort(items: list) -> list:
    """Return the elements of *items* in ascending order using a heap."""
    pending = items.copy()  # work on a copy; the caller's list is untouched
    heapq.heapify(pending)
    ordered = []
    while pending:
        ordered.append(heapq.heappop(pending))
    return ordered
def top_k(items: list, k: int) -> list:
    """Keep track of k largest elements.

    Args:
        items: Values to scan.
        k: How many of the largest values to keep.

    Returns:
        Up to *k* largest values in descending order; an empty list when
        *k* is not positive.
    """
    # With k <= 0 the heap never fills, and peeking at ``h[0]`` would raise
    # IndexError on the first element.
    if k <= 0:
        return []
    h = []
    for x in items:
        if len(h) < k:
            heapq.heappush(h, x)
        elif x > h[0]:
            # h[0] is the smallest of the current top k; replace it.
            heapq.heapreplace(h, x)
    return sorted(h, reverse=True)


# Indexed Heap


class IndexedHeap:
    """Heap with priority updates.

    Entries are ``[priority, sequence, item]`` lists. The sequence number
    makes every entry's sort key unique, so items are never compared with
    each other (or with the removal marker), and equal priorities come out
    in insertion order.
    """

    # Unique marker for entries that were removed or superseded; a dedicated
    # object can never be identical to a real item.
    REMOVED = object()

    def __init__(self):
        self.heap = []
        self.entry_finder = {}  # item -> its live entry in ``self.heap``
        self._count = 0  # next sequence number

    def push(self, item, priority):
        """Add *item*, or update its priority if it is already present."""
        if item in self.entry_finder:
            self.remove(item)
        entry = [priority, self._count, item]
        self._count += 1
        self.entry_finder[item] = entry
        heapq.heappush(self.heap, entry)

    def remove(self, item):
        """Mark *item* as removed; raises KeyError if it is not present."""
        entry = self.entry_finder.pop(item)
        # The entry stays in the heap and is skipped later by ``pop``.
        entry[-1] = self.REMOVED

    def pop(self):
        """Remove and return the live item with the lowest priority.

        Raises:
            KeyError: If no live item is left.
        """
        while self.heap:
            item = heapq.heappop(self.heap)[-1]
            if item is not self.REMOVED:
                del self.entry_finder[item]
                return item
        raise KeyError("pop from empty heap")
def init_mutable(n: int) -> list:
    """Build a list holding *n* independent empty lists."""
    rows = []
    for _ in range(n):
        rows.append([])  # a new list object on every iteration
    return rows
class Stack:
    """Stack implementation using list.

    The end of the underlying list is the top of the stack.
    """

    def __init__(self):
        self._items = []

    def push(self, item):
        """Place *item* on top of the stack."""
        self._items.append(item)

    def pop(self):
        """Remove and return the top item; IndexError if the stack is empty."""
        return self._items.pop()

    def peek(self):
        """Return the top item without removing it, or None when empty."""
        return self._items[-1] if self._items else None

    def is_empty(self):
        """Return True when the stack holds no items."""
        return not self._items

    def __len__(self):
        return len(self._items)
def circular_buffer(items: list, maxlen: int) -> deque:
    """Create circular buffer with deque.

    Args:
        items: Values to feed into the buffer, oldest first.
        maxlen: Capacity of the buffer.

    Returns:
        A ``deque`` bounded to *maxlen*; once full, each new value pushes
        the oldest one out, so only the last *maxlen* values remain.
    """
    # The constructor consumes the iterable itself; no manual append loop.
    return deque(items, maxlen=maxlen)
def create_trie(words: list) -> dict:
    """Create trie from list of words.

    Each node is a ``defaultdict`` mapping one character to its child node.
    A node that terminates a word additionally stores that word under the
    key ``True``.

    Args:
        words: Strings to insert.

    Returns:
        The root node of the trie.
    """

    def new_node() -> defaultdict:
        # Missing children are created on first access, so inserting a word
        # never needs an explicit "create if absent" step.
        return defaultdict(new_node)

    end = True  # marker key under which a completed word is stored
    trie = new_node()
    for word in words:
        node = trie
        for char in word:
            node = node[char]
        node[end] = word
    return trie
class Vector:
    """Two-component vector demonstrating common magic methods."""

    def __init__(self, x: int, y: int):
        self.x, self.y = x, y

    def __repr__(self):
        """Return an unambiguous, constructor-like representation."""
        return f"Vector({self.x}, {self.y})"

    def __str__(self):
        """Return the short ``(x, y)`` display form."""
        return f"({self.x}, {self.y})"

    def __add__(self, other):
        """Return the component-wise sum of ``self`` and ``other``."""
        return Vector(self.x + other.x, self.y + other.y)

    def __mul__(self, scalar):
        """Return a new vector with both components multiplied by ``scalar``."""
        return Vector(self.x * scalar, self.y * scalar)

    def __eq__(self, other):
        """Compare component-wise with any object exposing ``x`` and ``y``.

        Returns ``NotImplemented`` for objects without those attributes so
        that ``Vector(1, 2) == None`` evaluates to ``False`` instead of
        raising ``AttributeError``.
        """
        try:
            return self.x == other.x and self.y == other.y
        except AttributeError:
            return NotImplemented
class Date:
    """Class with classmethod and staticmethod."""

    def __init__(self, year: int, month: int, day: int):
        self.year, self.month, self.day = year, month, day

    @classmethod
    def from_string(cls, date_string: str):
        """Alternate constructor: build an instance from ``"YYYY-MM-DD"``.

        Raises:
            ValueError: If the string does not split into exactly three
                integer fields.
        """
        year, month, day = map(int, date_string.split("-"))
        return cls(year, month, day)

    @staticmethod
    def is_valid(date_string: str) -> bool:
        """Return ``True`` when ``date_string`` names a real calendar date.

        The string must consist of three ``-``-separated integers. The
        calendar check is delegated to ``datetime.date``, so month lengths
        and leap years are honoured (``"2023-02-31"`` is rejected) and the
        year must lie in the range ``datetime.date`` supports.
        """
        import datetime

        try:
            year, month, day = map(int, date_string.split("-"))
            datetime.date(year, month, day)
        except (ValueError, TypeError, AttributeError, OverflowError):
            # ValueError: wrong field count, non-numeric field, or a date
            # that does not exist. OverflowError: integer too large for
            # datetime. TypeError/AttributeError: input that is not a str.
            return False
        return True
class TestProperty:
    """Exercise the ``Circle`` radius property and its computed area."""

    def test_circle_area(self):
        circle = Circle(5)
        error = abs(circle.area - 78.53975)
        assert error < 0.001

    def test_circle_setter(self):
        circle = Circle(5)
        circle.radius = 10
        assert circle.radius == 10

    def test_circle_invalid(self):
        circle = Circle(5)
        # A negative radius must be rejected by the property setter.
        with pytest.raises(ValueError):
            circle.radius = -1
test_diamond(self): assert D().method() == "B" def test_mro(self): assert D.mro() == [D, B, C, A, object] class TestSlots: def test_slots(self): p = PointWithSlots(1, 2) assert p.x == 1 assert p.y == 2 def test_slots_no_dict(self): p = PointWithSlots(1, 2) assert not hasattr(p, "__dict__") ================================================ FILE: src/basic/os_.py ================================================ """ Tests for operating system operations. These tests demonstrate Python's os module for file system operations, process management, environment variables, and path manipulation. """ import os import platform import subprocess import tempfile from pathlib import Path import pytest class TestSystemInfo: """Test system information functions.""" def test_os_name(self): """Test os.name returns valid value.""" assert os.name in ("posix", "nt") def test_platform_system(self): """Test platform.system returns valid value.""" assert platform.system() in ("Linux", "Darwin", "Windows") def test_cpu_count(self): """Test cpu_count returns positive integer.""" count = os.cpu_count() assert count is not None assert count > 0 def test_getpid(self): """Test getpid returns positive integer.""" pid = os.getpid() assert pid > 0 def test_getcwd(self): """Test getcwd returns valid path.""" cwd = os.getcwd() assert os.path.isdir(cwd) class TestEnvironmentVariables: """Test environment variable operations.""" def test_get_env(self): """Test getting environment variable.""" # PATH should exist on all systems path = os.environ.get("PATH") assert path is not None def test_getenv_default(self): """Test getenv with default value.""" value = os.getenv("NONEXISTENT_VAR_12345", "default") assert value == "default" def test_set_env(self): """Test setting environment variable.""" os.environ["TEST_VAR_PYSHEEET"] = "test_value" assert os.environ["TEST_VAR_PYSHEEET"] == "test_value" del os.environ["TEST_VAR_PYSHEEET"] def test_env_not_found(self): """Test KeyError for missing env var.""" with 
pytest.raises(KeyError): _ = os.environ["NONEXISTENT_VAR_12345"] class TestPathOperations: """Test path manipulation functions.""" def test_join(self): """Test os.path.join.""" path = os.path.join("dir1", "dir2", "file.txt") assert "dir1" in path assert "dir2" in path assert "file.txt" in path def test_dirname_basename(self): """Test dirname and basename.""" path = os.path.join("home", "user", "file.txt") assert os.path.basename(path) == "file.txt" assert "user" in os.path.dirname(path) def test_splitext(self): """Test splitext.""" name, ext = os.path.splitext("file.txt") assert name == "file" assert ext == ".txt" def test_abspath(self): """Test abspath returns absolute path.""" abs_path = os.path.abspath(".") assert os.path.isabs(abs_path) def test_exists(self): """Test path existence checks.""" assert os.path.exists(".") assert not os.path.exists("/nonexistent/path/12345") def test_isfile_isdir(self): """Test isfile and isdir.""" assert os.path.isdir(".") # Create temp file to test isfile with tempfile.NamedTemporaryFile() as f: assert os.path.isfile(f.name) class TestDirectoryOperations: """Test directory operations.""" def test_mkdir_rmdir(self): """Test creating and removing directory.""" with tempfile.TemporaryDirectory() as tmpdir: new_dir = os.path.join(tmpdir, "test_dir") os.mkdir(new_dir) assert os.path.isdir(new_dir) os.rmdir(new_dir) assert not os.path.exists(new_dir) def test_makedirs(self): """Test creating nested directories.""" with tempfile.TemporaryDirectory() as tmpdir: nested = os.path.join(tmpdir, "a", "b", "c") os.makedirs(nested) assert os.path.isdir(nested) def test_makedirs_exist_ok(self): """Test makedirs with exist_ok.""" with tempfile.TemporaryDirectory() as tmpdir: os.makedirs(tmpdir, exist_ok=True) # Should not raise def test_listdir(self): """Test listing directory contents.""" with tempfile.TemporaryDirectory() as tmpdir: # Create some files Path(tmpdir, "file1.txt").touch() Path(tmpdir, "file2.txt").touch() entries = 
class TestSubprocess:
    """Test subprocess operations.

    The child process is always the running Python interpreter, so the
    tests do not depend on ``echo`` or ``cat`` being available as
    executables (they are not on Windows).
    """

    def test_run_simple(self):
        """Test simple subprocess.run."""
        import sys

        result = subprocess.run(
            [sys.executable, "-c", "print('hello')"],
            capture_output=True,
            text=True,
        )
        assert result.returncode == 0
        assert "hello" in result.stdout

    def test_check_output(self):
        """Test subprocess.check_output."""
        import sys

        output = subprocess.check_output(
            [sys.executable, "-c", "print('test')"], text=True
        )
        assert "test" in output

    def test_run_with_input(self):
        """Test subprocess with input."""
        import sys

        # The child copies stdin to stdout unchanged, like ``cat``.
        result = subprocess.run(
            [
                sys.executable,
                "-c",
                "import sys; sys.stdout.write(sys.stdin.read())",
            ],
            input="hello world",
            capture_output=True,
            text=True,
        )
        assert result.stdout == "hello world"
tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as f: f.write("test data") f.flush() assert os.path.exists(f.name) assert f.name.endswith(".txt") def test_temp_directory(self): """Test TemporaryDirectory.""" with tempfile.TemporaryDirectory() as tmpdir: assert os.path.isdir(tmpdir) # Create file inside path = os.path.join(tmpdir, "test.txt") Path(path).write_text("data") assert os.path.exists(path) # Directory should be deleted assert not os.path.exists(tmpdir) def test_gettempdir(self): """Test getting temp directory path.""" tmpdir = tempfile.gettempdir() assert os.path.isdir(tmpdir) class TestPathlib: """Test pathlib operations.""" def test_path_creation(self): """Test creating Path objects.""" p = Path("/home/user/file.txt") assert p.name == "file.txt" assert p.stem == "file" assert p.suffix == ".txt" def test_path_join(self): """Test joining paths with /.""" p = Path("/home") / "user" / "file.txt" assert str(p) == "/home/user/file.txt" def test_path_parent(self): """Test getting parent.""" p = Path("/home/user/file.txt") assert p.parent == Path("/home/user") def test_path_exists(self): """Test path existence.""" assert Path(".").exists() assert Path(".").is_dir() def test_read_write_text(self): """Test reading and writing text.""" with tempfile.TemporaryDirectory() as tmpdir: p = Path(tmpdir) / "test.txt" p.write_text("hello world") assert p.read_text() == "hello world" def test_mkdir(self): """Test creating directory.""" with tempfile.TemporaryDirectory() as tmpdir: p = Path(tmpdir) / "a" / "b" / "c" p.mkdir(parents=True) assert p.is_dir() def test_glob(self): """Test glob pattern matching.""" with tempfile.TemporaryDirectory() as tmpdir: (Path(tmpdir) / "file1.py").touch() (Path(tmpdir) / "file2.py").touch() (Path(tmpdir) / "file3.txt").touch() py_files = list(Path(tmpdir).glob("*.py")) assert len(py_files) == 2 def test_iterdir(self): """Test iterating directory.""" with tempfile.TemporaryDirectory() as tmpdir: (Path(tmpdir) / "file1.txt").touch() 
class TestPsutil:
    """Test psutil operations."""

    @pytest.fixture
    def psutil_available(self):
        """Return the psutil module, skipping the test if it is not installed."""
        return pytest.importorskip("psutil")

    def test_cpu_count(self, psutil_available):
        """Test CPU count."""
        psutil = psutil_available
        logical = psutil.cpu_count()
        physical = psutil.cpu_count(logical=False)
        assert logical is not None
        assert logical > 0
        # psutil documents that the count may be None when it cannot be
        # determined; only compare the physical count when it is known.
        if physical is not None:
            assert physical > 0
            assert logical >= physical

    def test_cpu_percent(self, psutil_available):
        """Test CPU percentage."""
        psutil = psutil_available
        percent = psutil.cpu_percent(interval=0.1)
        assert 0 <= percent <= 100

    def test_virtual_memory(self, psutil_available):
        """Test virtual memory info."""
        psutil = psutil_available
        mem = psutil.virtual_memory()
        assert mem.total > 0
        assert mem.available > 0
        assert 0 <= mem.percent <= 100

    def test_disk_usage(self, psutil_available):
        """Test disk usage."""
        psutil = psutil_available
        usage = psutil.disk_usage("/")
        assert usage.total > 0
        assert usage.used >= 0
        assert 0 <= usage.percent <= 100

    def test_process(self, psutil_available):
        """Test current process info."""
        psutil = psutil_available
        p = psutil.Process()
        assert p.pid == os.getpid()
        assert p.name()
        assert p.num_threads() > 0

    def test_boot_time(self, psutil_available):
        """Test boot time."""
        psutil = psutil_available
        boot = psutil.boot_time()
        assert boot > 0
pattern.""" return re.match(pattern, text) is not None def fullmatch(pattern: str, text: str) -> bool: """Check if entire text matches pattern.""" return re.fullmatch(pattern, text) is not None # Find All def find_all(pattern: str, text: str) -> list: """Find all matches of pattern.""" return re.findall(pattern, text) def find_all_groups(pattern: str, text: str) -> list: """Find all matches with groups.""" return re.findall(pattern, text) # Split def split_pattern(pattern: str, text: str) -> list: """Split text by pattern.""" return re.split(pattern, text) # Groups def parse_date(text: str) -> dict | None: """Parse date string into components.""" m = re.search(r"(\d{4})-(\d{2})-(\d{2})", text) if m: return {"year": m.group(1), "month": m.group(2), "day": m.group(3)} return None def parse_date_named(text: str) -> dict | None: """Parse date using named groups.""" pattern = r"(?P\d{4})-(?P\d{2})-(?P\d{2})" m = re.search(pattern, text) return m.groupdict() if m else None # Non-capturing group def parse_url(url: str) -> tuple | None: """Parse URL with non-capturing group for protocol.""" m = re.search(r"(?:https?|ftp)://([^/\r\n]+)(/[^\r\n]*)?", url) return m.groups() if m else None # Back Reference def has_repeated_char(text: str) -> bool: """Check if text has repeated adjacent characters.""" return re.search(r"(\w)\1", text) is not None def match_html_tag(text: str) -> str | None: """Match HTML tag with matching close tag.""" m = re.search(r"<(\w+)>[^<]*", text) return m.group() if m else None # Lookahead/Lookbehind def find_before_at(text: str) -> list: """Find words before @ symbol (positive lookahead).""" return re.findall(r"\w+(?=@)", text) def find_after_dollar(text: str) -> list: """Find numbers after $ symbol (positive lookbehind).""" return re.findall(r"(?<=\$)\d+", text) def find_not_followed_by(text: str, suffix: str) -> list: """Find digits not followed by suffix (negative lookahead).""" return re.findall(rf"\d+(?!{suffix})", text) # Substitution def 
# Compiled Patterns
EMAIL_PATTERN = re.compile(r"^[\w.+-]+@[\w-]+\.[\w.-]+$")
IP_PATTERN = re.compile(
    r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)
MAC_PATTERN = re.compile(r"^([0-9a-f]{2}:){5}[0-9a-f]{2}$", re.I)
# The path group is a single character-class repetition. Wrapping it in a
# second "*" matches the same strings but lets the engine split the path in
# exponentially many ways before rejecting an invalid URL.
URL_PATTERN = re.compile(
    r"^(https?://)?([\da-z.-]+)\.([a-z.]{2,6})([/\w.-]*)/?$", re.I
)
HEX_COLOR_PATTERN = re.compile(r"^#?([a-fA-F0-9]{6}|[a-fA-F0-9]{3})$")
PHONE_PATTERN = re.compile(
    r"^(\+1)?[-.\s]?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}$"
)
PASSWORD_PATTERN = re.compile(
    r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[@$!%*?&])[A-Za-z\d@$!%*?&]{8,}$"
)


# The validators use fullmatch() rather than match(): "$" also matches just
# before a trailing newline, so match() would accept "value\n".
def is_valid_email(text: str) -> bool:
    """Validate email address."""
    return EMAIL_PATTERN.fullmatch(text) is not None


def is_valid_ip(text: str) -> bool:
    """Validate IPv4 address."""
    return IP_PATTERN.fullmatch(text) is not None


def is_valid_mac(text: str) -> bool:
    """Validate MAC address."""
    return MAC_PATTERN.fullmatch(text) is not None


def is_valid_url(text: str) -> bool:
    """Validate URL."""
    return URL_PATTERN.fullmatch(text) is not None


def is_valid_hex_color(text: str) -> bool:
    """Validate hex color code."""
    return HEX_COLOR_PATTERN.fullmatch(text) is not None


def is_valid_phone(text: str) -> bool:
    """Validate US phone number."""
    return PHONE_PATTERN.fullmatch(text) is not None


def is_strong_password(text: str) -> bool:
    """Validate password strength."""
    return PASSWORD_PATTERN.fullmatch(text) is not None
def find_close_tags(html: str) -> list:
    """Find all close tags."""
    return re.findall(r"</[^>]+>", html)


def strip_html_tags(html: str) -> str:
    """Remove all HTML tags."""
    return re.sub(r"<[^>]+>", "", html)


# Lexer
Token = namedtuple("Token", ["type", "value"])


def tokenize(text: str) -> list:
    """Tokenize arithmetic expression.

    Returns:
        A list of ``Token(type, value)`` for numbers and the operators
        ``+ - * /``; whitespace is dropped. Scanning stops silently at the
        first character that no pattern matches.
    """
    # Each alternative is a named group so that ``lastgroup`` reports the
    # token type. NOTE(review): the names MINUS, TIMES and DIVIDE are
    # reconstructed; only NUMBER, PLUS and WS are pinned by the
    # surrounding code and tests. Confirm against the upstream file.
    tokens = [
        r"(?P<NUMBER>\d+)",
        r"(?P<PLUS>\+)",
        r"(?P<MINUS>-)",
        r"(?P<TIMES>\*)",
        r"(?P<DIVIDE>/)",
        r"(?P<WS>\s+)",
    ]
    lex = re.compile("|".join(tokens))
    scan = lex.scanner(text)
    # iter(callable, sentinel) keeps calling scan.match() until it
    # returns None, i.e. until the input is exhausted or unmatched.
    return [
        Token(m.lastgroup, m.group())
        for m in iter(scan.match, None)
        if m.lastgroup != "WS"
    ]
class TestBackReference:
    """Tests for the back-reference helpers."""

    def test_repeated_char(self):
        assert has_repeated_char("hello")  # ll
        assert not has_repeated_char("world")

    def test_html_tag(self):
        # NOTE(review): the HTML literals in this test were lost during
        # extraction and are reconstructed: a matching pair must be
        # returned whole, a mismatched close tag must not match.
        assert match_html_tag("<b>bold</b>") == "<b>bold</b>"
        assert match_html_tag("<b>text</i>") is None


class TestLookaround:
    """Tests for the lookahead / lookbehind helpers."""

    def test_lookahead(self):
        assert find_before_at("user@example.com") == ["user"]

    def test_lookbehind(self):
        assert find_after_dollar("$100 $200") == ["100", "200"]

    def test_negative_lookahead(self):
        # "12px" yields "1": the engine backtracks from "12" to a shorter
        # digit run that is not directly followed by "px".
        assert find_not_followed_by("12px 34em 56", "px") == ["1", "34", "56"]


class TestSubstitution:
    """Tests for the re.sub based helpers."""

    def test_replace(self):
        assert replace_pattern(r"\d+", "X", "a1b2c3") == "aXbXcX"

    def test_double_numbers(self):
        assert double_numbers("a1b2c3") == "a2b4c6"

    def test_camel_to_snake(self):
        assert camel_to_snake("CamelCase") == "camel_case"
        assert camel_to_snake("SimpleHTTPServer") == "simple_http_server"


class TestValidation:
    """Tests for the compiled-pattern validators."""

    def test_email(self):
        assert is_valid_email("user@example.com")
        assert is_valid_email("user+tag@sub.domain.org")
        assert not is_valid_email("invalid@")

    def test_ip(self):
        assert is_valid_ip("192.168.1.1")
        assert is_valid_ip("255.255.255.0")
        assert not is_valid_ip("256.0.0.0")

    def test_mac(self):
        assert is_valid_mac("3c:38:51:05:03:1e")
        assert is_valid_mac("AA:BB:CC:DD:EE:FF")

    def test_url(self):
        assert is_valid_url("https://www.example.com/path")
        assert is_valid_url("example.com")

    def test_hex_color(self):
        assert is_valid_hex_color("#ffffff")
        assert is_valid_hex_color("fff")
        assert not is_valid_hex_color("#gggggg")

    def test_phone(self):
        assert is_valid_phone("123-456-7890")
        assert is_valid_phone("(123) 456-7890")

    def test_password(self):
        assert is_strong_password("Passw0rd!")
        assert not is_strong_password("weakpass")


class TestHtmlTags:
    """Tests for the HTML tag helpers."""

    # NOTE(review): the HTML literals in these tests were lost during
    # extraction. The strip_tags input is reconstructed from the surviving
    # line structure; the open/close inputs reuse the same document and
    # should be confirmed against the upstream file.

    def test_open_tags(self):
        assert "<div>" in find_open_tags("<div><p>Hello</p></div>")

    def test_close_tags(self):
        assert "</div>" in find_close_tags("<div><p>Hello</p></div>")

    def test_strip_tags(self):
        assert strip_html_tags("<div><p>Hello</p></div>") == "Hello"
def frozenset_as_dict_key():
    """Build a dictionary whose keys are frozensets.

    A frozenset is hashable, which is what allows it to be a dict key.
    """
    first_key = frozenset([1, 2])
    second_key = frozenset([3, 4])
    mapping = {}
    mapping[first_key] = "a"
    mapping[second_key] = "b"
    return mapping
class TestMembership:
    """Check ``membership_test`` for a present and an absent element."""

    def test_in(self):
        present = membership_test({1, 2, 3}, 2)
        assert present
        absent = membership_test({1, 2, 3}, 10)
        assert not absent
class TestByteOrder:
    """Test network byte order conversion."""

    def test_htons(self):
        """htons(1), stored in host order, must read as big-endian 1."""
        import sys

        # Host to network short. Serialising the result in the host's own
        # byte order must give the big-endian encoding of the input, so
        # the check is exact on both little- and big-endian machines.
        result = socket.htons(1)
        assert result.to_bytes(2, sys.byteorder) == (1).to_bytes(2, "big")

    def test_htonl(self):
        """htonl(1), stored in host order, must read as big-endian 1."""
        import sys

        # Host to network long
        result = socket.htonl(1)
        assert result.to_bytes(4, sys.byteorder) == (1).to_bytes(4, "big")

    def test_ntohs(self):
        """ntohs must invert htons."""
        # Network to host short
        val = socket.htons(1234)
        assert socket.ntohs(val) == 1234

    def test_ntohl(self):
        """ntohl must invert htonl."""
        # Network to host long
        val = socket.htonl(12345678)
        assert socket.ntohl(val) == 12345678
class TestTCPEchoServer:
    """Test TCP echo server functionality."""

    def test_echo(self):
        """Round-trip one message through a single-connection echo server.

        The listening socket is bound in the test thread to an OS-assigned
        port, so the test cannot collide with another process and the
        client can never connect before the server is listening. Every
        blocking call has a timeout and every socket is closed by a
        context manager, so a failure surfaces as an error, not a hang or
        a leaked descriptor.
        """
        timeout = 5.0

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
            listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # Port 0 asks the kernel for any free port.
            listener.bind(("127.0.0.1", 0))
            listener.listen(1)
            listener.settimeout(timeout)
            host, port = listener.getsockname()

            def server():
                # Serve exactly one connection: read once, echo it back.
                conn, _addr = listener.accept()
                with conn:
                    conn.settimeout(timeout)
                    data = conn.recv(1024)
                    conn.sendall(data)

            t = threading.Thread(target=server, daemon=True)
            t.start()

            with socket.create_connection((host, port), timeout=timeout) as client:
                client.sendall(b"Hello")
                response = client.recv(1024)

            t.join(timeout=timeout)

        assert response == b"Hello"
        # The server thread must have finished its single exchange.
        assert not t.is_alive()
class TestPortCheck:
    """Test port availability checking."""

    def test_port_available(self):
        """Binding to port 0 yields a real, kernel-assigned port number."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))  # Bind to any available port
            port = sock.getsockname()[1]
            assert port > 0

    def test_port_in_use(self):
        """A second bind to a port that is already listening raises OSError.

        The occupied port is chosen by the OS rather than hard-coded, so
        the first bind cannot fail because another process holds the
        port. Both sockets are closed by the ``with`` block on any outcome.
        """
        with socket.socket(
            socket.AF_INET, socket.SOCK_STREAM
        ) as holder, socket.socket(
            socket.AF_INET, socket.SOCK_STREAM
        ) as contender:
            holder.bind(("", 0))
            holder.listen(1)
            busy_port = holder.getsockname()[1]

            try:
                contender.bind(("", busy_port))
            except OSError:
                pass
            else:
                # Raised explicitly: a bare ``assert False`` is removed
                # under ``python -O`` and would let the test pass.
                raise AssertionError("Should have raised OSError")
class TestEngine:
    """Check that create_engine accepts in-memory and file-based SQLite URLs."""

    def test_create_sqlite_memory(self):
        memory_engine = create_engine("sqlite:///:memory:")
        assert memory_engine is not None

    def test_create_sqlite_file(self, tmp_path):
        # tmp_path is pytest's per-test temporary directory fixture.
        database_file = tmp_path / "test.db"
        file_engine = create_engine(f"sqlite:///{database_file}")
        assert file_engine is not None
class TestMetadata:
    """Test metadata and table definitions."""

    def test_define_table(self):
        """Columns declared on a Table keep their declaration order."""
        engine = create_engine("sqlite:///:memory:")
        registry = MetaData()
        column_specs = (
            Column("id", Integer, primary_key=True),
            Column("name", String(50)),
            Column("email", String(100)),
        )
        users = Table("users", registry, *column_specs)
        registry.create_all(engine)

        declared_names = [col.name for col in users.columns]
        assert declared_names == ["id", "name", "email"]

    def test_reflect_table(self):
        """A table created with raw DDL is discovered by MetaData.reflect."""
        engine = create_engine("sqlite:///:memory:")
        ddl = text("CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT)")
        with engine.begin() as conn:
            conn.execute(ddl)

        reflected = MetaData()
        reflected.reflect(bind=engine)
        assert "products" in reflected.tables
class TestCoreSelect:
    """Test Core select operations."""

    @staticmethod
    def _seeded_users():
        """Build an in-memory ``users`` table holding Alice (30) and Bob (25).

        Returns:
            ``(engine, users)``: the engine and the Table object.
        """
        engine = create_engine("sqlite:///:memory:")
        metadata = MetaData()
        users = Table(
            "users",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("name", String(50)),
            Column("age", Integer),
        )
        metadata.create_all(engine)
        people = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
        with engine.begin() as conn:
            conn.execute(insert(users), people)
        return engine, users

    def test_select_all(self):
        engine, users = self._seeded_users()
        with engine.connect() as conn:
            fetched = conn.execute(select(users)).fetchall()
            assert len(fetched) == 2

    def test_select_where(self):
        engine, users = self._seeded_users()
        older_than_28 = select(users).where(users.c.age > 28)
        with engine.connect() as conn:
            fetched = conn.execute(older_than_28).fetchall()
            assert len(fetched) == 1
            assert fetched[0][1] == "Alice"
class TestCoreDelete:
    """Test Core delete operations."""

    def test_delete(self):
        """Deleting by name removes only the matching row."""
        engine = create_engine("sqlite:///:memory:")
        metadata = MetaData()
        users = Table(
            "users",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("name", String(50)),
        )
        metadata.create_all(engine)

        seed_rows = [{"name": "Alice"}, {"name": "Bob"}]
        remove_bob = delete(users).where(users.c.name == "Bob")
        # Insert and delete share one transaction, committed on exit.
        with engine.begin() as conn:
            conn.execute(insert(users), seed_rows)
            conn.execute(remove_bob)

        with engine.connect() as conn:
            remaining = conn.execute(select(users)).fetchall()
            assert len(remaining) == 1
            assert remaining[0][1] == "Alice"
class TestAggregate:
    """Test aggregate functions."""

    def test_count(self):
        """COUNT(*) over the table equals the number of inserted rows."""
        engine = create_engine("sqlite:///:memory:")
        metadata = MetaData()
        sales = Table(
            "sales",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("amount", Integer),
        )
        metadata.create_all(engine)

        with engine.begin() as conn:
            conn.execute(insert(sales), [{"amount": 100}, {"amount": 200}])

        count_stmt = select(func.count()).select_from(sales)
        with engine.connect() as conn:
            row_count = conn.execute(count_stmt).scalar()
            assert row_count == 2

    def test_sum_group_by(self):
        """SUM(amount) grouped by product gives one total per product."""
        engine = create_engine("sqlite:///:memory:")
        metadata = MetaData()
        sales = Table(
            "sales",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("product", String(50)),
            Column("amount", Integer),
        )
        metadata.create_all(engine)

        line_items = [
            {"product": "A", "amount": 100},
            {"product": "A", "amount": 150},
            {"product": "B", "amount": 200},
        ]
        with engine.begin() as conn:
            conn.execute(insert(sales), line_items)

        totals_stmt = select(
            sales.c.product, func.sum(sales.c.amount)
        ).group_by(sales.c.product)
        with engine.connect() as conn:
            grouped = conn.execute(totals_stmt).fetchall()
            # Each row is a (product, total) pair, so dict() maps
            # product -> total.
            totals = dict(grouped)
            assert totals["A"] == 250
            assert totals["B"] == 200
class TestSession:
    """Test session operations."""

    def test_add_commit(self):
        """After commit, the first inserted object has primary key 1."""
        Base = declarative_base()

        class User(Base):
            __tablename__ = "users"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        # The session's context manager closes it on exit, pass or fail.
        with session_factory() as session:
            alice = User(name="Alice")
            session.add(alice)
            session.commit()
            assert alice.id == 1

    def test_add_all(self):
        """add_all + commit assigns a primary key to every object."""
        Base = declarative_base()

        class User(Base):
            __tablename__ = "users"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            newcomers = [User(name="Bob"), User(name="Carol")]
            session.add_all(newcomers)
            session.commit()
            assert all(person.id is not None for person in newcomers)
class TestORMFilter:
    """Test ORM filter operations."""

    def test_and_filter(self):
        """and_() keeps only rows that satisfy both conditions."""
        Base = declarative_base()

        class User(Base):
            __tablename__ = "users"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))
            age = Column(Integer)

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            people = [
                User(name="Alice", age=30),
                User(name="Bob", age=25),
                User(name="Amy", age=35),
            ]
            session.add_all(people)
            session.commit()

            both_conditions = and_(User.age >= 30, User.name.like("A%"))
            matches = (
                session.execute(select(User).where(both_conditions))
                .scalars()
                .all()
            )
            assert len(matches) == 2

    def test_or_filter(self):
        """or_() keeps rows that satisfy either condition."""
        Base = declarative_base()

        class User(Base):
            __tablename__ = "users"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            people = [User(name="Alice"), User(name="Bob"), User(name="Carol")]
            session.add_all(people)
            session.commit()

            either_name = or_(User.name == "Alice", User.name == "Bob")
            matches = (
                session.execute(select(User).where(either_name))
                .scalars()
                .all()
            )
            assert len(matches) == 2

    def test_in_filter(self):
        """in_() keeps rows whose value appears in the given list."""
        Base = declarative_base()

        class User(Base):
            __tablename__ = "users"
            id = Column(Integer, primary_key=True)
            age = Column(Integer)

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            session.add_all([User(age=25), User(age=30), User(age=35)])
            session.commit()

            wanted_ages = [25, 35]
            query = select(User).where(User.age.in_(wanted_ages))
            matches = session.execute(query).scalars().all()
            assert len(matches) == 2
class TestManyToMany:
    """Test many-to-many relationships."""

    def test_relationship(self):
        """Students and courses linked through an association table."""
        Base = declarative_base()

        # Association table: one row per (student, course) enrolment.
        student_fk = Column(
            "student_id",
            Integer,
            ForeignKey("students.id"),
            primary_key=True,
        )
        course_fk = Column(
            "course_id",
            Integer,
            ForeignKey("courses.id"),
            primary_key=True,
        )
        student_course = Table(
            "student_course", Base.metadata, student_fk, course_fk
        )

        class Student(Base):
            __tablename__ = "students"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))
            courses = relationship(
                "Course", secondary=student_course, back_populates="students"
            )

        class Course(Base):
            __tablename__ = "courses"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))
            students = relationship(
                "Student", secondary=student_course, back_populates="courses"
            )

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            math = Course(name="Math")
            physics = Course(name="Physics")
            enrolled = [
                Student(name="Alice", courses=[math, physics]),
                Student(name="Bob", courses=[math]),
            ]
            session.add_all(enrolled)
            session.commit()

            math_query = select(Course).where(Course.name == "Math")
            math = session.execute(math_query).scalars().first()
            # Both Alice and Bob take Math.
            assert len(math.students) == 2
class TestCascade:
    """Test cascade delete."""

    def test_delete_orphan(self):
        """Deleting a parent removes its children via the relationship cascade."""
        Base = declarative_base()

        class Parent(Base):
            __tablename__ = "parents"
            id = Column(Integer, primary_key=True)
            children = relationship("Child", cascade="all, delete-orphan")

        class Child(Base):
            __tablename__ = "children"
            id = Column(Integer, primary_key=True)
            parent_id = Column(Integer, ForeignKey("parents.id"))

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            family_head = Parent()
            family_head.children = [Child(), Child()]
            session.add(family_head)
            session.commit()

            session.delete(family_head)
            session.commit()

            leftover = session.execute(select(Child)).scalars().all()
            assert len(leftover) == 0
class TestEventHooks:
    """Test event hooks."""

    def test_before_insert(self):
        """A before_insert listener can fill a column before the row is written."""
        Base = declarative_base()

        class User(Base):
            __tablename__ = "users"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))
            created_at = Column(DateTime)

        def set_created_at(mapper, connection, target):
            # Stamp the instance that is about to be inserted.
            target.created_at = datetime.now()

        # Functional form of the @event.listens_for(User, "before_insert")
        # decorator.
        event.listen(User, "before_insert", set_created_at)

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            alice = User(name="Alice")
            session.add(alice)
            session.commit()
            assert alice.created_at is not None
class TestOrderBy:
    """Test order by operations."""

    def test_ascending(self):
        """order_by(column) sorts from smallest to largest."""
        Base = declarative_base()

        class User(Base):
            __tablename__ = "users"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))
            age = Column(Integer)

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            people = [
                User(name="Alice", age=30),
                User(name="Bob", age=25),
                User(name="Carol", age=35),
            ]
            session.add_all(people)
            session.commit()

            by_age = select(User).order_by(User.age)
            ordered = session.execute(by_age).scalars().all()
            names_in_order = [person.name for person in ordered]
            assert names_in_order == ["Bob", "Alice", "Carol"]

    def test_descending(self):
        """order_by(desc(column)) sorts from largest to smallest."""
        Base = declarative_base()

        class User(Base):
            __tablename__ = "users"
            id = Column(Integer, primary_key=True)
            age = Column(Integer)

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            session.add_all([User(age=30), User(age=25), User(age=35)])
            session.commit()

            oldest_first = select(User).order_by(desc(User.age))
            ordered = session.execute(oldest_first).scalars().all()
            ages_in_order = [person.age for person in ordered]
            assert ages_in_order == [35, 30, 25]
class TestGroupBy:
    """Test group by and aggregate operations."""

    def test_sum_group_by(self):
        """SUM(amount) per product yields one total for each product."""
        Base = declarative_base()

        class Sale(Base):
            __tablename__ = "sales"
            id = Column(Integer, primary_key=True)
            product = Column(String(50))
            amount = Column(Integer)

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            ledger = [
                Sale(product="A", amount=100),
                Sale(product="A", amount=150),
                Sale(product="B", amount=200),
            ]
            session.add_all(ledger)
            session.commit()

            per_product = select(Sale.product, func.sum(Sale.amount)).group_by(
                Sale.product
            )
            grouped = session.execute(per_product).fetchall()
            # Each row is a (product, total) pair.
            totals = dict(grouped)
            assert totals["A"] == 250
            assert totals["B"] == 200

    def test_having(self):
        """HAVING count > 1 keeps only products sold more than once."""
        Base = declarative_base()

        class Sale(Base):
            __tablename__ = "sales"
            id = Column(Integer, primary_key=True)
            product = Column(String(50))

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            ledger = [Sale(product="A"), Sale(product="A"), Sale(product="B")]
            session.add_all(ledger)
            session.commit()

            counted = select(Sale.product, func.count()).group_by(Sale.product)
            repeated_only = counted.having(func.count() > 1)
            survivors = session.execute(repeated_only).fetchall()
            assert len(survivors) == 1
            assert survivors[0][0] == "A"
class TestSubquery:
    """Test subquery operations."""

    def test_scalar_subquery(self):
        """Select users whose score exceeds the average from a scalar subquery."""
        Base = declarative_base()

        class User(Base):
            __tablename__ = "users"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))
            score = Column(Integer)

        engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)

        with session_factory() as session:
            scoreboard = [
                User(name="Alice", score=85),
                User(name="Bob", score=90),
                User(name="Carol", score=75),
            ]
            session.add_all(scoreboard)
            session.commit()

            # The average of 85, 90 and 75 is about 83.3, so Alice and Bob
            # are above it.
            mean_score = select(func.avg(User.score)).scalar_subquery()
            above_mean = select(User).where(User.score > mean_score)
            achievers = session.execute(above_mean).scalars().all()
            assert len(achievers) == 2
declarative_base() class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True) name = Column(String(50)) class Order(Base): __tablename__ = "orders" id = Column(Integer, primary_key=True) user_id = Column(Integer, ForeignKey("users.id")) engine = create_engine("sqlite:///:memory:") Base.metadata.create_all(engine) Session = sessionmaker(bind=engine) session = Session() try: session.add_all([User(name="Alice"), User(name="Bob")]) session.commit() alice = ( session.execute(select(User).where(User.name == "Alice")) .scalars() .first() ) session.add(Order(user_id=alice.id)) session.commit() has_orders = exists().where(Order.user_id == User.id) stmt = select(User).where(~has_orders) users = session.execute(stmt).scalars().all() assert len(users) == 1 assert users[0].name == "Bob" finally: session.close() class TestUnion: """Test union operations.""" def test_union_all(self): Base = declarative_base() class Customer(Base): __tablename__ = "customers" id = Column(Integer, primary_key=True) name = Column(String(50)) class Supplier(Base): __tablename__ = "suppliers" id = Column(Integer, primary_key=True) name = Column(String(50)) engine = create_engine("sqlite:///:memory:") Base.metadata.create_all(engine) Session = sessionmaker(bind=engine) session = Session() try: session.add_all([Customer(name="Alice"), Customer(name="Bob")]) session.add_all([Supplier(name="Acme"), Supplier(name="Bob")]) session.commit() stmt = union_all(select(Customer.name), select(Supplier.name)) results = [row[0] for row in session.execute(stmt)] assert len(results) == 4 assert results.count("Bob") == 2 finally: session.close() class TestCase: """Test case expressions.""" def test_case(self): Base = declarative_base() class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True) name = Column(String(50)) score = Column(Integer) engine = create_engine("sqlite:///:memory:") Base.metadata.create_all(engine) Session = sessionmaker(bind=engine) session = Session() 
try: session.add_all( [ User(name="Alice", score=95), User(name="Bob", score=75), User(name="Carol", score=55), ] ) session.commit() grade = case( (User.score >= 90, "A"), (User.score >= 70, "B"), else_="C" ) stmt = select(User.name, grade.label("grade")) results = dict(session.execute(stmt).fetchall()) assert results["Alice"] == "A" assert results["Bob"] == "B" assert results["Carol"] == "C" finally: session.close() class TestDistinct: """Test distinct operations.""" def test_distinct(self): Base = declarative_base() class Order(Base): __tablename__ = "orders" id = Column(Integer, primary_key=True) customer = Column(String(50)) engine = create_engine("sqlite:///:memory:") Base.metadata.create_all(engine) Session = sessionmaker(bind=engine) session = Session() try: session.add_all( [ Order(customer="Alice"), Order(customer="Alice"), Order(customer="Bob"), ] ) session.commit() stmt = select(Order.customer).distinct() results = session.execute(stmt).fetchall() assert len(results) == 2 finally: session.close() def test_count_distinct(self): Base = declarative_base() class Order(Base): __tablename__ = "orders" id = Column(Integer, primary_key=True) product = Column(String(50)) engine = create_engine("sqlite:///:memory:") Base.metadata.create_all(engine) Session = sessionmaker(bind=engine) session = Session() try: session.add_all( [ Order(product="Book"), Order(product="Book"), Order(product="Pen"), ] ) session.commit() stmt = select(func.count(distinct(Order.product))) result = session.execute(stmt).scalar() assert result == 2 finally: session.close() class TestAliased: """Test aliased tables.""" def test_aliased(self): Base = declarative_base() class Employee(Base): __tablename__ = "employees" id = Column(Integer, primary_key=True) name = Column(String(50)) salary = Column(Integer) engine = create_engine("sqlite:///:memory:") Base.metadata.create_all(engine) Session = sessionmaker(bind=engine) session = Session() try: session.add_all( [ Employee(name="Alice", 
salary=50000), Employee(name="Bob", salary=60000), Employee(name="Carol", salary=55000), ] ) session.commit() alice_alias = aliased(Employee, name="alice") stmt = ( select(Employee.name) .select_from(Employee) .join(alice_alias, alice_alias.name == "Alice") .where(Employee.salary > alice_alias.salary) ) results = [row[0] for row in session.execute(stmt)] assert "Bob" in results assert "Carol" in results assert "Alice" not in results finally: session.close() class TestRawSQL: """Test raw SQL execution.""" def test_text(self): Base = declarative_base() class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True) name = Column(String(50)) engine = create_engine("sqlite:///:memory:") Base.metadata.create_all(engine) Session = sessionmaker(bind=engine) session = Session() try: session.add_all([User(name="Alice"), User(name="Bob")]) session.commit() result = session.execute( text("SELECT * FROM users WHERE name = :name"), {"name": "Alice"}, ) rows = result.fetchall() assert len(rows) == 1 assert rows[0][1] == "Alice" finally: session.close() ================================================ FILE: src/basic/typing_.py ================================================ """Python Typing Examples Source code for docs/notes/basic/python-typing.rst """ import pytest from typing import ( Optional, Union, Callable, TypeVar, Generic, Protocol, TypedDict, Literal, Final, ClassVar, ) # Basic Types def greet(name: str) -> str: """Function with type annotations.""" return f"Hello, {name}!" 
def add(a: int, b: int) -> int: """Add two integers.""" return a + b # Collection Types def sum_list(numbers: list[int]) -> int: """Sum a list of integers.""" return sum(numbers) def get_value(data: dict[str, int], key: str) -> int | None: """Get value from dict.""" return data.get(key) # Optional and Union def find_user(user_id: int) -> Optional[str]: """Return username or None.""" users = {1: "Alice", 2: "Bob"} return users.get(user_id) def process(value: int | str) -> str: """Process int or str.""" return str(value) # Callable def apply(func: Callable[[int, int], int], a: int, b: int) -> int: """Apply function to arguments.""" return func(a, b) double: Callable[[int], int] = lambda x: x * 2 # TypeVar and Generics T = TypeVar("T") def first(items: list[T]) -> T: """Return first item.""" return items[0] Number = TypeVar("Number", int, float) def double_num(x: Number) -> Number: """Double a number.""" return x * 2 # Generic Class class Stack(Generic[T]): """Generic stack class.""" def __init__(self) -> None: self._items: list[T] = [] def push(self, item: T) -> None: self._items.append(item) def pop(self) -> T: return self._items.pop() def is_empty(self) -> bool: return len(self._items) == 0 # Protocol class Drawable(Protocol): """Protocol for drawable objects.""" def draw(self) -> str: ... 
class Circle:
    """Circle that implements Drawable."""

    def draw(self) -> str:
        return "Circle"


class Square:
    """Square that implements Drawable."""

    def draw(self) -> str:
        return "Square"


def render(shape: Drawable) -> str:
    """Render any drawable by returning the result of ``shape.draw()``."""
    return shape.draw()


# TypedDict
class UserDict(TypedDict):
    """Typed dictionary for user data."""

    name: str
    age: int


# Literal
def set_status(status: Literal["active", "inactive"]) -> str:
    """Return ``"Status: <status>"`` for the given literal status."""
    return f"Status: {status}"


# Final
MAX_SIZE: Final = 100


# ClassVar
class Config:
    """Class with ClassVar."""

    # Declared on the class itself; the test reads it as ``Config.debug``.
    debug: ClassVar[bool] = False
    # Per-instance attribute, assigned in ``__init__``.
    name: str

    def __init__(self, name: str) -> None:
        self.name = name


# Tests
class TestBasicTypes:
    def test_greet(self):
        assert greet("World") == "Hello, World!"

    def test_add(self):
        assert add(2, 3) == 5


class TestCollections:
    def test_sum_list(self):
        assert sum_list([1, 2, 3, 4, 5]) == 15

    def test_get_value(self):
        assert get_value({"a": 1, "b": 2}, "a") == 1
        # A missing key yields None rather than raising.
        assert get_value({"a": 1}, "x") is None


class TestOptionalUnion:
    def test_find_user(self):
        assert find_user(1) == "Alice"
        assert find_user(999) is None

    def test_process(self):
        assert process(42) == "42"
        assert process("hello") == "hello"


class TestCallable:
    def test_apply(self):
        assert apply(lambda a, b: a + b, 2, 3) == 5

    def test_double(self):
        assert double(5) == 10


class TestGenerics:
    def test_first(self):
        assert first([1, 2, 3]) == 1
        assert first(["a", "b"]) == "a"

    def test_double_num(self):
        assert double_num(5) == 10
        assert double_num(2.5) == 5.0


class TestGenericClass:
    def test_stack(self):
        s: Stack[int] = Stack()
        s.push(1)
        s.push(2)
        # Items come back in reverse push order.
        assert s.pop() == 2
        assert s.pop() == 1
        assert s.is_empty()


class TestProtocol:
    def test_render(self):
        # Neither class inherits from Drawable; both only define draw().
        assert render(Circle()) == "Circle"
        assert render(Square()) == "Square"


class TestTypedDict:
    def test_user_dict(self):
        user: UserDict = {"name": "Alice", "age": 30}
        assert user["name"] == "Alice"
        assert user["age"] == 30


class TestLiteral:
    def test_set_status(self):
        assert set_status("active") == "Status: active"


class TestClassVar:
    def test_config(self):
        c = Config("test")
        assert c.name == "test"
        assert Config.debug is False


================================================
FILE: src/basic/unicode_.py
================================================
"""Python Unicode Examples

Source code for docs/notes/basic/python-unicode.rst
"""

import pytest
import unicodedata


# Encoding and Decoding
def encode_utf8(s: str) -> bytes:
    """Encode string to UTF-8 bytes."""
    return s.encode("utf-8")


def decode_utf8(b: bytes) -> str:
    """Decode UTF-8 bytes to string."""
    return b.decode("utf-8")


def encode_with_errors(s: str, encoding: str, errors: str) -> bytes:
    """Encode *s* with *encoding*, passing *errors* as the error handler.

    The tests below exercise the ``"ignore"`` and ``"replace"`` handlers.
    """
    return s.encode(encoding, errors=errors)


# Code Points
def get_code_point(char: str) -> int:
    """Get Unicode code point of character."""
    return ord(char)


def get_char(code_point: int) -> str:
    """Get character from code point."""
    return chr(code_point)


def format_code_points(s: str) -> list[str]:
    """Format each character of *s* as ``U+XXXX`` (at least four hex digits)."""
    return [f"U+{ord(c):04X}" for c in s]


# Normalization
def normalize_nfc(s: str) -> str:
    """Normalize to NFC (composed) form."""
    return unicodedata.normalize("NFC", s)


def normalize_nfd(s: str) -> str:
    """Normalize to NFD (decomposed) form."""
    return unicodedata.normalize("NFD", s)


# Character Info
def get_char_name(char: str) -> str:
    """Get Unicode name of character."""
    return unicodedata.name(char)


def get_char_category(char: str) -> str:
    """Get Unicode category of character."""
    return unicodedata.category(char)


def lookup_char(name: str) -> str:
    """Lookup character by Unicode name."""
    return unicodedata.lookup(name)


# String Operations
def case_insensitive_equal(s1: str, s2: str) -> bool:
    """Case-insensitive comparison using casefold."""
    return s1.casefold() == s2.casefold()


# Tests
class TestEncodingDecoding:
    def test_encode_utf8(self):
        # "é" is encoded as the two bytes C3 A9.
        assert encode_utf8("Café") == b"Caf\xc3\xa9"

    def test_decode_utf8(self):
        assert decode_utf8(b"Caf\xc3\xa9") == "Café"

    def test_roundtrip(self):
        s = "Hello, 世界!"
        assert decode_utf8(encode_utf8(s)) == s

    def test_encode_errors_ignore(self):
        # "ignore" drops the character ASCII cannot represent.
        assert encode_with_errors("Café", "ascii", "ignore") == b"Caf"

    def test_encode_errors_replace(self):
        # "replace" substitutes "?" for the unencodable character.
        assert encode_with_errors("Café", "ascii", "replace") == b"Caf?"


class TestCodePoints:
    def test_get_code_point(self):
        assert get_code_point("A") == 65
        assert get_code_point("é") == 233
        assert get_code_point("中") == 20013

    def test_get_char(self):
        assert get_char(65) == "A"
        assert get_char(233) == "é"
        assert get_char(20013) == "中"

    def test_format_code_points(self):
        assert format_code_points("AB") == ["U+0041", "U+0042"]


class TestNormalization:
    def test_nfc(self):
        # e + combining accent -> single é
        composed = normalize_nfc("e\u0301")
        assert composed == "é"
        assert len(composed) == 1

    def test_nfd(self):
        # single é -> e + combining accent
        decomposed = normalize_nfd("é")
        assert len(decomposed) == 2

    def test_normalization_equality(self):
        s1 = "é"
        s2 = "e\u0301"
        # Different code point sequences compare unequal until normalized.
        assert s1 != s2
        assert normalize_nfc(s1) == normalize_nfc(s2)


class TestCharInfo:
    def test_get_char_name(self):
        assert get_char_name("A") == "LATIN CAPITAL LETTER A"
        assert "ACUTE" in get_char_name("é")

    def test_get_char_category(self):
        assert get_char_category("A") == "Lu"  # Letter, uppercase
        assert get_char_category("a") == "Ll"  # Letter, lowercase
        assert get_char_category("1") == "Nd"  # Number, digit

    def test_lookup_char(self):
        assert lookup_char("LATIN CAPITAL LETTER A") == "A"
        assert lookup_char("GREEK SMALL LETTER ALPHA") == "α"


class TestStringOperations:
    def test_case_insensitive(self):
        assert case_insensitive_equal("CAFÉ", "café")
        # The test expects casefold() to make "ß" compare equal to "ss".
        assert case_insensitive_equal("Straße", "strasse")

    def test_unicode_upper_lower(self):
        assert "café".upper() == "CAFÉ"
        assert "CAFÉ".lower() == "café"

    def test_unicode_isalpha(self):
        assert "中文".isalpha()
        assert "αβγ".isalpha()


================================================
FILE: src/cext/CMakeLists.txt
================================================
cmake_minimum_required(VERSION 3.15)
project(pysheeet_cext)

# Build the C++ sources as C++17 and as position-independent code.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Find Python and pybind11
find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
# QUIET: a missing CMake package is not an error yet; the fallback below runs.
find_package(pybind11 CONFIG QUIET)

# If pybind11 not found via CMake, try to find it via Python
if(NOT pybind11_FOUND)
    # Ask the pybind11 Python package where its CMake config files live.
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -c "import pybind11; print(pybind11.get_cmake_dir())"
        OUTPUT_VARIABLE pybind11_DIR
        OUTPUT_STRIP_TRAILING_WHITESPACE
        RESULT_VARIABLE pybind11_RESULT
    )
    if(pybind11_RESULT EQUAL 0)
        find_package(pybind11 CONFIG REQUIRED PATHS ${pybind11_DIR})
    else()
        message(FATAL_ERROR "pybind11 not found. Install with: pip install pybind11")
    endif()
endif()

# Example 1: Basic functions (add, fib)
pybind11_add_module(example example.cpp)

# Example 2: Vector2D class
pybind11_add_module(vector vector.cpp)

# Example 3: NumPy integration
pybind11_add_module(numpy_example numpy_example.cpp)

# Example 4: GIL release example
pybind11_add_module(gil_example gil_example.cpp)

# Pure C library (for ctypes/cffi examples)
add_library(fib SHARED fib.c)
set_target_properties(fib PROPERTIES PREFIX "lib")

# Installation
install(TARGETS example vector numpy_example gil_example fib
        LIBRARY DESTINATION .)

================================================
FILE: src/cext/README.md
================================================
# pybind11 C++ Extension Examples

This directory contains C++ source files demonstrating pybind11 bindings.

## Prerequisites

```bash
pip install pybind11 numpy
```

## Building with CMake

```bash
mkdir build && cd build
cmake ..
make
```

The compiled modules will be in the `build/` directory.

## Building with setup.py (Alternative)

```bash
pip install .
```

## Examples

### example.cpp

Basic function bindings (add, fibonacci).
```python
>>> import example
>>> example.add(1, 2)
3
>>> example.fib(10)
55
```

### vector.cpp

Class binding with operators and properties.

```python
>>> from vector import Vector2D
>>> v = Vector2D(3, 4)
>>> v.length()
5.0
>>> v2 = v + Vector2D(1, 1)
>>> v2
Vector2D(4.0, 5.0)
```

### numpy_example.cpp

NumPy array operations (zero-copy).

```python
>>> import numpy as np
>>> from numpy_example import multiply_inplace
>>> arr = np.array([1.0, 2.0, 3.0])
>>> multiply_inplace(arr, 2.0)
>>> arr
array([2., 4., 6.])
```

### gil_example.cpp

GIL release for parallel execution.

```python
>>> from gil_example import fib_nogil
>>> import threading
>>> # Runs in parallel because GIL is released
>>> threads = [threading.Thread(target=fib_nogil, args=(30,)) for _ in range(4)]
```

### fib.c

Pure C library for ctypes/cffi examples.

```bash
# Compile
gcc -shared -fPIC -o libfib.so fib.c       # Linux
clang -shared -fPIC -o libfib.dylib fib.c  # macOS
```

```python
>>> import ctypes
>>> lib = ctypes.CDLL("./libfib.so")
>>> lib.fib(10)
55
```

================================================
FILE: src/cext/capi/args.c
================================================
/* Demonstrate argument parsing in Python C API.
*/ #include /* METH_NOARGS - no arguments */ static PyObject* no_args(PyObject* self) { Py_RETURN_NONE; } /* METH_O - single object argument */ static PyObject* single_arg(PyObject* self, PyObject* arg) { return Py_BuildValue("O", arg); } /* METH_VARARGS - positional arguments */ static PyObject* pos_args(PyObject* self, PyObject* args) { PyObject *x = NULL, *y = NULL; if (!PyArg_ParseTuple(args, "OO", &x, &y)) { return NULL; } return Py_BuildValue("OO", x, y); } /* METH_VARARGS | METH_KEYWORDS - keyword arguments */ static PyObject* kw_args(PyObject* self, PyObject* args, PyObject* kwargs) { static char* keywords[] = {"x", "y", "z", NULL}; PyObject *x = NULL, *y = NULL, *z = Py_None; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO|O", keywords, &x, &y, &z)) { return NULL; } return Py_BuildValue("OOO", x, y, z); } /* Parse specific types */ static PyObject* typed_args(PyObject* self, PyObject* args) { int i; double d; const char* s; if (!PyArg_ParseTuple(args, "ids", &i, &d, &s)) { return NULL; } return Py_BuildValue("{s:i,s:d,s:s}", "int", i, "double", d, "str", s); } static PyMethodDef methods[] = { {"no_args", (PyCFunction)no_args, METH_NOARGS, "No arguments"}, {"single_arg", (PyCFunction)single_arg, METH_O, "Single argument"}, {"pos_args", (PyCFunction)pos_args, METH_VARARGS, "Positional arguments"}, {"kw_args", (PyCFunction)kw_args, METH_VARARGS | METH_KEYWORDS, "Keyword arguments"}, {"typed_args", (PyCFunction)typed_args, METH_VARARGS, "Typed arguments"}, {NULL, NULL, 0, NULL} }; static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, "args", "Argument parsing examples", -1, methods}; PyMODINIT_FUNC PyInit_args(void) { return PyModule_Create(&module); } ================================================ FILE: src/cext/capi/errors.c ================================================ /* Demonstrate exception handling in Python C API. 
*/ #include static PyObject* FooError; /* Raise built-in exception */ static PyObject* raise_value_error(PyObject* self) { PyErr_SetString(PyExc_ValueError, "This is a ValueError"); return NULL; } /* Raise custom exception */ static PyObject* raise_foo_error(PyObject* self) { PyErr_SetString(FooError, "This is a custom FooError"); return NULL; } /* Raise with format string */ static PyObject* raise_with_format(PyObject* self, PyObject* args) { int code; if (!PyArg_ParseTuple(args, "i", &code)) { return NULL; } PyErr_Format(PyExc_RuntimeError, "Error code: %d", code); return NULL; } /* Check and propagate exception */ static PyObject* divide(PyObject* self, PyObject* args) { double a, b; if (!PyArg_ParseTuple(args, "dd", &a, &b)) { return NULL; } if (b == 0.0) { PyErr_SetString(PyExc_ZeroDivisionError, "division by zero"); return NULL; } return PyFloat_FromDouble(a / b); } static PyMethodDef methods[] = { {"raise_value_error", (PyCFunction)raise_value_error, METH_NOARGS, NULL}, {"raise_foo_error", (PyCFunction)raise_foo_error, METH_NOARGS, NULL}, {"raise_with_format", (PyCFunction)raise_with_format, METH_VARARGS, NULL}, {"divide", (PyCFunction)divide, METH_VARARGS, "Divide a by b"}, {NULL, NULL, 0, NULL} }; static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, "errors", "Exception handling examples", -1, methods}; PyMODINIT_FUNC PyInit_errors(void) { PyObject* m = PyModule_Create(&module); if (!m) return NULL; FooError = PyErr_NewException("errors.FooError", NULL, NULL); Py_INCREF(FooError); PyModule_AddObject(m, "FooError", FooError); return m; } ================================================ FILE: src/cext/capi/gil.c ================================================ /* Demonstrate GIL release and acquire in Python C API. 
*/ #include #ifdef _WIN32 #include #define sleep(x) Sleep((x) * 1000) #else #include #endif /* Sleep WITHOUT releasing GIL - blocks other threads */ static PyObject* sleep_with_gil(PyObject* self, PyObject* args) { int seconds; if (!PyArg_ParseTuple(args, "i", &seconds)) { return NULL; } sleep(seconds); Py_RETURN_NONE; } /* Sleep WITH releasing GIL - allows other threads to run */ static PyObject* sleep_no_gil(PyObject* self, PyObject* args) { int seconds; if (!PyArg_ParseTuple(args, "i", &seconds)) { return NULL; } Py_BEGIN_ALLOW_THREADS sleep(seconds); Py_END_ALLOW_THREADS Py_RETURN_NONE; } /* CPU work without GIL */ static unsigned long fib_impl(unsigned long n) { if (n < 2) return n; return fib_impl(n - 1) + fib_impl(n - 2); } static PyObject* fib_no_gil(PyObject* self, PyObject* args) { unsigned long n, result; if (!PyArg_ParseTuple(args, "k", &n)) { return NULL; } Py_BEGIN_ALLOW_THREADS result = fib_impl(n); Py_END_ALLOW_THREADS return PyLong_FromUnsignedLong(result); } static PyMethodDef methods[] = { {"sleep_with_gil", (PyCFunction)sleep_with_gil, METH_VARARGS, "Sleep holding GIL (blocks threads)"}, {"sleep_no_gil", (PyCFunction)sleep_no_gil, METH_VARARGS, "Sleep releasing GIL (allows threads)"}, {"fib_no_gil", (PyCFunction)fib_no_gil, METH_VARARGS, "Fibonacci with GIL released"}, {NULL, NULL, 0, NULL} }; static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, "gil", "GIL handling examples", -1, methods}; PyMODINIT_FUNC PyInit_gil(void) { return PyModule_Create(&module); } ================================================ FILE: src/cext/capi/setup.py ================================================ from setuptools import setup, Extension extensions = [ Extension("simple", ["simple.c"]), Extension("args", ["args.c"]), Extension("gil", ["gil.c"]), Extension("errors", ["errors.c"]), Extension("types_demo", ["types_demo.c"]), ] setup( name="capi_examples", version="1.0", ext_modules=extensions, ) ================================================ FILE: 
src/cext/capi/simple.c ================================================ /* Simple C extension module demonstrating Python C API basics. */ #include PyDoc_STRVAR(doc_mod, "Simple example C extension module.\n"); PyDoc_STRVAR(doc_hello, "hello() -> str\n\nReturn a greeting string."); PyDoc_STRVAR(doc_add, "add(a, b) -> int\n\nAdd two integers."); PyDoc_STRVAR(doc_fib, "fib(n) -> int\n\nCompute Fibonacci number."); static PyObject* hello(PyObject* self) { return PyUnicode_FromString("Hello from C!"); } static PyObject* add(PyObject* self, PyObject* args) { long a, b; if (!PyArg_ParseTuple(args, "ll", &a, &b)) { return NULL; } return PyLong_FromLong(a + b); } static unsigned long fib_impl(unsigned long n) { if (n < 2) return n; return fib_impl(n - 1) + fib_impl(n - 2); } static PyObject* fib(PyObject* self, PyObject* args) { unsigned long n; if (!PyArg_ParseTuple(args, "k", &n)) { return NULL; } return PyLong_FromUnsignedLong(fib_impl(n)); } static PyMethodDef methods[] = { {"hello", (PyCFunction)hello, METH_NOARGS, doc_hello}, {"add", (PyCFunction)add, METH_VARARGS, doc_add}, {"fib", (PyCFunction)fib, METH_VARARGS, doc_fib}, {NULL, NULL, 0, NULL} }; static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, "simple", doc_mod, -1, methods}; PyMODINIT_FUNC PyInit_simple(void) { return PyModule_Create(&module); } ================================================ FILE: src/cext/capi/test_capi.py ================================================ """Tests for Python C API extension examples.""" import pytest import sys import os # Add build directory to path build_dir = os.path.join(os.path.dirname(__file__), "build") for d in os.listdir(build_dir) if os.path.exists(build_dir) else []: path = os.path.join(build_dir, d) if os.path.isdir(path) and path not in sys.path: sys.path.insert(0, path) class TestSimple: """Test simple module.""" def test_hello(self): import simple assert simple.hello() == "Hello from C!" 
def test_add(self): import simple assert simple.add(1, 2) == 3 assert simple.add(-5, 10) == 5 def test_fib(self): import simple assert simple.fib(0) == 0 assert simple.fib(1) == 1 assert simple.fib(10) == 55 assert simple.fib(20) == 6765 class TestArgs: """Test argument parsing module.""" def test_no_args(self): import args assert args.no_args() is None def test_single_arg(self): import args assert args.single_arg(42) == 42 assert args.single_arg("hello") == "hello" def test_pos_args(self): import args assert args.pos_args(1, 2) == (1, 2) assert args.pos_args("a", "b") == ("a", "b") def test_kw_args(self): import args assert args.kw_args(1, 2) == (1, 2, None) assert args.kw_args(1, 2, 3) == (1, 2, 3) assert args.kw_args(x=1, y=2, z=3) == (1, 2, 3) def test_typed_args(self): import args result = args.typed_args(42, 3.14, "hello") assert result == {"int": 42, "double": 3.14, "str": "hello"} class TestGil: """Test GIL handling module.""" def test_fib_no_gil(self): import gil assert gil.fib_no_gil(10) == 55 assert gil.fib_no_gil(20) == 6765 class TestErrors: """Test exception handling module.""" def test_raise_value_error(self): import errors with pytest.raises(ValueError, match="This is a ValueError"): errors.raise_value_error() def test_raise_foo_error(self): import errors with pytest.raises(errors.FooError, match="This is a custom FooError"): errors.raise_foo_error() def test_raise_with_format(self): import errors with pytest.raises(RuntimeError, match="Error code: 42"): errors.raise_with_format(42) def test_divide(self): import errors assert errors.divide(10.0, 2.0) == 5.0 with pytest.raises(ZeroDivisionError): errors.divide(1.0, 0.0) class TestTypesDemo: """Test Python types manipulation.""" def test_list_demo(self): import types_demo assert types_demo.list_demo() == [1, 2, 3] def test_list_sum(self): import types_demo assert types_demo.list_sum([1, 2, 3, 4]) == 10 def test_iter_list(self): import types_demo assert types_demo.iter_list([1, 2, 3]) == [2, 4, 6] def 
test_dict_demo(self): import types_demo assert types_demo.dict_demo() == {"name": "Python", "version": 3} def test_dict_get(self): import types_demo d = {"a": 1, "b": 2} assert types_demo.dict_get(d, "a") == 1 assert types_demo.dict_get(d, "c") is None def test_iter_dict(self): import types_demo d = {"x": 1, "y": 2} result = types_demo.iter_dict(d) assert set(result) == {("x", 1), ("y", 2)} def test_tuple_demo(self): import types_demo assert types_demo.tuple_demo() == (1, "hello", 3.14) def test_tuple_unpack(self): import types_demo result = types_demo.tuple_unpack((42, "test", 2.5)) assert result == {"int": 42, "str": "test", "float": 2.5} def test_set_demo(self): import types_demo assert types_demo.set_demo() == {1, 2, 3} def test_set_contains(self): import types_demo s = {1, 2, 3} assert types_demo.set_contains(s, 2) is True assert types_demo.set_contains(s, 5) is False def test_str_demo(self): import types_demo assert types_demo.str_demo() == "Hello World" def test_str_format(self): import types_demo assert types_demo.str_format("Alice", 30) == "Alice is 30 years old" def test_bytes_demo(self): import types_demo assert types_demo.bytes_demo() == b"hello bytes" def test_bytes_len(self): import types_demo assert types_demo.bytes_len(b"hello") == 5 ================================================ FILE: src/cext/capi/types_demo.c ================================================ /* Demonstrate Python types manipulation in C API. 
*/ #include /* List operations */ static PyObject* list_demo(PyObject* self) { PyObject* list = PyList_New(0); PyList_Append(list, PyLong_FromLong(1)); PyList_Append(list, PyLong_FromLong(2)); PyList_Append(list, PyLong_FromLong(3)); return list; } static PyObject* list_sum(PyObject* self, PyObject* args) { PyObject* list; if (!PyArg_ParseTuple(args, "O!", &PyList_Type, &list)) { return NULL; } long sum = 0; Py_ssize_t len = PyList_Size(list); for (Py_ssize_t i = 0; i < len; i++) { PyObject* item = PyList_GetItem(list, i); /* borrowed ref */ sum += PyLong_AsLong(item); } return PyLong_FromLong(sum); } /* Iterate list with iterator protocol */ static PyObject* iter_list(PyObject* self, PyObject* args) { PyObject *list, *iter, *item; if (!PyArg_ParseTuple(args, "O", &list)) { return NULL; } iter = PyObject_GetIter(list); if (!iter) return NULL; PyObject* result = PyList_New(0); while ((item = PyIter_Next(iter)) != NULL) { PyObject* doubled = PyLong_FromLong(PyLong_AsLong(item) * 2); PyList_Append(result, doubled); Py_DECREF(doubled); Py_DECREF(item); } Py_DECREF(iter); return result; } /* Dict operations */ static PyObject* dict_demo(PyObject* self) { PyObject* dict = PyDict_New(); PyDict_SetItemString(dict, "name", PyUnicode_FromString("Python")); PyDict_SetItemString(dict, "version", PyLong_FromLong(3)); return dict; } static PyObject* dict_get(PyObject* self, PyObject* args) { PyObject* dict; const char* key; if (!PyArg_ParseTuple(args, "O!s", &PyDict_Type, &dict, &key)) { return NULL; } PyObject* value = PyDict_GetItemString(dict, key); /* borrowed ref */ if (!value) { Py_RETURN_NONE; } Py_INCREF(value); return value; } /* Iterate dict */ static PyObject* iter_dict(PyObject* self, PyObject* args) { PyObject* dict; if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict)) { return NULL; } PyObject* result = PyList_New(0); PyObject *key, *value; Py_ssize_t pos = 0; while (PyDict_Next(dict, &pos, &key, &value)) { PyObject* pair = PyTuple_Pack(2, key, value); 
PyList_Append(result, pair); Py_DECREF(pair); } return result; } /* Tuple operations */ static PyObject* tuple_demo(PyObject* self) { return Py_BuildValue("(isd)", 1, "hello", 3.14); } static PyObject* tuple_unpack(PyObject* self, PyObject* args) { int a; const char* b; double c; if (!PyArg_ParseTuple(args, "(isd)", &a, &b, &c)) { return NULL; } return Py_BuildValue("{s:i,s:s,s:d}", "int", a, "str", b, "float", c); } /* Set operations */ static PyObject* set_demo(PyObject* self) { PyObject* set = PySet_New(NULL); PySet_Add(set, PyLong_FromLong(1)); PySet_Add(set, PyLong_FromLong(2)); PySet_Add(set, PyLong_FromLong(2)); /* duplicate ignored */ PySet_Add(set, PyLong_FromLong(3)); return set; } static PyObject* set_contains(PyObject* self, PyObject* args) { PyObject *set, *item; if (!PyArg_ParseTuple(args, "OO", &set, &item)) { return NULL; } int result = PySet_Contains(set, item); if (result == -1) return NULL; return PyBool_FromLong(result); } /* String operations */ static PyObject* str_demo(PyObject* self) { PyObject* s1 = PyUnicode_FromString("Hello"); PyObject* s2 = PyUnicode_FromString(" World"); PyObject* result = PyUnicode_Concat(s1, s2); Py_DECREF(s1); Py_DECREF(s2); return result; } static PyObject* str_format(PyObject* self, PyObject* args) { const char* name; int age; if (!PyArg_ParseTuple(args, "si", &name, &age)) { return NULL; } return PyUnicode_FromFormat("%s is %d years old", name, age); } /* Bytes operations */ static PyObject* bytes_demo(PyObject* self) { return PyBytes_FromString("hello bytes"); } static PyObject* bytes_len(PyObject* self, PyObject* args) { PyObject* bytes; if (!PyArg_ParseTuple(args, "S", &bytes)) { return NULL; } return PyLong_FromSsize_t(PyBytes_Size(bytes)); } static PyMethodDef methods[] = { {"list_demo", (PyCFunction)list_demo, METH_NOARGS, "Create a list [1,2,3]"}, {"list_sum", (PyCFunction)list_sum, METH_VARARGS, "Sum list elements"}, {"iter_list", (PyCFunction)iter_list, METH_VARARGS, "Double each element"}, {"dict_demo", 
def pytest_configure(config):
    """Make freshly built extension modules importable by the tests.

    If a ``build`` directory sits next to this file, it is placed at the
    front of ``sys.path``; otherwise nothing happens. ``config`` is the
    pytest config object and is not used.
    """
    candidate = Path(__file__).parent.joinpath("build")
    if not candidate.exists():
        return
    sys.path.insert(0, str(candidate))
/*
 * Iterative Fibonacci: returns F(n) with F(0) = 0 and F(1) = 1.
 * Arithmetic is done in unsigned long, so large n wraps modulo the
 * width of that type.
 */
unsigned long fib_iter(unsigned long n) {
  unsigned long prev = 0;
  unsigned long curr = 1;
  unsigned long remaining;

  if (n <= 1) {
    return n;
  }

  /* Starting from (F(0), F(1)), advance the pair n - 1 times. */
  remaining = n - 1;
  while (remaining-- > 0) {
    unsigned long next = prev + curr;
    prev = curr;
    curr = next;
  }
  return curr;
}
// CPU-intensive Fibonacci computed while the GIL is released.
//
// Returns F(n) (F(0) = 0, F(1) = 1) using naive double recursion, so the
// running time grows exponentially with n. No Python objects are touched
// inside this function.
unsigned long fib_nogil(unsigned long n) {
  // Release the GIL for the rest of this scope so other Python threads can
  // run during the computation.
  py::gil_scoped_release release;
  // Declared before it is assigned so the lambda below can refer to itself.
  // NOTE(review): the template arguments of std::function look lost in this
  // copy; presumably std::function<unsigned long(unsigned long)> - confirm.
  std::function fib_impl;
  // Captures by reference in order to call fib_impl recursively; the lambda
  // parameter n shadows the function parameter of the same name.
  fib_impl = [&](unsigned long n) -> unsigned long {
    if (n < 2) return n;
    return fib_impl(n - 1) + fib_impl(n - 2);
  };
  return fib_impl(n);
}
// Element-wise sum of two 1-D arrays, returned as a newly allocated array.
//
// Throws std::runtime_error when the two inputs differ in length. The inputs
// are only read, never modified.
// NOTE(review): the element type of py::array_t is not visible in this copy
// (template arguments appear stripped); presumably double - confirm.
py::array_t add_arrays(py::array_t a, py::array_t b) {
  // unchecked<1>() yields 1-D read-only access proxies for a and b.
  auto buf_a = a.unchecked<1>();
  auto buf_b = b.unchecked<1>();
  // Validate before allocating the result.
  if (buf_a.shape(0) != buf_b.shape(0)) {
    throw std::runtime_error("Arrays must have same length");
  }
  // Output array with the same length as the inputs.
  auto result = py::array_t(buf_a.shape(0));
  auto buf_r = result.mutable_unchecked<1>();
  for (py::ssize_t i = 0; i < buf_a.shape(0); i++) {
    buf_r(i) = buf_a(i) + buf_b(i);
  }
  return result;
}
"""
setup.py for pybind11 examples

Build:
    pip install .
    # or
    python setup.py build_ext --inplace
"""

from setuptools import setup, find_packages

# Guard only the import. If the setup() call lived inside the ``try`` body,
# an ImportError raised while building the extensions would be mistaken for
# "pybind11 is missing" and silently trigger a second, extension-less setup().
try:
    from pybind11.setup_helpers import Pybind11Extension, build_ext
except ImportError:
    # pybind11 not installed, create minimal setup (metadata only).
    setup(
        name="pysheeet_cext",
        version="1.0.0",
        description="pybind11 extension examples for pysheeet",
    )
else:
    # One extension module per example source file.
    ext_modules = [
        Pybind11Extension(
            "example",
            ["example.cpp"],
        ),
        Pybind11Extension(
            "vector",
            ["vector.cpp"],
        ),
        Pybind11Extension(
            "numpy_example",
            ["numpy_example.cpp"],
        ),
        Pybind11Extension(
            "gil_example",
            ["gil_example.cpp"],
        ),
    ]

    setup(
        name="pysheeet_cext",
        version="1.0.0",
        description="pybind11 extension examples for pysheeet",
        ext_modules=ext_modules,
        cmdclass={"build_ext": build_ext},
        python_requires=">=3.8",
    )
@pytest.mark.skipif(not HAS_EXAMPLE, reason="example module not built")
class TestExample:
    """Exercise the plain functions exported by the ``example`` module."""

    def test_add(self):
        # Each row is (a, b, expected a + b).
        cases = ((1, 2, 3), (-5, 10, 5), (0, 0, 0))
        for lhs, rhs, total in cases:
            assert example.add(lhs, rhs) == total

    def test_fib(self):
        # Maps n to the expected value of the recursive implementation.
        expected = {0: 0, 1: 1, 10: 55, 20: 6765}
        for n, value in expected.items():
            assert example.fib(n) == value

    def test_fib_iter(self):
        # Maps n to the expected value; 50 is only practical iteratively.
        expected = {0: 0, 1: 1, 10: 55, 50: 12586269025}
        for n, value in expected.items():
            assert example.fib_iter(n) == value
@pytest.mark.skipif(not HAS_GIL, reason="gil_example module not built")
class TestGIL:
    """Test GIL release functionality."""

    def test_fib_nogil(self):
        """fib_nogil(20) must return the 20th Fibonacci number."""
        result = gil_example.fib_nogil(20)
        assert result == 6765

    def test_slow_operation_parallel(self):
        """Test that slow_operation releases GIL allowing parallel execution."""
        # Imported locally so this test does not depend on a module-level
        # import. A monotonic clock is used because wall-clock time can jump
        # (NTP adjustment, DST change) and corrupt the measured duration.
        import time

        results = []
        start = time.monotonic()

        def worker(n):
            results.append(n)
            gil_example.slow_operation(1)

        threads = [
            threading.Thread(target=worker, args=(i,)) for i in range(3)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        elapsed = time.monotonic() - start
        # If GIL was released, all 3 should complete in ~1 second
        # If GIL was held, it would take ~3 seconds
        assert elapsed < 2.0, f"Took {elapsed}s, GIL may not be released"
        assert len(results) == 3

    def test_callback(self):
        """Test calling Python callback from C++."""
        results = []

        def callback(msg):
            results.append(msg)

        gil_example.call_python_callback(callback, "hello")
        assert results == ["hello"]
&Vector2D::operator*) .def("__eq__", &Vector2D::operator==) .def("__repr__", &Vector2D::repr); } ================================================ FILE: src/cpp_from_python/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.15) project(cpp_from_python) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) # Download and configure Google Test include(FetchContent) FetchContent_Declare( googletest GIT_REPOSITORY https://github.com/google/googletest.git GIT_TAG release-1.12.1 ) set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) enable_testing() # Single executable with tests add_executable(cpp_from_py cpp_from_py.cpp) target_link_libraries(cpp_from_py gtest gtest_main) # Add test add_test(NAME CppFromPythonTests COMMAND cpp_from_py) ================================================ FILE: src/cpp_from_python/cpp_from_py.cpp ================================================ /* * Learn C++ from Python - Modern C++ Examples with Tests * * Demonstrates modern C++ syntax with Python equivalents in Doxygen comments. * Build: mkdir build && cd build && cmake .. && make * Test: make test */ #include #include #include #include #include #include #include #include #include #include #include /** * @brief Print hello world message * * Python equivalent: * @code{.py} * print("Hello, World!") * @endcode */ void hello_world() { std::cout << "Hello, World!" 
// Builds the sequence 0..9, demonstrates three slicing patterns on it, and
// returns the reversed sequence. The [2:5] slice and the every-second-element
// vector are constructed for illustration only and are not returned.
// NOTE(review): the std::vector element type is not visible in this copy
// (template arguments appear stripped); presumably int - confirm.
std::vector array_slicing() {
  std::vector numbers = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  // Slicing [2:5]: iterator-range constructor, half-open like Python slices.
  std::vector slice(numbers.begin() + 2, numbers.begin() + 5);
  // Every second element [::2]: step the index by 2.
  std::vector every_second;
  for (size_t i = 0; i < numbers.size(); i += 2) {
    every_second.push_back(numbers[i]);
  }
  // Reversed [::-1]: copy through the reverse iterators.
  std::vector reversed(numbers.rbegin(), numbers.rend());
  return reversed;
}
// Returns the sum of a and b; written with a trailing return type to show
// that syntax.
auto add(int a, int b) -> int {
  const int total = a + b;
  return total;
}
// Looks `key` up in a small fixed table ("a" -> 1, "b" -> 2).
//
// Returns the mapped value when the key is present and std::nullopt
// otherwise, mirroring Python's dict.get(key) returning None.
// NOTE(review): template arguments of std::optional / std::map are not
// visible in this copy; presumably std::optional<int> and
// std::map<std::string, int> - confirm.
std::optional find_value(const std::string& key) {
  // The table is rebuilt on every call.
  std::map data = {{"a", 1}, {"b", 2}};
  auto it = data.find(key);
  if (it != data.end()) {
    // Found: return the mapped value.
    return it->second;
  }
  return std::nullopt;
}
// Returns the smallest element of `numbers`.
//
// Precondition: `numbers` must not be empty. For an empty range
// std::min_element returns the end iterator, and dereferencing it is
// undefined behavior (Python's min() would raise ValueError instead).
// NOTE(review): the std::vector element type is not visible in this copy;
// presumably int, matching the return type - confirm.
int find_min(const std::vector& numbers) {
  return *std::min_element(numbers.begin(), numbers.end());
}
evens = list_comprehension(); EXPECT_EQ(evens.size(), 5); EXPECT_EQ(evens[0], 0); EXPECT_EQ(evens[4], 8); } TEST(StringTest, StringOperations) { auto str = string_operations(); EXPECT_EQ(str, "Hello World"); EXPECT_EQ(str.size(), 11); } TEST(ClassTest, PersonClass) { Person p("Alice", 30); EXPECT_EQ(p.name, "Alice"); EXPECT_EQ(p.age, 30); EXPECT_EQ(p.greet(), "Hello, I'm Alice"); } TEST(OptionalTest, FindValue) { auto result = find_value("a"); ASSERT_TRUE(result.has_value()); EXPECT_EQ(result.value(), 1); auto missing = find_value("z"); EXPECT_FALSE(missing.has_value()); } TEST(TupleTest, CreateTuple) { auto [x, y] = create_tuple(); EXPECT_EQ(x, 10); EXPECT_EQ(y, 20); } TEST(AlgorithmTest, FilterEvens) { std::vector numbers = {1, 2, 3, 4, 5}; auto evens = filter_evens(numbers); EXPECT_EQ(evens.size(), 2); EXPECT_EQ(evens[0], 2); EXPECT_EQ(evens[1], 4); } TEST(AlgorithmTest, HasEven) { std::vector with_even = {1, 2, 3}; std::vector without_even = {1, 3, 5}; EXPECT_TRUE(has_even(with_even)); EXPECT_FALSE(has_even(without_even)); } TEST(AlgorithmTest, AllPositive) { std::vector all_pos = {1, 2, 3}; std::vector has_neg = {1, -2, 3}; EXPECT_TRUE(all_positive(all_pos)); EXPECT_FALSE(all_positive(has_neg)); } TEST(AlgorithmTest, SortVector) { std::vector unsorted = {3, 1, 4, 1, 5}; auto sorted = sort_vector(unsorted); EXPECT_EQ(sorted[0], 1); EXPECT_EQ(sorted[4], 5); } TEST(AlgorithmTest, FindMin) { std::vector numbers = {3, 1, 4, 1, 5}; EXPECT_EQ(find_min(numbers), 1); } TEST(AlgorithmTest, SumVector) { std::vector numbers = {1, 2, 3, 4, 5}; EXPECT_EQ(sum_vector(numbers), 15); } TEST(FunctionTest, DefaultArguments) { EXPECT_EQ(greet("Alice"), "Hello, Alice"); EXPECT_EQ(greet("Bob", "Hi"), "Hi, Bob"); } int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } ================================================ FILE: src/gin/Dockerfile ================================================ # Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. # # Modifications copyright (c) 2025 chang-ning # Modifications licensed under the Creative Commons Attribution 4.0 International License (CC BY 4.0). # See LICENSE or https://creativecommons.org/licenses/by/4.0/ # ref: https://github.com/aws-samples/awsome-distributed-training/blob/main/micro-benchmarks/nccl-tests ARG CUDA_VERSION=12.8.1 FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu24.04 ARG GDRCOPY_VERSION=v2.5.1 ARG EFA_INSTALLER_VERSION=1.47.0 ARG AWS_OFI_NCCL_VERSION=5f4202f11db1585d878196db4430aeda0e834a0c ARG NCCL_VERSION=v2.29.3-1 ARG NCCL_TESTS_VERSION=v2.17.9 ARG NVSHMEM_VERSION=v3.5.19-1 ARG TORCH_VERSION=2.9.1 RUN apt-get update -y && apt-get upgrade -y RUN apt-get remove -y --allow-change-held-packages \ ibverbs-utils \ libibverbs-dev \ libibverbs1 \ libmlx5-1 \ libnccl2 \ libnccl-dev RUN rm -rf /opt/hpcx \ && rm -rf /usr/local/mpi \ && rm -f /etc/ld.so.conf.d/hpcx.conf \ && ldconfig ENV OPAL_PREFIX= RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ apt-utils \ autoconf \ automake \ build-essential \ check \ cmake \ ninja-build \ curl \ debhelper \ devscripts \ git \ gcc \ gdb \ kmod \ libsubunit-dev \ libtool \ openssh-client \ openssh-server \ pkg-config \ vim \ hwloc \ libhwloc-dev \ python3-dev \ python3-venv \ libomp-dev RUN apt-get purge -y cuda-compat-* RUN mkdir -p /var/run/sshd RUN sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config && \ echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config && \ sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config # Set paths for both aarch64 and x86_64 ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/amazon/ofi-nccl/lib:/usr/local/lib:$LD_LIBRARY_PATH ENV PATH=/opt/amazon/openmpi/bin/:/opt/amazon/efa/bin:/usr/bin:/usr/local/bin:$PATH RUN apt-get install -y python3-pip \ && pip3 install --break-system-packages --no-cache-dir awscli nvidia-ml-py 
Cython ################################################# ## Install NVIDIA GDRCopy ## ## NOTE: if `nccl-tests` or `/opt/gdrcopy/bin/sanity -v` crashes with incompatible version, ensure ## that the cuda-compat-xx-x package is the latest. RUN git clone -b ${GDRCOPY_VERSION} https://github.com/NVIDIA/gdrcopy.git /tmp/gdrcopy \ && cd /tmp/gdrcopy \ && make prefix=/opt/gdrcopy install \ && rm -rf /tmp/gdrcopy ENV LD_LIBRARY_PATH=/opt/gdrcopy/lib:$LD_LIBRARY_PATH ENV LIBRARY_PATH=/opt/gdrcopy/lib:$LIBRARY_PATH ENV CPATH=/opt/gdrcopy/include:$CPATH ENV PATH=/opt/gdrcopy/bin:$PATH ################################################# ## Install EFA installer RUN cd $HOME \ && curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && tar -xf $HOME/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && cd aws-efa-installer \ && ./efa_installer.sh -y -g -d --skip-kmod --skip-limit-conf --no-verify --skip-plugin \ && rm -rf $HOME/aws-efa-installer $HOME/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz ################################################### ## Install aws-ofi-nccl from source (pinned commit) RUN git clone https://github.com/aws/aws-ofi-nccl.git /tmp/aws-ofi-nccl \ && cd /tmp/aws-ofi-nccl \ && git checkout ${AWS_OFI_NCCL_VERSION} \ && ./autogen.sh \ && ./configure --prefix=/opt/amazon/ofi-nccl \ --with-libfabric=/opt/amazon/efa \ --with-cuda=/usr/local/cuda \ && make -j$(nproc) \ && make install \ && rm -rf /tmp/aws-ofi-nccl ################################################### ## Install NCCL RUN git clone -b ${NCCL_VERSION} https://github.com/NVIDIA/nccl.git /opt/nccl \ && cd /opt/nccl \ && make -j $(nproc) src.build CUDA_HOME=/usr/local/cuda \ NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_89,code=sm_89 -gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_100,code=sm_100" ################################################### ## Install NCCL-tests RUN git clone -b 
# Build NVSHMEM from source and install it into /opt/nvshmem.
# The checkout lives at an absolute path under /tmp so that the final cleanup
# removes exactly the directory that was cloned, independent of WORKDIR; this
# keeps the source and build tree out of the image layer.
RUN git clone -b ${NVSHMEM_VERSION} https://github.com/NVIDIA/nvshmem.git /tmp/nvshmem \
    && cd /tmp/nvshmem \
    && mkdir -p build \
    && cd build \
    && cmake -DNVSHMEM_PREFIX=/opt/nvshmem \
        -DCMAKE_CUDA_ARCHITECTURES="80;90" \
        -DNVSHMEM_MPI_SUPPORT=1 \
        -DNVSHMEM_PMIX_SUPPORT=1 \
        -DNVSHMEM_LIBFABRIC_SUPPORT=1 \
        -DNVSHMEM_IBRC_SUPPORT=1 \
        -DNVSHMEM_IBGDA_SUPPORT=1 \
        -DNVSHMEM_USE_GDRCOPY=1 \
        -DNVSHMEM_BUILD_TESTS=1 \
        -DNVSHMEM_BUILD_EXAMPLES=1 \
        -DNVSHMEM_BUILD_HYDRA_LAUNCHER=1 \
        -DNVSHMEM_BUILD_TXZ_PACKAGE=0 \
        -DNVSHMEM_BUILD_PYTHON_LIB=0 \
        -DMPI_HOME=/opt/amazon/openmpi \
        -DPMIX_HOME=/opt/amazon/pmix \
        -DGDRCOPY_HOME=/opt/gdrcopy \
        -DLIBFABRIC_HOME=/opt/amazon/efa \
        -G Ninja .. \
    && ninja -j $(nproc) \
    && ninja install \
    && rm -rf /tmp/nvshmem
ENV OMPI_MCA_pml=^ucx \ OMPI_MCA_btl=tcp,self \ OMPI_MCA_btl_tcp_if_exclude=lo,docker0,veth_def_agent\ OPAL_PREFIX=/opt/amazon/openmpi \ NCCL_SOCKET_IFNAME=^docker,lo,veth ENV FI_EFA_USE_DEVICE_RDMA=1 ENV FI_PROVIDER=efa ENV FI_EFA_FORK_SAFE=1 ENV NCCL_BUFFSIZE=8388608 ENV NCCL_P2P_NET_CHUNKSIZE=524288 ENV NCCL_TUNER_PLUGIN=/opt/amazon/ofi-nccl/lib/libnccl-tuner-ofi.so ## Turn off PMIx Error https://github.com/open-mpi/ompi/issues/7516 ENV PMIX_MCA_gds=hash ## Set LD_PRELOAD for NCCL library ENV LD_PRELOAD=/opt/nccl/build/lib/libnccl.so ================================================ FILE: src/gin/Makefile ================================================ .PHONY: help docker save clean .DEFAULT_GOAL := help IMAGE_NAME ?= nccl IMAGE_TAG ?= latest help: @echo "NCCL GIN Test Makefile" @echo "" @echo "Targets:" @echo " docker Build Docker image" @echo " save Save Docker image to tar.gz and sqsh" @echo " clean Remove image, tarball, and sqsh" @echo "" @echo "Usage:" @echo " make docker && make save" docker: docker build -t $(IMAGE_NAME):$(IMAGE_TAG) -f Dockerfile . 
save:
	enroot import -o $(IMAGE_NAME)+$(IMAGE_TAG).sqsh dockerd://$(IMAGE_NAME):$(IMAGE_TAG)
	docker save $(IMAGE_NAME):$(IMAGE_TAG) | pigz > $(IMAGE_NAME)+$(IMAGE_TAG).tar.gz

# Remove the image plus both artifacts written by `save`.
# `save` names its outputs "$(IMAGE_NAME)+$(IMAGE_TAG).{sqsh,tar.gz}", so the
# tarball has to be deleted under the "+" name. The "-" name this target
# deleted before is still removed so an older leftover does not linger.
clean:
	-docker rmi $(IMAGE_NAME):$(IMAGE_TAG) 2>/dev/null || true
	-rm -f $(IMAGE_NAME)+$(IMAGE_TAG).sqsh
	-rm -f $(IMAGE_NAME)+$(IMAGE_TAG).tar.gz
	-rm -f $(IMAGE_NAME)-$(IMAGE_TAG).tar.gz

================================================
FILE: src/gin/run.enroot
================================================
#!/bin/bash
# Launch command inside enroot container via srun + pyxis
# Usage: salloc -N 2 ./run.enroot

set -exo pipefail

DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
SQSH="${SQSH:-${DIR}/nccl+latest.sqsh}"
MOUNT="/fsx:/fsx"

master_addr=$(scontrol show hostnames $SLURM_JOB_NODELIST 2>/dev/null | head -n1)
master_addr=$(getent hosts "${master_addr}" 2>/dev/null | awk '{print $1}' || echo "${master_addr}")
master_addr=${master_addr:-127.0.0.1}

cmd="$(cat <&2; }

IMAGE=""
CONTAINER_MOUNT="/fsx"
WORKSPACE="$PWD"
FORCE_PULL=false
ENABLE_NSYS=false
VLLM_STARTED=false
SERVE_ARGS=()

while (("$#")); do
  case "$1" in
  --image)
    IMAGE="$2"
    shift 2
    ;;
  --container-mount)
    CONTAINER_MOUNT="$2"
    shift 2
    ;;
  --workspace | -w)
    WORKSPACE="$2"
    shift 2
    ;;
  --force | -f)
    FORCE_PULL=true
    shift
    ;;
  --nsys)
    ENABLE_NSYS=true
    shift
    ;;
  --profile)
    SERVE_ARGS+=(--profiler-config "{\"profiler\": \"torch\", \"torch_profiler_dir\": \"${PWD}/vllm_profile\"}")
    shift
    ;;
  --profiler-config)
    SERVE_ARGS+=(--profiler-config "$2")
    shift 2
    ;;
  *)
    SERVE_ARGS+=("$1")
    shift
    ;;
  esac
done

# Build nsys command prefix
NSYS_CMD=""
if [[ "${ENABLE_NSYS}" == "true" ]]; then
  NSYS_DIR="${WORKSPACE}/nsys-vllm"
  mkdir -p "${NSYS_DIR}"
  NSYS_PATH="${NSYS_DIR}/profile-node${SLURM_NODEID:-0}.nsys-rep"
  NSYS_CMD="nsys profile"
  NSYS_CMD+=" -t cuda,nvtx,osrt,cudnn,cublas"
  NSYS_CMD+=" --trace-fork-before-exec=true"
  NSYS_CMD+=" --cuda-graph-trace=node"
  NSYS_CMD+=" --capture-range=cudaProfilerApi"
  NSYS_CMD+=" --capture-range-end=repeat"
  NSYS_CMD+=" --cuda-memory-usage=true"
NSYS_CMD+=" --cudabacktrace=true" NSYS_CMD+=" -o ${NSYS_PATH}" NSYS_CMD+=" --force-overwrite=true" fi IMAGE="${IMAGE:-${WORKSPACE}/vllm-serve-latest.tar.gz}" LOGDIR="${WORKSPACE}/logs" # Build a shell-safe string from SERVE_ARGS for nested bash -c / docker exec SERVE_ARGS_STR=$(printf '%q ' "${SERVE_ARGS[@]+"${SERVE_ARGS[@]}"}") # Peek at SERVE_ARGS to extract values needed for topology computation _peek_arg() { local short="$1" long="$2" default="$3" local i=0 while ((i < ${#SERVE_ARGS[@]})); do if [[ "${SERVE_ARGS[$i]}" == "$short" || "${SERVE_ARGS[$i]}" == "$long" ]]; then echo "${SERVE_ARGS[$((i + 1))]}" return fi ((i++)) done echo "$default" } _has_flag() { for arg in "${SERVE_ARGS[@]}"; do [[ "$arg" == "$1" ]] && return 0; done return 1 } TP=$(_peek_arg "-tp" "--tensor-parallel-size" "1") PP=$(_peek_arg "-pp" "--pipeline-parallel-size" "1") ENABLE_EP=$(_has_flag "--enable-expert-parallel" && echo "true" || echo "false") load_or_pull_image() { if [[ "${FORCE_PULL}" == "true" ]]; then info "Force pull: cleaning up existing images..." srun --ntasks-per-node=1 bash -c ' docker ps -aq | xargs -r docker rm -f 2>/dev/null || true docker images -aq | xargs -r docker rmi -f 2>/dev/null || true ' fi if [[ "${IMAGE}" == *.tar.gz ]]; then info "Loading Docker image from tarball..." CONTAINER_IMAGE=$(pigz -dc "${IMAGE}" | tar -xf - -O manifest.json | python3 -c "import sys,json; print(json.load(sys.stdin)[0]['RepoTags'][0])") info "Image tag: ${CONTAINER_IMAGE}" srun --ntasks-per-node=1 bash -c " if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then pigz -dc '${IMAGE}' | docker load fi " else info "Pulling Docker image from registry..." local registry="${IMAGE%%/*}" local region=$(echo "${registry}" | sed -n 's/.*\.ecr\.\([^.]*\)\.amazonaws\.com/\1/p') region="${region:-us-west-2}" srun --ntasks-per-node=1 bash -c " if ! 
docker image inspect '${IMAGE}' &>/dev/null; then aws ecr get-login-password --region '${region}' | docker login --username AWS --password-stdin '${registry}' docker pull '${IMAGE}' fi " CONTAINER_IMAGE="${IMAGE}" fi } launch_container() { local name="${1}" cmd="${2}" local devices=("--device=/dev/gdrdrv") while IFS= read -r -d '' d; do devices+=("--device=${d}") done < <(find "/dev/infiniband" -name "uverbs*" -print0 2>/dev/null) local net_if="${GLOO_SOCKET_IFNAME:-$(ip -o -4 route show to default | awk '{print $5}' | head -1)}" docker run --gpus "${GPUS}" \ --privileged -d \ --name "${name}" \ --uts=host --ipc=host --net=host \ --ulimit stack=67108864 --ulimit memlock=-1 \ --security-opt seccomp=unconfined \ "${devices[@]}" \ -v "${CONTAINER_MOUNT}:${CONTAINER_MOUNT}" \ -e NCCL_SOCKET_IFNAME="${net_if}" \ -e GLOO_SOCKET_IFNAME="${net_if}" \ -e TP_SOCKET_IFNAME="${net_if}" \ --entrypoint bash \ "${CONTAINER_IMAGE:-${IMAGE}}" \ -c "${cmd}" } setup_topology() { NUM_NODES=${SLURM_JOB_NUM_NODES:-1} GPUS_PER_NODE=8 TOTAL_GPUS=$((NUM_NODES * GPUS_PER_NODE)) if [[ "$PP" -gt 1 && "$ENABLE_EP" == "true" ]]; then err "Pipeline parallel (PP=$PP) and expert parallel cannot be enabled simultaneously" exit 1 fi [[ "$PP" -gt 1 ]] && DP_BACKEND="mp" if [[ "$ENABLE_EP" == "true" ]]; then DP=$((TOTAL_GPUS / TP)) if [[ $((DP * TP)) -ne $TOTAL_GPUS ]]; then err "DP($DP) * TP($TP) = $((DP * TP)) != TOTAL_GPUS($TOTAL_GPUS)" exit 1 fi else DP=$((TOTAL_GPUS / (TP * PP))) if [[ $((DP * TP * PP)) -ne $TOTAL_GPUS ]]; then err "DP($DP) * TP($TP) * PP($PP) = $((DP * TP * PP)) != TOTAL_GPUS($TOTAL_GPUS)" exit 1 fi fi DP_LOCAL=$((GPUS_PER_NODE / TP)) readarray -t NODES < <(scontrol show hostnames "$SLURM_JOB_NODELIST") HEAD_NODE=${NODES[0]} HEAD_IP=$(getent ahostsv4 "$HEAD_NODE" | head -1 | awk '{print $1}') RAY_PORT=$((6379 + (SLURM_JOB_ID % 1000))) RPC_PORT=$((13345 + (SLURM_JOB_ID % 1000))) mkdir -p "${LOGDIR}" info "========================================" info "vLLM Server" info 
"========================================" info "Image: ${IMAGE}" info "Nodes: ${NUM_NODES}, Head: ${HEAD_NODE} (${HEAD_IP}), GPUs: ${TOTAL_GPUS}" info "Parallelism: TP=${TP}, PP=${PP}, DP=${DP}, DP_LOCAL=${DP_LOCAL}, EP=${ENABLE_EP}" info "Backend: ${DP_BACKEND}" info "SERVE_ARGS: ${SERVE_ARGS[*]+"${SERVE_ARGS[*]}"}" info "========================================" } stop_nsys() { [[ "${VLLM_STARTED}" != "true" ]] && return 0 info "Sending SIGINT to nsys processes for graceful shutdown..." srun --ntasks-per-node=1 bash -c ' for cid in $(docker ps -q); do docker exec "$cid" pkill -INT -f "^nsys profile" 2>/dev/null || true done ' 2>/dev/null || true } wait_for_nsys() { [[ "${VLLM_STARTED}" != "true" ]] && return 0 info "Waiting 60s for nsys to finalize profiles..." sleep 60 } cleanup() { info "Cleaning up containers..." if [[ "${ENABLE_NSYS}" == "true" ]]; then stop_nsys wait_for_nsys fi srun --ntasks-per-node=1 bash -c ' docker ps -aq | xargs -r docker stop -t 30 2>/dev/null || true docker ps -aq | xargs -r docker rm -f 2>/dev/null || true ' 2>/dev/null || true rm -f "${LOGDIR}/vllm_server_${SLURM_JOB_ID}.log" } start_ray_head() { info "Starting Ray head on ${HEAD_NODE}..." srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " $(declare -f launch_container) CONTAINER_IMAGE='${CONTAINER_IMAGE:-}' IMAGE='${IMAGE}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' launch_container ray-head 'sleep infinity' " sleep 5 srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " docker exec ray-head ray start --head --port=${RAY_PORT} \ --num-gpus=${GPUS_PER_NODE} --num-cpus=96 --disable-usage-stats " } start_ray_workers() { [[ "$NUM_NODES" -le 1 ]] && return local worker_nodes=$(echo "${NODES[@]:1}" | tr ' ' ',') info "Starting Ray workers on ${worker_nodes}..." 
srun --nodes=$((NUM_NODES - 1)) --nodelist="${worker_nodes}" --ntasks-per-node=1 bash -c " $(declare -f launch_container) CONTAINER_IMAGE='${CONTAINER_IMAGE:-}' IMAGE='${IMAGE}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' launch_container ray-worker 'sleep infinity' " sleep 5 srun --nodes=$((NUM_NODES - 1)) --nodelist="${worker_nodes}" --ntasks-per-node=1 bash -c " docker exec ray-worker ray start --address=${HEAD_IP}:${RAY_PORT} \ --num-gpus=${GPUS_PER_NODE} --num-cpus=96 --disable-usage-stats " } wait_for_gpus() { info "Waiting for ${TOTAL_GPUS} GPUs..." for _ in {1..120}; do local gpu_count gpu_count=$(srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " docker exec ray-head python3 -c \ 'import ray; ray.init(address=\"auto\"); print(int(ray.cluster_resources().get(\"GPU\",0))); ray.shutdown()' \ 2>/dev/null" || echo 0) [[ "$gpu_count" -ge "$TOTAL_GPUS" ]] && return 0 sleep 5 done err "Timeout waiting for GPUs" return 1 } start_vllm_ray() { info "Launching vllm serve (Ray)..." local logfile="${LOGDIR}/vllm_server_${SLURM_JOB_ID}.log" local extra="--host 0.0.0.0 --port 8000 --data-parallel-backend ray --data-parallel-address ${HEAD_IP} --data-parallel-size ${DP}" srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c \ "docker exec -d ray-head bash -c '${DEEPEP_ENV} ${NSYS_CMD} vllm serve ${SERVE_ARGS_STR} ${extra} 2>&1 | tee ${logfile}'" } start_vllm_mp() { info "Starting vLLM with PP (multiprocessing)..." 
local logfile="${LOGDIR}/vllm_server_${SLURM_JOB_ID}.log" for i in $(seq 0 $((NUM_NODES - 1))); do srun --nodes=1 --nodelist="${NODES[$i]}" bash -c " $(declare -f launch_container) IMAGE='${CONTAINER_IMAGE:-${IMAGE}}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' launch_container vllm-node-${i} 'sleep infinity' " & done wait local extra="--host 0.0.0.0 --port 8000 --nnodes ${NUM_NODES} --master-addr ${HEAD_IP} --master-port 29500" for i in $(seq 1 $((NUM_NODES - 1))); do srun --nodes=1 --nodelist="${NODES[$i]}" bash -c \ "docker exec -d vllm-node-${i} bash -c '${DEEPEP_ENV} ${NSYS_CMD} vllm serve ${SERVE_ARGS_STR} ${extra} --node-rank ${i} --headless 2>&1 | tee ${logfile}.node${i}'" done srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c \ "docker exec -d vllm-node-0 bash -c '${DEEPEP_ENV} ${NSYS_CMD} vllm serve ${SERVE_ARGS_STR} ${extra} --node-rank 0 2>&1 | tee ${logfile}'" } # RPC backend start_vllm_rpc() { info "Starting vLLM with RPC-based DP..." local logfile="${LOGDIR}/vllm_server_${SLURM_JOB_ID}.log" srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " $(declare -f launch_container) IMAGE='${CONTAINER_IMAGE:-${IMAGE}}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' launch_container vllm-head 'sleep infinity' " for i in $(seq 1 $((NUM_NODES - 1))); do srun --nodes=1 --nodelist="${NODES[$i]}" bash -c " $(declare -f launch_container) IMAGE='${CONTAINER_IMAGE:-${IMAGE}}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' launch_container vllm-worker 'sleep infinity' " done sleep 3 local extra="--data-parallel-size ${DP} --data-parallel-size-local ${DP_LOCAL} --data-parallel-address ${HEAD_IP} --data-parallel-rpc-port ${RPC_PORT}" for i in $(seq 1 $((NUM_NODES - 1))); do local start_rank=$((i * DP_LOCAL)) info "Starting RPC worker on ${NODES[$i]} (rank ${start_rank})..." 
srun --nodes=1 --nodelist="${NODES[$i]}" bash -c " docker exec -d vllm-worker bash -c '${DEEPEP_ENV} ${NSYS_CMD} vllm serve ${SERVE_ARGS_STR} ${extra} \ --data-parallel-start-rank ${start_rank} --headless \ 2>&1 | tee ${LOGDIR}/vllm_worker_${SLURM_JOB_ID}_${i}.log' " done srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " docker exec -d vllm-head bash -c '${DEEPEP_ENV} ${NSYS_CMD} vllm serve ${SERVE_ARGS_STR} ${extra} \ --host 0.0.0.0 --port 8000 \ 2>&1 | tee ${logfile}' " } wait_for_server() { info "Waiting for vLLM server at ${HEAD_IP}:8000..." for _ in {1..360}; do if srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c "curl -sf localhost:8000/health" &>/dev/null && srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c "curl -sf localhost:8000/v1/models | grep -q '\"id\"'" &>/dev/null; then info "Server ready at ${HEAD_IP}:8000" return 0 fi sleep 10 done err "Timeout waiting for server" return 1 } setup_topology trap cleanup EXIT cleanup LOGFILE="${LOGDIR}/vllm_server_${SLURM_JOB_ID}.log" load_or_pull_image case "${DP_BACKEND}" in ray) start_ray_head start_ray_workers wait_for_gpus start_vllm_ray ;; mp) start_vllm_mp ;; rpc) start_vllm_rpc ;; *) err "Unknown backend: ${DP_BACKEND}" exit 1 ;; esac tail -f "${LOGFILE}" 2>/dev/null & wait_for_server || exit 1 VLLM_STARTED=true info "vLLM serving on ${HEAD_IP}:8000 — Ctrl+C or scancel to stop" info "Logs: ${LOGFILE}" sleep infinity ================================================ FILE: src/llm/sglang/Dockerfile ================================================ ARG SGLANG_VERSION=0.5.8 ARG CUDA_VERSION=12.8.1 ARG GDRCOPY_VERSION=v2.5.1 ARG EFA_INSTALLER_VERSION=1.46.0 ARG NCCL_VERSION=v2.29.2-1 FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu24.04 ARG GDRCOPY_VERSION ARG EFA_INSTALLER_VERSION ARG NCCL_VERSION ARG SGLANG_VERSION # Prevent interactive prompts ENV DEBIAN_FRONTEND=noninteractive ENV TZ=UTC # Update and remove conflicting packages RUN apt-get update -y && apt-get upgrade -y RUN apt-get remove -y 
--allow-change-held-packages \ ibverbs-utils \ libibverbs-dev \ libibverbs1 \ libmlx5-1 \ libnccl2 \ libnccl-dev # Clean up existing MPI installations RUN rm -rf /opt/hpcx \ && rm -rf /usr/local/mpi \ && rm -f /etc/ld.so.conf.d/hpcx.conf \ && ldconfig ENV OPAL_PREFIX= # Install build dependencies RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ apt-utils \ autoconf \ automake \ build-essential \ check \ cmake \ curl \ debhelper \ devscripts \ git \ gcc \ gdb \ kmod \ libnuma-dev \ libsubunit-dev \ libtool \ openssh-client \ openssh-server \ pkg-config \ python3 \ python3-dev \ python3-pip \ vim \ wget \ ninja-build \ && rm -rf /var/lib/apt/lists/* # Remove cuda-compat if present RUN apt-get purge -y cuda-compat-* || true # Configure SSH RUN mkdir -p /var/run/sshd RUN sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config && \ echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config && \ sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config # Set library paths ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/gdrcopy/lib:/usr/local/lib:$LD_LIBRARY_PATH ENV PATH=/opt/amazon/openmpi/bin:/opt/amazon/efa/bin:/opt/gdrcopy/bin:/usr/bin:/usr/local/bin:$PATH # Remove PEP 668 restriction and install packages RUN rm -f /usr/lib/python*/EXTERNALLY-MANAGED \ && pip3 install --no-cache-dir awscli nvidia-ml-py Cython # Install GDRCopy RUN git clone -b ${GDRCOPY_VERSION} https://github.com/NVIDIA/gdrcopy.git /tmp/gdrcopy \ && cd /tmp/gdrcopy \ && make prefix=/opt/gdrcopy install \ && rm -rf /tmp/gdrcopy ENV LIBRARY_PATH=/opt/gdrcopy/lib:${LIBRARY_PATH:-} ENV CPATH=/opt/gdrcopy/include # Install EFA dependencies RUN apt-get update -y && apt-get install -y --no-install-recommends \ pciutils \ environment-modules \ tcl \ libnl-3-200 \ libnl-3-dev \ libnl-route-3-200 \ libnl-route-3-dev \ udev \ dmidecode \ ethtool \ iproute2 \ libevent-core-2.1-7t64 \ 
libevent-pthreads-2.1-7t64 \ libhwloc15 \ && rm -rf /var/lib/apt/lists/* # Install EFA RUN cd /tmp \ && curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && tar -xf aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && cd aws-efa-installer \ && ./efa_installer.sh -y -g -d --skip-kmod --skip-limit-conf --no-verify \ && rm -rf /tmp/aws-efa-installer* # Install NCCL RUN git clone -b ${NCCL_VERSION} https://github.com/NVIDIA/nccl.git /tmp/nccl \ && cd /tmp/nccl \ && make -j $(nproc) src.build CUDA_HOME=/usr/local/cuda \ NVCC_GENCODE="-gencode=arch=compute_90,code=sm_90" \ && mkdir -p /opt/nccl/build/lib \ && cp -r build/lib/* /opt/nccl/build/lib/ \ && cp -r build/include /opt/nccl/build/ \ && rm -rf /tmp/nccl # OpenMPI settings ENV OMPI_MCA_pml=^ucx ENV OMPI_MCA_btl=tcp,self ENV OMPI_MCA_btl_tcp_if_exclude=lo,docker0,veth_def_agent ENV OPAL_PREFIX=/opt/amazon/openmpi ENV PMIX_MCA_gds=hash # NCCL settings ENV NCCL_DEBUG=INFO ENV NCCL_SOCKET_IFNAME=^docker,lo,veth ENV NCCL_P2P_NET_CHUNKSIZE=524288 ENV NCCL_BUFFSIZE=8388608 ENV NCCL_TUNER_PLUGIN=/opt/amazon/ofi-nccl/lib/libnccl-tuner-ofi.so ENV LD_PRELOAD=/opt/nccl/build/lib/libnccl.so # EFA settings ENV FI_PROVIDER=efa ENV FI_EFA_USE_DEVICE_RDMA=1 ENV FI_EFA_FORK_SAFE=1 ENV RDMAV_FORK_SAFE=1 # Install SGLang (with CUDA 12.8 wheels) RUN pip3 install --no-cache-dir "sglang[all]==${SGLANG_VERSION}" \ --find-links https://docs.sglang.ai/whl/cu128/ # Install Nsight Systems for profiling RUN apt-get update -y && apt-get install -y --no-install-recommends gnupg \ && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/3bf863cc.pub \ && echo "deb https://developer.download.nvidia.com/devtools/repos/ubuntu2404/$(dpkg --print-architecture) /" \ > /etc/apt/sources.list.d/nvidia-devtools.list \ && apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F60F4B3D7FA2AF80 \ && apt-get update -y \ && apt-get install -y --no-install-recommends 
nsight-systems-cli \ && rm -rf /var/lib/apt/lists/* WORKDIR /workspace ================================================ FILE: src/llm/sglang/Makefile ================================================ .PHONY: help docker sqush save load serve test bench clean .DEFAULT_GOAL := help help: @echo "SGLang Serving Makefile" @echo "" @echo "Build targets:" @echo " docker Build Docker image" @echo " sqush Build Enroot sqsh file" @echo " save Save Docker image to tar.gz" @echo " load Load Docker image from tar.gz" @echo "" @echo "Run targets:" @echo " serve Launch SGLang server" @echo " test Test SGLang API endpoints" @echo " bench Run benchmarks (HOST=ip)" @echo "" @echo "Variables:" @echo " MODEL=$(MODEL)" @echo " PORT=$(PORT)" @echo " TP=$(TP)" @echo "" @echo "Examples:" @echo " make docker" @echo " make serve MODEL=Qwen/Qwen2.5-14B-Instruct TP=8" @echo " make test HOST=10.0.128.193" @echo " make bench HOST=10.0.128.193" @echo "" @echo "Cleanup:" @echo " clean Remove containers and images" IMAGE_NAME ?= sglang-serve IMAGE_TAG ?= latest CONTAINER_NAME ?= sglang-server MODEL ?= Qwen/Qwen2.5-7B-Instruct PORT ?= 30000 TP ?= 1 HOST ?= localhost BENCH_TYPE ?= DEVICES := --device=/dev/gdrdrv $(shell find /dev/infiniband -name "uverbs*" 2>/dev/null | sed 's/^/--device=/') DOCKER_RUN = docker run --gpus all \ --privileged \ --uts=host \ --ipc=host \ --net=host \ --ulimit stack=67108864 \ --ulimit memlock=-1 \ --security-opt seccomp=unconfined \ $(DEVICES) \ --rm \ --name $(CONTAINER_NAME) \ -v /fsx:/fsx \ --entrypoint bash \ $(IMAGE_NAME):$(IMAGE_TAG) docker: docker build -t $(IMAGE_NAME):$(IMAGE_TAG) -f Dockerfile . 
sqush: docker enroot import -o $(IMAGE_NAME)-$(IMAGE_TAG).sqsh dockerd://$(IMAGE_NAME):$(IMAGE_TAG) save: docker save $(IMAGE_NAME):$(IMAGE_TAG) | pigz > $(IMAGE_NAME)-$(IMAGE_TAG).tar.gz load: pigz -dc $(IMAGE_NAME)-$(IMAGE_TAG).tar.gz | docker load serve: $(DOCKER_RUN) -c 'python3 -m sglang.launch_server --model-path $(MODEL) --host 0.0.0.0 --port $(PORT) --tp $(TP)' test: @./test.sh -H $(HOST) -p $(PORT) bench: @bash bench.sh -H $(HOST) -p $(PORT) -i $(IMAGE_NAME):$(IMAGE_TAG) $(if $(BENCH_TYPE),--type $(BENCH_TYPE)) clean: -docker rm -f $(CONTAINER_NAME) 2>/dev/null || true -docker rmi $(IMAGE_NAME):$(IMAGE_TAG) 2>/dev/null || true -rm -f $(IMAGE_NAME)-$(IMAGE_TAG).sqsh -rm -f $(IMAGE_NAME)-$(IMAGE_TAG).tar.gz ================================================ FILE: src/llm/sglang/README.rst ================================================ ============= SGLang Serving ============= .. contents:: Table of Contents :backlinks: none This cheat sheet provides quick-reference commands for launching an SGLang server in both local (single-node) and SLURM (multi-node) environments. It covers building the Docker image, running with different parallelism strategies, and testing the server. For more details, see the `SGLang documentation <https://docs.sglang.ai/>`_ and `GitHub repository <https://github.com/sgl-project/sglang>`_. For parallelism strategies and benchmark methodology, see the `LLM Serving Guide <../../../docs/notes/llm/llm-serving.rst>`_ and `LLM Benchmark Guide <../../../docs/notes/llm/llm-bench.rst>`_. Build Docker Image ------------------ The Dockerfile bundles SGLang with EFA drivers, NCCL, and GDRCopy for high-performance multi-node inference on GPU clusters. .. code-block:: bash # Build the Docker image make docker # Save as a compressed tarball for SLURM nodes # Output: sglang-serve-latest.tar.gz make save Local Serving (Single Node) --------------------------- For development or single-node deployments, SGLang can run directly on the host or inside a Docker container. The server exposes an OpenAI-compatible API on port 30000. **Bare metal** — run SGLang directly without Docker: ..
code-block:: bash # Single GPU python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000 # Tensor parallel across 8 GPUs python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --tp 8 # MoE model with expert parallelism (EP subdivides TP) python -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --tp 8 --ep 2 **Using Docker (via Makefile)**: .. code-block:: bash # Single GPU with default model make serve MODEL=Qwen/Qwen2.5-7B-Instruct # Tensor parallel across 8 GPUs make serve MODEL=Qwen/Qwen2.5-14B-Instruct TP=8 **Using Docker directly**: .. code-block:: bash # Single GPU docker run --gpus all --rm --net=host -v /fsx:/fsx \ sglang-serve:latest \ python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000 # Tensor parallel across 8 GPUs docker run --gpus all --rm --net=host -v /fsx:/fsx \ sglang-serve:latest \ python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --tp 8 --host 0.0.0.0 --port 30000 SLURM Serving (Multi-Node) --------------------------- ``run.sbatch`` orchestrates multi-node SGLang serving on SLURM clusters. It handles Docker image distribution, container launch with EFA/GPU passthrough, and health checking. The server runs until you stop it with ``Ctrl+C`` or ``scancel``. **Script flags** — consumed by the script, not passed to SGLang: .. list-table:: :widths: 30 70 :header-rows: 1 * - Flag - Description * - ``--image PATH`` - Docker image tarball or registry path (default: ``$WORKSPACE/sglang-serve-latest.tar.gz``) * - ``--workspace, -w PATH`` - Base directory for default image and logs (default: ``$PWD``) * - ``--container-mount PATH`` - Host path to bind-mount into containers (default: ``/fsx``) * - ``--force, -f`` - Force remove existing containers and images before loading * - ``--profile`` - Enable PyTorch profiler (writes to ``$PWD/sglang_profile``) All other arguments are passed directly to ``python -m sglang.launch_server``. 
**Basic usage**: .. code-block:: bash # Allocate 2 nodes with 8 GPUs each salloc -N 2 --gpus-per-node=8 --exclusive # MoE with expert parallelism (TP=8, EP=2 across 2 nodes) bash run.sbatch \ --model-path Qwen/Qwen1.5-MoE-A2.7B \ --tp 8 --ep 2 **Data parallelism** — requires ``--enable-dp-attention`` for multi-node: .. code-block:: bash # TP=8, DP=2 (2 replicas across 16 GPUs) bash run.sbatch \ --model-path Qwen/Qwen2.5-14B-Instruct \ --tp 8 --dp 2 --enable-dp-attention **Pipeline parallelism**: .. code-block:: bash # TP=8, PP=2 bash run.sbatch \ --model-path deepseek-ai/DeepSeek-V2-Lite \ --tp 8 --pp 2 **Custom image**: .. code-block:: bash bash run.sbatch \ --image /fsx/images/sglang-serve-latest.tar.gz \ --model-path Qwen/Qwen2.5-72B-Instruct \ --tp 8 **Profiling**: .. code-block:: bash # PyTorch profiler bash run.sbatch --profile \ --model-path Qwen/Qwen2.5-14B-Instruct \ --tp 8 Test the Server --------------- SGLang serves on port 30000 by default: .. code-block:: bash # Health check curl http://<HEAD_IP>:30000/health # List models curl http://<HEAD_IP>:30000/v1/models # Chat completion (OpenAI-compatible) curl -X POST http://<HEAD_IP>:30000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "Qwen/Qwen2.5-14B-Instruct", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50 }' # SGLang native generate endpoint curl -X POST http://<HEAD_IP>:30000/generate \ -H "Content-Type: application/json" \ -d '{"text": "Hello", "sampling_params": {"max_new_tokens": 50}}' # Run the included test script bash test.sh # Test against a remote server bash test.sh -H 10.0.128.193 -p 30000 Benchmark --------- ``bench.sh`` measures serving performance (throughput, TTFT, ITL, latency) by sending requests to a running SGLang server. It handles Docker image loading and container management automatically. ..
code-block:: bash # Run all benchmarks bash bench.sh -H 10.0.128.193 -i sglang-serve:latest # Run specific benchmarks bash bench.sh -H 10.0.128.193 -i sglang-serve:latest --type throughput,prefill # Via Makefile make bench HOST=10.0.128.193 make bench HOST=10.0.128.193 BENCH_TYPE=throughput,prefill Available benchmark types: - **throughput** — peak output tokens/sec at max request rate - **prefill** — TTFT scaling with input length (128→16K tokens) - **decode** — ITL as output length grows (128→1024 tokens) - **latency** — end-to-end latency under minimal load - **concurrency** — throughput vs latency at different concurrency levels - **sharegpt** — realistic conversational workload Parallelism ----------- SGLang's parallelism formula: .. code-block:: text Total GPUs = TP × DP × PP **EP is a subdivision of TP**, not a separate multiplier. When using ``--ep N``, the TP GPUs are divided into N expert-parallel groups. .. list-table:: :widths: 20 10 10 10 10 40 :header-rows: 1 * - Config - TP - EP - DP - PP - Use case * - Dense, max throughput - 2 - 1 - 8 - 1 - 8 replicas of TP=2 * - Dense, large model - 8 - 1 - 2 - 1 - 2 replicas of TP=8 * - Dense, very large - 8 - 1 - 1 - 2 - Single replica, 2-stage pipeline * - MoE model - 8 - 2 - 1 - 1 - Experts split into 2 groups * - MoE, more EP - 8 - 4 - 1 - 1 - Experts split into 4 groups **Constraints:** - Multi-node DP requires ``--enable-dp-attention`` - EP only works with MoE models and requires ``--enable-ep`` - ``TP`` must be divisible by ``nnodes`` for multi-node ================================================ FILE: src/llm/sglang/bench.sh ================================================ #!/usr/bin/env bash # SGLang serving benchmark suite # Usage: # salloc -N1 bash bench.sh -H 10.0.128.193 -i /fsx/sglang-serve-latest.tar.gz # bash bench.sh -H 10.0.128.193 -i sglang-serve:latest set -euo pipefail info() { echo "[$(date +'%H:%M:%S')] $*"; } CONTAINER_MOUNT="${CONTAINER_MOUNT:-/fsx}" _run() { if [[ -n "${SLURM_JOB_ID:-}" 
]]; then
    srun -N1 --ntasks-per-node=1 bash -c "$*"
  else
    bash -c "$*"
  fi
}

# Make the image named by ${IMAGE} available to Docker and record its tag.
#
# Reads:  IMAGE - either a *.tar.gz produced by `docker save | pigz`, or an
#                 image reference to pull.
# Sets:   CONTAINER_IMAGE - the tag to pass to `docker run` (first RepoTag in
#                 the tarball's manifest.json, or IMAGE itself).
# The docker commands go through _run; the load/pull is skipped when
# `docker image inspect` already finds the image.
load_or_pull_image() {
  if [[ "${IMAGE}" == *.tar.gz ]]; then
    CONTAINER_IMAGE=$(pigz -dc "${IMAGE}" | tar -xf - -O manifest.json \
      | python3 -c "import sys,json; print(json.load(sys.stdin)[0]['RepoTags'][0])")
    info "Image tag: ${CONTAINER_IMAGE}"
    _run "
      if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then
        echo 'Loading Docker image from tarball...'
        pigz -dc '${IMAGE}' | docker load
      fi
    "
  else
    CONTAINER_IMAGE="${IMAGE}"
    # Derive registry/region here, in the parent shell. CONTAINER_IMAGE is not
    # exported, so expanding it inside the child `bash -c` spawned by _run
    # yields an empty registry and the ECR login targets nothing.
    local registry region
    registry="${CONTAINER_IMAGE%%/*}"
    region=$(echo "${registry}" | sed -n 's/.*\.ecr\.\([^.]*\)\.amazonaws\.com/\1/p')
    # Registry host is not of the form *.ecr.<region>.amazonaws.com: use the default.
    region="${region:-us-west-2}"
    _run "
      if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then
        echo 'Pulling ${CONTAINER_IMAGE}...'
        aws ecr get-login-password --region '${region}' \
          | docker login --username AWS --password-stdin '${registry}'
        docker pull '${CONTAINER_IMAGE}'
      fi
    "
  fi
}

launch_container() {
  local cmd="$1"
  _run "
    docker run --rm --net=host \
      -v '${PWD}:${PWD}' -w '${PWD}' \
      -v '${CONTAINER_MOUNT}:${CONTAINER_MOUNT}' \
      --entrypoint bash '${CONTAINER_IMAGE}' \
      -c '${cmd}'
  "
}

# If sglang is not available, load image and re-exec inside container
if ! python3 -c "import sglang" &>/dev/null; then
  IMAGE=""
  _args=("$@")
  for ((i=0; i<${#_args[@]}; i++)); do
    [[ "${_args[$i]}" == "--image" || "${_args[$i]}" == "-i" ]] \
      && { IMAGE="${_args[$((i+1))]}"; break; }
  done
  IMAGE="${IMAGE:-${PWD}/sglang-serve-latest.tar.gz}"
  load_or_pull_image
  _SCRIPT="$(cd "$(dirname "$0")" && pwd)/$(basename "$0")"
  launch_container "bash ${_SCRIPT} $*"
  exit $?
fi HOST="localhost" PORT="30000" MODEL="" SEED="42" RESULT_DIR="./results" TYPES="throughput,prefill,decode,latency,concurrency,sharegpt" usage() { cat < $label" python3 -m sglang.bench_serving \ --backend sglang \ --host "$HOST" \ --port "$PORT" \ --model "$MODEL" \ --seed "$SEED" \ --output-file "$outfile" \ "$@" echo "" } bench_throughput() { bench "Throughput (random 512in/256out, max rate)" \ --dataset-name random \ --random-input 512 --random-output 256 \ --num-prompts 100 --request-rate inf } bench_prefill() { for len in 128 512 2048 4096 16384; do bench "Prefill TTFT (input=${len})" \ --dataset-name random \ --random-input "$len" --random-output 1 \ --num-prompts 100 --request-rate 4 done } bench_decode() { for len in 128 256 512 1024; do bench "Decode ITL (output=${len})" \ --dataset-name random \ --random-input 128 --random-output "$len" \ --num-prompts 100 --request-rate 4 done } bench_latency() { bench "Latency (short 128/128, rate=1)" \ --dataset-name random \ --random-input 128 --random-output 128 \ --num-prompts 100 --request-rate 1 bench "Latency (medium 512/256, rate=1)" \ --dataset-name random \ --random-input 512 --random-output 256 \ --num-prompts 100 --request-rate 1 bench "Latency (long 4096/512, rate=1)" \ --dataset-name random \ --random-input 4096 --random-output 512 \ --num-prompts 100 --request-rate 1 } bench_concurrency() { for c in 1 4 16 64 256; do bench "Concurrency=${c} (512in/256out)" \ --dataset-name random \ --random-input 512 --random-output 256 \ --num-prompts 100 --request-rate inf --max-concurrency "$c" done } SHAREGPT_PATH="${SHAREGPT_PATH:-ShareGPT_V3_unfiltered_cleaned_split.json}" bench_sharegpt() { if [[ ! -f "$SHAREGPT_PATH" ]]; then echo "Downloading ShareGPT dataset..." 
wget -q -O "$SHAREGPT_PATH" \ https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json fi bench "ShareGPT (100 prompts, max rate)" \ --dataset-name sharegpt \ --dataset-path "$SHAREGPT_PATH" \ --num-prompts 100 --request-rate inf bench "ShareGPT (100 prompts, rate=4)" \ --dataset-name sharegpt \ --dataset-path "$SHAREGPT_PATH" \ --num-prompts 100 --request-rate 4 } IFS=',' read -ra TESTS <<< "$TYPES" for t in "${TESTS[@]}"; do t=$(echo "$t" | xargs) echo "========================================" echo "Running: $t" echo "========================================" case "$t" in throughput) bench_throughput ;; prefill) bench_prefill ;; decode) bench_decode ;; latency) bench_latency ;; concurrency) bench_concurrency ;; sharegpt) bench_sharegpt ;; *) echo "Unknown test: $t"; exit 1 ;; esac done ================================================ FILE: src/llm/sglang/run.sbatch ================================================ #!/bin/bash set -euo pipefail GPUS="${GPUS:-all}" info() { echo -e "[$(date +'%Y-%m-%dT%H:%M:%S%z')][info] $*"; } err() { echo -e "[$(date +'%Y-%m-%dT%H:%M:%S%z')][error] $*" >&2; } IMAGE="" CONTAINER_MOUNT="/fsx" WORKSPACE="$PWD" FORCE_PULL=false PROFILE_DIR="" SERVE_ARGS=() while (( "$#" )); do case "$1" in --image) IMAGE="$2"; shift 2 ;; --container-mount) CONTAINER_MOUNT="$2"; shift 2 ;; --workspace|-w) WORKSPACE="$2"; shift 2 ;; --force|-f) FORCE_PULL=true; shift ;; --profile) PROFILE_DIR="${PWD}/sglang_profile"; shift ;; *) SERVE_ARGS+=("$1"); shift ;; esac done IMAGE="${IMAGE:-${WORKSPACE}/sglang-serve-latest.tar.gz}" LOGDIR="${WORKSPACE}/logs" [[ -n "${PROFILE_DIR}" ]] && mkdir -p "${PROFILE_DIR}" SERVE_ARGS_STR=$(printf '%q ' "${SERVE_ARGS[@]+"${SERVE_ARGS[@]}"}") _peek_arg() { local short="$1" long="$2" default="$3" local i=0 while (( i < ${#SERVE_ARGS[@]} )); do if [[ "${SERVE_ARGS[$i]}" == "$short" || "${SERVE_ARGS[$i]}" == "$long" ]]; then echo 
"${SERVE_ARGS[$((i+1))]}"; return fi ((i++)) done echo "$default" } load_or_pull_image() { if [[ "${FORCE_PULL}" == "true" ]]; then info "Force pull: cleaning up existing images..." srun --ntasks-per-node=1 bash -c ' docker ps -aq | xargs -r docker rm -f 2>/dev/null || true docker images -aq | xargs -r docker rmi -f 2>/dev/null || true ' fi if [[ "${IMAGE}" == *.tar.gz ]]; then info "Loading Docker image from tarball..." CONTAINER_IMAGE=$(pigz -dc "${IMAGE}" | tar -xf - -O manifest.json | python3 -c "import sys,json; print(json.load(sys.stdin)[0]['RepoTags'][0])") info "Image tag: ${CONTAINER_IMAGE}" srun --ntasks-per-node=1 bash -c " if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then pigz -dc '${IMAGE}' | docker load fi " else info "Pulling Docker image from registry..." local registry="${IMAGE%%/*}" local region=$(echo "${registry}" | sed -n 's/.*\.ecr\.\([^.]*\)\.amazonaws\.com/\1/p') region="${region:-us-west-2}" srun --ntasks-per-node=1 bash -c " if ! docker image inspect '${IMAGE}' &>/dev/null; then aws ecr get-login-password --region '${region}' | docker login --username AWS --password-stdin '${registry}' docker pull '${IMAGE}' fi " CONTAINER_IMAGE="${IMAGE}" fi } launch_container() { local name="${1}" cmd="${2}" local devices=("--device=/dev/gdrdrv") while IFS= read -r -d '' d; do devices+=("--device=${d}") done < <(find "/dev/infiniband" -name "uverbs*" -print0 2>/dev/null) local net_if="${GLOO_SOCKET_IFNAME:-$(ip -o -4 route show to default | awk '{print $5}' | head -1)}" docker run --gpus "${GPUS}" \ --privileged -d \ --name "${name}" \ --uts=host --ipc=host --net=host \ --ulimit stack=67108864 --ulimit memlock=-1 \ --security-opt seccomp=unconfined \ "${devices[@]}" \ -v "${CONTAINER_MOUNT}:${CONTAINER_MOUNT}" \ -e NCCL_SOCKET_IFNAME="${net_if}" \ -e GLOO_SOCKET_IFNAME="${net_if}" \ -e TP_SOCKET_IFNAME="${net_if}" \ ${PROFILE_DIR:+-e SGLANG_TORCH_PROFILER_DIR="${PROFILE_DIR}"} \ --entrypoint bash \ "${CONTAINER_IMAGE:-${IMAGE}}" \ -c 
"${cmd}" } setup_topology() { NUM_NODES=${SLURM_JOB_NUM_NODES:-1} GPUS_PER_NODE=8 TOTAL_GPUS=$((NUM_NODES * GPUS_PER_NODE)) readarray -t NODES < <(scontrol show hostnames "$SLURM_JOB_NODELIST") HEAD_NODE=${NODES[0]} HEAD_IP=$(getent ahostsv4 "$HEAD_NODE" | head -1 | awk '{print $1}') DIST_PORT=$((25000 + (SLURM_JOB_ID % 1000))) mkdir -p "${LOGDIR}" info "========================================" info "SGLang Server" info "========================================" info "Image: ${IMAGE}" info "Nodes: ${NUM_NODES}, Head: ${HEAD_NODE} (${HEAD_IP}), GPUs: ${TOTAL_GPUS}" info "SERVE_ARGS: ${SERVE_ARGS[*]+"${SERVE_ARGS[*]}"}" info "========================================" } cleanup() { info "Cleaning up containers..." srun --ntasks-per-node=1 bash -c ' docker ps -aq | xargs -r docker rm -f 2>/dev/null || true ' 2>/dev/null || true rm -f "${LOGDIR}/sglang_server_${SLURM_JOB_ID}.log" } start_sglang() { local logfile="${LOGDIR}/sglang_server_${SLURM_JOB_ID}.log" # Launch containers on all nodes for i in $(seq 0 $((NUM_NODES - 1))); do srun --nodes=1 --nodelist="${NODES[$i]}" bash -c " $(declare -f launch_container) CONTAINER_IMAGE='${CONTAINER_IMAGE:-${IMAGE}}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' PROFILE_DIR='${PROFILE_DIR}' launch_container sglang-node-${i} 'sleep infinity' " & done wait sleep 3 # Only inject host/port and multi-node flags; SGLang owns parallelism (TP/PP/DP/EP) local extra="--host 0.0.0.0 --port 30000" if [[ "$NUM_NODES" -gt 1 ]]; then extra+=" --nnodes ${NUM_NODES} --dist-init-addr ${HEAD_IP}:${DIST_PORT}" # Start worker nodes first (rank 1..N-1) for i in $(seq 1 $((NUM_NODES - 1))); do srun --nodes=1 --nodelist="${NODES[$i]}" bash -c " docker exec -d sglang-node-${i} bash -c 'python3 -m sglang.launch_server ${SERVE_ARGS_STR} ${extra} --node-rank ${i} 2>&1 | tee ${logfile}.node${i}' " done fi # Start head node (rank 0) — this one serves the API srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " docker exec -d sglang-node-0 bash -c 'python3 
-m sglang.launch_server ${SERVE_ARGS_STR} ${extra} --node-rank 0 2>&1 | tee ${logfile}' " } wait_for_server() { info "Waiting for SGLang server at ${HEAD_IP}:30000..." for _ in {1..360}; do if srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c "curl -sf localhost:30000/health" &>/dev/null && srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c "curl -sf localhost:30000/v1/models | grep -q '\"id\"'" &>/dev/null; then info "Server ready at ${HEAD_IP}:30000" return 0 fi sleep 10 done err "Timeout waiting for server"; return 1 } setup_topology trap cleanup EXIT cleanup LOGFILE="${LOGDIR}/sglang_server_${SLURM_JOB_ID}.log" load_or_pull_image start_sglang tail -f "${LOGFILE}" 2>/dev/null & wait_for_server || exit 1 info "SGLang serving on ${HEAD_IP}:30000 — Ctrl+C or scancel to stop" info "Logs: ${LOGFILE}" sleep infinity ================================================ FILE: src/llm/sglang/test.sh ================================================ #!/usr/bin/env bash # SGLang API test script set -uo pipefail HOST="localhost" PORT="30000" MODEL="" while (( "$#" )); do case "$1" in -h|--help) echo "Usage: $0 [-H host] [-p port] [-m model]"; exit 0 ;; -H|--host) HOST="$2"; shift 2 ;; -p|--port) PORT="$2"; shift 2 ;; -m|--model) MODEL="$2"; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done BASE_URL="http://${HOST}:${PORT}" # Auto-detect model from server if not specified if [[ -z "$MODEL" ]]; then MODEL=$(curl -sf "${BASE_URL}/v1/models" | python3 -c "import sys,json; print(json.load(sys.stdin)['data'][0]['id'])" 2>/dev/null) if [[ -z "$MODEL" ]]; then echo "ERROR: Cannot detect model. Is the server running at ${BASE_URL}?" exit 1 fi fi PASS=0 FAIL=0 echo "Testing SGLang server at ${BASE_URL}" echo "Model: ${MODEL}" echo "========================================" test_endpoint() { local name="$1" cmd="$2" echo -ne "\n${name}... 
" if eval "$cmd" 2>&1; then ((PASS++)) else ((FAIL++)) fi } test_endpoint "[1/10] List models" \ "curl -sf '${BASE_URL}/v1/models' | jq -e '.data'" test_endpoint "[2/10] Basic completions" \ "curl -sf -X POST '${BASE_URL}/v1/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"prompt\": \"Hello\", \"max_tokens\": 10}' | jq -e '.choices'" test_endpoint "[3/10] Batch completions" \ "curl -sf -X POST '${BASE_URL}/v1/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"prompt\": [\"Once\", \"In\"], \"max_tokens\": 10}' | jq -e '.choices'" test_endpoint "[4/10] Chat completions" \ "curl -sf -X POST '${BASE_URL}/v1/chat/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"messages\": [{\"role\": \"user\", \"content\": \"Hi\"}], \"max_tokens\": 10}' | jq -e '.choices'" test_endpoint "[5/10] Sampling parameters" \ "curl -sf -X POST '${BASE_URL}/v1/chat/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"messages\": [{\"role\": \"user\", \"content\": \"Hi\"}], \"max_tokens\": 10, \"temperature\": 0.9, \"top_p\": 0.95}' | jq -e '.choices'" test_endpoint "[6/10] Streaming" \ "curl -sf -X POST '${BASE_URL}/v1/chat/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"messages\": [{\"role\": \"user\", \"content\": \"Hi\"}], \"stream\": true, \"max_tokens\": 10}' | grep -q 'data:'" test_endpoint "[7/10] Logprobs" \ "curl -sf -X POST '${BASE_URL}/v1/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"prompt\": \"Hello\", \"max_tokens\": 5, \"logprobs\": 5}' | jq -e '.choices[0].logprobs'" test_endpoint "[8/10] Stop sequences" \ "curl -sf -X POST '${BASE_URL}/v1/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"prompt\": \"1.\", \"max_tokens\": 20, \"stop\": [\"3.\"]}' | jq -e '.choices'" test_endpoint "[9/10] Native generate" \ "curl -sf -X POST 
'${BASE_URL}/generate' \
    -H 'Content-Type: application/json' \
    -d '{\"text\": \"Hello\", \"sampling_params\": {\"max_new_tokens\": 10}}' | jq -e '.text'"

test_endpoint "[10/10] Health check" \
    "curl -sf '${BASE_URL}/health' | jq -e '.'"

echo -e "\n========================================"
echo "Results: ${PASS} passed, ${FAIL} failed"

# Report failure through the exit status only when a test actually failed.
# A trailing `[[ $FAIL -gt 0 ]] && exit 1` leaves the false test as the last
# status, so the script would return 1 even when every test passed.
if [[ $FAIL -gt 0 ]]; then
    exit 1
fi
exit 0

================================================
FILE: src/llm/tensorrt-llm/Dockerfile
================================================
ARG TRTLLM_VERSION=1.1.0
ARG CUDA_VERSION=13.0.0
ARG GDRCOPY_VERSION=v2.5.1
ARG EFA_INSTALLER_VERSION=1.46.0
ARG NCCL_VERSION=v2.29.2-1

FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu24.04

ARG GDRCOPY_VERSION
ARG EFA_INSTALLER_VERSION
ARG NCCL_VERSION
ARG TRTLLM_VERSION

# Prevent interactive prompts
ENV DEBIAN_FRONTEND=noninteractive
ENV TZ=UTC

# Update and remove conflicting packages
RUN apt-get update -y && apt-get upgrade -y
RUN apt-get remove -y --allow-change-held-packages \
    ibverbs-utils \
    libibverbs-dev \
    libibverbs1 \
    libmlx5-1 \
    libnccl2 \
    libnccl-dev

# Clean up existing MPI installations
RUN rm -rf /opt/hpcx \
    && rm -rf /usr/local/mpi \
    && rm -f /etc/ld.so.conf.d/hpcx.conf \
    && ldconfig

ENV OPAL_PREFIX=

# Install build dependencies
RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
    apt-utils \
    autoconf \
    automake \
    build-essential \
    check \
    cmake \
    curl \
    debhelper \
    devscripts \
    git \
    gcc \
    gdb \
    kmod \
    libnuma-dev \
    libsubunit-dev \
    libtool \
    openssh-client \
    openssh-server \
    pkg-config \
    python3 \
    python3-dev \
    python3-pip \
    vim \
    wget \
    ninja-build \
    && rm -rf /var/lib/apt/lists/*

# Remove cuda-compat if present
RUN apt-get purge -y cuda-compat-* || true

# Configure SSH
RUN mkdir -p /var/run/sshd
RUN sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config && \
    echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config && \
    sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config

# Set library paths
ENV
LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/gdrcopy/lib:/usr/local/lib:$LD_LIBRARY_PATH ENV PATH=/opt/amazon/openmpi/bin:/opt/amazon/efa/bin:/opt/gdrcopy/bin:/usr/bin:/usr/local/bin:$PATH # Remove PEP 668 restriction and install packages RUN rm -f /usr/lib/python*/EXTERNALLY-MANAGED \ && pip3 install --no-cache-dir awscli nvidia-ml-py Cython # Install GDRCopy RUN git clone -b ${GDRCOPY_VERSION} https://github.com/NVIDIA/gdrcopy.git /tmp/gdrcopy \ && cd /tmp/gdrcopy \ && make prefix=/opt/gdrcopy install \ && rm -rf /tmp/gdrcopy ENV LIBRARY_PATH=/opt/gdrcopy/lib:${LIBRARY_PATH:-} ENV CPATH=/opt/gdrcopy/include # Install EFA dependencies RUN apt-get update -y && apt-get install -y --no-install-recommends \ pciutils \ environment-modules \ tcl \ libnl-3-200 \ libnl-3-dev \ libnl-route-3-200 \ libnl-route-3-dev \ udev \ dmidecode \ ethtool \ iproute2 \ libevent-core-2.1-7t64 \ libevent-pthreads-2.1-7t64 \ libhwloc15 \ && rm -rf /var/lib/apt/lists/* # Install EFA RUN cd /tmp \ && curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && tar -xf aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && cd aws-efa-installer \ && ./efa_installer.sh -y -g -d --skip-kmod --skip-limit-conf --no-verify \ && rm -rf /tmp/aws-efa-installer* # Install NCCL RUN git clone -b ${NCCL_VERSION} https://github.com/NVIDIA/nccl.git /tmp/nccl \ && cd /tmp/nccl \ && make -j $(nproc) src.build CUDA_HOME=/usr/local/cuda \ NVCC_GENCODE="-gencode=arch=compute_90,code=sm_90" \ && mkdir -p /opt/nccl/build/lib \ && cp -r build/lib/* /opt/nccl/build/lib/ \ && cp -r build/include /opt/nccl/build/ \ && rm -rf /tmp/nccl # OpenMPI settings ENV OMPI_MCA_pml=^ucx ENV OMPI_MCA_btl=tcp,self ENV OMPI_MCA_btl_tcp_if_exclude=lo,docker0,veth_def_agent ENV OPAL_PREFIX=/opt/amazon/openmpi ENV PMIX_MCA_gds=hash # NCCL settings ENV NCCL_DEBUG=INFO ENV NCCL_SOCKET_IFNAME=^docker,lo,veth ENV 
NCCL_P2P_NET_CHUNKSIZE=524288
ENV NCCL_BUFFSIZE=8388608
ENV NCCL_TUNER_PLUGIN=/opt/amazon/ofi-nccl/lib/libnccl-tuner-ofi.so
ENV LD_PRELOAD=/opt/nccl/build/lib/libnccl.so

# EFA settings
ENV FI_PROVIDER=efa
ENV FI_EFA_USE_DEVICE_RDMA=1
ENV FI_EFA_FORK_SAFE=1
ENV RDMAV_FORK_SAFE=1

# Install TensorRT-LLM
# urllib3 is installed by debian so pip cannot uninstall it; ignore it
# Only the final pynvml uninstall is best-effort. The `|| true` is grouped with
# it so that a failed tensorrt-llm or nvidia-ml-py install still fails the build
# instead of being masked by a trailing `|| true` on the whole chain.
RUN pip3 install --no-cache-dir --ignore-installed urllib3 tensorrt-llm==${TRTLLM_VERSION} \
    && pip3 install --no-cache-dir --force-reinstall nvidia-ml-py \
    && { pip3 uninstall -y pynvml 2>/dev/null || true; }

# Install Nsight Systems for profiling
RUN apt-get update -y && apt-get install -y --no-install-recommends gnupg \
    && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/3bf863cc.pub \
    && echo "deb https://developer.download.nvidia.com/devtools/repos/ubuntu2404/$(dpkg --print-architecture) /" \
        > /etc/apt/sources.list.d/nvidia-devtools.list \
    && apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F60F4B3D7FA2AF80 \
    && apt-get update -y \
    && apt-get install -y --no-install-recommends nsight-systems-cli \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /workspace

================================================
FILE: src/llm/tensorrt-llm/Makefile
================================================
.PHONY: help docker sqush save load serve test bench clean

.DEFAULT_GOAL := help

help:
	@echo "TensorRT-LLM Serving Makefile"
	@echo ""
	@echo "Build targets:"
	@echo " docker Build Docker image"
	@echo " sqush Build Enroot sqsh file"
	@echo " save Save Docker image to tar.gz"
	@echo " load Load Docker image from tar.gz"
	@echo ""
	@echo "Run targets:"
	@echo " serve Launch TensorRT-LLM server"
	@echo " test Test API endpoints"
	@echo " bench Run benchmarks (HOST=ip)"
	@echo ""
	@echo "Variables:"
	@echo " MODEL=$(MODEL)"
	@echo " PORT=$(PORT)"
	@echo " TP=$(TP)"
	@echo ""
	@echo "Examples:"
	@echo " make docker"
	@echo " make serve MODEL=Qwen/Qwen2.5-14B-Instruct TP=8"
	@echo " make test
HOST=10.0.128.193" @echo " make bench HOST=10.0.128.193" @echo "" @echo "Cleanup:" @echo " clean Remove containers and images" IMAGE_NAME ?= tensorrt-llm-serve IMAGE_TAG ?= latest CONTAINER_NAME ?= trtllm-server MODEL ?= Qwen/Qwen2.5-7B-Instruct PORT ?= 8000 TP ?= 1 HOST ?= localhost BENCH_TYPE ?= DEVICES := --device=/dev/gdrdrv $(shell find /dev/infiniband -name "uverbs*" 2>/dev/null | sed 's/^/--device=/') DOCKER_RUN = docker run --gpus all \ --privileged \ --uts=host \ --ipc=host \ --net=host \ --ulimit stack=67108864 \ --ulimit memlock=-1 \ --security-opt seccomp=unconfined \ $(DEVICES) \ --rm \ --name $(CONTAINER_NAME) \ -v /fsx:/fsx \ --entrypoint bash \ $(IMAGE_NAME):$(IMAGE_TAG) docker: docker build -t $(IMAGE_NAME):$(IMAGE_TAG) -f Dockerfile . sqush: docker enroot import -o $(IMAGE_NAME)-$(IMAGE_TAG).sqsh dockerd://$(IMAGE_NAME):$(IMAGE_TAG) save: docker save $(IMAGE_NAME):$(IMAGE_TAG) | pigz > $(IMAGE_NAME)-$(IMAGE_TAG).tar.gz load: pigz -dc $(IMAGE_NAME)-$(IMAGE_TAG).tar.gz | docker load serve: $(DOCKER_RUN) -c 'trtllm-serve $(MODEL) --host 0.0.0.0 --port $(PORT) --tp_size $(TP)' test: @./test.sh -H $(HOST) -p $(PORT) bench: @bash bench.sh -H $(HOST) -p $(PORT) -i $(IMAGE_NAME):$(IMAGE_TAG) $(if $(BENCH_TYPE),--type $(BENCH_TYPE)) clean: -docker rm -f $(CONTAINER_NAME) 2>/dev/null || true -docker rmi $(IMAGE_NAME):$(IMAGE_TAG) 2>/dev/null || true -rm -f $(IMAGE_NAME)-$(IMAGE_TAG).sqsh -rm -f $(IMAGE_NAME)-$(IMAGE_TAG).tar.gz ================================================ FILE: src/llm/tensorrt-llm/README.rst ================================================ ==================== TensorRT-LLM Serving ==================== .. contents:: Table of Contents :backlinks: none This cheat sheet provides quick-reference commands for launching a TensorRT-LLM server in both local (single-node) and SLURM environments. It covers building the Docker image, running with different parallelism strategies, and testing the server. 
TensorRT-LLM v1.1.0 uses ``trtllm-serve`` for OpenAI-compatible online serving and ``trtllm-bench`` for benchmarking. For more details, see the `TensorRT-LLM documentation `_ and `GitHub repository `_. For parallelism strategies and benchmark methodology, see the `LLM Serving Guide `_ and `LLM Benchmark Guide `_. Build Docker Image ------------------ The Dockerfile bundles TensorRT-LLM with EFA drivers, NCCL, and GDRCopy for high-performance inference on GPU clusters. .. code-block:: bash # Build the Docker image make docker # Save as a compressed tarball for SLURM nodes # Output: tensorrt-llm-serve-latest.tar.gz make save Local Serving (Single Node) --------------------------- TensorRT-LLM exposes an OpenAI-compatible API on port 8000 by default via ``trtllm-serve``. **Bare metal** — run directly (requires TensorRT-LLM installed): .. code-block:: bash # Single GPU trtllm-serve Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 8000 # Tensor parallel across 8 GPUs trtllm-serve Qwen/Qwen2.5-14B-Instruct --tp_size 8 # FP8 quantized model trtllm-serve nvidia/Qwen3-8B-FP8 **Using Docker (via Makefile)**: .. code-block:: bash # Single GPU with default model make serve MODEL=Qwen/Qwen2.5-7B-Instruct # Tensor parallel across 8 GPUs make serve MODEL=Qwen/Qwen2.5-14B-Instruct TP=8 **Using the NGC container directly**: .. code-block:: bash docker run --gpus all --rm --ipc host --net=host \ --ulimit memlock=-1 --ulimit stack=67108864 \ -v /fsx:/fsx \ nvcr.io/nvidia/tensorrt-llm/release:1.1.0 \ trtllm-serve Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 8000 SLURM Serving ------------- ``run.sbatch`` orchestrates TensorRT-LLM serving on SLURM clusters. It handles Docker image distribution, container launch with EFA/GPU passthrough, and health checking. The server runs until you stop it with ``Ctrl+C`` or ``scancel``. **Script flags** — consumed by the script, not passed to trtllm-serve: .. 
list-table::
   :widths: 30 70
   :header-rows: 1

   * - Flag
     - Description
   * - ``--image PATH``
     - Docker image tarball or registry path (default: ``$WORKSPACE/tensorrt-llm-serve-latest.tar.gz``)
   * - ``--workspace, -w PATH``
     - Base directory for default image and logs (default: ``$PWD``)
   * - ``--container-mount PATH``
     - Host path to bind-mount into containers (default: ``/fsx``)
   * - ``--force, -f``
     - Force remove existing containers and images before loading

All other arguments are passed directly to ``trtllm-serve``.

**Basic usage**:

.. code-block:: bash

    # Allocate 1 node with 8 GPUs
    salloc -N 1 --gpus-per-node=8 --exclusive

    # Serve with TP=8
    bash run.sbatch \
      Qwen/Qwen2.5-14B-Instruct \
      --tp_size 8

**Custom image**:

.. code-block:: bash

    bash run.sbatch \
      --image /fsx/images/tensorrt-llm-serve-latest.tar.gz \
      Qwen/Qwen2.5-72B-Instruct \
      --tp_size 8

**FP8 quantized model**:

.. code-block:: bash

    bash run.sbatch nvidia/Qwen3-8B-FP8

Test the Server
---------------

TensorRT-LLM serves on port 8000 by default:

.. code-block:: bash

    # Health check
    curl http://<host>:8000/health

    # List models
    curl http://<host>:8000/v1/models

    # Chat completion (OpenAI-compatible)
    curl -X POST http://<host>:8000/v1/chat/completions \
      -H "Content-Type: application/json" \
      -d '{
        "model": "Qwen/Qwen2.5-14B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 50
      }'

    # Completions endpoint
    curl -X POST http://<host>:8000/v1/completions \
      -H "Content-Type: application/json" \
      -d '{
        "model": "Qwen/Qwen2.5-14B-Instruct",
        "prompt": "The capital of France is",
        "max_tokens": 20
      }'

    # Run the included test script
    bash test.sh

    # Test against a remote server
    bash test.sh -H 10.0.128.193 -p 8000

Benchmark
---------

``bench.sh`` measures serving performance (throughput, TTFT, ITL, latency) by
sending requests to a running TensorRT-LLM server. It handles Docker image
loading and container management automatically.

.. 
code-block:: bash

    # Run all benchmarks
    bash bench.sh -H 10.0.128.193 -i tensorrt-llm-serve:latest

    # Run specific benchmarks
    bash bench.sh -H 10.0.128.193 -i tensorrt-llm-serve:latest --type throughput,prefill

    # Via Makefile
    make bench HOST=10.0.128.193
    make bench HOST=10.0.128.193 BENCH_TYPE=throughput,prefill

Available benchmark types:

- **throughput** — peak output tokens/sec at max request rate
- **prefill** — TTFT scaling with input length (128→2048 tokens)
- **decode** — ITL as output length grows (128→1024 tokens)
- **latency** — end-to-end latency under minimal load
- **concurrency** — throughput vs latency at different concurrency levels
- **sharegpt** — realistic conversational workload

You can also use the built-in ``trtllm-bench`` CLI for more detailed benchmarking:

.. code-block:: bash

    # Static benchmark (no server needed)
    trtllm-bench --model Qwen/Qwen2.5-7B-Instruct \
      --dataset-type synthetic --num-requests 100

Parallelism
-----------

TensorRT-LLM supports multiple parallelism strategies, configured via a YAML
file (``parallel_config.yaml``) passed with ``--config``:

- **Tensor Parallel (TP)** — shards model weights across GPUs
- **Pipeline Parallel (PP)** — distributes layers across GPUs
- **Data Parallel (DP)** — replicates model across GPUs for different requests
- **Expert Parallel (EP)** — distributes MoE experts across GPUs
- **Context Parallel (CP)** — distributes long-context processing across GPUs
- **Wide-EP** — advanced EP with load balancing for large-scale MoE (DeepSeek-V3/R1, LLaMA4, Qwen3)

**Attention module** supports TP (small batches) or DP (large batches) via
``enable_attention_dp``. **MoE FFN** supports TP, EP, or hybrid ETP where
``moe_tensor_parallel_size × moe_expert_parallel_size = tensor_parallel_size``.

.. 
list-table:: :widths: 25 55 20 :header-rows: 1 * - Strategy - Use case - Key config * - TP only - Dense models, small batch, memory-constrained - ``tensor_parallel_size: 8`` * - PP - Very large models that don't fit in single-node GPU memory - ``pipeline_parallel_size: 2`` * - DP (attention) - Large batch, high throughput - ``enable_attention_dp: true`` * - EP only (MoE) - MoE models with high expert count - ``moe_expert_parallel_size: 8`` * - Hybrid ETP (MoE) - Balance workload and kernel efficiency - ``moe_tensor_parallel_size: 4, moe_expert_parallel_size: 2`` * - Wide-EP (MoE) - Large-scale MoE with load balancing (hot expert replication) - See ``examples/wide_ep/`` **Configuration via YAML** (recommended): .. code-block:: yaml # parallel_config.yaml # Dense model: TP=8 tensor_parallel_size: 8 # Dense model: TP=8 with attention DP # tensor_parallel_size: 8 # enable_attention_dp: true # MoE: EP only # tensor_parallel_size: 8 # moe_expert_parallel_size: 8 # MoE: Hybrid TP-4 × EP-2 # tensor_parallel_size: 8 # moe_tensor_parallel_size: 4 # moe_expert_parallel_size: 2 .. code-block:: bash trtllm-serve Qwen/Qwen2.5-14B-Instruct --config parallel_config.yaml **Quick examples via CLI flags**: .. code-block:: bash # TP=4 trtllm-serve Qwen/Qwen2.5-14B-Instruct --tp_size 4 # TP=8, PP=2 trtllm-serve Qwen/Qwen2.5-72B-Instruct --tp_size 8 --pp_size 2 # MoE with EP trtllm-serve Qwen/Qwen1.5-MoE-A2.7B --tp_size 8 --ep_size 4 Key Differences from SGLang / vLLM ----------------------------------- .. 
list-table:: :widths: 25 25 25 25 :header-rows: 1 * - Feature - TensorRT-LLM - SGLang - vLLM * - Serve command - ``trtllm-serve `` - ``python -m sglang.launch_server`` - ``vllm serve `` * - Default port - 8000 - 30000 - 8000 * - TP flag - ``--tp_size N`` - ``--tp N`` - ``--tensor-parallel-size N`` * - Bench tool - ``trtllm-bench`` - ``sglang.bench_serving`` - ``vllm bench`` * - Container - ``nvcr.io/nvidia/tensorrt-llm/release`` - Custom build - Custom build * - Quantization - FP8, FP4, INT4 AWQ, INT8 SQ - FP8, AWQ, GPTQ - FP8, AWQ, GPTQ, BitsAndBytes ================================================ FILE: src/llm/tensorrt-llm/bench.sh ================================================ #!/usr/bin/env bash # TensorRT-LLM serving benchmark suite # Uses the official benchmark_serving.py for proper TTFT/TPOT/ITL/E2EL metrics # Usage: # salloc -N1 bash bench.sh -H 10.0.128.193 -i /fsx/tensorrt-llm-serve-latest.tar.gz # bash bench.sh -H 10.0.128.193 -i tensorrt-llm-serve:latest set -euo pipefail info() { echo "[$(date +'%H:%M:%S')] $*"; } CONTAINER_MOUNT="${CONTAINER_MOUNT:-/fsx}" _run() { if [[ -n "${SLURM_JOB_ID:-}" ]]; then srun -N1 --ntasks-per-node=1 bash -c "$*" else bash -c "$*" fi } load_or_pull_image() { if [[ "${IMAGE}" == *.tar.gz ]]; then CONTAINER_IMAGE=$(pigz -dc "${IMAGE}" | tar -xf - -O manifest.json \ | python3 -c "import sys,json; print(json.load(sys.stdin)[0]['RepoTags'][0])") info "Image tag: ${CONTAINER_IMAGE}" _run " if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then echo 'Loading Docker image from tarball...' pigz -dc '${IMAGE}' | docker load fi " else CONTAINER_IMAGE="${IMAGE}" _run " if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then echo 'Pulling ${CONTAINER_IMAGE}...' 
registry=\"\${CONTAINER_IMAGE%%/*}\" region=\$(echo \"\${registry}\" | sed -n 's/.*\.ecr\.\([^.]*\)\.amazonaws\.com/\1/p') region=\"\${region:-us-west-2}\" aws ecr get-login-password --region \"\${region}\" \ | docker login --username AWS --password-stdin \"\${registry}\" docker pull '${CONTAINER_IMAGE}' fi " fi } launch_container() { local cmd="$1" _run " docker run --rm --net=host --gpus all \ -v '${PWD}:${PWD}' -w '${PWD}' \ -v '${CONTAINER_MOUNT}:${CONTAINER_MOUNT}' \ --entrypoint bash '${CONTAINER_IMAGE}' \ -c '${cmd}' " } # If tensorrt_llm package is not installed, load image and re-exec inside container if ! python3 -c "import importlib.metadata; importlib.metadata.version('tensorrt-llm')" &>/dev/null 2>&1; then info "tensorrt_llm not found, bootstrapping into container..." info "SLURM_JOB_ID=${SLURM_JOB_ID:-} SLURM_NODELIST=${SLURM_NODELIST:-}" IMAGE="" _args=("$@") for ((i=0; i<${#_args[@]}; i++)); do [[ "${_args[$i]}" == "--image" || "${_args[$i]}" == "-i" ]] \ && { IMAGE="${_args[$((i+1))]}"; break; } done IMAGE="${IMAGE:-${PWD}/tensorrt-llm-serve-latest.tar.gz}" info "IMAGE=${IMAGE}" load_or_pull_image info "CONTAINER_IMAGE=${CONTAINER_IMAGE}" _SCRIPT="$(cd "$(dirname "$0")" && pwd)/$(basename "$0")" info "Launching container on: $(hostname)" launch_container "bash ${_SCRIPT} $*" exit $? 
fi HOST="localhost" PORT="8000" MODEL="" SEED="42" RESULT_DIR="./results" TYPES="throughput,prefill,decode,latency,concurrency,sharegpt" usage() { cat < $label" python3 -m tensorrt_llm.serve.scripts.benchmark_serving \ --backend openai \ --base-url "${BASE_URL}" \ --model "$MODEL" \ --seed "$SEED" \ --save-result \ --result-dir "$RESULT_DIR" \ --result-filename "$(basename "$outfile")" \ "$@" echo "" } bench_throughput() { bench "Throughput (random 512in/256out, max rate)" \ --dataset-name random --random-ids --random-prefix-len 0 \ --random-input-len 512 --random-output-len 256 \ --num-prompts 100 --ignore-eos } bench_prefill() { for len in 128 512 1024 2048; do bench "Prefill TTFT (input=${len})" \ --dataset-name random --random-ids --random-prefix-len 0 \ --random-input-len "$len" --random-output-len 1 \ --num-prompts 100 --max-concurrency 4 --ignore-eos done } bench_decode() { for len in 128 256 512 1024; do bench "Decode ITL (output=${len})" \ --dataset-name random --random-ids --random-prefix-len 0 \ --random-input-len 128 --random-output-len "$len" \ --num-prompts 100 --max-concurrency 4 --ignore-eos done } bench_latency() { bench "Latency (short 128/128, concurrency=1)" \ --dataset-name random --random-ids --random-prefix-len 0 \ --random-input-len 128 --random-output-len 128 \ --num-prompts 100 --max-concurrency 1 --ignore-eos bench "Latency (medium 512/256, concurrency=1)" \ --dataset-name random --random-ids --random-prefix-len 0 \ --random-input-len 512 --random-output-len 256 \ --num-prompts 100 --max-concurrency 1 --ignore-eos } bench_concurrency() { for c in 1 4 16 64 256; do bench "Concurrency=${c} (512in/256out)" \ --dataset-name random --random-ids --random-prefix-len 0 \ --random-input-len 512 --random-output-len 256 \ --num-prompts $((c * 5)) --max-concurrency "$c" --ignore-eos done } SHAREGPT_PATH="${SHAREGPT_PATH:-ShareGPT_V3_unfiltered_cleaned_split.json}" bench_sharegpt() { if [[ ! 
-f "$SHAREGPT_PATH" ]]; then echo "Downloading ShareGPT dataset..." wget -q -O "$SHAREGPT_PATH" \ https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json fi bench "ShareGPT (100 prompts, max rate)" \ --dataset-name sharegpt \ --dataset-path "$SHAREGPT_PATH" \ --num-prompts 100 --ignore-eos bench "ShareGPT (100 prompts, concurrency=4)" \ --dataset-name sharegpt \ --dataset-path "$SHAREGPT_PATH" \ --num-prompts 100 --max-concurrency 4 --ignore-eos } IFS=',' read -ra TESTS <<< "$TYPES" for t in "${TESTS[@]}"; do t=$(echo "$t" | xargs) echo "========================================" echo "Running: $t" echo "========================================" case "$t" in throughput) bench_throughput ;; prefill) bench_prefill ;; decode) bench_decode ;; latency) bench_latency ;; concurrency) bench_concurrency ;; sharegpt) bench_sharegpt ;; *) echo "Unknown test: $t"; exit 1 ;; esac done ================================================ FILE: src/llm/tensorrt-llm/run.sbatch ================================================ #!/bin/bash set -euo pipefail GPUS="${GPUS:-all}" info() { echo -e "[$(date +'%Y-%m-%dT%H:%M:%S%z')][info] $*"; } err() { echo -e "[$(date +'%Y-%m-%dT%H:%M:%S%z')][error] $*" >&2; } IMAGE="" CONTAINER_MOUNT="/fsx" WORKSPACE="$PWD" FORCE_PULL=false SERVE_ARGS=() while (( "$#" )); do case "$1" in --image) IMAGE="$2"; shift 2 ;; --container-mount) CONTAINER_MOUNT="$2"; shift 2 ;; --workspace|-w) WORKSPACE="$2"; shift 2 ;; --force|-f) FORCE_PULL=true; shift ;; *) SERVE_ARGS+=("$1"); shift ;; esac done IMAGE="${IMAGE:-${WORKSPACE}/tensorrt-llm-serve-latest.tar.gz}" LOGDIR="${WORKSPACE}/logs" SERVE_ARGS_STR=$(printf '%q ' "${SERVE_ARGS[@]+"${SERVE_ARGS[@]}"}") _peek_arg() { local short="$1" long="$2" default="$3" local i=0 while (( i < ${#SERVE_ARGS[@]} )); do if [[ "${SERVE_ARGS[$i]}" == "$short" || "${SERVE_ARGS[$i]}" == "$long" ]]; then echo "${SERVE_ARGS[$((i+1))]}"; return fi ((i++)) done echo 
"$default" } load_or_pull_image() { if [[ "${FORCE_PULL}" == "true" ]]; then info "Force pull: cleaning up existing images..." srun --ntasks-per-node=1 bash -c ' docker ps -aq | xargs -r docker rm -f 2>/dev/null || true docker images -aq | xargs -r docker rmi -f 2>/dev/null || true ' fi if [[ "${IMAGE}" == *.tar.gz ]]; then info "Loading Docker image from tarball..." CONTAINER_IMAGE=$(pigz -dc "${IMAGE}" | tar -xf - -O manifest.json | python3 -c "import sys,json; print(json.load(sys.stdin)[0]['RepoTags'][0])") info "Image tag: ${CONTAINER_IMAGE}" srun --ntasks-per-node=1 bash -c " if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then pigz -dc '${IMAGE}' | docker load fi " else info "Pulling Docker image from registry..." local registry="${IMAGE%%/*}" local region=$(echo "${registry}" | sed -n 's/.*\.ecr\.\([^.]*\)\.amazonaws\.com/\1/p') region="${region:-us-west-2}" srun --ntasks-per-node=1 bash -c " if ! docker image inspect '${IMAGE}' &>/dev/null; then aws ecr get-login-password --region '${region}' | docker login --username AWS --password-stdin '${registry}' docker pull '${IMAGE}' fi " CONTAINER_IMAGE="${IMAGE}" fi } launch_container() { local name="${1}" cmd="${2}" local devices=("--device=/dev/gdrdrv") while IFS= read -r -d '' d; do devices+=("--device=${d}") done < <(find "/dev/infiniband" -name "uverbs*" -print0 2>/dev/null) local net_if="${GLOO_SOCKET_IFNAME:-$(ip -o -4 route show to default | awk '{print $5}' | head -1)}" docker run --gpus "${GPUS}" \ --privileged -d \ --name "${name}" \ --uts=host --ipc=host --net=host \ --ulimit stack=67108864 --ulimit memlock=-1 \ --security-opt seccomp=unconfined \ "${devices[@]}" \ -v "${CONTAINER_MOUNT}:${CONTAINER_MOUNT}" \ -e NCCL_SOCKET_IFNAME="${net_if}" \ -e GLOO_SOCKET_IFNAME="${net_if}" \ -e TP_SOCKET_IFNAME="${net_if}" \ --entrypoint bash \ "${CONTAINER_IMAGE:-${IMAGE}}" \ -c "${cmd}" } setup_topology() { NUM_NODES=${SLURM_JOB_NUM_NODES:-1} GPUS_PER_NODE=8 TOTAL_GPUS=$((NUM_NODES * GPUS_PER_NODE)) 
readarray -t NODES < <(scontrol show hostnames "$SLURM_JOB_NODELIST") HEAD_NODE=${NODES[0]} HEAD_IP=$(getent ahostsv4 "$HEAD_NODE" | head -1 | awk '{print $1}') mkdir -p "${LOGDIR}" info "========================================" info "TensorRT-LLM Server" info "========================================" info "Image: ${IMAGE}" info "Nodes: ${NUM_NODES}, Head: ${HEAD_NODE} (${HEAD_IP}), GPUs: ${TOTAL_GPUS}" info "SERVE_ARGS: ${SERVE_ARGS[*]+"${SERVE_ARGS[*]}"}" info "========================================" } cleanup() { info "Cleaning up containers..." srun --ntasks-per-node=1 bash -c ' docker ps -aq | xargs -r docker rm -f 2>/dev/null || true ' 2>/dev/null || true rm -f "${LOGDIR}/trtllm_server_${SLURM_JOB_ID}.log" } start_trtllm() { local logfile="${LOGDIR}/trtllm_server_${SLURM_JOB_ID}.log" # Launch container on head node srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " $(declare -f launch_container) CONTAINER_IMAGE='${CONTAINER_IMAGE:-${IMAGE}}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' launch_container trtllm-node-0 'sleep infinity' " sleep 3 # trtllm-serve syntax: trtllm-serve [OPTIONS] srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " docker exec -d trtllm-node-0 bash -c 'trtllm-serve ${SERVE_ARGS_STR} --host 0.0.0.0 --port 8000 2>&1 | tee ${logfile}' " } wait_for_server() { info "Waiting for TensorRT-LLM server at ${HEAD_IP}:8000..." 
for _ in {1..360}; do if srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c "curl -sf localhost:8000/health" &>/dev/null && srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c "curl -sf localhost:8000/v1/models | grep -q '\"id\"'" &>/dev/null; then info "Server ready at ${HEAD_IP}:8000" return 0 fi sleep 10 done err "Timeout waiting for server"; return 1 } setup_topology trap cleanup EXIT cleanup LOGFILE="${LOGDIR}/trtllm_server_${SLURM_JOB_ID}.log" load_or_pull_image start_trtllm tail -f "${LOGFILE}" 2>/dev/null & wait_for_server || exit 1 info "TensorRT-LLM serving on ${HEAD_IP}:8000 — Ctrl+C or scancel to stop" info "Logs: ${LOGFILE}" sleep infinity ================================================ FILE: src/llm/tensorrt-llm/test.sh ================================================ #!/usr/bin/env bash # TensorRT-LLM API test script set -uo pipefail HOST="localhost" PORT="8000" MODEL="" while (( "$#" )); do case "$1" in -h|--help) echo "Usage: $0 [-H host] [-p port] [-m model]"; exit 0 ;; -H|--host) HOST="$2"; shift 2 ;; -p|--port) PORT="$2"; shift 2 ;; -m|--model) MODEL="$2"; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done BASE_URL="http://${HOST}:${PORT}" # Auto-detect model from server if not specified if [[ -z "$MODEL" ]]; then MODEL=$(curl -sf "${BASE_URL}/v1/models" | python3 -c "import sys,json; print(json.load(sys.stdin)['data'][0]['id'])" 2>/dev/null) if [[ -z "$MODEL" ]]; then echo "ERROR: Cannot detect model. Is the server running at ${BASE_URL}?" exit 1 fi fi PASS=0 FAIL=0 echo "Testing TensorRT-LLM server at ${BASE_URL}" echo "Model: ${MODEL}" echo "========================================" test_endpoint() { local name="$1" cmd="$2" echo -ne "\n${name}... 
" if eval "$cmd" 2>&1; then ((PASS++)) else ((FAIL++)) fi } test_endpoint "[1/8] Health check" \ "curl -sf '${BASE_URL}/health' | jq -e '.'" test_endpoint "[2/8] List models" \ "curl -sf '${BASE_URL}/v1/models' | jq -e '.data'" test_endpoint "[3/8] Basic completions" \ "curl -sf -X POST '${BASE_URL}/v1/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"prompt\": \"Hello\", \"max_tokens\": 10}' | jq -e '.choices'" test_endpoint "[4/8] Chat completions" \ "curl -sf -X POST '${BASE_URL}/v1/chat/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"messages\": [{\"role\": \"user\", \"content\": \"Hi\"}], \"max_tokens\": 10}' | jq -e '.choices'" test_endpoint "[5/8] Sampling parameters" \ "curl -sf -X POST '${BASE_URL}/v1/chat/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"messages\": [{\"role\": \"user\", \"content\": \"Hi\"}], \"max_tokens\": 10, \"temperature\": 0.9, \"top_p\": 0.95}' | jq -e '.choices'" test_endpoint "[6/8] Streaming" \ "curl -sf -X POST '${BASE_URL}/v1/chat/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"messages\": [{\"role\": \"user\", \"content\": \"Hi\"}], \"stream\": true, \"max_tokens\": 10}' | grep -q 'data:'" test_endpoint "[7/8] Stop sequences" \ "curl -sf -X POST '${BASE_URL}/v1/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"prompt\": \"1.\", \"max_tokens\": 20, \"stop\": [\"3.\"]}' | jq -e '.choices'" test_endpoint "[8/8] Batch completions" \ "curl -sf -X POST '${BASE_URL}/v1/completions' \ -H 'Content-Type: application/json' \ -d '{\"model\": \"${MODEL}\", \"prompt\": [\"Once\", \"In\"], \"max_tokens\": 10}' | jq -e '.choices'" echo -e "\n========================================" echo "Results: ${PASS} passed, ${FAIL} failed" [[ $FAIL -gt 0 ]] && exit 1 ================================================ FILE: src/llm/vllm/Dockerfile 
================================================ ARG VLLM_VERSION=0.15.1 ARG CUDA_VERSION=12.8.1 ARG GDRCOPY_VERSION=v2.5.1 ARG EFA_INSTALLER_VERSION=1.46.0 ARG NCCL_VERSION=v2.29.2-1 ARG NVSHMEM_VERSION=v3.5.19-1 FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu24.04 ARG GDRCOPY_VERSION ARG EFA_INSTALLER_VERSION ARG NCCL_VERSION ARG NVSHMEM_VERSION ARG VLLM_VERSION # Prevent interactive prompts ENV DEBIAN_FRONTEND=noninteractive ENV TZ=UTC # Update and remove conflicting packages RUN apt-get update -y && apt-get upgrade -y RUN apt-get remove -y --allow-change-held-packages \ ibverbs-utils \ libibverbs-dev \ libibverbs1 \ libmlx5-1 \ libnccl2 \ libnccl-dev # Clean up existing MPI installations RUN rm -rf /opt/hpcx \ && rm -rf /usr/local/mpi \ && rm -f /etc/ld.so.conf.d/hpcx.conf \ && ldconfig ENV OPAL_PREFIX= # Install build dependencies RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ apt-utils \ autoconf \ automake \ build-essential \ check \ cmake \ curl \ debhelper \ devscripts \ git \ gcc \ gdb \ kmod \ libsubunit-dev \ libtool \ openssh-client \ openssh-server \ pkg-config \ python3 \ python3-dev \ python3-pip \ vim \ wget \ ninja-build \ && rm -rf /var/lib/apt/lists/* # Remove cuda-compat if present RUN apt-get purge -y cuda-compat-* || true # Configure SSH RUN mkdir -p /var/run/sshd RUN sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config && \ echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config && \ sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config # Set library paths ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/gdrcopy/lib:/opt/nvshmem/lib:/usr/local/lib:$LD_LIBRARY_PATH ENV PATH=/opt/amazon/openmpi/bin:/opt/amazon/efa/bin:/opt/gdrcopy/bin:/usr/bin:/usr/local/bin:$PATH # Remove PEP 668 restriction and install packages RUN rm -f /usr/lib/python*/EXTERNALLY-MANAGED \ && pip3 install --no-cache-dir awscli nvidia-ml-py 
Cython # Install GDRCopy RUN git clone -b ${GDRCOPY_VERSION} https://github.com/NVIDIA/gdrcopy.git /tmp/gdrcopy \ && cd /tmp/gdrcopy \ && make prefix=/opt/gdrcopy install \ && rm -rf /tmp/gdrcopy ENV LIBRARY_PATH=/opt/gdrcopy/lib:${LIBRARY_PATH:-} ENV CPATH=/opt/gdrcopy/include # Install EFA dependencies RUN apt-get update -y && apt-get install -y --no-install-recommends \ pciutils \ environment-modules \ tcl \ libnl-3-200 \ libnl-3-dev \ libnl-route-3-200 \ libnl-route-3-dev \ udev \ dmidecode \ ethtool \ iproute2 \ libevent-core-2.1-7t64 \ libevent-pthreads-2.1-7t64 \ libhwloc15 \ && rm -rf /var/lib/apt/lists/* # Install EFA RUN cd /tmp \ && curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && tar -xf aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && cd aws-efa-installer \ && ./efa_installer.sh -y -g -d --skip-kmod --skip-limit-conf --no-verify \ && rm -rf /tmp/aws-efa-installer* # Install NCCL RUN git clone -b ${NCCL_VERSION} https://github.com/NVIDIA/nccl.git /tmp/nccl \ && cd /tmp/nccl \ && make -j $(nproc) src.build CUDA_HOME=/usr/local/cuda \ NVCC_GENCODE="-gencode=arch=compute_90,code=sm_90" \ && mkdir -p /opt/nccl/build/lib \ && cp -r build/lib/* /opt/nccl/build/lib/ \ && cp -r build/include /opt/nccl/build/ \ && rm -rf /tmp/nccl # Install NVSHMEM ENV NVSHMEM_DIR=/opt/nvshmem ENV NVSHMEM_HOME=/opt/nvshmem RUN git clone https://github.com/NVIDIA/nvshmem.git /tmp/nvshmem \ && cd /tmp/nvshmem \ && git checkout ${NVSHMEM_VERSION} \ && mkdir -p build \ && cd build \ && cmake -DNVSHMEM_PREFIX=/opt/nvshmem \ -DCMAKE_CUDA_ARCHITECTURES="90" \ -DNVSHMEM_MPI_SUPPORT=1 \ -DNVSHMEM_PMIX_SUPPORT=1 \ -DNVSHMEM_LIBFABRIC_SUPPORT=1 \ -DNVSHMEM_IBRC_SUPPORT=1 \ -DNVSHMEM_IBGDA_SUPPORT=1 \ -DNVSHMEM_USE_GDRCOPY=1 \ -DNVSHMEM_BUILD_TESTS=0 \ -DNVSHMEM_BUILD_EXAMPLES=0 \ -DNVSHMEM_BUILD_HYDRA_LAUNCHER=0 \ -DNVSHMEM_BUILD_TXZ_PACKAGE=0 \ -DNVSHMEM_BUILD_PYTHON_LIB=0 \ -DMPI_HOME=/opt/amazon/openmpi \ 
-DPMIX_HOME=/opt/amazon/pmix \ -DGDRCOPY_HOME=/opt/gdrcopy \ -DLIBFABRIC_HOME=/opt/amazon/efa \ -G Ninja .. \ && ninja -j $(nproc) \ && ninja install \ && rm -rf /tmp/nvshmem # OpenMPI settings ENV OMPI_MCA_pml=^ucx ENV OMPI_MCA_btl=tcp,self ENV OMPI_MCA_btl_tcp_if_exclude=lo,docker0,veth_def_agent ENV OPAL_PREFIX=/opt/amazon/openmpi ENV PMIX_MCA_gds=hash # NCCL settings ENV NCCL_DEBUG=INFO ENV NCCL_SOCKET_IFNAME=^docker,lo,veth ENV NCCL_P2P_NET_CHUNKSIZE=524288 ENV NCCL_BUFFSIZE=8388608 ENV NCCL_TUNER_PLUGIN=/opt/amazon/ofi-nccl/lib/libnccl-tuner-ofi.so ENV LD_PRELOAD=/opt/nccl/build/lib/libnccl.so # EFA settings ENV FI_PROVIDER=efa ENV FI_EFA_USE_DEVICE_RDMA=1 ENV FI_EFA_FORK_SAFE=1 ENV RDMAV_FORK_SAFE=1 # vLLM settings ENV VLLM_RPC_TIMEOUT=3600000 ENV VLLM_ENGINE_READY_TIMEOUT_S=3600 ENV VLLM_USE_DEEP_GEMM=1 ENV DG_JIT_CACHE_DIR=/tmp # NVSHMEM settings ENV NVSHMEM_REMOTE_TRANSPORT=libfabric ENV NVSHMEM_LIBFABRIC_PROVIDER=efa ENV NVSHMEM_DISABLE_CUDA_VMM=1 # Install vLLM RUN pip3 install --no-cache-dir vllm==${VLLM_VERSION} # Install Nsight Systems for profiling RUN apt-get update -y && apt-get install -y --no-install-recommends gnupg \ && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/3bf863cc.pub \ && echo "deb https://developer.download.nvidia.com/devtools/repos/ubuntu2404/$(dpkg --print-architecture) /" \ > /etc/apt/sources.list.d/nvidia-devtools.list \ && apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F60F4B3D7FA2AF80 \ && apt-get update -y \ && apt-get install -y --no-install-recommends nsight-systems-cli \ && rm -rf /var/lib/apt/lists/* # Install DeepGEMM (requires torch from vLLM) RUN git clone --recursive -b v2.1.1.post3 https://github.com/deepseek-ai/DeepGEMM.git /tmp/deepgemm \ && cd /tmp/deepgemm \ && python3 setup.py bdist_wheel \ && pip3 install dist/*.whl \ && rm -rf /tmp/deepgemm # Copy run.sh script COPY run.sh /workspace/run.sh RUN chmod +x /workspace/run.sh WORKDIR /workspace 
================================================ FILE: src/llm/vllm/Makefile ================================================ .PHONY: help docker save clean format .DEFAULT_GOAL := help IMAGE_NAME ?= vllm-serve IMAGE_TAG ?= latest help: @echo "vLLM Makefile" @echo "" @echo "Targets:" @echo " docker Build Docker image" @echo " save Save Docker image to tar.gz" @echo " format Format shell scripts with shfmt (indent=2)" @echo " clean Remove image and tarball" @echo "" @echo "Usage:" @echo " make docker && make save" @echo " make format" docker: docker build -t $(IMAGE_NAME):$(IMAGE_TAG) -f Dockerfile . save: docker save $(IMAGE_NAME):$(IMAGE_TAG) | pigz > $(IMAGE_NAME)-$(IMAGE_TAG).tar.gz format: @command -v shfmt >/dev/null 2>&1 || { echo "shfmt not found. Install: brew install shfmt"; exit 1; } shfmt -i 2 -w *.sh *.sbatch clean: -docker rmi $(IMAGE_NAME):$(IMAGE_TAG) 2>/dev/null || true -rm -f $(IMAGE_NAME)-$(IMAGE_TAG).tar.gz ================================================ FILE: src/llm/vllm/README.rst ================================================ ============ vLLM Serving ============ .. contents:: Table of Contents :backlinks: none This cheat sheet provides quick-reference commands for launching a vLLM server in both local (single-node) and SLURM (multi-node) environments. It covers building the Docker image, running with different parallelism strategies, and testing the server endpoints. For more details, see the `vLLM documentation `_ and `GitHub repository `_. For detailed explanations of tensor parallelism, pipeline parallelism, data parallelism, and expert parallelism, as well as benchmark methodology, see the `LLM Serving Guide `_ and `LLM Benchmark Guide `_. Build Docker Image ------------------ The Dockerfile bundles vLLM with EFA drivers, NCCL, NVSHMEM, and GDRCopy for high-performance multi-node inference on GPU clusters.
Build the image and save it as a compressed tarball for distribution to SLURM nodes via a shared filesystem. .. code-block:: bash # Build the Docker image with all dependencies make docker # Save as a compressed tarball for SLURM nodes # Output: vllm-serve-latest.tar.gz make save Local Serving (Single Node) --------------------------- For development or single-node deployments, vLLM can run directly on the host or inside a Docker container. The server exposes an OpenAI-compatible API on port 8000. **Bare metal** — run vLLM directly without Docker: .. code-block:: bash # Single GPU — simplest way to serve a model vllm serve Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 8000 # Tensor parallel across 8 GPUs — for models too large for a single GPU vllm serve Qwen/Qwen2.5-14B-Instruct --tensor-parallel-size 8 **Using Docker (via Makefile)** — convenient targets for common configurations: .. code-block:: bash # Single GPU with default model make single # Tensor parallel across 8 GPUs make mp TP=8 # Pipeline parallel — split model into 2 stages, each with TP=4 make mp TP=4 PP=2 # Ray backend with data parallelism — 2 replicas, each using 4 GPUs make ray TP=4 DP=2 **Using Docker (via run.sh)** — the entrypoint script supports multiple serving modes (single, ray, mp, rpc) with explicit control over parallelism settings: .. code-block:: bash # Single GPU mode docker run --gpus all --rm --net=host -v /fsx:/fsx \ vllm-serve:latest ./run.sh single --model Qwen/Qwen2.5-7B-Instruct # Multiprocessing mode with tensor parallelism docker run --gpus all --rm --net=host -v /fsx:/fsx \ vllm-serve:latest ./run.sh mp --model Qwen/Qwen2.5-14B-Instruct --tp 8 SLURM Serving (Multi-Node) --------------------------- ``run.sbatch`` orchestrates multi-node vLLM serving on SLURM clusters. It handles Docker image distribution, container launch with EFA/GPU passthrough, parallelism computation, and health checking. The server runs until you stop it with ``Ctrl+C`` or ``scancel``. 
**Script flags** — these are consumed by the script and not passed to vLLM: .. list-table:: :widths: 30 70 :header-rows: 1 * - Flag - Description * - ``--image PATH`` - Docker image tarball or registry path (default: ``$WORKSPACE/vllm-serve-latest.tar.gz``) * - ``--workspace, -w PATH`` - Base directory for default image and logs (default: ``$PWD``) * - ``--container-mount PATH`` - Host path to bind-mount into containers (default: ``/fsx``) * - ``--force, -f`` - Force remove existing containers and images before loading * - ``--nsys`` - Enable Nsight Systems profiling (writes to ``$WORKSPACE/nsys-vllm/``) All other arguments are passed directly to ``vllm serve`` as-is. **Basic usage** — allocate nodes with ``salloc``, then run the script. The script auto-detects the SLURM allocation and computes DP based on total GPUs and TP/PP: .. code-block:: bash # Allocate 2 nodes with 8 GPUs each salloc -N 2 --gpus-per-node=8 --exclusive # Expert parallel for MoE models (TP=8, DP=2, EP=16 auto-computed) bash run.sbatch \ Qwen/Qwen3-30B-A3B-FP8 \ --tensor-parallel-size 8 \ --enable-expert-parallel **Custom image** — specify a different Docker image tarball or registry path: .. code-block:: bash bash run.sbatch \ --image /fsx/images/vllm-serve-latest.tar.gz \ Qwen/Qwen3-30B-A3B-FP8 \ --tensor-parallel-size 8 \ --enable-expert-parallel **Pipeline parallel** — for large dense models that need to be split across nodes. The script auto-selects the multiprocessing backend when PP > 1: .. code-block:: bash bash run.sbatch \ deepseek-ai/DeepSeek-V2-Lite \ --tensor-parallel-size 8 \ --pipeline-parallel-size 2 **Force reload** — remove cached containers and images before loading: .. code-block:: bash bash run.sbatch -f \ Qwen/Qwen3-30B-A3B-FP8 \ --tensor-parallel-size 8 \ --enable-expert-parallel **Backend selection** — the script supports three distributed backends. RPC is the default and works best for most cases. Ray is useful when you need its scheduling features. 
Multiprocessing is auto-selected for pipeline parallelism: .. code-block:: bash # RPC backend (default) — lightweight, best for DP + EP bash run.sbatch Qwen/Qwen3-30B-A3B-FP8 --tensor-parallel-size 8 # Ray backend — uses Ray cluster for worker management DP_BACKEND=ray bash run.sbatch \ Qwen/Qwen3-30B-A3B-FP8 \ --tensor-parallel-size 8 \ --enable-expert-parallel # Multiprocessing backend — auto-selected when PP > 1 bash run.sbatch \ deepseek-ai/DeepSeek-V2-Lite \ --tensor-parallel-size 8 \ --pipeline-parallel-size 2 Test the Server --------------- Once the server is ready, it prints the head node IP address. Use standard HTTP requests to interact with the OpenAI-compatible API, replacing ``<HEAD_IP>`` with that address: .. code-block:: bash # Health check — returns 200 when server is ready curl http://<HEAD_IP>:8000/health # List available models curl http://<HEAD_IP>:8000/v1/models # Chat completion request curl -X POST http://<HEAD_IP>:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "Qwen/Qwen3-30B-A3B-FP8", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50 }' # Or run the included test script bash test.sh # Test against a remote server bash test.sh -H 10.0.128.193 # Test with specific port and model bash test.sh -H 10.0.128.193 -p 8000 -m Qwen/Qwen3-30B-A3B-FP8 Benchmark --------- ``bench.sh`` measures serving performance (throughput, TTFT, ITL, latency) by sending requests to a running vLLM server. It handles Docker image loading and container management automatically. .. code-block:: bash # Run all benchmarks bash bench.sh -H 10.0.128.193 -i vllm-serve:latest # Run specific benchmarks bash bench.sh -H 10.0.128.193 -i vllm-serve:latest --type throughput,prefill # Via Makefile make bench HOST=10.0.128.193 make bench HOST=10.0.128.193 BENCH_TYPE=throughput,prefill Sweep ----- ``sweep.sh`` runs predefined sweep suites by calling ``sweep.sbatch`` for each configuration. Each suite launches its own ``vllm serve``, sweeps a parameter, and collects results. Requires GPU access. ..
code-block:: bash # All suites (rate, concurrency, input, output) bash sweep.sh -m Qwen/Qwen3-0.6B # Select specific suites bash sweep.sh -m Qwen/Qwen3-30B-A3B-FP8 \ -i vllm-serve:latest \ --serve-cmd "vllm serve Qwen/Qwen3-30B-A3B-FP8 -tp 8 --enable-expert-parallel" \ --type rate,input # Via Makefile make sweep make sweep SWEEP_MODEL=Qwen/Qwen3-30B-A3B-FP8 SWEEP_TYPE=rate,concurrency # Show vllm serve stdout (model loading, request logs) — very useful for debugging bash sweep.sh -m Qwen/Qwen3-0.6B --show-stdout # Custom serve command — TP=2 with expert parallel, rate sweep only bash sweep.sh -m Qwen/Qwen1.5-MoE-A2.7B \ --serve-cmd "vllm serve Qwen/Qwen1.5-MoE-A2.7B -tp 2 --enable-expert-parallel" \ --type rate --show-stdout Available suites: - **rate** — Sweeps request rate (1, 2, 4, 8, 16, 32, inf) to find saturation point - **concurrency** — Sweeps concurrent requests (1–128) to find optimal batch size - **input** — Sweeps input length (128–16K) to measure TTFT scaling with context - **output** — Sweeps output length (64–2048) to measure ITL as KV cache grows For direct control over sweep parameters, use ``sweep.sbatch`` which passes all args through to ``vllm bench sweep serve``: .. code-block:: bash # Custom serve + bench commands bash sweep.sbatch -m Qwen/Qwen3-30B-A3B-FP8 \ --serve-cmd "vllm serve Qwen/Qwen3-30B-A3B-FP8 -tp 8 --enable-expert-parallel" \ --bench-cmd "vllm bench serve --model Qwen/Qwen3-30B-A3B-FP8 --dataset-name sharegpt" \ --bench-params results/bench_params.json --show-stdout See the `LLM Benchmark Guide `_ for detailed explanations of each benchmark type and metric. Notes and Limitations --------------------- **Profiling:** vLLM supports PyTorch profiler tracing via ``--profiler-config``. The server must be started with profiling enabled, then the benchmark client triggers ``/start_profile`` and ``/stop_profile`` endpoints via ``--profile``. .. 
code-block:: bash # Server — start with profiling enabled (default writes to $PWD/vllm_profile) bash run.sbatch --profile \ Qwen/Qwen3-30B-A3B-FP8 \ --tensor-parallel-size 8 # Server — custom profiler config bash run.sbatch \ --profiler-config '{"profiler": "torch", "torch_profiler_dir": "/fsx/traces"}' \ Qwen/Qwen3-30B-A3B-FP8 \ --tensor-parallel-size 8 # Client — run benchmark with profiling bash bench.sh -H 10.0.128.193 --type throughput --profile View traces at https://ui.perfetto.dev/ (supports ``.gz`` files directly). See the `vLLM Profiling Guide `_ for more details. **Nsight Systems profiling:** ``--nsys`` wraps the ``vllm serve`` command with ``nsys profile`` for GPU-level tracing (CUDA kernels, NVTX ranges, memory usage). Profiles are saved per-node to ``$WORKSPACE/nsys-vllm/``. The script sends ``SIGINT`` to nsys on cleanup for graceful finalization. .. code-block:: bash # Server — enable nsys + vLLM's CUDA profiler (terminal 0) bash run.sbatch --nsys \ Qwen/Qwen3-30B-A3B-FP8 \ --tensor-parallel-size 8 \ --enable-expert-parallel \ --profiler-config '{"profiler": "cuda"}' # Client — run benchmark with profiling (terminal 1) bash bench.sh -H <HEAD_IP> --type throughput --profile # Stop server with Ctrl+C (terminal 0) # Nsys will finalize profiles (~30s) # Profile files: nsys-vllm/profile-node*.nsys-rep Open ``.nsys-rep`` files with `Nsight Systems `_ or export to JSON for custom analysis. **Parallelism constraints:** - vLLM does not support combining PP, TP, and DP simultaneously. When using PP mode, DP is not available (``TOTAL_GPUS = TP × PP``). - EP and PP are mutually exclusive. Use EP for MoE models and PP for large dense models. - EP is computed automatically: ``EP = TP × DP = world_size``. As TP increases, DP decreases proportionally to maintain the same total parallelism.
**Formulas:** - EP mode: ``TOTAL_GPUS = DP × TP``, EP auto-computed by vLLM - PP mode: ``TOTAL_GPUS = TP × PP`` (no DP) - DP mode: ``TOTAL_GPUS = DP × TP`` Offline Benchmarking --------------------- ``offline_bench.sh`` measures raw inference performance without API server overhead. Uses ``torchrun`` for multi-node coordination and supports profiling with Nsight Systems. **Single GPU:** .. code-block:: bash bash offline_bench.sh \ --model meta-llama/Llama-3.1-8B \ --input-len 512 --output-len 128 \ --num-prompts 100 **Multi-GPU with tensor parallelism:** .. code-block:: bash salloc -N 1 bash offline_bench.sh \ --model Qwen/Qwen2-57B-A14B \ --tensor-parallel-size 4 --enable-expert-parallel \ --input-len 1024 --output-len 256 \ --num-prompts 100 # Multi-node with custom docker image salloc -N 4 bash offline_bench.sh \ --image "$PWD/vllm-serve-latest.tar.gz" \ --model Qwen/Qwen2-57B-A14B \ --all2all-backend allgather_reducescatter \ --tensor-parallel-size 4 --enable-expert-parallel \ --gpu-memory-utilization 0.8 \ --input-len 2048 --output-len 512 \ --num-prompts 50 # ShareGPT dataset wget -O ShareGPT_V3_unfiltered_cleaned_split.json \ https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json bash offline_bench.sh \ --model meta-llama/Llama-3.1-8B \ --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \ --num-prompts 100 **Nsight Systems profiling:** Similar to the server workflow above, ``--nsys`` wraps the ``torchrun`` command (instead of ``vllm serve``) with ``nsys profile``. Profiles are saved per-node to ``$WORKSPACE/nsys-offline/``. .. 
code-block:: bash salloc -N 4 bash offline_bench.sh --nsys \ --model Qwen/Qwen2-57B-A14B \ --tensor-parallel-size 4 --enable-expert-parallel \ --all2all-backend allgather_reducescatter \ --input-len 2048 --output-len 512 \ --num-prompts 50 # Profile files: nsys-offline/profile-node*.nsys-rep **VizTracer profiling (Python-level tracing):** VizTracer is a lightweight Python profiler that traces function calls without the memory overhead of PyTorch profiler. It works reliably with multi-node/high data-parallelism setups where PyTorch profiler may cause OOM. Use VizTracer to understand Python-level execution flow and identify bottlenecks in application logic. .. code-block:: bash salloc -N 2 bash offline_bench.sh \ --model Qwen/Qwen2-57B-A14B \ --tensor-parallel-size 4 --enable-expert-parallel \ --viztracer ./vllm-trace.json \ --num-prompts 50 # View trace at https://ui.perfetto.dev/ **Profiling comparison:** - **nsys**: GPU kernel-level profiling (CUDA operations, memory transfers, NCCL) - **VizTracer**: Python function-level profiling (application logic, scheduling) - Use nsys for GPU performance analysis, VizTracer for Python code analysis ================================================ FILE: src/llm/vllm/bench.sh ================================================ #!/usr/bin/env bash # vLLM serving benchmark suite # Usage: # salloc -N1 bash bench.sh -H 10.0.128.193 -i /fsx/vllm-serve-latest.tar.gz # bash bench.sh -H 10.0.128.193 -i vllm-serve:latest set -euo pipefail info() { echo "[$(date +'%H:%M:%S')] $*"; } # Docker image helpers (mirrors run.sbatch) CONTAINER_MOUNT="${CONTAINER_MOUNT:-/fsx}" # Wrap a command with srun if inside a SLURM allocation, otherwise run directly _run() { if [[ -n "${SLURM_JOB_ID:-}" ]]; then srun -N1 --ntasks-per-node=1 bash -c "$*" else bash -c "$*" fi } load_or_pull_image() { if [[ "${IMAGE}" == *.tar.gz ]]; then CONTAINER_IMAGE=$(pigz -dc "${IMAGE}" | tar -xf - -O manifest.json \ | python3 -c "import sys,json; 
print(json.load(sys.stdin)[0]['RepoTags'][0])") info "Image tag: ${CONTAINER_IMAGE}" _run " if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then echo 'Loading Docker image from tarball...' pigz -dc '${IMAGE}' | docker load fi " else CONTAINER_IMAGE="${IMAGE}" _run " if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then echo 'Pulling ${CONTAINER_IMAGE}...' registry=\"\${CONTAINER_IMAGE%%/*}\" region=\$(echo \"\${registry}\" | sed -n 's/.*\.ecr\.\([^.]*\)\.amazonaws\.com/\1/p') region=\"\${region:-us-west-2}\" aws ecr get-login-password --region \"\${region}\" \ | docker login --username AWS --password-stdin \"\${registry}\" docker pull '${CONTAINER_IMAGE}' fi " fi } launch_container() { local cmd="$1" _run " docker run --rm --net=host \ -v '${PWD}:${PWD}' -w '${PWD}' \ -v '${CONTAINER_MOUNT}:${CONTAINER_MOUNT}' \ --entrypoint bash '${CONTAINER_IMAGE}' \ -c '${cmd}' " } # If vllm CLI is not available, load image and re-exec inside container if ! command -v vllm &>/dev/null; then # Pre-parse --image/-i before full arg parsing IMAGE="" _args=("$@") for ((i=0; i<${#_args[@]}; i++)); do [[ "${_args[$i]}" == "--image" || "${_args[$i]}" == "-i" ]] \ && { IMAGE="${_args[$((i+1))]}"; break; } done IMAGE="${IMAGE:-${PWD}/vllm-serve-latest.tar.gz}" load_or_pull_image _SCRIPT="$(cd "$(dirname "$0")" && pwd)/$(basename "$0")" launch_container "bash ${_SCRIPT} $*" exit $? fi HOST="localhost" PORT="8000" MODEL="" SEED="42" RESULT_DIR="./results" PROFILE="" TYPES="throughput,prefill,decode,latency,concurrency,longctx,sharegpt,sonnet" usage() { cat < $label" vllm bench serve \ --model "$MODEL" \ --base-url "$BASE_URL" \ --backend openai-chat \ --endpoint /v1/chat/completions \ --seed "$SEED" \ --save-result \ --result-dir "$RESULT_DIR" \ $PROFILE \ "$@" echo "" } # Throughput: measures peak output tokens/sec and request throughput. 
# Uses request-rate=inf to saturate the server — all 1000 prompts are sent as fast as # possible, forcing the scheduler to batch aggressively. 512in/256out is a moderate # workload that exercises both prefill and decode phases. 1000 prompts follows the # GPUStack methodology for statistically stable throughput numbers. bench_throughput() { bench "Throughput (random 512in/256out, max rate)" \ --dataset-name random \ --random-input-len 512 --random-output-len 256 \ --num-prompts 100 --request-rate inf } # Prefill (TTFT): measures Time to First Token, which reflects prompt processing speed. # output-len=1 isolates prefill from decode — we only care how fast the model processes # the input. Sweeps 128→16K tokens to show how TTFT scales with context length (should # grow roughly linearly due to attention's O(n) compute per layer during prefill). # rate=4 keeps the server lightly loaded so TTFT reflects compute, not queueing. bench_prefill() { for len in 128 512 2048 4096 16384; do bench "Prefill TTFT (input=${len})" \ --dataset-name random \ --random-input-len "$len" --random-output-len 1 \ --num-prompts 100 --request-rate 4 done } # Decode (ITL): measures Inter-Token Latency and Time Per Output Token during generation. # input=128 keeps prefill minimal so the benchmark focuses on autoregressive decode. # Sweeps 128→1024 output tokens to reveal how ITL changes as KV cache grows — longer # sequences increase memory pressure and may trigger preemption/swapping. # rate=4 avoids batching interference so ITL reflects per-request decode speed. bench_decode() { for len in 128 256 512 1024; do bench "Decode ITL (output=${len})" \ --dataset-name random \ --random-input-len 128 --random-output-len "$len" \ --num-prompts 100 --request-rate 4 done } # Latency (E2E): measures end-to-end request latency under minimal load. # rate=1 ensures requests are mostly processed alone (no batching), giving a baseline # for best-case latency. 
Tests short/medium/long to show how total latency scales. # These numbers represent the "single user" experience (similar to ChatGPT-style usage # where one user waits for a complete response). bench_latency() { bench "Latency (short 128/128, rate=1)" \ --dataset-name random \ --random-input-len 128 --random-output-len 128 \ --num-prompts 100 --request-rate 1 bench "Latency (medium 512/256, rate=1)" \ --dataset-name random \ --random-input-len 512 --random-output-len 256 \ --num-prompts 100 --request-rate 1 bench "Latency (long 4096/512, rate=1)" \ --dataset-name random \ --random-input-len 4096 --random-output-len 512 \ --num-prompts 100 --request-rate 1 } # Concurrency: finds the server's saturation point by sweeping concurrent requests. # request-rate=inf with max-concurrency=N caps how many requests run in parallel. # At low concurrency (1-4), latency is good but throughput is low (GPU underutilized). # At high concurrency (64-256), throughput plateaus and latency degrades (queueing). # The "knee" where throughput stops improving is the optimal operating point. # 500 prompts per level gives enough samples for stable percentile metrics. bench_concurrency() { for c in 1 4 16 64 256; do bench "Concurrency=${c} (512in/256out)" \ --dataset-name random \ --random-input-len 512 --random-output-len 256 \ --num-prompts 100 --request-rate inf --max-concurrency "$c" done } # Long context: tests behavior with very long inputs (4K→32K tokens). # Inspired by GPUStack's "very long prompt" config (32000in/100out). Long inputs stress # KV cache memory, attention compute, and may trigger chunked prefill. output=100 keeps # decode short so we focus on prefill scaling. rate=1 and fewer prompts (50) because # each request is expensive and we want to avoid OOM under memory pressure. 
bench_longctx() {
  for len in 4096 16384 32000; do
    bench "Long context (input=${len})" \
      --dataset-name random \
      --random-input-len "$len" --random-output-len 100 \
      --num-prompts 50 --request-rate 1
  done
}

# ShareGPT: realistic conversational workload from real user conversations. Variable
# input/output lengths reflecting actual usage patterns. This is the standard dataset
# used by vLLM CI, GPUStack perf lab, and most published benchmarks. Unlike random
# datasets, ShareGPT captures the natural distribution of short/long prompts and
# responses, making it the best proxy for production chat traffic.
# Ref: github.com/vllm-project/vllm/blob/main/benchmarks/README.md
# Ref: GPUStack perf lab uses ShareGPT with 1000 prompts as primary benchmark.
#      This script sends 100 prompts per run (see num_prompts below).
#
# Requires: wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
SHAREGPT_PATH="${SHAREGPT_PATH:-ShareGPT_V3_unfiltered_cleaned_split.json}"

# Run the ShareGPT workload twice: at max request rate, then at rate=4.
# Downloads the dataset to $SHAREGPT_PATH first when it is not already present.
# Returns non-zero if the download fails.
bench_sharegpt() {
  # Single source of truth so the printed label always matches what is run.
  local num_prompts=100

  if [[ ! -f "$SHAREGPT_PATH" ]]; then
    echo "Downloading ShareGPT dataset..."
    # wget -O creates/truncates its target before any data arrives, so a failed
    # download would leave an empty file that passes the -f check on the next
    # run. Download to a temporary name and move it into place only on success.
    local tmp="${SHAREGPT_PATH}.part"
    if ! wget -q -O "$tmp" \
      https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json; then
      rm -f "$tmp"
      echo "ERROR: failed to download ShareGPT dataset to ${SHAREGPT_PATH}" >&2
      return 1
    fi
    mv "$tmp" "$SHAREGPT_PATH"
  fi

  bench "ShareGPT (${num_prompts} prompts, max rate)" \
    --dataset-name sharegpt \
    --dataset-path "$SHAREGPT_PATH" \
    --num-prompts "$num_prompts" --request-rate inf
  bench "ShareGPT (${num_prompts} prompts, rate=4)" \
    --dataset-name sharegpt \
    --dataset-path "$SHAREGPT_PATH" \
    --num-prompts "$num_prompts" --request-rate 4
}

# Sonnet: uses Shakespeare's sonnets with a shared prefix to test prefix caching.
# All prompts share a common prefix (--sonnet-prefix-len=200 tokens of sonnet text),
# then each request gets a unique suffix. This exercises vLLM's automatic prefix
# caching — if enabled, the shared prefix KV cache is computed once and reused across
# requests, dramatically reducing TTFT. Comparing sonnet results with prefix caching
# on vs off shows the caching speedup.
# Ref: vllm/benchmarks/datasets.py SonnetDataset # Ref: Default params: input=550, output=150, prefix=200 bench_sonnet() { SONNET_PATH="${SONNET_PATH:-sonnet.txt}" if [[ ! -f "$SONNET_PATH" ]]; then echo "Downloading Shakespeare sonnets..." wget -q -O "$SONNET_PATH" \ https://raw.githubusercontent.com/vllm-project/vllm/main/benchmarks/sonnet.txt fi bench "Sonnet (prefix caching, max rate)" \ --dataset-name sonnet --dataset-path "$SONNET_PATH" \ --sonnet-input-len 550 --sonnet-output-len 150 --sonnet-prefix-len 200 \ --num-prompts 100 --request-rate inf bench "Sonnet (prefix caching, rate=4)" \ --dataset-name sonnet --dataset-path "$SONNET_PATH" \ --sonnet-input-len 550 --sonnet-output-len 150 --sonnet-prefix-len 200 \ --num-prompts 100 --request-rate 4 } IFS=',' read -ra TESTS <<< "$TYPES" for t in "${TESTS[@]}"; do t=$(echo "$t" | xargs) # trim whitespace echo "========================================" echo "Running: $t" echo "========================================" case "$t" in throughput) bench_throughput ;; prefill) bench_prefill ;; decode) bench_decode ;; latency) bench_latency ;; concurrency) bench_concurrency ;; longctx) bench_longctx ;; sharegpt) bench_sharegpt ;; sonnet) bench_sonnet ;; *) echo "Unknown test: $t"; exit 1 ;; esac done ================================================ FILE: src/llm/vllm/offline_bench.py ================================================ #!/usr/bin/env python3 """ Offline vLLM benchmark without API server overhead. Based on vllm torchrun_dp_example.py for distributed inference. 
Usage: # Single GPU python offline_bench.py --model meta-llama/Llama-3.1-8B --num-prompts 50 # Multi-GPU with tensor parallelism torchrun --nproc-per-node=4 offline_bench.py \ --model Qwen/Qwen2-57B-A14B \ --tp-size 4 --enable-ep \ --num-prompts 100 # Data parallelism torchrun --nproc-per-node=8 offline_bench.py \ --model meta-llama/Llama-3.1-8B \ --tp-size 2 --dp-size 4 \ --num-prompts 200 """ import argparse import json import os import time from contextlib import contextmanager from pathlib import Path from typing import List import numpy as np from vllm import LLM, SamplingParams @contextmanager def viztracer_profiler(output_file, world_rank): """VizTracer profiler context manager.""" viztracer = None if output_file: try: from viztracer import VizTracer viztracer = VizTracer(output_file=output_file, verbose=0, log_torch=True) if world_rank == 0: print(f"VizTracer profiling enabled. Output: {output_file}") except ImportError: if world_rank == 0: print("Warning: viztracer not installed (pip install viztracer)") try: if viztracer: viztracer.start() if world_rank == 0: print("VizTracer started") yield finally: if viztracer: viztracer.stop() viztracer.save() if world_rank == 0: print(f"VizTracer trace saved to {output_file}") @contextmanager def cuda_profiler(enabled, world_rank): """CUDA profiler context manager for nsys.""" profiler = None if enabled: import torch.cuda.profiler as profiler try: if enabled: profiler.start() if world_rank == 0: print("CUDA profiler started for nsys") yield finally: if enabled and profiler: profiler.stop() if world_rank == 0: print("CUDA profiler stopped") def load_sharegpt_prompts(dataset_path: str, num_prompts: int) -> List[str]: """Load prompts from ShareGPT dataset.""" with open(dataset_path) as f: data = json.load(f) prompts = [] for item in data[:num_prompts]: if "conversations" in item and len(item["conversations"]) > 0: prompts.append(item["conversations"][0]["value"]) return prompts[:num_prompts] def 
generate_random_prompts(num_prompts: int, input_len: int, tokenizer) -> List[str]: """Generate random prompts with specified input length.""" # Use a fixed vocabulary for reproducibility vocab = list(range(1000, 10000)) # Use token IDs from vocab prompts = [] for _ in range(num_prompts): # Generate random token IDs token_ids = [vocab[i % len(vocab)] for i in range(input_len)] # Decode to text prompt = tokenizer.decode(token_ids) prompts.append(prompt) return prompts def generate_dummy_prompts(num_prompts: int) -> List[str]: """Generate dummy prompts for testing.""" base_prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", "Explain quantum computing in simple terms:", "Write a short story about a robot:", ] return [base_prompts[i % len(base_prompts)] for i in range(num_prompts)] def parse_args(): parser = argparse.ArgumentParser(description="Offline vLLM benchmark") # Model args parser.add_argument("--model", type=str, required=True, help="Model name or path") parser.add_argument( "--max-model-len", type=int, default=None, help="Max model length" ) parser.add_argument( "--gpu-memory-utilization", type=float, default=0.9, help="GPU memory utilization", ) # Parallelism args parser.add_argument( "--tensor-parallel-size", "--tp-size", type=int, default=1, dest="tp_size", help="Tensor parallel size", ) parser.add_argument( "--pipeline-parallel-size", "--pp-size", type=int, default=1, dest="pp_size", help="Pipeline parallel size", ) parser.add_argument( "--data-parallel-size", "--dp-size", type=int, default=1, dest="dp_size", help="Data parallel size", ) parser.add_argument( "--enable-expert-parallel", "--enable-ep", action="store_true", dest="enable_ep", help="Enable expert parallel", ) # Benchmark args parser.add_argument("--num-prompts", type=int, default=50, help="Number of prompts") parser.add_argument( "--dataset-path", type=str, default=None, help="Path to ShareGPT dataset" ) 
parser.add_argument( "--input-len", type=int, default=None, help="Input length for random prompts (default: use dataset or dummy prompts)", ) parser.add_argument( "--output-len", type=int, default=None, help="Output length (maps to --max-tokens)", ) parser.add_argument("--max-tokens", type=int, default=128, help="Max output tokens") parser.add_argument( "--temperature", type=float, default=0.0, help="Sampling temperature" ) parser.add_argument("--seed", type=int, default=0, help="Random seed") # Profiling args parser.add_argument( "--nsys", type=str, default=None, help="Enable nsys profiling and save to file (e.g., ./profile.nsys-rep)", ) parser.add_argument( "--viztracer", type=str, default=None, help="Enable VizTracer profiling and save to file (e.g., ./vllm-trace.json)", ) # vLLM args parser.add_argument( "--enforce-eager", action="store_true", help="Disable CUDA graph" ) parser.add_argument( "--all2all-backend", type=str, default=None, help="All-to-all backend (allgather_reducescatter, nccl)", ) return parser.parse_args() def main(): args = parse_args() # Map --output-len to --max-tokens if args.output_len: args.max_tokens = args.output_len # Sampling params sampling_params = SamplingParams( temperature=args.temperature, max_tokens=args.max_tokens, seed=args.seed, ignore_eos=args.input_len is not None, # Ignore EOS for random prompts ) # Initialize LLM first to get tokenizer llm_kwargs = { "model": args.model, "tensor_parallel_size": args.tp_size, "pipeline_parallel_size": args.pp_size, "data_parallel_size": args.dp_size, "enable_expert_parallel": args.enable_ep, "gpu_memory_utilization": args.gpu_memory_utilization, "seed": args.seed, "enforce_eager": args.enforce_eager, "disable_log_stats": False, } if args.max_model_len: llm_kwargs["max_model_len"] = args.max_model_len if args.all2all_backend: llm_kwargs["all2all_backend"] = args.all2all_backend # Use external launcher for multi-process if args.dp_size > 1: llm_kwargs["distributed_executor_backend"] = 
"external_launcher" print(f"Initializing LLM with config: {json.dumps(llm_kwargs, indent=2)}") llm = LLM(**llm_kwargs) # Load prompts if args.input_len: prompts = generate_random_prompts( args.num_prompts, args.input_len, llm.get_tokenizer() ) elif args.dataset_path: prompts = load_sharegpt_prompts(args.dataset_path, args.num_prompts) else: prompts = generate_dummy_prompts(args.num_prompts) # Get data parallel rank/size dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size # Get world rank for printing import torch.distributed as dist world_rank = dist.get_rank() if dist.is_initialized() else 0 # Distribute prompts across DP ranks local_prompts = [p for i, p in enumerate(prompts) if i % dp_size == dp_rank] # Warmup if world_rank == 0: print("Warming up...") _ = llm.generate(local_prompts[:1], sampling_params) # Benchmark with profiling if world_rank == 0: print(f"Running benchmark with {len(prompts)} total prompts...") with cuda_profiler(args.nsys, world_rank), viztracer_profiler( args.viztracer, world_rank ): start = time.perf_counter() outputs = llm.generate(local_prompts, sampling_params) elapsed = time.perf_counter() - start # Collect per-request metrics metrics = [] for output in outputs: num_output_tokens = len(output.outputs[0].token_ids) input_tokens = len(output.prompt_token_ids) # Check if detailed metrics are available req_metrics = output.metrics if req_metrics and hasattr(req_metrics, "first_token_ts"): ttft = (req_metrics.first_token_ts - req_metrics.scheduled_ts) * 1000 if num_output_tokens > 1: total_gen_time = ( req_metrics.last_token_ts - req_metrics.first_token_ts ) * 1000 tpot = total_gen_time / (num_output_tokens - 1) else: tpot = 0 else: ttft = 0 tpot = 0 metrics.append( { "ttft": ttft, "tpot": tpot, "input_tokens": input_tokens, "output_tokens": num_output_tokens, } ) # Aggregate stats total_input = sum(m["input_tokens"] for m in metrics) total_output = 
sum(m["output_tokens"] for m in metrics) ttfts = [m["ttft"] for m in metrics] tpots = [m["tpot"] for m in metrics if m["tpot"] > 0] # Aggregate across DP ranks if dp_size > 1: import torch.distributed as dist # Collect stats from all ranks local_stats = [ len(metrics), elapsed, total_input, total_output, np.mean(ttfts) if ttfts else 0, np.median(ttfts) if ttfts else 0, np.percentile(ttfts, 99) if ttfts else 0, np.mean(tpots) if tpots else 0, np.median(tpots) if tpots else 0, np.percentile(tpots, 99) if tpots else 0, ] # Gather from all world ranks world_size = dist.get_world_size() world_rank = dist.get_rank() all_stats = [None] * world_size dist.all_gather_object(all_stats, local_stats) # Only print from world rank 0 if world_rank == 0: # Filter to only DP ranks (every TP_SIZE-th rank) tp_size = llm.llm_engine.vllm_config.parallel_config.tensor_parallel_size dp_stats = [all_stats[i] for i in range(0, world_size, tp_size)] # Aggregate total_reqs = sum(s[0] for s in dp_stats) max_time = max(s[1] for s in dp_stats) total_in = sum(s[2] for s in dp_stats) total_out = sum(s[3] for s in dp_stats) # Average latency metrics mean_ttft = np.mean([s[4] for s in dp_stats]) median_ttft = np.mean([s[5] for s in dp_stats]) p99_ttft = max(s[6] for s in dp_stats) mean_tpot = np.mean([s[7] for s in dp_stats]) median_tpot = np.mean([s[8] for s in dp_stats]) p99_tpot = max(s[9] for s in dp_stats) print_results( total_reqs, 0, max_time, total_in, total_out, mean_ttft, median_ttft, p99_ttft, mean_tpot, median_tpot, p99_tpot, ) # Barrier to ensure all ranks finish before cleanup dist.barrier() else: # Single process print_results( len(metrics), 0, elapsed, total_input, total_output, np.mean(ttfts) if ttfts else 0, np.median(ttfts) if ttfts else 0, np.percentile(ttfts, 99) if ttfts else 0, np.mean(tpots) if tpots else 0, np.median(tpots) if tpots else 0, np.percentile(tpots, 99) if tpots else 0, ) def print_results( successful, failed, duration, total_input, total_output, mean_ttft, 
median_ttft, p99_ttft, mean_tpot, median_tpot, p99_tpot, ): """Print results in vLLM bench serve format.""" print("\n============ Serving Benchmark Result ============") print(f"Successful requests: {successful:<10}") print(f"Failed requests: {failed:<10}") print(f"Benchmark duration (s): {duration:<10.2f}") print(f"Total input tokens: {total_input:<10}") print(f"Total generated tokens: {total_output:<10}") print(f"Request throughput (req/s): {successful/duration:<10.2f}") print(f"Output token throughput (tok/s): {total_output/duration:<10.2f}") print( f"Total token throughput (tok/s): {(total_input+total_output)/duration:<10.2f}" ) print("---------------Time to First Token----------------") print(f"Mean TTFT (ms): {mean_ttft:<10.2f}") print(f"Median TTFT (ms): {median_ttft:<10.2f}") print(f"P99 TTFT (ms): {p99_ttft:<10.2f}") print("-----Time per Output Token (excl. 1st token)------") print(f"Mean TPOT (ms): {mean_tpot:<10.2f}") print(f"Median TPOT (ms): {median_tpot:<10.2f}") print(f"P99 TPOT (ms): {p99_tpot:<10.2f}") print("---------------Inter-token Latency----------------") print(f"Mean ITL (ms): {mean_tpot:<10.2f}") print(f"Median ITL (ms): {median_tpot:<10.2f}") print(f"P99 ITL (ms): {p99_tpot:<10.2f}") print("==================================================\n") if __name__ == "__main__": main() ================================================ FILE: src/llm/vllm/offline_bench.sh ================================================ #!/usr/bin/env bash # Offline vLLM benchmark wrapper (no API server) # Usage: # salloc -N1 bash offline_bench.sh --model meta-llama/Llama-3.1-8B --num-prompts 50 # salloc -N2 bash offline_bench.sh --model Qwen/Qwen2-57B-A14B --tp-size 4 --enable-ep set -euo pipefail info() { echo "[$(date +'%H:%M:%S')] $*"; } CONTAINER_MOUNT="${CONTAINER_MOUNT:-/fsx}" IMAGE="${IMAGE:-${PWD}/vllm-serve-latest.tar.gz}" NPROC="${NPROC:-}" # Setup multi-node coordination early (before container launch) if [[ -n "${SLURM_JOB_ID:-}" ]]; then 
NUM_NODES=${SLURM_JOB_NUM_NODES:-1} readarray -t NODES < <(scontrol show hostnames "$SLURM_JOB_NODELIST") HEAD_NODE=${NODES[0]} HEAD_IP=$(getent ahostsv4 "$HEAD_NODE" | head -1 | awk '{print $1}') MASTER_PORT=$((29500 + (SLURM_JOB_ID % 1000))) else NUM_NODES=1 HEAD_IP="127.0.0.1" MASTER_PORT=29500 fi _run() { if [[ -n "${SLURM_JOB_ID:-}" ]]; then srun bash -c "$*" else bash -c "$*" fi } load_or_pull_image() { if [[ "${IMAGE}" == *.tar.gz ]]; then CONTAINER_IMAGE=$(pigz -dc "${IMAGE}" | tar -xf - -O manifest.json | python3 -c "import sys,json; print(json.load(sys.stdin)[0]['RepoTags'][0])") info "Image tag: ${CONTAINER_IMAGE}" _run " if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then echo 'Loading Docker image from tarball...' pigz -dc '${IMAGE}' | docker load fi " else CONTAINER_IMAGE="${IMAGE}" _run " if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then echo 'Pulling ${CONTAINER_IMAGE}...' registry=\"\${CONTAINER_IMAGE%%/*}\" region=\$(echo \"\${registry}\" | sed -n 's/.*\.ecr\.\([^.]*\)\.amazonaws\.com/\1/p') region=\"\${region:-us-west-2}\" aws ecr get-login-password --region \"\${region}\" \ | docker login --username AWS --password-stdin \"\${registry}\" docker pull '${CONTAINER_IMAGE}' fi " fi } launch_container() { local cmd="$1" _run " docker run --rm --gpus all --privileged --ipc=host --net=host \ -v '${PWD}:${PWD}' -w '${PWD}' \ -v '${CONTAINER_MOUNT}:${CONTAINER_MOUNT}' \ --entrypoint bash '${CONTAINER_IMAGE}' \ -c '${cmd}' " } usage() { cat </dev/null || ! 
python3 -c "import vllm" &>/dev/null 2>&1; then
  # vLLM is not usable from this shell: run the benchmark inside the container.
  load_or_pull_image
  SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"

  # Build nsys command prefix
  NSYS_CMD=""
  NSYS_ARG=""
  if [[ "${ENABLE_NSYS}" == "true" ]]; then
    NSYS_DIR="${PWD}/nsys-offline"
    mkdir -p "${NSYS_DIR}"
    # NOTE(review): SLURM_NODEID is expanded here, on the launching host, so
    # every node appears to get the same report path — confirm for multi-node.
    NSYS_PATH="${NSYS_DIR}/profile-node${SLURM_NODEID:-0}.nsys-rep"
    NSYS_CMD="nsys profile"
    NSYS_CMD+=" -t cuda,nvtx,osrt,cudnn,cublas"
    NSYS_CMD+=" --trace-fork-before-exec=true"
    NSYS_CMD+=" --cuda-graph-trace=node"
    NSYS_CMD+=" --capture-range=cudaProfilerApi"
    NSYS_CMD+=" --capture-range-end=repeat"
    NSYS_CMD+=" --cuda-memory-usage=true"
    NSYS_CMD+=" --cudabacktrace=true"
    NSYS_CMD+=" -o ${NSYS_PATH}"
    NSYS_CMD+=" --force-overwrite=true"
    NSYS_ARG="--nsys ${NSYS_PATH}"
  fi

  # NSYS_ARG must reach offline_bench.py: nsys is started with
  # --capture-range=cudaProfilerApi, and the script only starts the CUDA
  # profiler when --nsys is passed. Without it the capture range never opens.
  TORCHRUN_CMD="${NSYS_CMD:+${NSYS_CMD} }torchrun \
    --nnodes=${NUM_NODES} \
    --nproc-per-node=${NPROC} \
    --rdzv-backend=c10d \
    --rdzv-endpoint=${HEAD_IP}:${MASTER_PORT} \
    --rdzv-id=${SLURM_JOB_ID:-12345} \
    ${SCRIPT_DIR}/offline_bench.py${NSYS_ARG:+ ${NSYS_ARG}} ${BENCH_ARGS[*]}"

  info "========================================"
  info "Offline vLLM Benchmark"
  info "========================================"
  info "Nodes: ${NUM_NODES}, Processes per node: ${NPROC}"
  info "Rendezvous: ${HEAD_IP}:${MASTER_PORT}"
  info "Command: ${TORCHRUN_CMD}"
  info "========================================"

  launch_container "${TORCHRUN_CMD}"
  exit $?
fi

# This code only runs if already inside container with vllm available
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"

# NSYS_CMD / NSYS_ARG are otherwise assigned only in the container-launch
# branch, which exits. Under `set -u` they must be defined here before use,
# and --nsys has to be honoured on this path too.
NSYS_CMD="${NSYS_CMD:-}"
NSYS_ARG="${NSYS_ARG:-}"
if [[ -z "${NSYS_CMD}" && "${ENABLE_NSYS:-false}" == "true" ]]; then
  NSYS_DIR="${PWD}/nsys-offline"
  mkdir -p "${NSYS_DIR}"
  NSYS_PATH="${NSYS_DIR}/profile-node${SLURM_NODEID:-0}.nsys-rep"
  NSYS_CMD="nsys profile"
  NSYS_CMD+=" -t cuda,nvtx,osrt,cudnn,cublas"
  NSYS_CMD+=" --trace-fork-before-exec=true"
  NSYS_CMD+=" --cuda-graph-trace=node"
  NSYS_CMD+=" --capture-range=cudaProfilerApi"
  NSYS_CMD+=" --capture-range-end=repeat"
  NSYS_CMD+=" --cuda-memory-usage=true"
  NSYS_CMD+=" --cudabacktrace=true"
  NSYS_CMD+=" -o ${NSYS_PATH}"
  NSYS_CMD+=" --force-overwrite=true"
  NSYS_ARG="--nsys ${NSYS_PATH}"
fi

info "========================================"
info "Offline vLLM Benchmark (inside container)"
info "========================================"
info "Nodes: ${NUM_NODES}, Processes per node: ${NPROC}"
info "Rendezvous: ${HEAD_IP}:${MASTER_PORT}"
info "========================================"

# NSYS_CMD and NSYS_ARG are deliberately unquoted: each holds several words
# (or nothing). BENCH_ARGS uses the guarded expansion so an empty array is
# safe under `set -u` on older bash.
${NSYS_CMD} torchrun \
  --nnodes="${NUM_NODES}" \
  --nproc-per-node="${NPROC}" \
  --rdzv-backend=c10d \
  --rdzv-endpoint="${HEAD_IP}:${MASTER_PORT}" \
  --rdzv-id="${SLURM_JOB_ID:-12345}" \
  "${SCRIPT_DIR}/offline_bench.py" \
  ${NSYS_ARG} \
  ${BENCH_ARGS[@]+"${BENCH_ARGS[@]}"}

================================================
FILE: src/llm/vllm/run.sbatch
================================================
#!/bin/bash
set -euo pipefail

GPUS="${GPUS:-all}"
DP_BACKEND="${DP_BACKEND:-rpc}"

# Timestamped loggers; err writes to stderr.
info() { echo -e "[$(date +'%Y-%m-%dT%H:%M:%S%z')][info] $*"; }
err() { echo -e "[$(date +'%Y-%m-%dT%H:%M:%S%z')][error] $*" >&2; }

IMAGE=""
CONTAINER_MOUNT="/fsx"
WORKSPACE="$PWD"
FORCE_PULL=false
ENABLE_NSYS=false
VLLM_STARTED=false
SERVE_ARGS=()

# Options known to this launcher are consumed here; everything else is
# collected in SERVE_ARGS and forwarded to `vllm serve`.
while (( "$#" )); do
  case "$1" in
    --image) IMAGE="$2"; shift 2 ;;
    --container-mount) CONTAINER_MOUNT="$2"; shift 2 ;;
    --workspace|-w) WORKSPACE="$2"; shift 2 ;;
    --force|-f) FORCE_PULL=true; shift ;;
    --nsys) ENABLE_NSYS=true; shift ;;
    --profile)
      # Shorthand for a torch profiler config writing to ./vllm_profile.
      SERVE_ARGS+=(--profiler-config "{\"profiler\": \"torch\", \"torch_profiler_dir\": \"${PWD}/vllm_profile\"}")
      shift ;;
    --profiler-config) SERVE_ARGS+=(--profiler-config "$2"); shift 2 ;;
    *) SERVE_ARGS+=("$1"); shift ;;
  esac
done

# Build nsys command prefix
NSYS_CMD=""
if [[ "${ENABLE_NSYS}" == "true" ]]; then
  NSYS_DIR="${WORKSPACE}/nsys-vllm"
  mkdir -p "${NSYS_DIR}"
  NSYS_PATH="${NSYS_DIR}/profile-node${SLURM_NODEID:-0}.nsys-rep"
  NSYS_CMD="nsys profile"
  NSYS_CMD+=" -t cuda,nvtx,osrt,cudnn,cublas"
  NSYS_CMD+=" --trace-fork-before-exec=true"
  NSYS_CMD+=" --cuda-graph-trace=node"
  NSYS_CMD+=" --capture-range=cudaProfilerApi"
NSYS_CMD+=" --capture-range-end=repeat" NSYS_CMD+=" --cuda-memory-usage=true" NSYS_CMD+=" --cudabacktrace=true" NSYS_CMD+=" -o ${NSYS_PATH}" NSYS_CMD+=" --force-overwrite=true" fi IMAGE="${IMAGE:-${WORKSPACE}/vllm-serve-latest.tar.gz}" LOGDIR="${WORKSPACE}/logs" # Build a shell-safe string from SERVE_ARGS for nested bash -c / docker exec SERVE_ARGS_STR=$(printf '%q ' "${SERVE_ARGS[@]+"${SERVE_ARGS[@]}"}") # Peek at SERVE_ARGS to extract values needed for topology computation _peek_arg() { local short="$1" long="$2" default="$3" local i=0 while (( i < ${#SERVE_ARGS[@]} )); do if [[ "${SERVE_ARGS[$i]}" == "$short" || "${SERVE_ARGS[$i]}" == "$long" ]]; then echo "${SERVE_ARGS[$((i+1))]}"; return fi ((i++)) done echo "$default" } _has_flag() { for arg in "${SERVE_ARGS[@]}"; do [[ "$arg" == "$1" ]] && return 0; done return 1 } TP=$(_peek_arg "-tp" "--tensor-parallel-size" "1") PP=$(_peek_arg "-pp" "--pipeline-parallel-size" "1") ENABLE_EP=$(_has_flag "--enable-expert-parallel" && echo "true" || echo "false") load_or_pull_image() { if [[ "${FORCE_PULL}" == "true" ]]; then info "Force pull: cleaning up existing images..." srun --ntasks-per-node=1 bash -c ' docker ps -aq | xargs -r docker rm -f 2>/dev/null || true docker images -aq | xargs -r docker rmi -f 2>/dev/null || true ' fi if [[ "${IMAGE}" == *.tar.gz ]]; then info "Loading Docker image from tarball..." CONTAINER_IMAGE=$(pigz -dc "${IMAGE}" | tar -xf - -O manifest.json | python3 -c "import sys,json; print(json.load(sys.stdin)[0]['RepoTags'][0])") info "Image tag: ${CONTAINER_IMAGE}" srun --ntasks-per-node=1 bash -c " if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then pigz -dc '${IMAGE}' | docker load fi " else info "Pulling Docker image from registry..." local registry="${IMAGE%%/*}" local region=$(echo "${registry}" | sed -n 's/.*\.ecr\.\([^.]*\)\.amazonaws\.com/\1/p') region="${region:-us-west-2}" srun --ntasks-per-node=1 bash -c " if ! 
docker image inspect '${IMAGE}' &>/dev/null; then aws ecr get-login-password --region '${region}' | docker login --username AWS --password-stdin '${registry}' docker pull '${IMAGE}' fi " CONTAINER_IMAGE="${IMAGE}" fi } launch_container() { local name="${1}" cmd="${2}" local devices=("--device=/dev/gdrdrv") while IFS= read -r -d '' d; do devices+=("--device=${d}") done < <(find "/dev/infiniband" -name "uverbs*" -print0 2>/dev/null) local net_if="${GLOO_SOCKET_IFNAME:-$(ip -o -4 route show to default | awk '{print $5}' | head -1)}" docker run --gpus "${GPUS}" \ --privileged -d \ --name "${name}" \ --uts=host --ipc=host --net=host \ --ulimit stack=67108864 --ulimit memlock=-1 \ --security-opt seccomp=unconfined \ "${devices[@]}" \ -v "${CONTAINER_MOUNT}:${CONTAINER_MOUNT}" \ -e NCCL_SOCKET_IFNAME="${net_if}" \ -e GLOO_SOCKET_IFNAME="${net_if}" \ -e TP_SOCKET_IFNAME="${net_if}" \ --entrypoint bash \ "${CONTAINER_IMAGE:-${IMAGE}}" \ -c "${cmd}" } setup_topology() { NUM_NODES=${SLURM_JOB_NUM_NODES:-1} GPUS_PER_NODE=8 TOTAL_GPUS=$((NUM_NODES * GPUS_PER_NODE)) if [[ "$PP" -gt 1 && "$ENABLE_EP" == "true" ]]; then err "Pipeline parallel (PP=$PP) and expert parallel cannot be enabled simultaneously" exit 1 fi [[ "$PP" -gt 1 ]] && DP_BACKEND="mp" if [[ "$ENABLE_EP" == "true" ]]; then DP=$((TOTAL_GPUS / TP)) if [[ $((DP * TP)) -ne $TOTAL_GPUS ]]; then err "DP($DP) * TP($TP) = $((DP * TP)) != TOTAL_GPUS($TOTAL_GPUS)"; exit 1 fi else DP=$((TOTAL_GPUS / (TP * PP))) if [[ $((DP * TP * PP)) -ne $TOTAL_GPUS ]]; then err "DP($DP) * TP($TP) * PP($PP) = $((DP * TP * PP)) != TOTAL_GPUS($TOTAL_GPUS)"; exit 1 fi fi DP_LOCAL=$((GPUS_PER_NODE / TP)) readarray -t NODES < <(scontrol show hostnames "$SLURM_JOB_NODELIST") HEAD_NODE=${NODES[0]} HEAD_IP=$(getent ahostsv4 "$HEAD_NODE" | head -1 | awk '{print $1}') RAY_PORT=$((6379 + (SLURM_JOB_ID % 1000))) RPC_PORT=$((13345 + (SLURM_JOB_ID % 1000))) mkdir -p "${LOGDIR}" info "========================================" info "vLLM Server" info 
"========================================" info "Image: ${IMAGE}" info "Nodes: ${NUM_NODES}, Head: ${HEAD_NODE} (${HEAD_IP}), GPUs: ${TOTAL_GPUS}" info "Parallelism: TP=${TP}, PP=${PP}, DP=${DP}, DP_LOCAL=${DP_LOCAL}, EP=${ENABLE_EP}" info "Backend: ${DP_BACKEND}" info "SERVE_ARGS: ${SERVE_ARGS[*]+"${SERVE_ARGS[*]}"}" info "========================================" } stop_nsys() { [[ "${VLLM_STARTED}" != "true" ]] && return 0 info "Sending SIGINT to nsys processes for graceful shutdown..." srun --ntasks-per-node=1 bash -c ' for cid in $(docker ps -q); do docker exec "$cid" pkill -INT -f "^nsys profile" 2>/dev/null || true done ' 2>/dev/null || true } wait_for_nsys() { [[ "${VLLM_STARTED}" != "true" ]] && return 0 info "Waiting 30s for nsys to finalize profiles..." sleep 30 } cleanup() { info "Cleaning up containers..." if [[ "${ENABLE_NSYS}" == "true" ]]; then stop_nsys wait_for_nsys fi srun --ntasks-per-node=1 bash -c ' docker ps -aq | xargs -r docker stop -t 30 2>/dev/null || true docker ps -aq | xargs -r docker rm -f 2>/dev/null || true ' 2>/dev/null || true rm -f "${LOGDIR}/vllm_server_${SLURM_JOB_ID}.log" } start_ray_head() { info "Starting Ray head on ${HEAD_NODE}..." srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " $(declare -f launch_container) CONTAINER_IMAGE='${CONTAINER_IMAGE:-}' IMAGE='${IMAGE}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' launch_container ray-head 'sleep infinity' " sleep 5 srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " docker exec ray-head ray start --head --port=${RAY_PORT} \ --num-gpus=${GPUS_PER_NODE} --num-cpus=96 --disable-usage-stats " } start_ray_workers() { [[ "$NUM_NODES" -le 1 ]] && return local worker_nodes=$(echo "${NODES[@]:1}" | tr ' ' ',') info "Starting Ray workers on ${worker_nodes}..." 
srun --nodes=$((NUM_NODES - 1)) --nodelist="${worker_nodes}" --ntasks-per-node=1 bash -c " $(declare -f launch_container) CONTAINER_IMAGE='${CONTAINER_IMAGE:-}' IMAGE='${IMAGE}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' launch_container ray-worker 'sleep infinity' " sleep 5 srun --nodes=$((NUM_NODES - 1)) --nodelist="${worker_nodes}" --ntasks-per-node=1 bash -c " docker exec ray-worker ray start --address=${HEAD_IP}:${RAY_PORT} \ --num-gpus=${GPUS_PER_NODE} --num-cpus=96 --disable-usage-stats " } wait_for_gpus() { info "Waiting for ${TOTAL_GPUS} GPUs..." for _ in {1..120}; do local gpu_count gpu_count=$(srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " docker exec ray-head python3 -c \ 'import ray; ray.init(address=\"auto\"); print(int(ray.cluster_resources().get(\"GPU\",0))); ray.shutdown()' \ 2>/dev/null" || echo 0) [[ "$gpu_count" -ge "$TOTAL_GPUS" ]] && return 0 sleep 5 done err "Timeout waiting for GPUs"; return 1 } start_vllm_ray() { info "Launching vllm serve (Ray)..." local logfile="${LOGDIR}/vllm_server_${SLURM_JOB_ID}.log" local extra="--host 0.0.0.0 --port 8000 --data-parallel-backend ray --data-parallel-address ${HEAD_IP} --data-parallel-size ${DP}" srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c \ "docker exec -d ray-head bash -c '${NSYS_CMD} vllm serve ${SERVE_ARGS_STR} ${extra} 2>&1 | tee ${logfile}'" } start_vllm_mp() { info "Starting vLLM with PP (multiprocessing)..." 
local logfile="${LOGDIR}/vllm_server_${SLURM_JOB_ID}.log" for i in $(seq 0 $((NUM_NODES - 1))); do srun --nodes=1 --nodelist="${NODES[$i]}" bash -c " $(declare -f launch_container) IMAGE='${CONTAINER_IMAGE:-${IMAGE}}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' launch_container vllm-node-${i} 'sleep infinity' " & done wait local extra="--host 0.0.0.0 --port 8000 --nnodes ${NUM_NODES} --master-addr ${HEAD_IP} --master-port 29500" for i in $(seq 1 $((NUM_NODES - 1))); do srun --nodes=1 --nodelist="${NODES[$i]}" bash -c \ "docker exec -d vllm-node-${i} bash -c '${NSYS_CMD} vllm serve ${SERVE_ARGS_STR} ${extra} --node-rank ${i} --headless 2>&1 | tee ${logfile}.node${i}'" done srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c \ "docker exec -d vllm-node-0 bash -c '${NSYS_CMD} vllm serve ${SERVE_ARGS_STR} ${extra} --node-rank 0 2>&1 | tee ${logfile}'" } # RPC backend start_vllm_rpc() { info "Starting vLLM with RPC-based DP..." local logfile="${LOGDIR}/vllm_server_${SLURM_JOB_ID}.log" srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " $(declare -f launch_container) IMAGE='${CONTAINER_IMAGE:-${IMAGE}}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' launch_container vllm-head 'sleep infinity' " for i in $(seq 1 $((NUM_NODES - 1))); do srun --nodes=1 --nodelist="${NODES[$i]}" bash -c " $(declare -f launch_container) IMAGE='${CONTAINER_IMAGE:-${IMAGE}}' GPUS='${GPUS}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' launch_container vllm-worker 'sleep infinity' " done sleep 3 local extra="--data-parallel-size ${DP} --data-parallel-size-local ${DP_LOCAL} --data-parallel-address ${HEAD_IP} --data-parallel-rpc-port ${RPC_PORT}" for i in $(seq 1 $((NUM_NODES - 1))); do local start_rank=$((i * DP_LOCAL)) info "Starting RPC worker on ${NODES[$i]} (rank ${start_rank})..." 
srun --nodes=1 --nodelist="${NODES[$i]}" bash -c " docker exec -d vllm-worker bash -c '${NSYS_CMD} vllm serve ${SERVE_ARGS_STR} ${extra} \ --data-parallel-start-rank ${start_rank} --headless \ 2>&1 | tee ${LOGDIR}/vllm_worker_${SLURM_JOB_ID}_${i}.log' " done srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " docker exec -d vllm-head bash -c '${NSYS_CMD} vllm serve ${SERVE_ARGS_STR} ${extra} \ --host 0.0.0.0 --port 8000 \ 2>&1 | tee ${logfile}' " } wait_for_server() { info "Waiting for vLLM server at ${HEAD_IP}:8000..." for _ in {1..360}; do if srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c "curl -sf localhost:8000/health" &>/dev/null && srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c "curl -sf localhost:8000/v1/models | grep -q '\"id\"'" &>/dev/null; then info "Server ready at ${HEAD_IP}:8000" return 0 fi sleep 10 done err "Timeout waiting for server"; return 1 } setup_topology trap cleanup EXIT cleanup LOGFILE="${LOGDIR}/vllm_server_${SLURM_JOB_ID}.log" load_or_pull_image case "${DP_BACKEND}" in ray) start_ray_head start_ray_workers wait_for_gpus start_vllm_ray ;; mp) start_vllm_mp ;; rpc) start_vllm_rpc ;; *) err "Unknown backend: ${DP_BACKEND}"; exit 1 ;; esac tail -f "${LOGFILE}" 2>/dev/null & wait_for_server || exit 1 VLLM_STARTED=true info "vLLM serving on ${HEAD_IP}:8000 — Ctrl+C or scancel to stop" info "Logs: ${LOGFILE}" sleep infinity ================================================ FILE: src/llm/vllm/run.sh ================================================ #!/bin/bash # vLLM server launcher with multiple modes set -uo pipefail PROGRAM="$0" MODE="single" MODEL="Qwen/Qwen2.5-7B-Instruct" PORT="8000" TP="1" PP="1" DP="1" HEAD_IP="127.0.0.1" NODE_RANK="0" # Log info message with timestamp info() { echo -e "[$(date +'%Y-%m-%dT%H:%M:%S%z')][info] $*" } # Log error message with timestamp to stderr err() { echo -e "[$(date +'%Y-%m-%dT%H:%M:%S%z')][error] $*" >&2 } # Display usage information usage() { cat </dev/null; then echo 'Loading Docker image from 
tarball...' pigz -dc '${IMAGE}' | docker load fi " else CONTAINER_IMAGE="${IMAGE}" _run " if ! docker image inspect '${CONTAINER_IMAGE}' &>/dev/null; then echo 'Pulling ${CONTAINER_IMAGE}...' registry=\"\${CONTAINER_IMAGE%%/*}\" region=\$(echo \"\${registry}\" | sed -n 's/.*\.ecr\.\([^.]*\)\.amazonaws\.com/\1/p') region=\"\${region:-us-west-2}\" aws ecr get-login-password --region \"\${region}\" \ | docker login --username AWS --password-stdin \"\${registry}\" docker pull '${CONTAINER_IMAGE}' fi " fi } # Passes args as an array directly to docker so --serve-cmd/--bench-cmd quoting # is preserved (no bash -c needed). launch_container() { local devices=() [[ -e /dev/gdrdrv ]] && devices+=("--device=/dev/gdrdrv") while IFS= read -r -d '' d; do devices+=("--device=${d}") done < <(find "/dev/infiniband" -name "uverbs*" -print0 2>/dev/null) local docker_cmd=( docker run --rm --gpus "${GPUS}" --ipc=host --net=host --uts=host --ulimit stack=67108864 --ulimit memlock=-1 "${devices[@]}" -v "${PWD}:${PWD}" -w "${PWD}" -v "${CONTAINER_MOUNT}:${CONTAINER_MOUNT}" -e VLLM_LOGGING_LEVEL="${VLLM_LOGGING_LEVEL:-INFO}" "${CONTAINER_IMAGE}" "$@" ) if [[ -n "${SLURM_JOB_ID:-}" ]]; then srun -N1 --ntasks-per-node=1 "${docker_cmd[@]}" else "${docker_cmd[@]}" fi } # Auto-generate --serve-cmd and --bench-cmd if not in SWEEP_ARGS has_arg() { for a in "${SWEEP_ARGS[@]+"${SWEEP_ARGS[@]}"}"; do [[ "$a" == "$1" ]] && return 0; done; return 1; } defaults=() if ! has_arg "--serve-cmd"; then defaults+=(--serve-cmd "vllm serve ${MODEL}") fi if ! 
# Request rate: find saturation point with fixed workload
sweep_rate() {
  run_sweep "rate" \
    "--dataset-name random --random-input-len 512 --random-output-len 256 --num-prompts ${NUM_PROMPTS}" \
    '[{"request-rate":1},{"request-rate":2},{"request-rate":4},{"request-rate":8},{"request-rate":16},{"request-rate":32},{"request-rate":"inf"}]'
}

# Concurrency: find optimal batch size
sweep_concurrency() {
  run_sweep "concurrency" \
    "--dataset-name random --random-input-len 512 --random-output-len 256 --num-prompts ${NUM_PROMPTS}" \
    '[{"max-concurrency":1},{"max-concurrency":2},{"max-concurrency":4},{"max-concurrency":8},{"max-concurrency":16},{"max-concurrency":32},{"max-concurrency":64},{"max-concurrency":128}]'
}

# Input length: measure TTFT scaling with context size
sweep_input() {
  run_sweep "input" \
    "--dataset-name random --random-output-len 128 --num-prompts ${NUM_PROMPTS} --request-rate inf" \
    '[{"random-input-len":128},{"random-input-len":256},{"random-input-len":512},{"random-input-len":1024},{"random-input-len":2048},{"random-input-len":4096},{"random-input-len":8192},{"random-input-len":16384}]'
}

# Output length: measure ITL as KV cache grows
sweep_output() {
  run_sweep "output" \
    "--dataset-name random --random-input-len 512 --num-prompts ${NUM_PROMPTS} --request-rate inf" \
    '[{"random-output-len":64},{"random-output-len":128},{"random-output-len":256},{"random-output-len":512},{"random-output-len":1024},{"random-output-len":2048}]'
}

# Split the comma-separated TYPES list and validate every entry BEFORE running
# anything, so a typo in the last entry is reported immediately instead of
# after the earlier sweeps have already finished.
IFS=',' read -ra TESTS <<< "$TYPES"
SWEEPS=()
for t in "${TESTS[@]}"; do
  t=$(echo "$t" | xargs) # trim surrounding whitespace
  case "$t" in
    rate | concurrency | input | output) SWEEPS+=("$t") ;;
    *)
      echo "Unknown sweep: $t"
      exit 1
      ;;
  esac
done

# Every entry is now known to map onto one of the sweep_<name> functions above.
for t in "${SWEEPS[@]+"${SWEEPS[@]}"}"; do
  "sweep_${t}"
done

info "All sweeps complete — results in ${RESULT_DIR}/"
PASS=0
FAIL=0

echo "Testing vLLM server at ${BASE_URL}"
echo "Model: ${MODEL}"
echo "========================================"

# Run one API check and record the outcome.
#   $1 - label printed before the command's output
#   $2 - command string, executed with eval; exit status 0 counts as a pass
# Updates the global PASS / FAIL counters.
test_endpoint() {
  local name="$1" cmd="$2"
  echo -ne "\n${name}... "
  if eval "$cmd" 2>&1; then
    # Plain assignment instead of ((PASS++)): the arithmetic command returns
    # status 1 when the old value is 0, which would abort under `set -e`.
    PASS=$((PASS + 1))
  else
    FAIL=$((FAIL + 1))
  fi
}

test_endpoint "[1/10] List models" \
  "curl -sf '${BASE_URL}/v1/models' | jq -e '.data'"

test_endpoint "[2/10] Basic completions" \
  "curl -sf -X POST '${BASE_URL}/v1/completions' \
    -H 'Content-Type: application/json' \
    -d '{\"model\": \"${MODEL}\", \"prompt\": \"Hello\", \"max_tokens\": 10}' | jq -e '.choices'"

test_endpoint "[3/10] Batch completions" \
  "curl -sf -X POST '${BASE_URL}/v1/completions' \
    -H 'Content-Type: application/json' \
    -d '{\"model\": \"${MODEL}\", \"prompt\": [\"Once\", \"In\"], \"max_tokens\": 10}' | jq -e '.choices'"

test_endpoint "[4/10] Chat completions" \
  "curl -sf -X POST '${BASE_URL}/v1/chat/completions' \
    -H 'Content-Type: application/json' \
    -d '{\"model\": \"${MODEL}\", \"messages\": [{\"role\": \"user\", \"content\": \"Hi\"}], \"max_tokens\": 10}' | jq -e '.choices'"

test_endpoint "[5/10] Sampling parameters" \
  "curl -sf -X POST '${BASE_URL}/v1/chat/completions' \
    -H 'Content-Type: application/json' \
    -d '{\"model\": \"${MODEL}\", \"messages\": [{\"role\": \"user\", \"content\": \"Hi\"}], \"max_tokens\": 10, \"temperature\": 0.9, \"top_p\": 0.95}' | jq -e '.choices'"

test_endpoint "[6/10] Streaming" \
  "curl -sf -X POST '${BASE_URL}/v1/chat/completions' \
    -H 'Content-Type: application/json' \
    -d '{\"model\": \"${MODEL}\", \"messages\": [{\"role\": \"user\", \"content\": \"Hi\"}], \"stream\": true, \"max_tokens\": 10}' | grep -q 'data:'"

test_endpoint "[7/10] Logprobs" \
  "curl -sf -X POST '${BASE_URL}/v1/completions' \
    -H 'Content-Type: application/json' \
    -d '{\"model\": \"${MODEL}\", \"prompt\": \"Hello\", \"max_tokens\": 5, \"logprobs\": 5}' | jq -e '.choices[0].logprobs'"

test_endpoint "[8/10] Stop sequences" \
  "curl -sf -X POST '${BASE_URL}/v1/completions' \
    -H 'Content-Type: application/json' \
    -d '{\"model\": \"${MODEL}\", \"prompt\": \"1.\", \"max_tokens\": 20, \"stop\": [\"3.\"]}' | jq -e '.choices'"

test_endpoint "[9/10] Echo" \
  "curl -sf -X POST '${BASE_URL}/v1/completions' \
    -H 'Content-Type: application/json' \
    -d '{\"model\": \"${MODEL}\", \"prompt\": \"Hello\", \"max_tokens\": 5, \"echo\": true}' | jq -e '.choices'"

test_endpoint "[10/10] Multiple completions (n=2)" \
  "curl -sf -X POST '${BASE_URL}/v1/completions' \
    -H 'Content-Type: application/json' \
    -d '{\"model\": \"${MODEL}\", \"prompt\": \"Hi\", \"max_tokens\": 10, \"n\": 2}' | jq -e '.choices | length == 2'"

echo -e "\n========================================"
echo "Results: ${PASS} passed, ${FAIL} failed"

# Use an explicit if/else: `[[ $FAIL -gt 0 ]] && exit 1` as the last command
# leaves status 1 behind when FAIL is 0, so a fully passing run would still
# make the script exit non-zero.
if [[ $FAIL -gt 0 ]]; then
  exit 1
fi
exit 0
# See LICENSE or https://creativecommons.org/licenses/by/4.0/ ARG CUDA_VERSION=12.8.1 FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu24.04 ARG GDRCOPY_VERSION=v2.5.1 ARG EFA_INSTALLER_VERSION=1.47.0 ARG AWS_OFI_NCCL_VERSION=5f4202f11db1585d878196db4430aeda0e834a0c ARG NCCL_VERSION=v2.29.3-1 ARG NCCL_TESTS_VERSION=v2.17.9 ARG NVSHMEM_VERSION=v3.5.19-1 ARG TORCH_VERSION=2.9.1 ARG MEGATRON_BRIDGE_VERSION=v0.2.2 RUN apt-get update -y && apt-get upgrade -y RUN apt-get remove -y --allow-change-held-packages \ ibverbs-utils \ libibverbs-dev \ libibverbs1 \ libmlx5-1 \ libnccl2 \ libnccl-dev RUN rm -rf /opt/hpcx \ && rm -rf /usr/local/mpi \ && rm -f /etc/ld.so.conf.d/hpcx.conf \ && ldconfig ENV OPAL_PREFIX= RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ apt-utils \ autoconf \ automake \ build-essential \ check \ cmake \ ninja-build \ curl \ debhelper \ devscripts \ git \ gcc \ gdb \ kmod \ libsubunit-dev \ libtool \ openssh-client \ openssh-server \ pkg-config \ vim \ hwloc \ libhwloc-dev \ python3-dev \ python3-venv \ libomp-dev RUN apt-get purge -y cuda-compat-* RUN mkdir -p /var/run/sshd RUN sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config && \ echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config && \ sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config # Set paths for both aarch64 and x86_64 ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/amazon/ofi-nccl/lib:/usr/local/lib:$LD_LIBRARY_PATH ENV PATH=/opt/amazon/openmpi/bin/:/opt/amazon/efa/bin:/usr/bin:/usr/local/bin:$PATH RUN apt-get install -y python3-pip \ && pip3 install --break-system-packages --no-cache-dir awscli nvidia-ml-py Cython ################################################# ## Install NVIDIA GDRCopy RUN git clone -b ${GDRCOPY_VERSION} https://github.com/NVIDIA/gdrcopy.git /tmp/gdrcopy \ && cd /tmp/gdrcopy \ && make prefix=/opt/gdrcopy install \ && rm -rf 
/tmp/gdrcopy ENV LD_LIBRARY_PATH=/opt/gdrcopy/lib:$LD_LIBRARY_PATH ENV LIBRARY_PATH=/opt/gdrcopy/lib:$LIBRARY_PATH ENV CPATH=/opt/gdrcopy/include:$CPATH ENV PATH=/opt/gdrcopy/bin:$PATH ################################################# ## Install EFA installer RUN cd $HOME \ && curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && tar -xf $HOME/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && cd aws-efa-installer \ && ./efa_installer.sh -y -g -d --skip-kmod --skip-limit-conf --no-verify --skip-plugin \ && rm -rf $HOME/aws-efa-installer $HOME/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz ################################################### ## Install aws-ofi-nccl from source (pinned commit) RUN git clone https://github.com/aws/aws-ofi-nccl.git /tmp/aws-ofi-nccl \ && cd /tmp/aws-ofi-nccl \ && git checkout ${AWS_OFI_NCCL_VERSION} \ && ./autogen.sh \ && ./configure --prefix=/opt/amazon/ofi-nccl \ --with-libfabric=/opt/amazon/efa \ --with-cuda=/usr/local/cuda \ --with-nvtx=/usr/local/cuda \ && make -j$(nproc) \ && make install \ && rm -rf /tmp/aws-ofi-nccl ################################################### ## Install NCCL RUN git clone -b ${NCCL_VERSION} https://github.com/NVIDIA/nccl.git /opt/nccl \ && cd /opt/nccl \ && make -j $(nproc) src.build CUDA_HOME=/usr/local/cuda \ NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_89,code=sm_89 -gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_100,code=sm_100" ################################################### ## Install NCCL-tests RUN git clone -b ${NCCL_TESTS_VERSION} https://github.com/NVIDIA/nccl-tests.git /opt/nccl-tests \ && cd /opt/nccl-tests \ && make -j $(nproc) \ MPI=1 \ MPI_HOME=/opt/amazon/openmpi/ \ CUDA_HOME=/usr/local/cuda \ NCCL_HOME=/opt/nccl/build \ NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_89,code=sm_89 
###################################################
## Install NVSHMEM
ENV NVSHMEM_DIR=/opt/nvshmem
ENV NVSHMEM_HOME=/opt/nvshmem

# Clone into an absolute scratch path and remove that same path afterwards.
# The clone used to go to a relative "nvshmem" directory while the cleanup
# removed /root/nvshmem; those only match when the working directory is /root,
# so the source and build tree could be left behind in the image layer.
RUN git clone -b ${NVSHMEM_VERSION} https://github.com/NVIDIA/nvshmem.git /tmp/nvshmem \
    && cd /tmp/nvshmem \
    && mkdir -p build \
    && cd build \
    && cmake -DNVSHMEM_PREFIX=/opt/nvshmem \
        -DCMAKE_CUDA_ARCHITECTURES="80;90" \
        -DNVSHMEM_MPI_SUPPORT=1 \
        -DNVSHMEM_PMIX_SUPPORT=1 \
        -DNVSHMEM_LIBFABRIC_SUPPORT=1 \
        -DNVSHMEM_IBRC_SUPPORT=1 \
        -DNVSHMEM_IBGDA_SUPPORT=1 \
        -DNVSHMEM_USE_GDRCOPY=1 \
        -DNVSHMEM_BUILD_TESTS=1 \
        -DNVSHMEM_BUILD_EXAMPLES=1 \
        -DNVSHMEM_BUILD_HYDRA_LAUNCHER=1 \
        -DNVSHMEM_BUILD_TXZ_PACKAGE=0 \
        -DNVSHMEM_BUILD_PYTHON_LIB=0 \
        -DMPI_HOME=/opt/amazon/openmpi \
        -DPMIX_HOME=/opt/amazon/pmix \
        -DGDRCOPY_HOME=/opt/gdrcopy \
        -DLIBFABRIC_HOME=/opt/amazon/efa \
        -G Ninja .. \
    && ninja -j $(nproc) \
    && ninja install \
    && rm -rf /tmp/nvshmem
RUN rm -rf /var/lib/apt/lists/* ## Set Open MPI variables ENV OMPI_MCA_pml=^ucx \ OMPI_MCA_btl=tcp,self \ OMPI_MCA_btl_tcp_if_exclude=lo,docker0,veth_def_agent\ OPAL_PREFIX=/opt/amazon/openmpi \ NCCL_SOCKET_IFNAME=^docker,lo,veth ENV FI_EFA_USE_DEVICE_RDMA=1 ENV FI_PROVIDER=efa ENV FI_EFA_FORK_SAFE=1 ENV NCCL_BUFFSIZE=8388608 ENV NCCL_P2P_NET_CHUNKSIZE=524288 ENV NCCL_TUNER_PLUGIN=/opt/amazon/ofi-nccl/lib/libnccl-tuner-ofi.so ## Turn off PMIx Error ENV PMIX_MCA_gds=hash ## Set LD_PRELOAD for NCCL library ENV LD_PRELOAD=/opt/nccl/build/lib/libnccl.so # NVSHMEM additional settings ENV NVSHMEM_DISABLE_CUDA_VMM=1 # Install Nsight Systems for profiling RUN apt-get update -y && apt-get install -y --no-install-recommends gnupg \ && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/3bf863cc.pub \ && echo "deb https://developer.download.nvidia.com/devtools/repos/ubuntu2404/$(dpkg --print-architecture) /" \ > /etc/apt/sources.list.d/nvidia-devtools.list \ && apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F60F4B3D7FA2AF80 \ && apt-get update -y \ && apt-get install -y --no-install-recommends nsight-systems-cli \ && rm -rf /var/lib/apt/lists/* ################################################### ## Install cuDNN (headers + libs needed for TransformerEngine build) RUN apt-get update -y && apt-get install -y --no-install-recommends libcudnn9-dev-cuda-12 \ && rm -rf /var/lib/apt/lists/* ################################################### ## Build tools and pinned numpy (before TE build) ENV TORCH_CUDA_ARCH_LIST="9.0;10.0" ENV NVTE_FRAMEWORK=pytorch ENV NVTE_CUDA_ARCHS="90;100" ENV NCCL_DIR=/opt/nccl/build ENV NCCL_INCLUDE_DIR=/opt/nccl/build/include ENV NCCL_LIB_DIR=/opt/nccl/build/lib ENV CPATH=/opt/nccl/build/include:${CPATH} RUN pip3 install --break-system-packages --no-cache-dir \ numpy ninja pybind11 packaging "setuptools<80" wheel ################################################### ## Pre-install CUDA extensions 
against current torch RUN pip3 install --break-system-packages --no-cache-dir --no-build-isolation causal-conv1d==1.6.0 mamba-ssm==2.3.0 ################################################### ## Install Megatron-Bridge RUN apt-get update && apt-get remove -y python3-blinker && rm -rf /var/lib/apt/lists/* RUN git clone -b ${MEGATRON_BRIDGE_VERSION} --depth 1 https://github.com/NVIDIA-NeMo/Megatron-Bridge.git /tmp/Megatron-Bridge \ && cd /tmp/Megatron-Bridge \ && pip install --break-system-packages --no-cache-dir --no-build-isolation "torch==${TORCH_VERSION}" . \ && rm -rf /tmp/Megatron-Bridge RUN pip install --break-system-packages --no-cache-dir viztracer # Sanity check RUN python3 -c "import torch; print(f'torch={torch.__version__}, CUDA={torch.version.cuda}')" \ && python3 -c "import transformer_engine; print(f'TE={transformer_engine.__version__}')" \ && python3 -c "import megatron.core; print('megatron-core OK')" \ && python3 -c "from megatron.bridge import AutoBridge; print('megatron-bridge OK')" WORKDIR /workspace ================================================ FILE: src/megatron/Makefile ================================================ IMAGE_NAME ?= megatron-lm .PHONY: build sqsh clean fmt build: ./enroot.sh -n "$(IMAGE_NAME)" -f Dockerfile sqsh: build clean: docker rmi -f "$(IMAGE_NAME)" 2>/dev/null || true rm -f *.sqsh fmt: shfmt -i 2 -w *.sh black recipes/ ================================================ FILE: src/megatron/README.md ================================================ # Megatron ```bash make build # Launch a 2-node DeepSeek V2 Lite pretrain job: salloc -N 2 ./srun.sh recipes/deepseek_v2_lite_pretrain.py \ hf_path=/fsx/models/deepseek-ai/DeepSeek-V2-Lite # Launch DeepSeek V2 pretrain (4 nodes): salloc -N 4 ./srun.sh recipes/deepseek_v2_pretrain.py \ hf_path=/fsx/models/deepseek-ai/DeepSeek-V2 # Override config with Hydra-style args: ./srun.sh recipes/deepseek_v2_lite_pretrain.py \ train.train_iters=1000 # Launch DeepSeek-V2-Lite using DeepEP with 
# Build a Docker image from a Dockerfile.
#   $1 - path to the Dockerfile
#   $2 - image name (optionally with tag) to build; required
#   $3 - build context directory
# Images already listed under the same repository name are force-removed
# before building. Returns 1 when the docker CLI is missing, the image name is
# empty or cannot be parsed, or `docker build` fails.
build() {
  local file="${1}"
  local image="${2}"
  local input="${3}"
  local arr=()
  local url

  if ! command -v docker &>/dev/null; then
    err "please install 'docker' cli"
    return 1
  fi

  if [ -z "${image}" ]; then
    err "please specify a image name"
    return 1
  fi

  # Split the image name on "/" and keep the first component as the
  # repository filter for `docker images`.
  IFS=$'\n' read -d "" -ra arr <<<"${image//\//$'\n'}"
  url="${arr[0]}"
  if [ -z "${url}" ]; then
    err "parse ${image} url fail"
    return 1
  fi

  # Drop stale images so the following build produces the only copy.
  docker images "${url}" -q | xargs -I{} docker rmi -f {}
  if ! docker build -f "${file}" -t "${image}" "${input}"; then
    # Message previously read "docker bukld".
    err "docker build -f $file -t $image $input fail"
    return 1
  fi
}
def load_recipe(path, **kwargs):
    """Load a recipe module from *path* and return ``configure(**kwargs)``.

    Args:
        path: Filesystem path of the recipe's Python source file.
        **kwargs: Forwarded unchanged to the recipe's ``configure()``.

    Returns:
        Whatever the recipe's ``configure()`` returns.

    Raises:
        ImportError: If no import spec/loader can be built for *path*.
    """
    spec = importlib.util.spec_from_file_location("recipe", path)
    # spec_from_file_location returns None rather than raising when no loader
    # matches the path (for example a file without a ".py" suffix).
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load recipe from {path!r}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod.configure(**kwargs)


def parse_cli_overrides(cfg, args):
    """Apply Hydra-style CLI overrides (``key=value``) to *cfg* in place.

    Does nothing when *args* is empty.
    """
    if not args:
        return
    omega_conf, excluded = create_omegaconf_dict_config(cfg)
    omega_conf = parse_hydra_overrides(omega_conf, args)
    apply_overrides(cfg, OmegaConf.to_container(omega_conf, resolve=True), excluded)


def main() -> None:
    """Entry point: ``entrypoint.py <recipe.py> [key=value ...]``.

    ``hf_path=`` and ``moe_token_dispatcher_type=`` are handed to the recipe's
    ``configure()``; every other argument is applied to the resulting config
    as a Hydra-style override before training starts.
    """
    if len(sys.argv) < 2:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit("Usage: entrypoint.py <recipe.py> [key=value ...]")

    recipe_path, *overrides = sys.argv[1:]

    # Keys consumed by the recipe itself rather than by the config overrides.
    recipe_keys = ("hf_path", "moe_token_dispatcher_type")
    recipe_kwargs = {}
    remaining = []
    for o in overrides:
        key, sep, value = o.partition("=")
        if sep and key in recipe_keys:
            recipe_kwargs[key] = value
        else:
            remaining.append(o)

    cfg = load_recipe(recipe_path, **recipe_kwargs)
    parse_cli_overrides(cfg, remaining)
    pretrain(config=cfg, forward_step_func=forward_step)
def configure(hf_path=None, moe_token_dispatcher_type=None):
    """Build the DeepSeek-V2-Lite pretrain config used by this recipe.

    ``hf_path`` is forwarded to the base recipe only when it is truthy.
    Passing ``moe_token_dispatcher_type="deepep"`` switches the MoE dispatcher
    to the flex dispatcher with the DeepEP backend.
    """
    recipe_args = {"hf_path": hf_path} if hf_path else {}
    recipe_args.update(
        tensor_model_parallel_size=8,
        pipeline_model_parallel_size=1,
        expert_model_parallel_size=2,
        sequence_parallel=True,
        seq_length=4096,
        train_iters=500,
        global_batch_size=64,
        micro_batch_size=1,
        eval_interval=100,
        lr_warmup_iters=50,
        save_interval=0,
    )
    cfg = deepseek_v2_lite_pretrain_config(**recipe_args)

    cfg.model.moe_permute_fusion = False
    if moe_token_dispatcher_type != "deepep":
        return cfg

    # DeepEP is selected through the "flex" dispatcher plus its backend flags.
    cfg.model.moe_token_dispatcher_type = "flex"
    cfg.model.moe_flex_dispatcher_backend = "deepep"
    cfg.model.moe_enable_deepep = True
    cfg.model.moe_shared_expert_overlap = False
    return cfg
# Python 3.8 - Positional-only parameters (PEP 570)
def positional_only(a, b, /, c, d):
    """Add four values; ``a`` and ``b`` can only be passed positionally."""
    total = a + b + c + d
    return total


# Python 3.8 - Walrus operator (PEP 572)
def walrus_example(data: list):
    """Return ``len(data)`` when it exceeds 3, otherwise ``None``."""
    return size if (size := len(data)) > 3 else None


def walrus_fib(count: int) -> list:
    """Build *count* Fibonacci numbers by rebinding a pair with ``:=``."""
    pair = (0, 1)
    return [(pair := (pair[1], sum(pair)))[0] for _ in range(count)]


# Python 3.7 - Data Classes (PEP 557)
@dataclass
class Point:
    """Plain dataclass: fields may be reassigned after construction."""

    x: int
    y: int


@dataclass(frozen=True)
class FrozenPoint:
    """Frozen dataclass: assigning to a field raises FrozenInstanceError."""

    x: int
    y: int


# Python 3.6 - f-string (PEP 498)
def fstring_basic(name: str) -> str:
    """Interpolate *name* into a greeting with an f-string."""
    greeting = f"Hello, {name}!"
    return greeting


def fstring_format(value: float) -> str:
    """Render *value* through an f-string format specification."""
    rendered = f"{value:1.3}"
    return rendered


# Python 3.5 - Async/Await (PEP 492)
async def async_greet() -> str:
    """Coroutine defined with ``async def`` that awaits briefly, then greets."""
    reply = "Hello"
    await asyncio.sleep(0.01)
    return reply


# Python 3.5 - General unpacking (PEP 448)
def general_unpacking() -> list:
    """Combine several ``*`` unpackings inside a single list literal."""
    head = range(3)
    tail = range(4, 6)
    return [*head, 3, *tail]


# Python 3.3 - yield from (PEP 380)
def fib_gen(n: int):
    """Yield the first *n* Fibonacci numbers, starting at 0."""
    current, following = 0, 1
    for _ in range(n):
        yield current
        current, following = following, current + following


def delegate_fib(n: int):
    """Yield everything ``fib_gen(n)`` produces via ``yield from``."""
    yield from fib_gen(n)


# Python 3.0 - Extended unpacking (PEP 3132)
def extended_unpacking() -> tuple:
    """Split ``range(5)`` into first item, middle list and last item."""
    first, *middle, last = range(5)
    return first, middle, last


# Python 3.0 - Keyword-only arguments (PEP 3102)
def keyword_only(a, b, *, kw):
    """Return the arguments as a tuple; ``kw`` is accepted by keyword only."""
    return (a, b, kw)


# Python 3.0 - nonlocal keyword (PEP 3104)
def nonlocal_example() -> str:
    """Rebind a variable of the enclosing function from a nested function."""
    outer = "original"

    def rebind():
        nonlocal outer
        outer = "modified"

    rebind()
    return outer
@pytest.mark.skipif(PY_VERSION < (3, 11), reason="Requires Python 3.11+")
class TestPython311:
    """Checks for syntax added in Python 3.11; skipped on older interpreters."""

    def test_exception_group(self):
        """Each ``except*`` clause handles its part of a mixed ExceptionGroup."""
        # The snippet is kept as a string and run with exec(); presumably so
        # this module can still be parsed by interpreters without ``except*``.
        code = """
caught_value = caught_type = False
try:
    raise ExceptionGroup("errors", [ValueError("invalid"), TypeError("wrong")])
except* ValueError:
    caught_value = True
except* TypeError:
    caught_type = True
assert caught_value and caught_type
"""
        exec(code)


@pytest.mark.skipif(PY_VERSION < (3, 10), reason="Requires Python 3.10+")
class TestPython310:
    """Checks for ``match`` statements (Python 3.10); skipped on older interpreters."""

    def test_http_status(self):
        """``match`` on literal values, with ``case _`` as the fallback."""
        code = """
def http_status(status: int) -> str:
    match status:
        case 200:
            return "OK"
        case 404:
            return "Not Found"
        case 500:
            return "Internal Server Error"
        case _:
            return "Unknown"
assert http_status(200) == "OK"
assert http_status(404) == "Not Found"
assert http_status(999) == "Unknown"
"""
        exec(code)

    def test_describe_point(self):
        """``match`` on tuple patterns that bind ``x`` / ``y``."""
        code = """
def describe_point(point: tuple) -> str:
    match point:
        case (0, 0):
            return "Origin"
        case (x, 0):
            return f"On x-axis at {x}"
        case (0, y):
            return f"On y-axis at {y}"
        case (x, y):
            return f"Point at ({x}, {y})"
assert describe_point((0, 0)) == "Origin"
assert describe_point((5, 0)) == "On x-axis at 5"
"""
        exec(code)


class TestPython39:
    """Checks for the dict ``|`` and ``|=`` helpers (Python 3.9)."""

    def test_dict_merge(self):
        """``dict_merge`` returns the union of both dicts."""
        assert dict_merge({"a": 1}, {"b": 2}) == {"a": 1, "b": 2}

    def test_dict_update(self):
        """``dict_update`` returns the first dict extended with the second."""
        assert dict_update({"a": 1}, {"b": 2}) == {"a": 1, "b": 2}
class TestPython37:
    """Dataclass behaviour (Python 3.7)."""

    def test_dataclass(self):
        left, right = Point(1, 2), Point(1, 2)
        assert left == right

    def test_frozen_dataclass(self):
        with pytest.raises(FrozenInstanceError):
            FrozenPoint(1, 2).x = 3


class TestPython36:
    """f-string helpers (Python 3.6)."""

    def test_fstring_basic(self):
        greeting = fstring_basic("World")
        assert greeting == "Hello, World!"

    def test_fstring_format(self):
        rendered = fstring_format(123.567)
        assert rendered == "1.24e+02"


class TestPython35:
    """async/await and generalized unpacking (Python 3.5)."""

    def test_async_greet(self):
        reply = asyncio.run(async_greet())
        assert reply == "Hello"

    def test_general_unpacking(self):
        expected = [0, 1, 2, 3, 4, 5]
        assert general_unpacking() == expected


class TestPython33:
    """``yield from`` delegation (Python 3.3)."""

    def test_delegate_fib(self):
        produced = list(delegate_fib(10))
        assert produced == [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]


class TestPython30:
    """Extended unpacking, keyword-only arguments and ``nonlocal`` (Python 3.0)."""

    def test_extended_unpacking(self):
        expected = (0, [1, 2, 3], 4)
        assert extended_unpacking() == expected

    def test_keyword_only(self):
        result = keyword_only(1, 2, kw=3)
        assert result == (1, 2, 3)

    def test_nonlocal(self):
        result = nonlocal_example()
        assert result == "modified"
libelf-dev \ libevent-core-2.1-7t64 \ libevent-pthreads-2.1-7t64 \ libgflags-dev \ libgrpc++-dev \ libgrpc-dev \ libhwloc-dev \ libhwloc15 \ libibumad-dev \ libibverbs-dev \ libnl-3-200 \ libnl-3-dev \ libnl-route-3-200 \ libnl-route-3-dev \ libnuma-dev \ libprotobuf-dev \ librdmacm-dev \ libssl-dev \ libtool \ meson \ ninja-build \ openssh-client \ openssh-server \ pciutils \ pkg-config \ protobuf-compiler-grpc \ pybind11-dev \ python3 \ python3-dev \ python3-pip \ rdma-core \ tcl \ udev \ uuid-dev \ vim \ wget \ zlib1g-dev \ && rm -rf /var/lib/apt/lists/* RUN apt-get purge -y cuda-compat-* || true RUN mkdir -p /var/run/sshd RUN sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config && \ echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config && \ sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/amazon/efa/lib:/opt/gdrcopy/lib:/usr/local/lib:$LD_LIBRARY_PATH ENV PATH=/opt/amazon/openmpi/bin:/opt/amazon/efa/bin:/usr/bin:/usr/local/bin:$PATH RUN rm -f /usr/lib/python*/EXTERNALLY-MANAGED \ && pip3 install --no-cache-dir awscli nvidia-ml-py # --- GDRCopy --- RUN git clone -b ${GDRCOPY_VERSION} https://github.com/NVIDIA/gdrcopy.git /tmp/gdrcopy \ && cd /tmp/gdrcopy \ && make prefix=/opt/gdrcopy install \ && rm -rf /tmp/gdrcopy ENV LD_LIBRARY_PATH=/opt/gdrcopy/lib:$LD_LIBRARY_PATH ENV LIBRARY_PATH=/opt/gdrcopy/lib:$LIBRARY_PATH ENV CPATH=/opt/gdrcopy/include ENV PATH=/opt/gdrcopy/bin:$PATH # --- EFA --- RUN cd $HOME \ && curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && tar -xf $HOME/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ && cd aws-efa-installer \ && ./efa_installer.sh -y -g -d --skip-kmod --skip-limit-conf --no-verify \ && rm -rf $HOME/aws-efa-installer $HOME/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz # --- UCX (verbs + rdmacm + dm + efa) --- ENV UCX_PREFIX=/usr/local/ucx RUN git clone 
--depth 1 --branch ${UCX_VERSION} https://github.com/openucx/ucx.git /tmp/ucx \ && cd /tmp/ucx && ./autogen.sh \ && ./contrib/configure-release-mt \ --prefix=${UCX_PREFIX} \ --enable-shared --disable-static \ --enable-optimizations --enable-cma --enable-mt \ --enable-devel-headers \ --with-cuda=/usr/local/cuda \ --with-gdrcopy=/opt/gdrcopy \ --with-verbs --with-rdmacm --with-dm --with-efa \ && make -j$(nproc) && make install \ && echo "${UCX_PREFIX}/lib" > /etc/ld.so.conf.d/ucx.conf \ && echo "${UCX_PREFIX}/lib/ucx" >> /etc/ld.so.conf.d/ucx.conf \ && ldconfig && rm -rf /tmp/ucx ENV PATH="${UCX_PREFIX}/bin:${PATH}" ENV LD_LIBRARY_PATH="${UCX_PREFIX}/lib:${UCX_PREFIX}/lib/ucx:${LD_LIBRARY_PATH}" # --- PyTorch (install before NIXL, needed for Python bindings) --- RUN pip3 install --no-cache-dir torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128 # --- NIXL --- COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ RUN git clone --depth 1 --branch ${NIXL_VERSION} \ https://github.com/ai-dynamo/nixl.git /opt/nixl \ && cd /opt/nixl \ && pip3 install --no-cache-dir tomlkit \ && export PKG_CONFIG_PATH="/opt/amazon/efa/lib/pkgconfig:$PKG_CONFIG_PATH" \ && export CPATH="/opt/amazon/efa/include:$CPATH" \ && export LIBRARY_PATH="/opt/amazon/efa/lib:/usr/local/cuda/lib64/stubs" \ && meson setup build \ --prefix=/usr/local \ --buildtype=release \ -Ducx_path=${UCX_PREFIX} \ -Dlibfabric_path=/opt/amazon/efa \ && ninja -C build -j$(nproc) \ && ninja -C build install \ && pip3 install --no-cache-dir build/src/bindings/python/nixl-meta/nixl-${NIXL_VERSION}-py3-none-any.whl \ && ldconfig # --- nixlbench --- # etcd-cpp-api (nixlbench dep) RUN git clone --depth 1 https://github.com/etcd-cpp-apiv3/etcd-cpp-apiv3.git /tmp/etcd-cpp \ && cd /tmp/etcd-cpp && mkdir build && cd build \ && cmake .. 
-DCMAKE_INSTALL_PREFIX=/usr/local -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release \ && make -j$(nproc) && make install && ldconfig \ && rm -rf /tmp/etcd-cpp RUN cd /opt/nixl/benchmark/nixlbench \ && meson setup build --prefix=/usr/local --buildtype=release -Dnixl_path=/usr/local \ && ninja -C build -j$(nproc) && ninja -C build install # --- kvbench (Python-based, lives in /opt/nixl/benchmark/kvbench) --- RUN pip3 install --no-cache-dir pyYAML click tabulate # --- Environment --- ENV OPAL_PREFIX=/opt/amazon/openmpi \ OMPI_MCA_pml=ob1 \ OMPI_MCA_btl="tcp,self" \ FI_PROVIDER=efa \ FI_EFA_FORK_SAFE=1 \ FI_EFA_USE_DEVICE_RDMA=1 \ FI_LOG_LEVEL=warn \ UCX_TLS="^cuda_ipc" \ UCX_NET_DEVICES=all \ PYTHONPATH="/opt/nixl/benchmark" # --- Validation --- RUN echo "=== NIXL Image Validation ===" \ && ucx_info -v | head -3 \ && fi_info --version \ && ${OPAL_PREFIX}/bin/mpirun --version | head -1 \ && which ucx_perftest \ && which nixlbench \ && python3 -c "import nixl; print('nixl OK')" # --- flash-attn + DeepGEMM --- RUN apt-get update && apt-get remove -y python3-blinker && rm -rf /var/lib/apt/lists/* RUN pip3 install --no-cache-dir packaging numpy ninja pybind11 "setuptools<80" wheel RUN pip3 install --no-cache-dir flash-attn==2.8.1 --no-build-isolation RUN git clone --recursive -b v2.1.1.post3 https://github.com/deepseek-ai/DeepGEMM.git /tmp/deepgemm \ && cd /tmp/deepgemm \ && python3 setup.py bdist_wheel \ && pip3 install dist/*.whl \ && rm -rf /tmp/deepgemm # --- vLLM + SGLang + Ray --- ENV VLLM_USE_DEEP_GEMM=1 ENV DG_JIT_CACHE_DIR=/tmp RUN pip3 install --no-cache-dir vllm==0.15.1 RUN pip3 install --no-cache-dir "sglang[all]==0.5.9" RUN pip3 install --no-cache-dir ray # --- vllm-router (PD disaggregation proxy) --- RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y ENV PATH="/root/.cargo/bin:${PATH}" RUN pip3 install --no-cache-dir vllm-router WORKDIR /workspace CMD ["/bin/bash"] ================================================ FILE: 
# Execute the given command string through `bash -c`.
# When SLURM_JOB_ID is set and non-empty the shell is started via srun
# (one node, one task per node); otherwise it is started directly.
_run() {
  local script="$*"
  local -a launcher=()
  if [[ -n "${SLURM_JOB_ID:-}" ]]; then
    launcher=(srun -N1 --ntasks-per-node=1)
  fi
  # The ${arr[@]+...} form keeps an empty array safe under `set -u`.
  ${launcher[@]+"${launcher[@]}"} bash -c "${script}"
}
# Run a one-shot command inside the benchmark image.
# Args:
#   $1  command string handed to `bash -c` inside the container
# The container is removed on exit (--rm), shares the host network, mounts the
# current directory at the same path (also used as working directory) and
# mounts ${CONTAINER_MOUNT} at the same path.
# NOTE: the command is spliced between single quotes in the generated script,
# so it must not itself contain single quotes.
launch_container() {
  local cmd="$1"
  _run "
    docker run --rm --net=host \
      -v '${PWD}:${PWD}' -w '${PWD}' \
      -v '${CONTAINER_MOUNT}:${CONTAINER_MOUNT}' \
      --entrypoint bash '${CONTAINER_IMAGE}' \
      -c '${cmd}'
  "
}
# # Examples: # salloc -N 2 bash nixl.sbatch --backend UCX --initiator_seg_type VRAM --target_seg_type VRAM # salloc -N 2 bash nixl.sbatch --image ./nixl-test-latest.tar.gz --backend UCX set -euo pipefail info() { echo -e "[$(date +'%Y-%m-%dT%H:%M:%S%z')][info] $*"; } err() { echo -e "[$(date +'%Y-%m-%dT%H:%M:%S%z')][error] $*" >&2; } IMAGE="" WORKSPACE="$PWD" CONTAINER_MOUNT="/fsx" ETCD_PORT=$((2379 + (SLURM_JOB_ID % 1000))) BENCH_ARGS=() DOCKER_ENVS=() while (("$#")); do case "$1" in --image) IMAGE="$2"; shift 2 ;; --container-mount) CONTAINER_MOUNT="$2"; shift 2 ;; --env) DOCKER_ENVS+=("-e" "$2"); shift 2 ;; *) BENCH_ARGS+=("$1"); shift ;; esac done IMAGE="${IMAGE:-${WORKSPACE}/nixl-latest.tar.gz}" NUM_NODES=${SLURM_JOB_NUM_NODES:-1} readarray -t NODES < <(scontrol show hostnames "$SLURM_JOB_NODELIST") HEAD_NODE=${NODES[0]} HEAD_IP=$(getent ahostsv4 "$HEAD_NODE" | head -1 | awk '{print $1}') info "========================================" info "nixlbench" info "========================================" info "Image: ${IMAGE}" info "Nodes: ${NUM_NODES}, Head: ${HEAD_NODE} (${HEAD_IP})" info "Args: ${BENCH_ARGS[*]+"${BENCH_ARGS[*]}"}" info "========================================" load_image() { if [[ "${IMAGE}" == *.tar.gz ]]; then CONTAINER_IMAGE=$(pigz -dc "${IMAGE}" | tar -xf - -O manifest.json | python3 -c "import sys,json; print(json.load(sys.stdin)[0]['RepoTags'][0])") info "Image tag: ${CONTAINER_IMAGE}" srun --ntasks-per-node=1 bash -c " if ! 
# Best-effort teardown: every step tolerates failure so that cleanup itself
# never aborts the script.
# NOTE(review): the docker commands act on every container reported by
# `docker ps -aq` on each allocated node, not only on containers started by
# this script - confirm that the nodes are dedicated to this job.
cleanup() {
  info "Cleaning up..."
  # Terminate background jobs of this shell; `kill` with no job ids fails,
  # which is why the error is silenced.
  kill $(jobs -p) 2>/dev/null || true
  # One task per node: stop containers (10s grace period), then force-remove.
  srun --ntasks-per-node=1 bash -c '
    docker ps -aq | xargs -r docker stop -t 10 2>/dev/null || true
    docker ps -aq | xargs -r docker rm -f 2>/dev/null || true
  ' 2>/dev/null || true
  # Drop the per-node nixlbench logs written under LOGDIR.
  rm -f "${LOGDIR}"/nixlbench-*.log
}
# Poll the etcd health endpoint until it answers or the retry budget is spent.
# The budget is ETCD_WAIT_ATTEMPTS probes spaced ETCD_WAIT_INTERVAL seconds
# apart (both whole numbers; defaults 60 and 2, i.e. roughly 120s). The
# failure message reports the budget actually used instead of a fixed "60s".
ETCD_WAIT_ATTEMPTS="${ETCD_WAIT_ATTEMPTS:-60}"
ETCD_WAIT_INTERVAL="${ETCD_WAIT_INTERVAL:-2}"
for i in $(seq 1 "${ETCD_WAIT_ATTEMPTS}"); do
  if curl -sf "${ETCD_URL}/health" >/dev/null 2>&1; then
    info "etcd is ready"
    break
  fi
  if [ "$i" -eq "${ETCD_WAIT_ATTEMPTS}" ]; then
    err "etcd failed to start within $((ETCD_WAIT_ATTEMPTS * ETCD_WAIT_INTERVAL))s"
    exit 1
  fi
  sleep "${ETCD_WAIT_INTERVAL}"
done
# Look up the value of an option in SERVE_ARGS without consuming it.
# Args:
#   $1  short option name (may be empty when the option has no short form)
#   $2  long option name
#   $3  value to print when the option is absent
# Prints the value of the first match on stdout. Both "--opt value" and
# "--opt=value" spellings are recognised. An option given as the very last
# argument (no value after it) yields the default rather than reading past
# the end of the array, which is an error under `set -u`.
_peek_arg() {
  local short="$1" long="$2" default="$3"
  local i=0 n=${#SERVE_ARGS[@]} arg flag
  while ((i < n)); do
    arg="${SERVE_ARGS[$i]}"
    for flag in "$short" "$long"; do
      # An empty name must not match an empty argument.
      [[ -n "$flag" ]] || continue
      if [[ "$arg" == "$flag" ]]; then
        if ((i + 1 < n)); then
          echo "${SERVE_ARGS[$((i + 1))]}"
        else
          echo "$default"
        fi
        return 0
      fi
      if [[ "$arg" == "$flag="* ]]; then
        echo "${arg#"$flag="}"
        return 0
      fi
    done
    # Plain assignment: ((i++)) returns status 1 when i is 0, which would
    # trip `set -e` if this function were ever called outside $(...).
    i=$((i + 1))
  done
  echo "$default"
}
divisible by --route ${ROUTE_SIZE}" exit 1 fi NODES_PER_GROUP=$((NUM_NODES / ROUTE_SIZE)) NUM_DECODE=$((NODES_PER_GROUP - NUM_PREFILL)) if (( NUM_PREFILL > 0 && NUM_DECODE < 1 )); then err "Each group needs at least 1 decode node. Got ${NODES_PER_GROUP} nodes/group, --prefill ${NUM_PREFILL}" exit 1 fi ROUTER_PORT=$((8000 + ROUTE_SIZE)) mkdir -p "${LOGDIR}" if (( NUM_PREFILL == 0 )); then MODE="pure-dp" else MODE="${NUM_PREFILL}P+${NUM_DECODE}D" fi info "========================================" info "vLLM Server" info "========================================" info "Image: ${IMAGE}" info "Model: ${MODEL}" info "Nodes: ${NUM_NODES}, TP=${TP}, DP_LOCAL=${DP_LOCAL}" info "Route: ${ROUTE_SIZE} group(s), ${NODES_PER_GROUP} node(s)/group, mode=${MODE}" info "Args: ${SERVE_ARGS[*]+"${SERVE_ARGS[*]}"}" info "========================================" load_image() { if [[ "${IMAGE}" == *.tar.gz ]]; then CONTAINER_IMAGE=$(pigz -dc "${IMAGE}" | tar -xf - -O manifest.json | python3 -c "import sys,json; print(json.load(sys.stdin)[0]['RepoTags'][0])") info "Image tag: ${CONTAINER_IMAGE}" srun --ntasks-per-node=1 bash -c " if ! 
# Start a detached, privileged GPU container on the node where this function
# executes. Callers ship the function to the target node with `declare -f`,
# so it only relies on CONTAINER_IMAGE / CONTAINER_MOUNT and on optional
# environment overrides.
# Args:
#   $1  container name
#   $2  command string handed to `bash -c` inside the container
launch_container() {
  local name="$1" cmd="$2"
  # Expose the GDRCopy device plus every uverbs device found under
  # /dev/infiniband (the find error is ignored when the directory is absent).
  local devices=("--device=/dev/gdrdrv")
  while IFS= read -r -d '' d; do
    devices+=("--device=${d}")
  done < <(find "/dev/infiniband" -name "uverbs*" -print0 2>/dev/null)
  # Interface: GLOO_SOCKET_IFNAME if set, otherwise the device of the first
  # IPv4 default route; host_ip is the first IPv4 address on that interface.
  local net_if="${GLOO_SOCKET_IFNAME:-$(ip -o -4 route show to default | awk '{print $5}' | head -1)}"
  local host_ip; host_ip=$(ip -o -4 addr show "${net_if}" | awk '{print $4}' | cut -d/ -f1 | head -1)
  # Host UTS/IPC/network namespaces, unlimited locked memory, and the chosen
  # interface/address forwarded to NCCL, Gloo and the vLLM NIXL side channel.
  # Timeouts and the side-channel port keep the caller's value when set.
  docker run --gpus all \
    --privileged -d \
    --name "${name}" \
    --uts=host --ipc=host --net=host \
    --ulimit stack=67108864 --ulimit memlock=-1 \
    --security-opt seccomp=unconfined \
    "${devices[@]}" \
    -e NCCL_SOCKET_IFNAME="${net_if}" \
    -e GLOO_SOCKET_IFNAME="${net_if}" \
    -e VLLM_NIXL_SIDE_CHANNEL_HOST="${host_ip}" \
    -e VLLM_NIXL_SIDE_CHANNEL_PORT="${SIDE_CHANNEL_PORT:-5600}" \
    -e VLLM_RPC_TIMEOUT="${VLLM_RPC_TIMEOUT:-3600000}" \
    -e VLLM_ENGINE_READY_TIMEOUT_S="${VLLM_ENGINE_READY_TIMEOUT_S:-3600}" \
    -v "${CONTAINER_MOUNT}:${CONTAINER_MOUNT}" \
    --entrypoint bash \
    "${CONTAINER_IMAGE}" \
    -c "${cmd}"
}
# --- Launch a pure-DP group on a set of nodes ---
# Args: group_id, port, node_list (one argument per node)
#
# The first node is the group head: it serves the HTTP API on <port>, while
# the remaining nodes run headless workers that join the head through the
# data-parallel RPC port. The head log is streamed to stdout in the
# background, prefixed with the group id.
start_dp_group() {
  local gid="$1" port="$2"
  shift 2
  local gnodes=("$@")
  local ghead="${gnodes[0]}"
  local ghead_ip; ghead_ip=$(getent ahostsv4 "$ghead" | head -1 | awk '{print $1}')
  local gsize=${#gnodes[@]}
  local dp=$((gsize * DP_LOCAL))
  # Offset by job id and group id so concurrent jobs/groups get distinct ports.
  local rpc_port=$((13345 + (SLURM_JOB_ID % 1000) + gid))
  local logfile="${LOGDIR}/vllm-g${gid}.log"

  info "Group ${gid}: ${gsize} nodes, DP=${dp}, head=${ghead}:${port}"

  # Launch containers (idle until `docker exec` starts vllm in them)
  for i in $(seq 0 $((gsize - 1))); do
    srun --nodes=1 --nodelist="${gnodes[$i]}" bash -c "
      $(declare -f launch_container)
      CONTAINER_IMAGE='${CONTAINER_IMAGE}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' \
        launch_container 'vllm-g${gid}-n${i}' 'sleep infinity'
    "
  done
  sleep 3

  local extra="--data-parallel-size ${dp} --data-parallel-size-local ${DP_LOCAL} --data-parallel-address ${ghead_ip} --data-parallel-rpc-port ${rpc_port}"

  # Start workers
  for i in $(seq 1 $((gsize - 1))); do
    local start_rank=$((i * DP_LOCAL))
    srun --nodes=1 --nodelist="${gnodes[$i]}" bash -c "
      docker exec -d vllm-g${gid}-n${i} bash -c 'vllm serve ${SERVE_ARGS_STR} ${extra} \
        --data-parallel-start-rank ${start_rank} --headless \
        2>&1 | tee ${logfile}.node${i}'
    "
  done

  # Start head
  srun --nodes=1 --nodelist="${ghead}" bash -c "
    docker exec -d vllm-g${gid}-n0 bash -c 'vllm serve ${SERVE_ARGS_STR} ${extra} \
      --host 0.0.0.0 --port ${port} \
      2>&1 | tee ${logfile}'
  "

  # `docker exec -d` returns before tee has created the log file, so follow by
  # name with retry (-F); plain -f gives up at once when the file is missing.
  tail -F "${logfile}" 2>/dev/null | sed "s/^/[g${gid}] /" &
}
nodes=${gnodes[*]}" # Launch prefill nodes for i in $(seq 0 $((NUM_PREFILL - 1))); do local node="${gnodes[$i]}" local port=$((base_port + i)) srun --nodes=1 --nodelist="${node}" bash -c " $(declare -f launch_container) CONTAINER_IMAGE='${CONTAINER_IMAGE}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' \ launch_container 'vllm-g${gid}-p${i}' \ 'vllm serve ${SERVE_ARGS_STR} \ --host 0.0.0.0 --port ${port} \ --data-parallel-size ${DP_LOCAL} \ --kv-transfer-config.kv_connector NixlConnector \ --kv-transfer-config.kv_role kv_producer \ --kv-transfer-config.kv_load_failure_policy fail \ --kv-transfer-config.kv_connector_extra_config.backends+ LIBFABRIC \ 2>&1 | tee ${LOGDIR}/vllm-g${gid}-p${i}.log' " tail -f "${LOGDIR}/vllm-g${gid}-p${i}.log" 2>/dev/null | sed "s/^/[g${gid}-p${i}] /" & done # Launch decode nodes for i in $(seq 0 $((NUM_DECODE - 1))); do local node="${gnodes[$((NUM_PREFILL + i))]}" local port=$((base_port + NUM_PREFILL + i)) srun --nodes=1 --nodelist="${node}" bash -c " $(declare -f launch_container) CONTAINER_IMAGE='${CONTAINER_IMAGE}' CONTAINER_MOUNT='${CONTAINER_MOUNT}' \ launch_container 'vllm-g${gid}-d${i}' \ 'vllm serve ${SERVE_ARGS_STR} \ --host 0.0.0.0 --port ${port} \ --data-parallel-size ${DP_LOCAL} \ --kv-transfer-config.kv_connector NixlConnector \ --kv-transfer-config.kv_role kv_consumer \ --kv-transfer-config.kv_load_failure_policy fail \ --kv-transfer-config.kv_connector_extra_config.backends+ LIBFABRIC \ 2>&1 | tee ${LOGDIR}/vllm-g${gid}-d${i}.log' " tail -f "${LOGDIR}/vllm-g${gid}-d${i}.log" 2>/dev/null | sed "s/^/[g${gid}-d${i}] /" & done } start_router() { local router_args="" if (( NUM_PREFILL == 0 )); then # Pure DP: round-robin across group endpoints router_args="--worker-urls" for g in $(seq 0 $((ROUTE_SIZE - 1))); do router_args="${router_args} http://${GROUP_IPS[$g]}:${GROUP_PORTS[$g]}" done info "Starting vllm-router (DP) on ${HEAD_NODE}:${ROUTER_PORT} ..." 
srun --nodes=1 --nodelist="${HEAD_NODE}" bash -c " docker exec -d vllm-g0-n0 bash -c \ 'vllm-router \ --policy round_robin \ ${router_args} \ --host 0.0.0.0 \ --port ${ROUTER_PORT} \ --intra-node-data-parallel-size ${DP_LOCAL} \ 2>&1 | tee ${LOGDIR}/vllm-router.log' " else # PD disaggregation: collect all prefill/decode endpoints across groups local prefill_args="" decode_args="" for g in $(seq 0 $((ROUTE_SIZE - 1))); do local goffset=$((g * NODES_PER_GROUP)) local base_port=$((8000 + g * NODES_PER_GROUP)) for i in $(seq 0 $((NUM_PREFILL - 1))); do local ip; ip=$(getent ahostsv4 "${NODES[$((goffset + i))]}" | head -1 | awk '{print $1}') prefill_args="${prefill_args} --prefill http://${ip}:$((base_port + i))" done for i in $(seq 0 $((NUM_DECODE - 1))); do local ip; ip=$(getent ahostsv4 "${NODES[$((goffset + NUM_PREFILL + i))]}" | head -1 | awk '{print $1}') decode_args="${decode_args} --decode http://${ip}:$((base_port + NUM_PREFILL + i))" done done info "Starting vllm-router (PD) on ${HEAD_NODE}:${ROUTER_PORT} ..." 
# Start every routing group on its own consecutive slice of NODES.
# Pure-DP mode (NUM_PREFILL == 0): group g listens on port 8000 + g and its
# head address/port are recorded in GROUP_IPS / GROUP_PORTS.
# PD mode: group g is handed base port 8000 + g * NODES_PER_GROUP and nothing
# is recorded in the GROUP_* arrays.
start_groups() {
  GROUP_IPS=()
  GROUP_PORTS=()
  for g in $(seq 0 $((ROUTE_SIZE - 1))); do
    local first=$((g * NODES_PER_GROUP))
    local members=("${NODES[@]:first:NODES_PER_GROUP}")
    local head_addr
    head_addr=$(getent ahostsv4 "${members[0]}" | head -1 | awk '{print $1}')

    if (( NUM_PREFILL != 0 )); then
      local pd_base=$((8000 + g * NODES_PER_GROUP))
      start_pd_group "$g" "${pd_base}" "${members[@]}"
      continue
    fi

    local listen_port=$((8000 + g))
    start_dp_group "$g" "${listen_port}" "${members[@]}"
    GROUP_IPS+=("${head_addr}")
    GROUP_PORTS+=("${listen_port}")
  done
}
class TestTimingAttack:
    """Contrast an early-exit byte comparison with a constant-time one."""

    def test_insecure_comparison(self):
        """Insecure comparison - vulnerable to timing attack."""
        secret = b"correct_secret_token"

        def insecure_compare(a, b):
            """INSECURE: Returns early on mismatch."""
            if len(a) != len(b):
                return False
            for x, y in zip(a, b):
                if x != y:
                    return False
            return True

        assert insecure_compare(secret, secret) is True

        # Different lengths are rejected by the length check alone.
        assert insecure_compare(secret, b"wrong_secret_token!") is False
        assert insecure_compare(secret, b"c") is False

        # Equal-length mismatches actually reach the byte loop: the first one
        # exits on byte 0, the second only on the last byte, so these
        # comparisons take different amounts of time.
        assert insecure_compare(secret, b"Xorrect_secret_token") is False
        assert insecure_compare(secret, b"correct_secret_tokeX") is False

    def test_secure_comparison(self):
        """Secure comparison using hmac.compare_digest."""
        secret = b"correct_secret_token"

        # Constant-time comparison - safe from timing attacks
        assert hmac.compare_digest(secret, secret) is True
        assert hmac.compare_digest(secret, b"wrong_secret_token!") is False
        assert hmac.compare_digest(secret, b"correct_secret_tokeX") is False
class TestSQLInjection:
    """Show that string-built SQL is injectable and placeholders are not."""

    def test_vulnerable_query_building(self):
        """INSECURE: the user-supplied name is interpolated into the SQL text."""

        def build_query_insecure(username):
            # VULNERABLE: the input becomes part of the statement itself
            sql = f"SELECT * FROM users WHERE username = '{username}'"
            return sql

        # A harmless name yields the intended statement
        benign_sql = build_query_insecure("alice")
        assert benign_sql == "SELECT * FROM users WHERE username = 'alice'"

        # A crafted name closes the quote and appends an always-true clause
        attack_sql = build_query_insecure("admin' OR '1'='1")
        assert "OR '1'='1'" in attack_sql

    def test_parameterized_query(self):
        """SECURE: the value travels separately from the SQL text."""

        def build_query_secure(username):
            # The statement only carries a placeholder; the driver binds the value
            statement = "SELECT * FROM users WHERE username = ?"
            return statement, (username,)

        hostile_name = "admin' OR '1'='1"
        _, bound = build_query_secure(hostile_name)

        # The hostile input is kept intact as a plain parameter value
        assert bound == ("admin' OR '1'='1",)
class TestHardcodedSecrets:
    """Demonstrate reading secrets from the environment instead of the source."""

    def test_environment_variables(self):
        """SECURE: Use environment variables for secrets."""
        key = "TEST_API_KEY"
        # Remember any value already present so the test leaves the process
        # environment exactly as it found it.
        previous = os.environ.get(key)

        # Set a test secret
        os.environ[key] = "secret_key_123"
        try:
            # Read from environment
            api_key = os.environ.get(key)
            assert api_key == "secret_key_123"
        finally:
            # Clean up even when the assertion above fails
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous

        # Missing secret should be handled
        missing = os.environ.get("NONEXISTENT_KEY")
        assert missing is None
class TestPasswordStorage:
    """Compare fast unsalted digests with a salted, iterated derivation."""

    def test_weak_hash(self):
        """INSECURE: MD5/SHA1 for passwords."""
        import hashlib

        encoded = "password123".encode()

        # INSECURE: fast hashes are easily brute-forced; each entry pairs the
        # constructor with the length of its hex digest.
        weak_algorithms = ((hashlib.md5, 32), (hashlib.sha1, 40))
        for make_digest, hex_length in weak_algorithms:
            assert len(make_digest(encoded).hexdigest()) == hex_length

    def test_secure_password_hash(self):
        """SECURE: Use slow, salted password hashing."""
        # In production, use argon2-cffi or bcrypt; this is a simplified demo
        import hashlib

        salt = secrets.token_bytes(16)
        assert len(salt) == 16

        # PBKDF2 with high iteration count (still not ideal, use Argon2)
        derived = hashlib.pbkdf2_hmac(
            "sha256",
            "password123".encode(),
            salt,
            iterations=100000,
        )
        assert len(derived) == 32