Repository: Amaindex/asyncio-socks-server Branch: main Commit: b83d575ecdbd Files: 85 Total size: 314.2 KB Directory structure: gitextract_ufvktzax/ ├── .dockerignore ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── feature_request.md │ └── workflows/ │ ├── docker.yml │ ├── release.yml │ └── tests.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── README.zh-CN.md ├── docs/ │ ├── addon-model.md │ ├── addon-model.zh-CN.md │ ├── addon-recipes.md │ ├── addon-recipes.zh-CN.md │ ├── architecture.md │ ├── architecture.zh-CN.md │ ├── public-api.md │ └── public-api.zh-CN.md ├── pyproject.toml ├── src/ │ └── asyncio_socks_server/ │ ├── __init__.py │ ├── __main__.py │ ├── addons/ │ │ ├── __init__.py │ │ ├── auth.py │ │ ├── base.py │ │ ├── chain.py │ │ ├── ip_filter.py │ │ ├── logger.py │ │ ├── manager.py │ │ ├── stats.py │ │ ├── traffic.py │ │ └── udp_over_tcp_entry.py │ ├── cli.py │ ├── client/ │ │ ├── __init__.py │ │ └── client.py │ ├── core/ │ │ ├── __init__.py │ │ ├── address.py │ │ ├── logging.py │ │ ├── protocol.py │ │ ├── socket.py │ │ └── types.py │ ├── py.typed │ └── server/ │ ├── __init__.py │ ├── connection.py │ ├── server.py │ ├── tcp_relay.py │ ├── udp_over_tcp.py │ ├── udp_over_tcp_exit.py │ └── udp_relay.py └── tests/ ├── __init__.py ├── conftest.py ├── e2e_helpers.py ├── test_addon_builtins.py ├── test_addon_builtins_extended.py ├── test_addon_chain.py ├── test_addon_edge_cases.py ├── test_addon_manager.py ├── test_addon_stats.py ├── test_cli.py ├── test_client.py ├── test_client_edge_cases.py ├── test_concurrent.py ├── test_connection.py ├── test_core_address.py ├── test_core_protocol.py ├── test_core_socket.py ├── test_core_types.py ├── test_e2e.py ├── test_e2e_auth_chain.py ├── test_e2e_data_paths.py ├── test_e2e_lifecycle.py ├── test_e2e_policy_errors.py ├── test_flow.py ├── test_ipv6.py ├── test_logging.py ├── test_protocol_robustness.py ├── test_server.py ├── test_server_errors.py ├── test_server_lifecycle.py ├── 
test_tcp_relay.py ├── test_udp_associate_hook.py ├── test_udp_over_tcp.py ├── test_udp_over_tcp_e2e.py ├── test_udp_over_tcp_exit.py └── test_udp_relay.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .dockerignore ================================================ .git .github .pytest_cache .ruff_cache .venv __pycache__ *.pyc build dist *.egg-info docs tests ax-spec ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: "[BUG]" labels: bug assignees: Amaindex --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: 1. 2. 3. 4. **Expected behavior** A clear and concise description of what you expected to happen. **Screenshots** If applicable, add screenshots to help explain your problem. **Desktop (please complete the information):** - OS: [e.g. Win10] - Version: [e.g. 1.0.0] **Additional context** Add any other context about the problem here. ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: "[FEATURE]" labels: enhancement assignees: Amaindex --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. 
================================================ FILE: .github/workflows/docker.yml ================================================ name: Docker on: push: branches: [main] pull_request: branches: [main] release: types: [published] workflow_dispatch: permissions: contents: read concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true env: IMAGE_NAME: amaindex/asyncio-socks-server DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} DOCKERHUB_PASSWORD: ${{ secrets.DOCKERHUB_PASSWORD }} jobs: docker: name: Build Docker image runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: docker/setup-buildx-action@v3 - uses: docker/metadata-action@v5 id: meta with: images: ${{ env.IMAGE_NAME }} tags: | type=ref,event=branch type=ref,event=pr type=semver,pattern={{version}} type=semver,pattern={{major}}.{{minor}} - uses: docker/login-action@v3 if: github.event_name != 'pull_request' && env.DOCKERHUB_USERNAME != '' && env.DOCKERHUB_PASSWORD != '' with: username: ${{ env.DOCKERHUB_USERNAME }} password: ${{ env.DOCKERHUB_PASSWORD }} - uses: docker/build-push-action@v6 with: context: . 
platforms: linux/amd64,linux/arm64 push: ${{ github.event_name != 'pull_request' && env.DOCKERHUB_USERNAME != '' && env.DOCKERHUB_PASSWORD != '' }} tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} cache-from: type=gha cache-to: type=gha,mode=max ================================================ FILE: .github/workflows/release.yml ================================================ name: Release on: workflow_dispatch: release: types: [published] permissions: contents: read jobs: publish: name: Publish Python package runs-on: ubuntu-latest permissions: contents: read id-token: write steps: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v4 with: enable-cache: true cache-dependency-glob: pyproject.toml - name: Set up Python run: uv python install 3.12 - name: Build run: uv build --python 3.12 - name: Publish to PyPI run: uv publish ================================================ FILE: .github/workflows/tests.yml ================================================ name: Tests on: push: branches: [main] pull_request: branches: [main] workflow_dispatch: permissions: contents: read concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: quality: name: Quality runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v4 with: enable-cache: true cache-dependency-glob: pyproject.toml - name: Set up Python run: uv python install 3.12 - name: Install dependencies run: uv sync --group dev --python 3.12 - name: Lint run: uv run ruff check . - name: Format check run: uv run ruff format --check . 
- name: Type check run: uv run pyright test: name: Test Python ${{ matrix.python-version }} runs-on: ubuntu-latest strategy: fail-fast: false matrix: python-version: ["3.12", "3.13"] steps: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v4 with: enable-cache: true cache-dependency-glob: pyproject.toml - name: Set up Python ${{ matrix.python-version }} run: uv python install ${{ matrix.python-version }} - name: Install dependencies run: uv sync --group dev --python ${{ matrix.python-version }} - name: Test run: uv run pytest -q build: name: Build package runs-on: ubuntu-latest needs: [quality, test] steps: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v4 with: enable-cache: true cache-dependency-glob: pyproject.toml - name: Set up Python run: uv python install 3.12 - name: Build sdist and wheel run: uv build --python 3.12 - uses: actions/upload-artifact@v4 with: name: python-package path: dist/* ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # Distribution / packaging build/ dist/ *.egg-info/ # Environments .env .venv/ venv/ # Testing .pytest_cache/ htmlcov/ .coverage # Agent-local tooling .claude/ .codex/ # Development specs ax-spec/ # Type checkers .mypy_cache/ .pyre/ # uv uv.lock .uv-cache/ # IDE .idea/ .vscode/ .DS_Store ================================================ FILE: Dockerfile ================================================ FROM python:3.12-slim WORKDIR /app COPY pyproject.toml README.md LICENSE ./ COPY src ./src RUN pip install --no-cache-dir --root-user-action=ignore . 
\ && useradd --create-home --shell /usr/sbin/nologin appuser USER appuser EXPOSE 1080 ENTRYPOINT ["asyncio_socks_server"] CMD ["--host", "0.0.0.0", "--port", "1080"] ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2021 Amaindex Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # asyncio-socks-server [![Tests](https://github.com/Amaindex/asyncio-socks-server/actions/workflows/tests.yml/badge.svg)](https://github.com/Amaindex/asyncio-socks-server/actions/workflows/tests.yml) [![Docker](https://github.com/Amaindex/asyncio-socks-server/actions/workflows/docker.yml/badge.svg)](https://github.com/Amaindex/asyncio-socks-server/actions/workflows/docker.yml) [![Python](https://img.shields.io/badge/python-3.12%2B-blue)](pyproject.toml) [![License](https://img.shields.io/badge/license-MIT-green)](LICENSE) SOCKS5 server with async Python addon hooks. [Docs](#docs) · [Architecture](docs/architecture.md) · [Addon recipes](docs/addon-recipes.md) · [Addon model](docs/addon-model.md) · [Public API](docs/public-api.md) · [简体中文](README.zh-CN.md) ## Install ```shell pip install asyncio-socks-server ``` Docker images are versioned: ```shell docker run --rm -p 1080:1080 amaindex/asyncio-socks-server:1.3.1 ``` ## Run ```shell asyncio_socks_server asyncio_socks_server --host 127.0.0.1 --port 9050 asyncio_socks_server --auth user:pass ``` CLI flags: | Flag | Default | Meaning | |------|---------|---------| | `--host` | `::` | Bind address | | `--port` | `1080` | Bind port | | `--auth` | None | `username:password` | | `--log-level` | `INFO` | `DEBUG`, `INFO`, `WARNING`, `ERROR` | ## Use from Python ```python from asyncio_socks_server import Server Server(host="::", port=1080).run() ``` Addons are optional. 
Add only the behavior you need: | Goal | Addons | |------|--------| | Runtime counters and active flows | `FlowStats` + `StatsAPI` | | Closed-flow usage audit | `FlowAudit` + `StatsAPI` | | TCP chain proxying | `ChainRouter` | | UDP chain proxying | `UdpOverTcpEntry` + `UdpOverTcpExitServer` | | Auth, source policy, logs | `FileAuth`, `IPFilter`, `Logger` | Runtime counters and audit API: ```python from asyncio_socks_server import FlowAudit, FlowStats, Server, StatsAPI audit = FlowAudit() stats = FlowStats() server = Server( addons=[ audit, stats, StatsAPI(stats=stats, audit=audit, host="127.0.0.1", port=9900), ], ) server.run() ``` Addon order is execution order. Built-in addons are opt-in; adding `StatsAPI` is what starts an HTTP listener. `FlowStats` has no network side effects. Use its `snapshot()` and `flows()` methods directly, or pair it with `StatsAPI` for a small local HTTP API. `FlowAudit` records closed-flow usage in memory and can be exposed through `StatsAPI` for Kafra-like usage audit summaries. For task-oriented examples, see [Addon recipes](docs/addon-recipes.md). ## Model The core handles SOCKS5 parsing, relay, and hook dispatch. Addons handle policy. 
Hook dispatch has three models: | Model | Hooks | Contract | |-------|-------|----------| | Competitive | `on_auth`, `on_connect`, `on_udp_associate` | First non-`None` result wins | | Pipeline | `on_data` | Output from one addon becomes input to the next | | Observational | `on_start`, `on_stop`, `on_flow_close`, `on_error` | All applicable addons run | Built-ins: - `ChainRouter` for TCP chain proxying - `UdpOverTcpEntry` and `UdpOverTcpExitServer` for UDP chain proxying - `FlowStats` for in-memory flow statistics - `FlowAudit` for closed-flow usage audit summaries - `StatsAPI` as an opt-in HTTP API around `FlowStats` - `StatsServer` as a backward-compatible name for `StatsAPI` - `TrafficCounter`, `FileAuth`, `IPFilter`, `Logger` ## Architecture sketch ```text Client ── SOCKS5 ──▶ Server ──▶ Target │ ├─ auth / route hooks ├─ data pipeline hooks └─ flow close hooks ChainRouter: Client ──▶ A ──▶ B ──▶ C ──▶ Target ``` ## Chain proxying Each node only knows its next hop: ```python # A ─▶ B ─▶ C ─▶ target Server(addons=[ChainRouter("B:1080")]) # A Server(addons=[ChainRouter("C:1080")]) # B Server() # C ``` UDP chain proxying uses TCP between proxy nodes: ```python from asyncio_socks_server import Server, UdpOverTcpEntry, UdpOverTcpExitServer entry = Server(addons=[UdpOverTcpEntry("exit-host:9020")]) exit_server = UdpOverTcpExitServer(host="::", port=9020) ``` ## Client ```python from asyncio_socks_server import Address, connect conn = await connect( proxy_addr=Address("127.0.0.1", 1080), target_addr=Address("93.184.216.34", 443), ) conn.writer.write(b"hello") await conn.writer.drain() data = await conn.reader.read(4096) ``` ## API surface Stable imports live at the package root: ```python from asyncio_socks_server import ( Addon, Address, ChainRouter, Flow, FlowAudit, FlowStats, Server, StatsAPI, StatsServer, UdpOverTcpEntry, UdpOverTcpExitServer, connect, ) ``` Root exports are the 1.x compatibility contract. Submodules remain importable. 
## Docs | Document | Scope | |----------|-------| | [Architecture](docs/architecture.md) | Core flow, relay design, UDP-over-TCP, Flow context | | [Addon recipes](docs/addon-recipes.md) | Goal-oriented addon combinations | | [Addon model](docs/addon-model.md) | Hook contracts, dispatch semantics, built-in addons | | [Public API](docs/public-api.md) | 1.x compatibility surface | ## Development ```shell git clone https://github.com/Amaindex/asyncio-socks-server.git cd asyncio-socks-server uv sync uv run ruff check . uv run ruff format --check . uv run pyright uv run pytest uv build ``` ## Release GitHub Actions tests Python 3.12 and 3.13, builds the Python package, and builds Docker images. Create a GitHub Release from a tag such as `v1.3.1`. The release workflow publishes the Python package. The Docker workflow publishes semver image tags. ## License MIT ================================================ FILE: README.zh-CN.md ================================================ # asyncio-socks-server [![Tests](https://github.com/Amaindex/asyncio-socks-server/actions/workflows/tests.yml/badge.svg)](https://github.com/Amaindex/asyncio-socks-server/actions/workflows/tests.yml) [![Docker](https://github.com/Amaindex/asyncio-socks-server/actions/workflows/docker.yml/badge.svg)](https://github.com/Amaindex/asyncio-socks-server/actions/workflows/docker.yml) [![Python](https://img.shields.io/badge/python-3.12%2B-blue)](pyproject.toml) [![License](https://img.shields.io/badge/license-MIT-green)](LICENSE) 带 async Python addon hooks 的 SOCKS5 server。 [文档](#文档) · [架构](docs/architecture.zh-CN.md) · [Addon recipes](docs/addon-recipes.zh-CN.md) · [Addon 模型](docs/addon-model.zh-CN.md) · [公共 API](docs/public-api.zh-CN.md) · [English](README.md) ## 安装 ```shell pip install asyncio-socks-server ``` Docker image 使用明确版本: ```shell docker run --rm -p 1080:1080 amaindex/asyncio-socks-server:1.3.1 ``` ## 运行 ```shell asyncio_socks_server asyncio_socks_server --host 127.0.0.1 --port 9050 
asyncio_socks_server --auth user:pass ``` CLI 参数: | 参数 | 默认值 | 含义 | |------|--------|------| | `--host` | `::` | 监听地址 | | `--port` | `1080` | 监听端口 | | `--auth` | 无 | `username:password` | | `--log-level` | `INFO` | `DEBUG`、`INFO`、`WARNING`、`ERROR` | ## Python API ```python from asyncio_socks_server import Server Server(host="::", port=1080).run() ``` Addon 是可选的。只添加你需要的行为: | 目标 | Addons | |------|--------| | 运行计数和活跃 flows | `FlowStats` + `StatsAPI` | | 已关闭 flow 的用量审计 | `FlowAudit` + `StatsAPI` | | TCP 链式代理 | `ChainRouter` | | UDP 链式代理 | `UdpOverTcpEntry` + `UdpOverTcpExitServer` | | 认证、来源策略、日志 | `FileAuth`、`IPFilter`、`Logger` | 运行计数和审计 API: ```python from asyncio_socks_server import FlowAudit, FlowStats, Server, StatsAPI audit = FlowAudit() stats = FlowStats() server = Server( addons=[ audit, stats, StatsAPI(stats=stats, audit=audit, host="127.0.0.1", port=9900), ], ) server.run() ``` Addon 顺序就是执行顺序。内置 addon 都是显式 opt-in;只有加入 `StatsAPI` 才会启动 HTTP listener。 `FlowStats` 没有网络副作用。使用它的 `snapshot()` 和 `flows()` 方法, 可以自行搭建 HTTP API、metrics exporter 或日志管道,也可以搭配 `StatsAPI` 使用一个小型本地 HTTP API。 `FlowAudit` 在内存中记录已关闭 flow 的用量,可通过 `StatsAPI` 暴露类似 Kafra 的用量审计摘要。 按目标组合 addon 的例子见 [Addon recipes](docs/addon-recipes.zh-CN.md)。 ## 模型 核心处理 SOCKS5 解析、中继和 hook 调度。策略由 addon 处理。 Hook 调度有三种模型: | 模型 | Hooks | 契约 | |------|-------|------| | 竞争型 | `on_auth`、`on_connect`、`on_udp_associate` | 第一个非 `None` 结果获胜 | | 管道型 | `on_data` | 前一个 addon 的输出成为下一个 addon 的输入 | | 观察型 | `on_start`、`on_stop`、`on_flow_close`、`on_error` | 所有适用 addon 都会执行 | 内置 addon: - `ChainRouter`:TCP 链式代理 - `UdpOverTcpEntry` 和 `UdpOverTcpExitServer`:UDP 链式代理 - `FlowStats`:内存 flow 统计 - `FlowAudit`:已关闭 flow 的用量审计摘要 - `StatsAPI`:基于 `FlowStats` 的显式 opt-in HTTP API - `StatsServer`:`StatsAPI` 的向后兼容名称 - `TrafficCounter`、`FileAuth`、`IPFilter`、`Logger` ## 架构简图 ```text Client ── SOCKS5 ──▶ Server ──▶ Target │ ├─ auth / route hooks ├─ data pipeline hooks └─ flow close hooks ChainRouter: Client ──▶ A ──▶ B ──▶ C ──▶ Target ``` ## 链式代理 
每个节点只知道自己的下一跳: ```python # A ─▶ B ─▶ C ─▶ target Server(addons=[ChainRouter("B:1080")]) # A Server(addons=[ChainRouter("C:1080")]) # B Server() # C ``` UDP 链式代理在代理节点之间使用 TCP: ```python from asyncio_socks_server import Server, UdpOverTcpEntry, UdpOverTcpExitServer entry = Server(addons=[UdpOverTcpEntry("exit-host:9020")]) exit_server = UdpOverTcpExitServer(host="::", port=9020) ``` ## Client ```python from asyncio_socks_server import Address, connect conn = await connect( proxy_addr=Address("127.0.0.1", 1080), target_addr=Address("93.184.216.34", 443), ) conn.writer.write(b"hello") await conn.writer.drain() data = await conn.reader.read(4096) ``` ## API 面 稳定导入面在包根: ```python from asyncio_socks_server import ( Addon, Address, ChainRouter, Flow, FlowAudit, FlowStats, Server, StatsAPI, StatsServer, UdpOverTcpEntry, UdpOverTcpExitServer, connect, ) ``` 包根导出是 1.x 兼容性契约。子模块仍可导入。 ## 文档 | 文档 | 范围 | |------|------| | [架构](docs/architecture.zh-CN.md) | 核心流程、relay 设计、UDP-over-TCP、Flow context | | [Addon recipes](docs/addon-recipes.zh-CN.md) | 按目标组合 addon 的示例 | | [Addon 模型](docs/addon-model.zh-CN.md) | Hook 契约、调度语义、内置 addon | | [公共 API](docs/public-api.zh-CN.md) | 1.x 兼容面 | ## 开发 ```shell git clone https://github.com/Amaindex/asyncio-socks-server.git cd asyncio-socks-server uv sync uv run ruff check . uv run ruff format --check . uv run pyright uv run pytest uv build ``` ## 发布 GitHub Actions 测试 Python 3.12 和 3.13,构建 Python package,并构建 Docker images。 从 `v1.3.1` 这样的 tag 创建 GitHub Release。Release workflow 发布 Python package。Docker workflow 发布 semver image tags。 ## License MIT ================================================ FILE: docs/addon-model.md ================================================ # Addon Model [README](../README.md) · [Architecture](architecture.md) · [Addon recipes](addon-recipes.md) · [Public API](public-api.md) · [简体中文](addon-model.zh-CN.md) Addons are Python classes with optional async methods. The server calls them at defined points in the SOCKS5 flow. 
This document explains dispatch semantics. If you already know what you want to build, start with [Addon recipes](addon-recipes.md). ## Execution Models A single dispatch rule is not enough: - Authentication and routing need first-match-wins. - Data processing needs output-to-input chaining. - Lifecycle events need all applicable addons to run. The manager uses three models: | Model | Semantics | When to use | Hooks | |-------|-----------|-------------|-------| | Competitive | First non-`None` wins, rest skipped | Mutually exclusive decisions | `on_auth`, `on_connect`, `on_udp_associate` | | Pipeline | Sequential, output→input chaining | Data transformation chains | `on_data` | | Observational | All called where applicable; flow-close/error exceptions are caught | Logging, monitoring, cleanup | `on_start`, `on_stop`, `on_flow_close`, `on_error` | ## Hook API All methods are optional — unimplemented hooks have no effect. ```python class Addon: # Lifecycle (observational) async def on_start(self) -> None: """Server started.""" async def on_stop(self) -> None: """Server stopped. Flush buffers, write stats.""" # Authentication (competitive) async def on_auth(self, username: str, password: str) -> bool | None: """True = allow, False = deny, None = abstain.""" # Connection interception (competitive) async def on_connect(self, flow: Flow) -> Connection | None: """Return Connection to intercept, None to abstain, raise to deny.""" async def on_udp_associate(self, flow: Flow) -> UdpRelayBase | None: """Return UdpRelayBase to intercept, None to abstain.""" # Data transformation (pipeline) async def on_data(self, direction: Direction, data: bytes, flow: Flow) -> bytes | None: """Return bytes to write, None to drop this chunk, raise to abort.""" # Teardown (observational) async def on_flow_close(self, flow: Flow) -> None: """Connection closed. Final stats available in flow.""" async def on_error(self, error: Exception) -> None: """Error occurred. 
For logging/monitoring only.""" ``` ### Return Value Contract Competitive and pipeline hooks use different `None` semantics: | Hook kind | Return | Meaning | |-----------|--------|---------| | Competitive | `None` | Abstain — let the next addon or default behavior decide | | Competitive | non-`None` | Win — use the returned value as the result | | Pipeline `on_data` | `bytes` | Write these bytes and pass them to the next addon | | Pipeline `on_data` | `None` | Drop this chunk and stop the pipeline | | Any | raise exception | Deny/reject/abort the current operation | Addons can share a list without coordinating if they use different hooks. ## Competitive Dispatch First non-`None` wins. Remaining addons are skipped. ``` on_auth("admin", "secret"): FileAuth → True ← wins, stops here IPFilter → (not called) Logger → (not called) ``` ``` on_auth("unknown", "pass"): FileAuth → False ← explicit deny IPFilter → (not called) ``` ``` on_auth("guest", "pass"): FileAuth → None ← abstain (user not in file) IPFilter → None ← abstain (IP not relevant for auth) → core uses default: no auth required → allow ``` Raising an exception rejects the operation. The client receives a SOCKS5 error reply. ## Pipeline Dispatch Sequential, output-chained. Returning `None` breaks the pipeline (data is dropped, subsequent addons are not called). ``` on_data(up, b"hello", flow): UpperAddon → b"HELLO" ← transform TrafficLogger → b"HELLO" ← pass through by returning input unchanged AppendNull → b"HELLO\x00" ← transform → write b"HELLO\x00" to target ``` ``` on_data(down, response, flow): DropAddon → None ← drops data, pipeline breaks UpperAddon → (not called) → nothing written to client ``` Pipeline order is addon list order. ## Observational Dispatch All applicable addons are called. Exceptions raised from `on_flow_close` and `on_error` are caught and not propagated.
``` on_flow_close(flow): TrafficCounter → aggregates bytes (may raise on write error) Logger → logs connection stats → all called, any exceptions logged but suppressed ``` This keeps teardown and monitoring isolated from individual addon failures. ## Built-in Addons | Addon | Primary role | Starts network listeners | |-------|--------------|--------------------------| | `ChainRouter` | TCP next-hop routing | No | | `UdpOverTcpEntry` | UDP-over-TCP entry routing | No | | `UdpOverTcpExitServer` | UDP-over-TCP exit service | Yes, as a separate server | | `FlowStats` | Runtime counters and active flow snapshots | No | | `FlowAudit` | Closed-flow usage audit window | No | | `StatsAPI` | Optional HTTP presentation for stats and audit | Yes, only when added | | `StatsServer` | Backward-compatible name for `StatsAPI` | Yes, only when added | | `TrafficCounter` | Minimal closed-flow byte totals | No | | `FileAuth` | Username/password auth from JSON | No | | `IPFilter` | Source IP allow/block policy | No | | `Logger` | Connection and data logging | No | All built-in addons are opt-in. CLI mode starts a direct SOCKS5 server; addon composition is configured from Python. ### ChainRouter — TCP Chain Proxying ```python class ChainRouter(Addon): def __init__(self, next_hop: str): ... async def on_connect(self, flow): conn = await client.connect(self.next_hop, flow.dst) return conn ``` `ChainRouter` returns a `Connection` to the next-hop SOCKS5 server. The server relays through the returned connection. Each node only knows its next hop: ``` User → [A: ChainRouter("B:1080")] → [B: ChainRouter("C:1080")] → [C: direct] → Target ``` ### UdpOverTcpEntry — UDP Chain Proxying UDP chain proxying follows the same competitive pattern as `ChainRouter`, but through the `on_udp_associate` hook: instead of a `Connection`, it returns a bridge that encapsulates UDP datagrams as TCP frames. ``` Client UDP → Entry addon (encapsulate) → TCP chain → Exit server (decapsulate) → UDP → Target ``` Middle nodes see only TCP bytes.
### TrafficCounter — Stats Aggregation ```python class TrafficCounter(Addon): async def on_connect(self, flow): self.connections += 1 async def on_flow_close(self, flow): self.bytes_up += flow.bytes_up self.bytes_down += flow.bytes_down ``` `TrafficCounter` aggregates in `on_flow_close` rather than `on_data`: `Flow` already carries cumulative byte counters, and UDP traffic does not pass through `on_data`, so counting chunks there would miss it. ### FlowStats — Flow Statistics Infrastructure ```python from asyncio_socks_server import FlowStats, Server stats = FlowStats() server = Server(addons=[stats]) ``` `FlowStats` has no network side effects. It records flow lifecycle data through addon hooks and exposes Python methods for application-specific presentation: | Method | Content | |--------|---------| | `snapshot()` | Aggregate counters, rates, errors, and active flows | | `flows()` | Active flows and snapshots of recently closed flows | | `errors()` | Error counters and recent errors | Use `FlowStats` as infrastructure for your own HTTP API, Prometheus exporter, file audit stream, or control-plane integration. ### FlowAudit — Usage Audit Infrastructure ```python from asyncio_socks_server import FlowAudit, Server audit = FlowAudit() server = Server(addons=[audit]) ``` `FlowAudit` has no network side effects. It records closed flows in memory and aggregates usage by source host and target host: | Method | Content | |--------|---------| | `snapshot()` | Kafra-like audit summary with period, records, totals, devices, and traffic | | `reset()` | Clear the in-memory audit window | The audit window resets when the process restarts. Use an application-specific sink if you need durable long-term audit storage.
### StatsAPI — Opt-in HTTP API ```python from asyncio_socks_server import FlowAudit, FlowStats, Server, StatsAPI audit = FlowAudit() stats = FlowStats() api = StatsAPI(stats=stats, audit=audit, host="127.0.0.1", port=9900) server = Server(addons=[audit, stats, api]) ``` `StatsAPI` is a simple stdlib HTTP wrapper around `FlowStats` and optional `FlowAudit`. It starts a listener only when explicitly added to the addon list: | Endpoint | Content | |----------|---------| | `GET /health` | Liveness response | | `GET /stats` | `FlowStats.snapshot()` | | `GET /flows` | `FlowStats.flows()` | | `GET /errors` | `FlowStats.errors()` | | `GET /audit?top=25&device=` | `FlowAudit.snapshot()` | | `POST /audit/refresh?top=25&device=` | Current `FlowAudit.snapshot()` for Kafra-like refresh flows | When constructed without a `FlowStats` instance, `StatsAPI` creates and owns one: ```python server = Server(addons=[StatsAPI(host="127.0.0.1", port=9900)]) ``` `StatsServer` remains as a backward-compatible name for `StatsAPI`. Put `FlowStats` (or a `StatsAPI` that owns its own `FlowStats`) early in the addon list. It observes flow starts through competitive hooks, so an earlier addon that returns a non-`None` result prevents it from seeing the start event. `on_flow_close` still receives the final Flow snapshot. ### FileAuth — Multi-user Auth Reads a JSON file mapping usernames to passwords. Caches after first load. `FileAuth` is consulted only when the server negotiates username/password auth, so configure `Server(auth=...)` when using it. ### IPFilter — Source IP Access Control ```python IPFilter(allowed=["10.0.0.0/24"]) # or IPFilter(blocked=["10.0.0.5"]) ``` Reads `flow.src.host` in `on_connect`. Denied connections receive a SOCKS5 `CONNECTION_NOT_ALLOWED` reply. ### Logger — Connection Logging Logs connection details and flow stats. It does not change proxy behavior.
## Custom Addon Patterns ### Selective Content Inspection ```python class ContentFilter(Addon): async def on_data(self, direction, data, flow): if flow.dst.port != 80: return data # only inspect HTTP if direction == Direction.UP and b"forbidden-keyword" in data: raise Exception("blocked content") return data # pass through ``` ### Per-connection Byte Quota ```python class RateLimiter(Addon): def __init__(self, max_bytes=1024 * 1024): # 1MB per connection self.max_bytes = max_bytes async def on_data(self, direction, data, flow): if flow.bytes_up + flow.bytes_down > self.max_bytes: raise Exception("rate limit exceeded") return data ``` This caps the total bytes a connection may transfer; it does not throttle throughput over time. ### Dynamic Next-hop Routing ```python class DynamicRouter(Addon): def __init__(self): self.routes = {} # domain pattern → next hop async def on_connect(self, flow): for pattern, hop in self.routes.items(): if pattern in flow.dst.host: return await client.connect(hop, flow.dst) ``` ## Dispatch Internals `AddonManager` skips unimplemented hooks by checking `type(addon).method is not Addon.method`. This avoids creating coroutines for base-class methods that do nothing — significant when processing thousands of chunks through `on_data`. Addon list order is execution order. There is no priority system or dependency resolution — if order matters, arrange the list accordingly. For hook signatures and Flow compatibility, see [`public-api.md`](public-api.md).
================================================ FILE: docs/addon-model.zh-CN.md ================================================ # Addon 模型 [README](../README.zh-CN.md) · [架构](architecture.zh-CN.md) · [Addon recipes](addon-recipes.zh-CN.md) · [公共 API](public-api.zh-CN.md) · [English](addon-model.md) Addon 是包含可选 async 方法的 Python 类。Server 在 SOCKS5 流程的固定位置调用它们。 本文解释派发语义。如果你已经知道自己想搭建什么,先看 [Addon recipes](addon-recipes.zh-CN.md)。 ## 执行模型 单一派发规则不够: - 认证和路由需要第一个结果胜出。 - 数据处理需要输出到输入的链式传递。 - 生命周期事件需要调用所有适用 addon。 Manager 使用三种模型: | 模型 | 语义 | 何时使用 | Hook | |------|------|----------|------| | 竞争型 | 第一个非 `None` 胜出,后续跳过 | 互斥决策 | `on_auth`、`on_connect`、`on_udp_associate` | | 管道型 | 顺序执行,输出→输入链式传递 | 数据转换链 | `on_data` | | 观察型 | 按场景全部调用;flow-close/error 异常被捕获 | 日志、监控、清理 | `on_start`、`on_stop`、`on_flow_close`、`on_error` | ## Hook API 所有方法可选——未实现的 hook 不影响流程。 ```python class Addon: # 生命周期(观察型) async def on_start(self) -> None: """服务器启动。""" async def on_stop(self) -> None: """服务器停止。刷新缓冲、写入统计。""" # 认证(竞争型) async def on_auth(self, username: str, password: str) -> bool | None: """True = 放行,False = 拒绝,None = 不干预。""" # 连接拦截(竞争型) async def on_connect(self, flow: Flow) -> Connection | None: """返回 Connection 拦截,None 不干预,抛异常拒绝。""" async def on_udp_associate(self, flow: Flow) -> UdpRelayBase | None: """返回 UdpRelayBase 拦截,None 不干预。""" # 数据转换(管道型) async def on_data(self, direction: Direction, data: bytes, flow: Flow) -> bytes | None: """返回 bytes 写出,None 丢弃当前 chunk,抛异常中止。""" # 拆解(观察型) async def on_flow_close(self, flow: Flow) -> None: """连接关闭。最终统计在 flow 中。""" async def on_error(self, error: Exception) -> None: """发生异常。仅用于日志/监控。""" ``` ### 返回值契约 竞争型和管道型 hook 的 `None` 语义不同: | Hook 类型 | 返回 | 含义 | |----------|------|------| | 竞争型 | `None` | 弃权——让下一个 addon 或默认行为决定 | | 竞争型 | 非 `None` | 胜出——将返回值作为结果 | | 管道型 `on_data` | `bytes` | 写出这些字节,并继续传给下一个 addon | | 管道型 `on_data` | `None` | 丢弃当前 chunk,并停止管道 | | 任意 | 抛异常 | 拒绝/中止当前操作 | 如果 addon 使用不同 hook,可以共存而不需要互相协调。 ## 竞争型派发 第一个非 `None` 胜出。剩余 addon 跳过。 ``` 
on_auth("admin", "secret"): FileAuth → True ← 胜出,在此停止 IPFilter → (不调用) Logger → (不调用) ``` ``` on_auth("unknown", "pass"): FileAuth → False ← 显式拒绝 IPFilter → (不调用) ``` ``` on_auth("guest", "pass"): FileAuth → None ← 不干预(用户不在文件中) IPFilter → None ← 不干预(IP 与认证无关) → 内核使用默认行为:无需认证 → 放行 ``` 抛异常会拒绝当前操作。客户端收到 SOCKS5 错误回复。 ## 管道型派发 顺序执行,输出链式传递。返回 `None` 中断管道(数据丢弃,后续 addon 不调用)。 ``` on_data(up, b"hello", flow): UpperAddon → b"HELLO" ← 转换 TrafficLogger → b"HELLO" ← 通过返回原输入来放行 AppendNull → b"HELLO\x00" ← 转换 → 写入目标: b"HELLO\x00" ``` ``` on_data(down, response, flow): DropAddon → None ← 丢弃数据,管道中断 UpperAddon → (不调用) → 不向客户端写入任何内容 ``` 管道顺序即 addon 列表顺序。 ## 观察型派发 所有 addon 调用。异常被捕获不传播。 ``` on_flow_close(flow): TrafficCounter → 聚合字节(写入时可能抛异常) Logger → 记录连接统计 → 全部调用,任何异常被记录但被抑制 ``` 这把 teardown 和监控从单个 addon 的失败中隔离出来。 ## 内置 Addon | Addon | 主要角色 | 是否启动网络 listener | |-------|----------|-----------------------| | `ChainRouter` | TCP 下一跳路由 | 否 | | `UdpOverTcpEntry` | UDP-over-TCP 入口路由 | 否 | | `UdpOverTcpExitServer` | UDP-over-TCP 出口服务 | 是,作为独立 server | | `FlowStats` | 运行计数和活跃 flow 快照 | 否 | | `FlowAudit` | 已关闭 flow 的用量审计窗口 | 否 | | `StatsAPI` | stats/audit 的可选 HTTP 展示层 | 是,只有加入 addon 列表才启动 | | `StatsServer` | `StatsAPI` 的向后兼容名称 | 是,只有加入 addon 列表才启动 | | `TrafficCounter` | 最小的已关闭 flow 字节汇总 | 否 | | `FileAuth` | JSON 用户名/密码认证 | 否 | | `IPFilter` | 来源 IP allow/block 策略 | 否 | | `Logger` | 连接和数据日志 | 否 | 所有内置 addon 都是显式 opt-in。CLI 模式启动直连 SOCKS5 server;addon 组合通过 Python 配置。 ### ChainRouter — TCP 链式代理 ```python class ChainRouter(Addon): def __init__(self, next_hop: str): ... 
async def on_connect(self, flow): conn = await client.connect(self.next_hop, flow.dst) return conn ``` `ChainRouter` 返回到下一跳 SOCKS5 server 的 `Connection`。Server 通过返回的连接中继。 每个节点只知道自己下一跳: ``` 用户 → [A: ChainRouter("B:1080")] → [B: ChainRouter("C:1080")] → [C: 直连] → 目标 ``` ### UdpOverTcpEntry — UDP 链式代理 UDP 链式代理复用同一个竞争型 hook(`on_udp_associate`),但返回一个将 UDP 数据报封装为 TCP 帧的 bridge,而非 `Connection`。 ``` 客户端 UDP → 入口 addon(封装)→ TCP 链式 → 出口服务(拆封)→ UDP → 目标 ``` 中间节点只看到 TCP bytes。 ### TrafficCounter — 统计聚合 ```python class TrafficCounter(Addon): async def on_connect(self, flow): self.connections += 1 async def on_flow_close(self, flow): self.bytes_up += flow.bytes_up self.bytes_down += flow.bytes_down ``` `TrafficCounter` 在 `on_flow_close` 中聚合。`Flow` 已经有累计字节计数,且 UDP 不经过 `on_data`。 ### FlowStats — Flow 统计基础设施 ```python from asyncio_socks_server import FlowStats, Server stats = FlowStats() server = Server(addons=[stats]) ``` `FlowStats` 没有网络副作用。它通过 addon hooks 记录 flow 生命周期数据, 并暴露 Python 方法供应用自行决定展示方式: | 方法 | 内容 | |------|------| | `snapshot()` | 聚合计数、速率、错误和活跃 flow | | `flows()` | 活跃 flow 和最近关闭 flow 快照 | | `errors()` | 错误计数和最近错误 | 用 `FlowStats` 搭建自己的 HTTP API、Prometheus exporter、文件审计流或控制面集成。 ### FlowAudit — 用量审计基础设施 ```python from asyncio_socks_server import FlowAudit, Server audit = FlowAudit() server = Server(addons=[audit]) ``` `FlowAudit` 没有网络副作用。它在内存中记录已关闭 flow,并按 source host 和 target host 聚合用量: | 方法 | 内容 | |------|------| | `snapshot()` | 类似 Kafra audit 的摘要,包含 period、records、total、devices 和 traffic | | `reset()` | 清空当前内存审计窗口 | 进程重启后审计窗口会重置。如果需要长期留痕,应在应用层接入持久化 sink。 ### StatsAPI — 显式 opt-in HTTP API ```python from asyncio_socks_server import FlowAudit, FlowStats, Server, StatsAPI audit = FlowAudit() stats = FlowStats() api = StatsAPI(stats=stats, audit=audit, host="127.0.0.1", port=9900) server = Server(addons=[audit, stats, api]) ``` `StatsAPI` 是基于 `FlowStats` 和可选 `FlowAudit` 的简单标准库 HTTP wrapper。只有显式加入 addon 列表时才会启动 listener: | Endpoint | 内容 | |----------|------| | `GET /health` 
| 存活响应 | | `GET /stats` | `FlowStats.snapshot()` | | `GET /flows` | `FlowStats.flows()` | | `GET /errors` | `FlowStats.errors()` | | `GET /audit?top=25&device=` | `FlowAudit.snapshot()` | | `POST /audit/refresh?top=25&device=` | 返回当前 `FlowAudit.snapshot()`,用于类似 Kafra 的刷新流程 | 如果不传入 `FlowStats`,`StatsAPI` 会自己创建并托管一个: ```python server = Server(addons=[StatsAPI(host="127.0.0.1", port=9900)]) ``` `StatsServer` 作为 `StatsAPI` 的向后兼容名称保留。 建议把 `FlowStats` 或托管自身 stats 的 `StatsAPI` 放在 addon 列表靠前位置。它通过竞争型 hook 观察 flow start。更早胜出的 addon 会让它看不到 start 事件。`on_flow_close` 仍会收到最终 Flow 快照。 ### FileAuth — 多用户认证 从 JSON 文件读取用户名/密码映射。首次加载后缓存。只有 server 协商 username/password auth 时才会调用 `FileAuth`,因此使用它时需要配置 `Server(auth=...)`。 ### IPFilter — 源 IP 访问控制 ```python IPFilter(allowed=["10.0.0.0/24"]) # 或 IPFilter(blocked=["10.0.0.5"]) ``` 在 `on_connect` 中读取 `flow.src.host`。被拒绝的连接收到 SOCKS5 `CONNECTION_NOT_ALLOWED` 回复。 ### Logger — 连接日志 记录连接详情和流量统计。不改变代理行为。 ## 自定义 Addon 模式 ### 选择性内容检查 ```python class ContentFilter(Addon): async def on_data(self, direction, data, flow): if flow.dst.port != 80: return data # 只检查 HTTP,其他端口原样放行 if direction == Direction.UP and b"forbidden-keyword" in data: raise Exception("blocked content") return data # 放行 ``` 端口判断必须放在 `on_data` 内:`on_connect` 返回 `None` 只是弃权,不会让后续 `on_data` 跳过这个 flow。 ### 每连接流量上限 ```python class RateLimiter(Addon): def __init__(self, max_bytes=1024 * 1024): # 每连接 1MB self.max_bytes = max_bytes async def on_data(self, direction, data, flow): if flow.bytes_up + flow.bytes_down > self.max_bytes: raise Exception("rate limit exceeded") return data ``` ### 动态下一跳路由 ```python class DynamicRouter(Addon): def __init__(self): self.routes = {} # 域名模式 → 下一跳 async def on_connect(self, flow): for pattern, hop in self.routes.items(): if pattern in flow.dst.host: return await client.connect(hop, flow.dst) ``` ## 派发内部机制 `AddonManager` 通过 `type(addon).method is not Addon.method` 检测子类是否重写了方法,跳过未重写的。这避免为基类的空方法创建协程——在处理数千个 chunk 经过 `on_data` 时影响显著。 Addon 
列表顺序即执行顺序。没有优先级系统或依赖解析——如果顺序重要,自行安排列表。 Hook 签名和 Flow 语义的兼容性承诺见
[`public-api.zh-CN.md`](public-api.zh-CN.md)。 ================================================ FILE: docs/addon-recipes.md ================================================ # Addon Recipes [README](../README.md) · [Architecture](architecture.md) · [Addon model](addon-model.md) · [Public API](public-api.md) · [简体中文](addon-recipes.zh-CN.md) Use this page when choosing which addons to combine. Addons are opt-in and run in the order listed in `Server(addons=[...])`. ## Direct SOCKS5 Server No addons are required for a plain SOCKS5 server: ```python from asyncio_socks_server import Server Server(host="::", port=1080).run() ``` CLI mode is equivalent to this direct shape plus optional single-user auth. ## Runtime Counters Use `FlowStats` for counters and `StatsAPI` only if you want an HTTP endpoint: ```python from asyncio_socks_server import FlowStats, Server, StatsAPI stats = FlowStats() server = Server( addons=[ stats, StatsAPI(stats=stats, host="127.0.0.1", port=9900), ], ) server.run() ``` Endpoints: | Endpoint | Use | |----------|-----| | `GET /health` | Liveness | | `GET /stats` | Totals, rates, errors, active flows | | `GET /flows` | Active and recent closed flows | | `GET /errors` | Error counters | `FlowStats` should appear before competitive routing addons if you need flow start visibility. ## Usage Audit Use `FlowAudit` for closed-flow usage grouped by source host and target host: ```python from asyncio_socks_server import FlowAudit, Server, StatsAPI audit = FlowAudit() server = Server( addons=[ audit, StatsAPI(audit=audit, host="127.0.0.1", port=9900), ], ) server.run() ``` Endpoints: | Endpoint | Use | |----------|-----| | `GET /audit?top=25&device=` | Current in-memory audit window | | `POST /audit/refresh?top=25&device=` | Same snapshot, useful for control-plane refresh flows | The audit window is in-memory and resets when the process restarts. Add a custom sink if you need durable records. 
## Runtime Counters Plus Audit This is the normal observability stack: ```python from asyncio_socks_server import FlowAudit, FlowStats, Server, StatsAPI audit = FlowAudit() stats = FlowStats() server = Server( addons=[ audit, stats, StatsAPI(stats=stats, audit=audit, host="127.0.0.1", port=9900), ], ) server.run() ``` `StatsAPI` is a presentation layer. It does not collect stats or audit data by itself unless it owns an internal `FlowStats`; pass explicit `FlowStats` and `FlowAudit` instances when other code also needs direct Python access. ## TCP Chain Proxy Use `ChainRouter` when this server should forward TCP CONNECT traffic through a downstream SOCKS5 server: ```python from asyncio_socks_server import ChainRouter, Server Server(addons=[ChainRouter("10.0.0.5:1080")]).run() ``` Each node only knows its next hop: ```python Server(addons=[ChainRouter("B:1080")]) # A Server(addons=[ChainRouter("C:1080")]) # B Server() # C ``` ## UDP Over TCP Chain Use `UdpOverTcpEntry` at the SOCKS-facing node and `UdpOverTcpExitServer` at the exit: ```python from asyncio_socks_server import Server, UdpOverTcpEntry, UdpOverTcpExitServer entry = Server(addons=[UdpOverTcpEntry("exit-host:9020")]) exit_server = UdpOverTcpExitServer(host="::", port=9020) ``` Middle chain nodes see TCP bytes. ## Auth, Source Policy, And Logs Use these independently or together: ```python import secrets from asyncio_socks_server import FileAuth, IPFilter, Logger, Server server = Server( auth=("fallback", secrets.token_urlsafe(32)), addons=[ FileAuth("/etc/asyncio-socks-users.json"), IPFilter(allowed=["10.0.0.0/24"]), Logger(), ], ) server.run() ``` `FileAuth` is consulted only when server auth is enabled. The `auth` tuple forces username/password negotiation and is still accepted as a fallback credential. Use an unguessable value, as shown above, and never a fixed placeholder that anyone reading these docs could reuse. 
`IPFilter` accepts either `allowed` or `blocked`, not both. `Logger` observes traffic without changing routing. ## Compatibility Names `StatsServer` is a backward-compatible name for `StatsAPI`.
New code should use `StatsAPI` because it describes the role more precisely. ================================================ FILE: docs/addon-recipes.zh-CN.md ================================================ # Addon Recipes [README](../README.zh-CN.md) · [架构](architecture.zh-CN.md) · [Addon 模型](addon-model.zh-CN.md) · [公共 API](public-api.zh-CN.md) · [English](addon-recipes.md) 当你需要选择 addon 组合时,从这里开始。Addon 都是显式 opt-in,并按 `Server(addons=[...])` 中的顺序执行。 ## 直连 SOCKS5 Server 普通 SOCKS5 server 不需要任何 addon: ```python from asyncio_socks_server import Server Server(host="::", port=1080).run() ``` CLI 模式等价于这个直连形态,加上可选的单用户认证。 ## 运行计数 用 `FlowStats` 收集计数;只有需要 HTTP endpoint 时才加 `StatsAPI`: ```python from asyncio_socks_server import FlowStats, Server, StatsAPI stats = FlowStats() server = Server( addons=[ stats, StatsAPI(stats=stats, host="127.0.0.1", port=9900), ], ) server.run() ``` Endpoints: | Endpoint | 用途 | |----------|------| | `GET /health` | 存活检查 | | `GET /stats` | 总量、速率、错误、活跃 flows | | `GET /flows` | 活跃和最近关闭 flows | | `GET /errors` | 错误计数 | 如果需要看到 flow start,`FlowStats` 应放在竞争型路由 addon 之前。 ## 用量审计 用 `FlowAudit` 按 source host 和 target host 聚合已关闭 flow 的用量: ```python from asyncio_socks_server import FlowAudit, Server, StatsAPI audit = FlowAudit() server = Server( addons=[ audit, StatsAPI(audit=audit, host="127.0.0.1", port=9900), ], ) server.run() ``` Endpoints: | Endpoint | 用途 | |----------|------| | `GET /audit?top=25&device=` | 当前内存审计窗口 | | `POST /audit/refresh?top=25&device=` | 同一份 snapshot,便于控制面做刷新流程 | 审计窗口在内存中,进程重启后会清空。如果需要长期留痕,应增加自定义 sink。 ## 运行计数加审计 这是常见的观测组合: ```python from asyncio_socks_server import FlowAudit, FlowStats, Server, StatsAPI audit = FlowAudit() stats = FlowStats() server = Server( addons=[ audit, stats, StatsAPI(stats=stats, audit=audit, host="127.0.0.1", port=9900), ], ) server.run() ``` `StatsAPI` 是展示层。除非它自己托管内部 `FlowStats`,否则它不直接收集 stats 或 audit 数据。当其他代码也需要 Python API 时,显式传入 `FlowStats` 和 `FlowAudit` 实例。 ## TCP 链式代理 当这个 server 需要把 TCP CONNECT 
流量转发到下游 SOCKS5 server 时,使用 `ChainRouter`: ```python from asyncio_socks_server import ChainRouter, Server Server(addons=[ChainRouter("10.0.0.5:1080")]).run() ``` 每个节点只知道自己的下一跳: ```python Server(addons=[ChainRouter("B:1080")]) # A Server(addons=[ChainRouter("C:1080")]) # B Server() # C ``` ## UDP Over TCP 链式代理 在面向 SOCKS 的入口节点使用 `UdpOverTcpEntry`,在出口节点使用 `UdpOverTcpExitServer`: ```python from asyncio_socks_server import Server, UdpOverTcpEntry, UdpOverTcpExitServer entry = Server(addons=[UdpOverTcpEntry("exit-host:9020")]) exit_server = UdpOverTcpExitServer(host="::", port=9020) ``` 中间链路节点只看到 TCP bytes。 ## 认证、来源策略和日志 这些 addon 可以独立使用,也可以组合: ```python import secrets from asyncio_socks_server import FileAuth, IPFilter, Logger, Server server = Server( auth=("fallback", secrets.token_urlsafe(32)), addons=[ FileAuth("/etc/asyncio-socks-users.json"), IPFilter(allowed=["10.0.0.0/24"]), Logger(), ], ) server.run() ``` 只有启用 server auth 时,`FileAuth` 才会被调用;`auth` tuple 用于强制 username/password 协商,且它本身仍会作为有效的 fallback 凭证被接受,因此应像上例一样使用不可猜测的值,不要使用文档中公开的固定占位符。 `IPFilter` 接受 `allowed` 或 `blocked`,不要同时传入。`Logger` 只观察流量, 不改变路由。 ## 兼容名称 `StatsServer` 是 `StatsAPI` 的向后兼容名称。新代码建议使用 `StatsAPI`, 因为这个名字更准确地表达它是展示层。 ================================================ FILE: docs/architecture.md ================================================ # Architecture [README](../README.md) · [Addon recipes](addon-recipes.md) · [Addon model](addon-model.md) · [Public API](public-api.md) · [简体中文](architecture.zh-CN.md) The core handles protocol parsing, relay, and hook dispatch. Addons handle policy and routing. Chain proxying, traffic counting, and access control are addon behavior.
## System Overview ```text SOCKS5 Client Server Remote ───────────── ────── ────── auth negotiation ────▶ parse_method_selection parse_username_password dispatch_auth (competitive) │ CONNECT/UDP request ─▶ parse_request → create Flow dispatch_connect / dispatch_udp_associate ├─ no addon ─────────────────▶ direct connect └─ ChainRouter ──────────────▶ client.connect(next_hop) │ bidirectional relay ─▶ dispatch_data (pipeline) per chunk flow.bytes_up/down per chunk or datagram │ connection close ────▶ dispatch_flow_close (observational) log stats from Flow ``` Chain proxying uses the same path. `dispatch_connect` returns a `Connection` to the next-hop SOCKS5 server instead of a direct TCP connection. ## Request Lifecycle Every request has three stages: | Stage | Entry | Core Action | Output | |-------|-------|-------------|--------| | Handshake | SOCKS5 client | Parse method selection + auth + request, create Flow | Flow with src/dst/protocol | | Relay | Flow + addon decision | Bidirectional data pump with addon pipeline, Flow tracks bytes | Data forwarded, bytes counted | | Teardown | Connection close | Log stats, dispatch `on_flow_close` | Addons get final stats | TCP and UDP share the hook lifecycle. TCP uses paired `_copy()` coroutines. UDP uses a shared socket and routing table. ### TCP Relay Data Flow ```text Client Server Target ────── ────── ────── │ │ │ │── handshake ─▶│ │ │ │ parse + auth + Flow │ │ │ │ │ │ dispatch_connect(flow) │ │ │ ├─ no addon ──▶ direct │ │ │ └─ ChainRouter ──▶ next hop │ │ │ │ │── data ──────▶│── _copy(client→target)───────────────▶│ │ │ dispatch_data(up, data, flow) │ │ │ flow.bytes_up += len(data) │ │ │ │ │◀─ response ───│◀── _copy(target→client) │ │ │ dispatch_data(down, data) │ │ │ flow.bytes_down += len(data) │ │ │ │ │── close ─────▶│── dispatch_flow_close(flow)──────────▶│ │ │ log: ↑1.2KB ↓45.6KB │ │ │ │ ``` ### UDP Relay Architecture Some SOCKS5 implementations create one outbound UDP socket per client source port. 
Long-running servers can accumulate sockets. This implementation uses one outbound socket and a bidirectional routing table: ```text Outbound: Client datagram ──▶ shared_socket.sendto(payload, target) route_map[("93.184.216.34", 443)] = ("10.0.0.1", 54321) flow.bytes_up += len(payload) Inbound: shared_socket.recvfrom() ──▶ lookup route_map ──▶ sendto(client, response) flow.bytes_down += len(response) ``` Routes expire by TTL. ## UDP-over-TCP Chaining UDP chain proxying does not use UDP between proxy nodes. Inter-node transport is TCP. Entry nodes encapsulate UDP datagrams as TCP frames. Exit nodes decapsulate them back to UDP. ```text Request: Client UDP ──▶ UdpOverTcpEntry ──▶ middle nodes ──▶ Exit server ──▶ raw UDP ──▶ Target encapsulate (TCP bytes) decapsulate UDP → TCP TCP → UDP Response: Target ──▶ raw UDP ──▶ Exit server ──▶ middle nodes ──▶ UdpOverTcpEntry ──▶ Client UDP encapsulate (TCP bytes) decapsulate UDP → TCP TCP → UDP ``` Frame format (4-byte length prefix + SOCKS5 encoded address + payload): ```text ┌──────────┬──────────────────┬─────────┐ │ Length │ Encoded Address │ Payload │ │ 4 bytes │ variable │ N bytes │ └──────────┴──────────────────┴─────────┘ ``` Properties: - Middle nodes only forward TCP CONNECT traffic. - `on_data` sees TCP bytes in both TCP and UDP-over-TCP cases. - UDP semantics remain at the client-entry and exit-target edges. - No per-hop UDP ASSOCIATE state is needed. ## Flow Context `Flow` is the per-connection context passed through hooks. ```python @dataclass class Flow: id: int # Monotonically increasing src: Address # Client address dst: Address # Target address protocol: Literal["tcp", "udp"] started_at: float # time.monotonic() bytes_up: int = 0 # Client → target (TCP: post-addon; UDP: raw payload) bytes_down: int = 0 # Target → client ``` Without `Flow`, data hooks have no connection identity. Byte counters also become easy to duplicate across relay and addon code. With `Flow`: - Bytes are counted once in relay code. 
- Hooks receive the same object for the connection lifecycle. - `on_flow_close` receives the final counters. Lifecycle: ``` on_connect / on_udp_associate(flow) → addon registers connection, gets identity on_data(direction, data, flow) → addon knows whose data, can read running stats └─ relay updates flow.bytes_* directly on_flow_close(flow) → addon gets final snapshot, can log/aggregate ``` ## IPv6 Dual-Stack Server listens on `::` with one `AF_INET6` socket (`IPV6_V6ONLY=0`), handling IPv4 and IPv6. Client connection uses Happy Eyeballs-style fallback. It resolves IPv6 and IPv4 candidates, starts one candidate, then staggers subsequent candidates every 250ms. A candidate that fails quickly does not stop the remaining candidates from being tried. UDP relay normalizes IPv4-mapped IPv6 addresses (`::ffff:x.x.x.x`) in routing tables. ## Async Hooks The data path uses `StreamReader` and `StreamWriter`. It is already async: `await reader.read()` -> process -> `await writer.drain()`. Async hooks allow: - `ChainRouter.on_connect` to `await client.connect()` directly. - `on_auth` to use async I/O. - No sync-to-async bridge in the relay path. The extra `await` per hook is negligible compared with network I/O.
## Design Decisions | Decision | Choice | Rationale | |----------|--------|-----------| | SOCKS version | SOCKS5 | Covers CONNECT and UDP ASSOCIATE | | Runtime deps | Zero | Stdlib only | | Addon model | Class-based, async | One class with multiple hooks gives natural state management; async matches the data path | | Config method | Python scripts | Addons are regular Python objects | | Hot reload | Not in kernel | Use an external watcher if needed | | Resource limits | Not in kernel | Use system-level limits | ## Topic Docs | Doc | Content | |-----|---------| | [`addon-model.md`](addon-model.md) | Hook API, execution models, built-in addons, chain proxying | | [`public-api.md`](public-api.md) | 1.x compatibility surface, root exports, hook contracts | ================================================ FILE: docs/architecture.zh-CN.md ================================================ # 架构与数据流 [README](../README.zh-CN.md) · [Addon recipes](addon-recipes.zh-CN.md) · [Addon 模型](addon-model.zh-CN.md) · [公共 API](public-api.zh-CN.md) · [English](architecture.md) 核心处理协议解析、中继和 hook 调度。Addon 处理策略和路由。链式代理、流量统计、访问控制都是 addon 行为。 ## 系统总览 ```text SOCKS5 Client Server Remote ───────────── ────── ────── auth negotiation ────▶ parse_method_selection parse_username_password dispatch_auth (competitive) │ CONNECT/UDP request ─▶ parse_request → create Flow dispatch_connect / dispatch_udp_associate ├─ no addon ─────────────────▶ direct connect └─ ChainRouter ──────────────▶ client.connect(next_hop) │ bidirectional relay ─▶ dispatch_data (pipeline) per chunk flow.bytes_up/down per chunk or datagram │ connection close ────▶ dispatch_flow_close (observational) log stats from Flow ``` 链式代理使用同一条路径。区别是 `dispatch_connect` 返回到下一跳 SOCKS5 server 的 `Connection`,而不是直连 TCP 连接。 ## 请求生命周期 每个请求经过三个阶段: | 阶段 | 入口 | 核心动作 | 输出 | |------|------|----------|------| | 握手 | SOCKS5 客户端 | 解析方法选择 + 认证 + 请求,创建 Flow | 含 src/dst/protocol 的 Flow | | 中继 | Flow + addon 决策 | 双向数据泵 + addon 管道,Flow 追踪字节 | 数据转发,字节计数 | | 拆解 | 
连接关闭 | 记录统计,派发 `on_flow_close` | Addon 获得最终统计 | TCP 和 UDP 共享 hook 生命周期。TCP 使用配对的 `_copy()` 协程。UDP 使用共享 socket 和路由表。 ### TCP 中继数据流 ```text Client Server Target ────── ────── ────── │ │ │ │── handshake ─▶│ │ │ │ parse + auth + Flow │ │ │ │ │ │ dispatch_connect(flow) │ │ │ ├─ no addon ──▶ direct │ │ │ └─ ChainRouter ──▶ next hop │ │ │ │ │── data ──────▶│── _copy(client→target)───────────────▶│ │ │ dispatch_data(up, data, flow) │ │ │ flow.bytes_up += len(data) │ │ │ │ │◀─ response ───│◀── _copy(target→client) │ │ │ dispatch_data(down, data) │ │ │ flow.bytes_down += len(data) │ │ │ │ │── close ─────▶│── dispatch_flow_close(flow)──────────▶│ │ │ log: ↑1.2KB ↓45.6KB │ │ │ │ ``` ### UDP Relay 架构 一些 SOCKS5 实现为每个客户端源端口创建独立的出向 UDP socket。长时间运行时容易积累 socket。 本实现使用一个出向 socket 和双向路由表: ```text Outbound: Client datagram ──▶ shared_socket.sendto(payload, target) route_map[("93.184.216.34", 443)] = ("10.0.0.1", 54321) flow.bytes_up += len(payload) Inbound: shared_socket.recvfrom() ──▶ lookup route_map ──▶ sendto(client, response) flow.bytes_down += len(response) ``` 路由通过 TTL 过期淘汰。 ## UDP-over-TCP 链式 UDP 链式代理不在代理节点之间使用 UDP。节点间传输走 TCP。 入口节点把 UDP datagram 封装为 TCP frame。出口节点拆封后发出 UDP。 ```text Request: Client UDP ──▶ UdpOverTcpEntry ──▶ middle nodes ──▶ Exit server ──▶ raw UDP ──▶ Target encapsulate (TCP bytes) decapsulate UDP → TCP TCP → UDP Response: Target ──▶ raw UDP ──▶ Exit server ──▶ middle nodes ──▶ UdpOverTcpEntry ──▶ Client UDP encapsulate (TCP bytes) decapsulate UDP → TCP TCP → UDP ``` 性质: - 中间节点只转发 TCP CONNECT 流量。 - `on_data` 在 TCP 和 UDP-over-TCP 场景下都只看到 TCP bytes。 - UDP 语义只存在于 client-entry 和 exit-target 两段。 - 不需要逐跳维护 UDP ASSOCIATE 状态。 ## Flow Context `Flow` 是贯穿 hooks 的每连接上下文。 ```python @dataclass class Flow: id: int # 全局递增 ID src: Address # 客户端地址 dst: Address # 目标地址 protocol: Literal["tcp", "udp"] started_at: float # time.monotonic() bytes_up: int = 0 # 客户端→目标(TCP: post-addon; UDP: 原始载荷) bytes_down: int = 0 # 目标→客户端 ``` 没有 `Flow` 时,data hook 没有连接身份。字节计数也容易在 relay 和 addon 中重复。 
有了 `Flow`: - 字节只在 relay 中计数一次。 - hook 在连接生命周期内收到同一个对象。 - `on_flow_close` 收到最终计数。 生命周期: ```text on_connect / on_udp_associate(flow) → addon registers connection, gets identity on_data(direction, data, flow) → addon knows whose data, can read live stats └─ relay updates flow.bytes_* directly on_flow_close(flow) → addon gets final snapshot, can log/aggregate ``` ## IPv6 双栈 服务端用一个 `AF_INET6` socket(`IPV6_V6ONLY=0`)监听 `::`,同时处理 IPv4 和 IPv6。 客户端连接使用 Happy Eyeballs 风格 fallback。解析 IPv6 和 IPv4 候选,启动一个候选,然后每 250ms 启动后续候选。快速失败不会终止后续候选。 UDP relay 在路由表中归一化 IPv4-mapped IPv6 地址(`::ffff:x.x.x.x`)。 ## Async Hooks 数据路径使用 `StreamReader` 和 `StreamWriter`。它本来就是 async:`await reader.read()` -> 处理 -> `await writer.drain()`。 Async hooks 允许: - `ChainRouter.on_connect` 直接 `await client.connect()`。 - `on_auth` 使用 async I/O。 - relay 路径不需要 sync-to-async 桥接。 额外一次 `await` 相比网络 I/O 不构成主要成本。 ## 设计决策 | 决策 | 选择 | 理由 | |------|------|------| | SOCKS 版本 | SOCKS5 | 覆盖 CONNECT 和 UDP ASSOCIATE | | 运行时依赖 | 零 | 仅标准库 | | Addon 模型 | 类式 + async | 一个类实现多个 hook,状态管理自然;async 匹配数据路径 | | 配置方式 | Python 脚本 | Addon 是普通 Python 对象 | | 热加载 | 内核不支持 | 需要时使用外部 watcher | | 资源限制 | 内核不处理 | 使用系统级限制 | ## 专题文档 | 文档 | 内容 | |------|------| | [`addon-model.zh-CN.md`](addon-model.zh-CN.md) | Hook API、执行模型、内置 addon、链式代理 | | [`public-api.zh-CN.md`](public-api.zh-CN.md) | 1.x 兼容面、包根导出、hook 契约 | ================================================ FILE: docs/public-api.md ================================================ # Public API [README](../README.md) · [Architecture](architecture.md) · [Addon recipes](addon-recipes.md) · [Addon model](addon-model.md) · [简体中文](public-api.zh-CN.md) This document defines the asyncio-socks-server 1.x compatibility surface. Stable imports live at the package root. Submodules remain importable. ## Compatibility Policy The package root is stable: ```python from asyncio_socks_server import Server, Addon, Address, connect ``` Within the 1.x series: - Root exports keep their names and broad behavior.
- Addon hook signatures remain compatible. - `Flow` byte counters and address fields keep their meaning. - CLI flags remain backward-compatible. Modules under `asyncio_socks_server.core`, `asyncio_socks_server.server`, `asyncio_socks_server.client`, and `asyncio_socks_server.addons` are importable. Root exports are the compatibility contract. ## Root Exports | Name | Category | Purpose | |------|----------|---------| | `Server` | Server | SOCKS5 server entry point | | `connect` | Client | Open a TCP connection through a SOCKS5 proxy | | `Addon` | Addon base | Base class for optional async hooks | | `ChainRouter` | Addon | Route TCP CONNECT through a downstream SOCKS5 proxy | | `UdpOverTcpEntry` | Addon | Tunnel UDP ASSOCIATE traffic through a TCP exit service | | `UdpOverTcpExitServer` | Server | Exit service for UDP-over-TCP chaining | | `FlowAudit` | Addon | In-memory closed-flow usage audit collector | | `FlowStats` | Addon | In-memory flow statistics collector | | `StatsAPI` | Addon | Opt-in HTTP API backed by FlowStats | | `StatsServer` | Addon | Backward-compatible name for StatsAPI | | `TrafficCounter` | Addon | Aggregate closed-flow byte counters | | `FileAuth` | Addon | Username/password auth from JSON | | `IPFilter` | Addon | Source IP allow/block rules | | `Logger` | Addon | Connection and data logging | | `Address` | Type | Host/port pair | | `Flow` | Type | Per-connection context and byte counters | | `Direction` | Type | Data direction enum | | `Connection` | Type | Reader/writer pair returned by connection hooks | | `UdpRelayBase` | Type | Base interface for custom UDP relay addons | ## Server Contract ```python server = Server( host="::", port=1080, addons=[], auth=None, log_level="INFO", shutdown_timeout=30.0, ) server.run() ``` `Server.run()` owns the event loop and installs SIGINT/SIGTERM handlers. Internal coroutines are not part of the stable public API. 
Shutdown stops accepting new clients, waits for active client tasks, then calls addon `on_stop`. If `shutdown_timeout` is `None`, shutdown waits indefinitely for active clients. Otherwise unfinished tasks are cancelled after the timeout. ## Addon Contract All addon hooks are optional async methods. The hook models are: | Model | Hooks | Return contract | |-------|-------|-----------------| | Competitive | `on_auth`, `on_connect`, `on_udp_associate` | `None` abstains; non-`None` wins | | Pipeline | `on_data` | `bytes` continues; `None` drops the chunk | | Observational | `on_start`, `on_stop`, `on_flow_close`, `on_error` | Return value ignored | Exceptions in competitive hooks reject the current SOCKS operation. Exceptions in `on_flow_close` and `on_error` are suppressed. ## Flow Semantics `Flow` is shared across hooks for one TCP CONNECT or UDP ASSOCIATE lifecycle. ```python Flow( id=1, src=Address("127.0.0.1", 54321), dst=Address("example.com", 443), protocol="tcp", started_at=..., bytes_up=0, bytes_down=0, ) ``` Byte counters are maintained by the relay path, not by addons: - `bytes_up`: client to target, after TCP data pipeline processing - `bytes_down`: target to client - UDP counters count SOCKS5 UDP payload bytes, not UDP header bytes Addons should treat `Flow` as readable context. Mutating byte counters or addresses is unsupported. ## Stats API `FlowStats` is the stats infrastructure. It has no network side effects and exposes plain Python methods: | Method | Meaning | |--------|---------| | `snapshot()` | Aggregate counters and active flow snapshots | | `flows()` | Active flows and recent closed flow snapshots | | `active_flows()` | Active flow snapshots | | `recent_closed_flows()` | Retained closed flow snapshots | | `errors()` | Error counters observed through `on_error` | Use `FlowStats` to build an application-specific HTTP API, metrics exporter, or logging pipeline. 
Put it early in the addon list so it can observe flow starts before another competitive addon wins. `FlowAudit` is the usage audit infrastructure. It has no network side effects and aggregates closed-flow usage by source host and target host: | Method | Meaning | |--------|---------| | `snapshot()` | Kafra-like audit summary with period, records, totals, devices, traffic, and recent records | | `reset()` | Clear the in-memory audit window | The audit window is in-memory and process-local. Use a custom addon or sink if you need durable long-term storage. `StatsAPI` is the built-in opt-in HTTP presentation addon. It can either own its own `FlowStats` instance, or expose `FlowStats` and `FlowAudit` instances supplied by the application: ```python from asyncio_socks_server import FlowAudit, FlowStats, Server, StatsAPI audit = FlowAudit() stats = FlowStats() server = Server( addons=[ audit, stats, StatsAPI(stats=stats, audit=audit, host="127.0.0.1", port=9900), ], ) ``` | Endpoint | Meaning | |----------|---------| | `GET /health` | Liveness response | | `GET /stats` | `FlowStats.snapshot()` | | `GET /flows` | `FlowStats.flows()` | | `GET /errors` | `FlowStats.errors()` | | `GET /audit?top=25&device=` | `FlowAudit.snapshot()` | | `POST /audit/refresh?top=25&device=` | Current `FlowAudit.snapshot()` for Kafra-like refresh flows | `StatsServer` remains available as a backward-compatible name for `StatsAPI`. ## CLI Contract ```shell asyncio_socks_server --host :: --port 1080 --auth user:pass --log-level INFO ``` CLI mode starts a direct SOCKS5 server with optional single-user auth. Addons and advanced routing are configured from Python. 
================================================ FILE: docs/public-api.zh-CN.md ================================================ # 公共 API [README](../README.zh-CN.md) · [架构](architecture.zh-CN.md) · [Addon recipes](addon-recipes.zh-CN.md) · [Addon 模型](addon-model.zh-CN.md) · [English](public-api.md) 本文定义 asyncio-socks-server 1.x 的兼容性边界。稳定导入面是包根。子模块仍可导入。 ## 兼容性策略 包根稳定: ```python from asyncio_socks_server import Server, Addon, Address, connect ``` 在 1.x 系列内: - 包根导出的名称和主要行为保持兼容。 - Addon hook 签名保持兼容。 - `Flow` 的字节计数和地址字段语义保持稳定。 - CLI 参数保持向后兼容。 `asyncio_socks_server.core`、`asyncio_socks_server.server`、 `asyncio_socks_server.client`、`asyncio_socks_server.addons` 下的模块可以导入。兼容性契约以包根导出为准。 ## 包根导出 | 名称 | 类别 | 用途 | |------|------|------| | `Server` | 服务端 | SOCKS5 server 入口 | | `connect` | 客户端 | 通过 SOCKS5 proxy 打开 TCP 连接 | | `Addon` | Addon 基类 | 可选 async hooks 的基类 | | `ChainRouter` | Addon | 将 TCP CONNECT 路由到下游 SOCKS5 proxy | | `UdpOverTcpEntry` | Addon | 将 UDP ASSOCIATE 流量封装到 TCP exit service | | `UdpOverTcpExitServer` | 服务端 | UDP-over-TCP 链式代理的出口服务 | | `FlowAudit` | Addon | 内存中的已关闭 flow 用量审计 collector | | `FlowStats` | Addon | 内存 flow 统计 collector | | `StatsAPI` | Addon | 基于 FlowStats 的显式 opt-in HTTP API | | `StatsServer` | Addon | `StatsAPI` 的向后兼容名称 | | `TrafficCounter` | Addon | 聚合已关闭 flow 的字节计数 | | `FileAuth` | Addon | 从 JSON 文件读取用户名/密码 | | `IPFilter` | Addon | 源 IP allow/block 规则 | | `Logger` | Addon | 连接和数据日志 | | `Address` | 类型 | host/port 二元组 | | `Flow` | 类型 | 每连接上下文和字节计数 | | `Direction` | 类型 | 数据方向枚举 | | `Connection` | 类型 | connection hook 返回的 reader/writer | | `UdpRelayBase` | 类型 | 自定义 UDP relay addon 的基础接口 | ## Server 契约 ```python server = Server( host="::", port=1080, addons=[], auth=None, log_level="INFO", shutdown_timeout=30.0, ) server.run() ``` `Server.run()` 接管 event loop,并安装 SIGINT/SIGTERM handler。内部 coroutine 不属于稳定公共 API。 Shutdown 会先停止接收新客户端,等待活跃 client task,再调用 addon `on_stop`。如果 `shutdown_timeout` 为 `None`,会无限等待活跃客户端;否则超时后 取消未完成 task。 ## Addon 契约 所有 addon 
hook 都是可选 async 方法。Hook 模型如下: | 模型 | Hooks | 返回值契约 | |------|-------|------------| | 竞争型 | `on_auth`, `on_connect`, `on_udp_associate` | `None` 表示弃权;非 `None` 获胜 | | 管道型 | `on_data` | `bytes` 继续;`None` 丢弃当前 chunk | | 观察型 | `on_start`, `on_stop`, `on_flow_close`, `on_error` | 返回值忽略 | 竞争型 hook 抛异常会拒绝当前 SOCKS 操作。`on_flow_close` 和 `on_error` 中的异常会被抑制。 ## Flow 语义 `Flow` 在一次 TCP CONNECT 或 UDP ASSOCIATE 生命周期内被所有 hook 共享。 ```python Flow( id=1, src=Address("127.0.0.1", 54321), dst=Address("example.com", 443), protocol="tcp", started_at=..., bytes_up=0, bytes_down=0, ) ``` 字节计数由 relay 路径维护,而不是由 addon 维护: - `bytes_up`:client 到 target,TCP 场景下为经过 data pipeline 后的字节 - `bytes_down`:target 到 client - UDP 计数统计 SOCKS5 UDP payload,不包含 UDP header Addon 应把 `Flow` 视为可读上下文。修改字节计数或地址字段不受支持。 ## Stats API `FlowStats` 是统计基础设施。它没有网络副作用,只暴露 Python 方法: | 方法 | 含义 | |------|------| | `snapshot()` | 聚合计数和活跃 flow 快照 | | `flows()` | 活跃 flow 和最近关闭 flow 快照 | | `active_flows()` | 活跃 flow 快照 | | `recent_closed_flows()` | 保留的关闭 flow 快照 | | `errors()` | 通过 `on_error` 观察到的错误计数 | 用 `FlowStats` 自行搭建应用需要的 HTTP API、metrics exporter 或日志管道。 建议把它放在 addon 列表靠前位置,这样它能在其他竞争型 addon 获胜前观察 flow start。 `FlowAudit` 是用量审计基础设施。它没有网络副作用,按 source host 和 target host 聚合已关闭 flow 的用量: | 方法 | 含义 | |------|------| | `snapshot()` | 类似 Kafra audit 的摘要,包含 period、records、total、devices、traffic 和 recent records | | `reset()` | 清空当前内存审计窗口 | 审计窗口是内存级、进程级的。如果需要长期留存,应使用自定义 addon 或 sink 做持久化。 `StatsAPI` 是内置的显式 opt-in HTTP 展示 addon。它可以自己托管 `FlowStats`,也可以暴露应用传入的 `FlowStats` 和 `FlowAudit`: ```python from asyncio_socks_server import FlowAudit, FlowStats, Server, StatsAPI audit = FlowAudit() stats = FlowStats() server = Server( addons=[ audit, stats, StatsAPI(stats=stats, audit=audit, host="127.0.0.1", port=9900), ], ) ``` | Endpoint | 含义 | |----------|------| | `GET /health` | 存活检查 | | `GET /stats` | `FlowStats.snapshot()` | | `GET /flows` | `FlowStats.flows()` | | `GET /errors` | `FlowStats.errors()` | | `GET /audit?top=25&device=` | 
`FlowAudit.snapshot()` | | `POST /audit/refresh?top=25&device=` | 返回当前 `FlowAudit.snapshot()`,用于类似 Kafra 的刷新流程 | `StatsServer` 作为 `StatsAPI` 的向后兼容名称保留。 ## CLI 契约 ```shell asyncio_socks_server --host :: --port 1080 --auth user:pass --log-level INFO ``` CLI 模式启动一个直连 SOCKS5 server,可选单用户认证。Addon 和高级路由通过 Python 配置。 ================================================ FILE: pyproject.toml ================================================ [project] name = "asyncio-socks-server" version = "1.3.1" description = "A SOCKS5 toolchain/framework with programmable addons" readme = "README.md" requires-python = ">=3.12" license = "MIT" authors = [ { name = "Amaindex", email = "amaindex@outlook.com" }, ] keywords = ["asyncio", "socks5", "proxy", "addon"] classifiers = [ "Development Status :: 5 - Production/Stable", "Framework :: AsyncIO", "Intended Audience :: System Administrators", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Topic :: Internet :: Proxy Servers", ] [project.urls] Homepage = "https://github.com/Amaindex/asyncio-socks-server" Repository = "https://github.com/Amaindex/asyncio-socks-server" Issues = "https://github.com/Amaindex/asyncio-socks-server/issues" [project.scripts] asyncio_socks_server = "asyncio_socks_server.cli:main" [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["src/asyncio_socks_server"] [tool.ruff] target-version = "py312" line-length = 88 [tool.ruff.lint] select = ["E", "F", "I", "W"] [tool.pyright] pythonVersion = "3.12" typeCheckingMode = "basic" include = ["src"] [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] [dependency-groups] dev = ["ruff>=0.11", "pytest>=8", "pytest-asyncio>=0.26", "pyright>=1.1"] ================================================ FILE: src/asyncio_socks_server/__init__.py ================================================ 
"""asyncio-socks-server: A SOCKS5 toolchain/framework with programmable addons.""" from asyncio_socks_server.addons import ( Addon, ChainRouter, FileAuth, FlowAudit, FlowStats, IPFilter, Logger, StatsAPI, StatsServer, TrafficCounter, UdpOverTcpEntry, ) from asyncio_socks_server.client.client import connect from asyncio_socks_server.core.types import Address, Direction, Flow from asyncio_socks_server.server.connection import Connection from asyncio_socks_server.server.server import Server from asyncio_socks_server.server.udp_over_tcp_exit import UdpOverTcpExitServer from asyncio_socks_server.server.udp_relay import UdpRelayBase __all__ = [ "Addon", "Address", "ChainRouter", "Connection", "Direction", "FileAuth", "Flow", "FlowAudit", "FlowStats", "IPFilter", "Logger", "Server", "StatsAPI", "StatsServer", "TrafficCounter", "UdpOverTcpEntry", "UdpOverTcpExitServer", "UdpRelayBase", "connect", ] ================================================ FILE: src/asyncio_socks_server/__main__.py ================================================ from asyncio_socks_server.cli import main main() ================================================ FILE: src/asyncio_socks_server/addons/__init__.py ================================================ from asyncio_socks_server.addons.auth import FileAuth from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.addons.chain import ChainRouter from asyncio_socks_server.addons.ip_filter import IPFilter from asyncio_socks_server.addons.logger import Logger from asyncio_socks_server.addons.stats import ( FlowAudit, FlowStats, StatsAPI, StatsServer, ) from asyncio_socks_server.addons.traffic import TrafficCounter from asyncio_socks_server.addons.udp_over_tcp_entry import UdpOverTcpEntry __all__ = [ "Addon", "ChainRouter", "FileAuth", "FlowAudit", "FlowStats", "IPFilter", "Logger", "StatsAPI", "StatsServer", "TrafficCounter", "UdpOverTcpEntry", ] ================================================ FILE: 
src/asyncio_socks_server/addons/auth.py ================================================ from __future__ import annotations import json from pathlib import Path from asyncio_socks_server.addons.base import Addon class FileAuth(Addon): """File-based username/password authentication. Reads a JSON file mapping usernames to passwords: {"user1": "pass1", "user2": "pass2"} """ def __init__(self, path: str | Path): self._path = Path(path) self._credentials: dict[str, str] = {} def _load(self) -> dict[str, str]: try: text = self._path.read_text(encoding="utf-8") return json.loads(text) except (OSError, json.JSONDecodeError): return {} async def on_auth(self, username: str, password: str) -> bool | None: if not self._credentials: self._credentials = self._load() if username in self._credentials: return self._credentials[username] == password return None ================================================ FILE: src/asyncio_socks_server/addons/base.py ================================================ from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: from asyncio_socks_server.core.types import Direction, Flow from asyncio_socks_server.server.connection import Connection from asyncio_socks_server.server.udp_relay import UdpRelayBase class Addon: """Base addon class. All hook methods are optional async methods. Competitive hooks use None to abstain and non-None values to take over. The on_data pipeline uses returned bytes as the outgoing payload and None to drop the current chunk. Exceptions reject or abort the current operation. 
""" async def on_start(self) -> None: """Called when the server starts.""" async def on_stop(self) -> None: """Called when the server stops.""" async def on_auth(self, username: str, password: str) -> bool | None: """Competitive: True=allow, False=deny, None=don't interfere.""" async def on_connect(self, flow: Flow) -> Connection | None: """Competitive: return Connection to intercept, None=don't interfere.""" async def on_udp_associate(self, flow: Flow) -> UdpRelayBase | None: """Competitive: return UdpRelayBase to intercept, None=don't interfere.""" async def on_data( self, direction: Direction, data: bytes, flow: Flow ) -> bytes | None: """Pipeline: return bytes to write, None=drop this chunk.""" async def on_flow_close(self, flow: Flow) -> None: """Observational: called when a flow (TCP or UDP) closes.""" async def on_error(self, error: Exception) -> None: """Observational: just notify, doesn't affect flow.""" ================================================ FILE: src/asyncio_socks_server/addons/chain.py ================================================ from __future__ import annotations from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.client import client from asyncio_socks_server.core.types import Address, Flow from asyncio_socks_server.server.connection import Connection class ChainRouter(Addon): """Route connections through a SOCKS5 proxy chain. Each instance represents one hop. The addon connects to next_hop via SOCKS5 and tunnels the connection through it. 
""" def __init__( self, next_hop: str, username: str | None = None, password: str | None = None, ): host, _, port_str = next_hop.rpartition(":") self._proxy_addr = Address(host, int(port_str)) self._username = username self._password = password async def on_connect(self, flow: Flow) -> Connection | None: conn = await client.connect( proxy_addr=self._proxy_addr, target_addr=flow.dst, username=self._username, password=self._password, ) return conn ================================================ FILE: src/asyncio_socks_server/addons/ip_filter.py ================================================ from __future__ import annotations import ipaddress from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.core.types import Flow class IPFilter(Addon): """Allow or deny connections based on source IP ranges. Either `allowed` or `blocked` can be provided (not both). If `allowed` is set, only listed IPs/ranges can connect. If `blocked` is set, listed IPs/ranges are denied. """ def __init__( self, allowed: list[str] | None = None, blocked: list[str] | None = None, ): self._allowed = [ipaddress.ip_network(n) for n in (allowed or [])] self._blocked = [ipaddress.ip_network(n) for n in (blocked or [])] def _is_allowed(self, host: str) -> bool: try: addr = ipaddress.ip_address(host) except ValueError: return False if self._blocked: return not any(addr in net for net in self._blocked) if self._allowed: return any(addr in net for net in self._allowed) return True async def on_connect(self, flow: Flow) -> None: if not self._is_allowed(flow.src.host): raise ConnectionRefusedError(f"IP blocked: {flow.src.host}") return None ================================================ FILE: src/asyncio_socks_server/addons/logger.py ================================================ from __future__ import annotations import logging from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.core.logging import fmt_connection from asyncio_socks_server.core.types 
import Direction, Flow class Logger(Addon): """Detailed connection logging addon.""" def __init__(self): self._logger = logging.getLogger("asyncio_socks_server.addon.logger") async def on_connect(self, flow: Flow) -> None: self._logger.info(f"{fmt_connection(flow.src, flow.dst)} | on_connect") async def on_data( self, direction: Direction, data: bytes, flow: Flow ) -> bytes | None: self._logger.debug(f"{direction} | {len(data)} bytes") return data async def on_error(self, error: Exception) -> None: self._logger.warning(f"error: {error}") ================================================ FILE: src/asyncio_socks_server/addons/manager.py ================================================ from __future__ import annotations from typing import TYPE_CHECKING from .base import Addon if TYPE_CHECKING: from asyncio_socks_server.core.types import Direction, Flow from asyncio_socks_server.server.connection import Connection from asyncio_socks_server.server.udp_relay import UdpRelayBase def _is_overridden(addon: Addon, method_name: str) -> bool: base_method = getattr(Addon, method_name, None) return getattr(type(addon), method_name, None) is not base_method class AddonManager: def __init__(self, addons: list[Addon] | None = None): self._addons: list[Addon] = addons or [] # lifecycle async def dispatch_start(self) -> None: for addon in self._addons: if _is_overridden(addon, "on_start"): await addon.on_start() async def dispatch_stop(self) -> None: for addon in self._addons: if _is_overridden(addon, "on_stop"): await addon.on_stop() # competitive: first non-None wins async def dispatch_auth(self, username: str, password: str) -> bool | None: for addon in self._addons: if _is_overridden(addon, "on_auth"): result = await addon.on_auth(username, password) if result is not None: return result return None async def dispatch_connect(self, flow: Flow) -> Connection | None: for addon in self._addons: if _is_overridden(addon, "on_connect"): result = await addon.on_connect(flow) if result is 
not None: return result return None async def dispatch_udp_associate(self, flow: Flow) -> UdpRelayBase | None: for addon in self._addons: if _is_overridden(addon, "on_udp_associate"): result = await addon.on_udp_associate(flow) if result is not None: return result return None # pipeline: chain outputs async def dispatch_data( self, direction: Direction, data: bytes, flow: Flow ) -> bytes | None: current: bytes | None = data for addon in self._addons: if _is_overridden(addon, "on_data"): if current is None: break current = await addon.on_data(direction, current, flow) return current # observational: call all async def dispatch_flow_close(self, flow: Flow) -> None: for addon in self._addons: if _is_overridden(addon, "on_flow_close"): try: await addon.on_flow_close(flow) except Exception: pass async def dispatch_error(self, error: Exception) -> None: for addon in self._addons: if _is_overridden(addon, "on_error"): try: await addon.on_error(error) except Exception: pass # observational hooks must not disrupt ================================================ FILE: src/asyncio_socks_server/addons/stats.py ================================================ from __future__ import annotations import asyncio import json import time from collections import deque from dataclasses import asdict from datetime import UTC, datetime from typing import Any from urllib.parse import parse_qs, urlsplit from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.core.types import Flow class FlowStats(Addon): """Flow statistics collector with no network side effects. FlowStats is the reusable stats infrastructure. It implements addon hooks, keeps in-memory flow counters, and exposes plain Python snapshot methods. Applications can attach their own HTTP API, metrics exporter, file writer, or any other presentation layer around it. 
""" def __init__( self, max_closed_flows: int = 100, max_recent_errors: int = 50, ) -> None: self.max_closed_flows = max_closed_flows self.max_recent_errors = max_recent_errors self._started_at = time.monotonic() self._started_wall_at = time.time() self._active: dict[int, Flow] = {} self._seen_flow_ids: set[int] = set() self._closed: deque[dict[str, Any]] = deque(maxlen=max_closed_flows) self._recent_errors: deque[dict[str, Any]] = deque(maxlen=max_recent_errors) self.total_flows = 0 self.total_tcp_flows = 0 self.total_udp_flows = 0 self.total_closed_flows = 0 self.closed_bytes_up = 0 self.closed_bytes_down = 0 self.total_errors = 0 self.errors_by_type: dict[str, int] = {} self._last_total_sample_at = self._started_at self._last_total_bytes_up = 0 self._last_total_bytes_down = 0 self._upload_rate = 0.0 self._download_rate = 0.0 self._flow_rates: dict[int, dict[str, float]] = {} async def on_connect(self, flow: Flow) -> None: self._track_flow(flow) async def on_udp_associate(self, flow: Flow) -> None: self._track_flow(flow) async def on_flow_close(self, flow: Flow) -> None: if flow.id not in self._seen_flow_ids: self._track_flow(flow) self._sample_flow_rate(flow) self._active.pop(flow.id, None) self._closed.append(self._flow_snapshot(flow, state="closed")) self.total_closed_flows += 1 self.closed_bytes_up += flow.bytes_up self.closed_bytes_down += flow.bytes_down self._flow_rates.pop(flow.id, None) async def on_error(self, error: Exception) -> None: name = type(error).__name__ self.total_errors += 1 self.errors_by_type[name] = self.errors_by_type.get(name, 0) + 1 self._recent_errors.append( { "type": name, "message": str(error), "at": self._format_wall_time(time.time()), } ) def snapshot(self) -> dict[str, Any]: """Return aggregate counters plus active flow snapshots.""" self._sample_rates() active_bytes_up = sum(flow.bytes_up for flow in self._active.values()) active_bytes_down = sum(flow.bytes_down for flow in self._active.values()) return { "started_at": 
self._format_wall_time(self._started_wall_at), "uptime_seconds": self._duration(self._started_at), "active_flows": len(self._active), "closed_flows": len(self._closed), "recent_closed_flows": len(self._closed), "total_closed_flows": self.total_closed_flows, "total_flows": self.total_flows, "total_tcp_flows": self.total_tcp_flows, "total_udp_flows": self.total_udp_flows, "active_bytes_up": active_bytes_up, "active_bytes_down": active_bytes_down, "closed_bytes_up": self.closed_bytes_up, "closed_bytes_down": self.closed_bytes_down, "total_bytes_up": self.closed_bytes_up + active_bytes_up, "total_bytes_down": self.closed_bytes_down + active_bytes_down, "upload_rate": self._upload_rate, "download_rate": self._download_rate, "errors": self.errors(), "active": self._active_flow_snapshots(), } def active_flows(self) -> list[dict[str, Any]]: """Return active flow snapshots.""" self._sample_rates() return self._active_flow_snapshots() def _active_flow_snapshots(self) -> list[dict[str, Any]]: return [ self._flow_snapshot(flow, state="active") for flow in self._active.values() ] def recent_closed_flows(self) -> list[dict[str, Any]]: """Return retained closed flow snapshots.""" return list(self._closed) def flows(self) -> dict[str, Any]: """Return active and retained closed flow snapshots.""" return { "active": self.active_flows(), "recent_closed": self.recent_closed_flows(), } def errors(self) -> dict[str, Any]: """Return error counters observed through on_error.""" return { "total": self.total_errors, "by_type": dict(sorted(self.errors_by_type.items())), "recent": list(self._recent_errors), } def _track_flow(self, flow: Flow) -> None: self._active[flow.id] = flow if flow.id in self._seen_flow_ids: return self._seen_flow_ids.add(flow.id) self._flow_rates[flow.id] = { "sample_at": time.monotonic(), "bytes_up": float(flow.bytes_up), "bytes_down": float(flow.bytes_down), "upload_rate": 0.0, "download_rate": 0.0, } self.total_flows += 1 if flow.protocol == "tcp": 
self.total_tcp_flows += 1 else: self.total_udp_flows += 1 def _flow_snapshot(self, flow: Flow, state: str) -> dict[str, Any]: rates = self._flow_rates.get(flow.id, {}) return { "id": flow.id, "state": state, "src": asdict(flow.src), "dst": asdict(flow.dst), "protocol": flow.protocol, "started_at": self._format_wall_time(flow.started_wall_at), "age_seconds": self._duration(flow.started_at), "bytes_up": flow.bytes_up, "bytes_down": flow.bytes_down, "upload_rate": rates.get("upload_rate", 0.0), "download_rate": rates.get("download_rate", 0.0), } def _sample_rates(self) -> None: for flow in self._active.values(): self._sample_flow_rate(flow) now = time.monotonic() active_bytes_up = sum(flow.bytes_up for flow in self._active.values()) active_bytes_down = sum(flow.bytes_down for flow in self._active.values()) total_bytes_up = self.closed_bytes_up + active_bytes_up total_bytes_down = self.closed_bytes_down + active_bytes_down elapsed = now - self._last_total_sample_at if elapsed > 0: self._upload_rate = (total_bytes_up - self._last_total_bytes_up) / elapsed self._download_rate = ( total_bytes_down - self._last_total_bytes_down ) / elapsed self._last_total_sample_at = now self._last_total_bytes_up = total_bytes_up self._last_total_bytes_down = total_bytes_down def _sample_flow_rate(self, flow: Flow) -> None: now = time.monotonic() sample = self._flow_rates.setdefault( flow.id, { "sample_at": now, "bytes_up": float(flow.bytes_up), "bytes_down": float(flow.bytes_down), "upload_rate": 0.0, "download_rate": 0.0, }, ) elapsed = now - sample["sample_at"] if elapsed > 0: sample["upload_rate"] = (flow.bytes_up - sample["bytes_up"]) / elapsed sample["download_rate"] = (flow.bytes_down - sample["bytes_down"]) / elapsed sample["sample_at"] = now sample["bytes_up"] = float(flow.bytes_up) sample["bytes_down"] = float(flow.bytes_down) @staticmethod def _format_wall_time(timestamp: float) -> str: return datetime.fromtimestamp(timestamp, UTC).isoformat().replace("+00:00", "Z") 
@staticmethod def _duration(started_at: float) -> float: return round(time.monotonic() - started_at, 6) class FlowAudit(Addon): """Closed-flow traffic audit collector with no network side effects.""" def __init__(self, max_recent_records: int = 100) -> None: self.max_recent_records = max_recent_records self._recent: deque[dict[str, Any]] = deque(maxlen=max_recent_records) self._devices: dict[str, dict[str, Any]] = {} self._traffic: dict[str, dict[str, Any]] = {} self._period_start: float | None = None self._period_end: float | None = None self.records = 0 self.skipped = 0 self.total_upload = 0 self.total_download = 0 async def on_flow_close(self, flow: Flow) -> None: self._record(flow) def snapshot( self, top: int = 25, device: str | None = None, ) -> dict[str, Any]: """Return a Kafra-like traffic audit summary.""" top = max(1, min(top, 100)) devices = self._sorted_totals(self._devices.values(), top, device) traffic = self._sorted_totals(self._traffic.values(), top) generated_at = self._format_wall_time(time.time()) return { "status": "ready" if self.records else "empty", "generated_at": generated_at, "period_start": self._format_optional_time(self._period_start), "period_end": self._format_optional_time(self._period_end), "duration_ms": 0, "records": self.records, "skipped": self.skipped, "total": { "upload": self.total_upload, "download": self.total_download, "total": self.total_upload + self.total_download, }, "devices": devices, "traffic": traffic, "recent": list(self._recent), } def reset(self) -> None: """Clear in-memory audit state.""" self._recent.clear() self._devices.clear() self._traffic.clear() self._period_start = None self._period_end = None self.records = 0 self.skipped = 0 self.total_upload = 0 self.total_download = 0 def _record(self, flow: Flow) -> None: upload = flow.bytes_up download = flow.bytes_down total = upload + download started_at = flow.started_wall_at ended_at = time.time() self._period_start = ( started_at if self._period_start is None 
else min(self._period_start, started_at) ) self._period_end = ended_at self.records += 1 self.total_upload += upload self.total_download += download device = flow.src.host destination = flow.dst.host self._add_total(self._devices, device, "device", upload, download) self._add_total(self._traffic, destination, "domain", upload, download) self._recent.append( { "id": flow.id, "src": asdict(flow.src), "dst": asdict(flow.dst), "protocol": flow.protocol, "started_at": self._format_wall_time(started_at), "ended_at": self._format_wall_time(ended_at), "upload": upload, "download": download, "total": total, } ) @staticmethod def _add_total( totals: dict[str, dict[str, Any]], key: str, label: str, upload: int, download: int, ) -> None: item = totals.setdefault( key, { label: key, "upload": 0, "download": 0, "total": 0, }, ) item["upload"] += upload item["download"] += download item["total"] += upload + download @staticmethod def _sorted_totals( items: Any, top: int, device: str | None = None, ) -> list[dict[str, Any]]: out = [dict(item) for item in items] if device: out = [item for item in out if item.get("device") == device] return sorted(out, key=lambda item: item["total"], reverse=True)[:top] @classmethod def _format_optional_time(cls, timestamp: float | None) -> str: if timestamp is None: return "" return cls._format_wall_time(timestamp) @staticmethod def _format_wall_time(timestamp: float) -> str: return datetime.fromtimestamp(timestamp, UTC).isoformat().replace("+00:00", "Z") class StatsAPI(Addon): """Opt-in HTTP API backed by FlowStats. StatsAPI starts an HTTP listener only when explicitly added to a Server. When constructed without a FlowStats instance, it owns one and forwards flow hooks into it. When constructed with an existing FlowStats instance, it acts only as a presentation layer so applications can compose both addons without double-counting flows. 
""" def __init__( self, host: str = "127.0.0.1", port: int = 0, max_closed_flows: int = 100, stats: FlowStats | None = None, audit: FlowAudit | None = None, ) -> None: self.host = host self.port = port self.max_closed_flows = max_closed_flows self.stats = stats or FlowStats(max_closed_flows=max_closed_flows) self._owns_stats = stats is None self.audit = audit self._server: asyncio.AbstractServer | None = None async def on_start(self) -> None: self._server = await asyncio.start_server( self._handle_http, self.host, self.port, ) sock = self._server.sockets[0] if self._server.sockets else None if sock is not None: self.port = sock.getsockname()[1] async def on_stop(self) -> None: if self._server is None: return self._server.close() await self._server.wait_closed() self._server = None async def on_connect(self, flow: Flow) -> None: if self._owns_stats: await self.stats.on_connect(flow) async def on_udp_associate(self, flow: Flow) -> None: if self._owns_stats: await self.stats.on_udp_associate(flow) async def on_flow_close(self, flow: Flow) -> None: if self._owns_stats: await self.stats.on_flow_close(flow) async def on_error(self, error: Exception) -> None: if self._owns_stats: await self.stats.on_error(error) def snapshot(self) -> dict[str, Any]: return self.stats.snapshot() async def _handle_http( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, ) -> None: try: line = await reader.readline() method, target, _ = line.decode("ascii", errors="replace").split(" ", 2) parsed = urlsplit(target) path = parsed.path query = parse_qs(parsed.query) while True: header = await reader.readline() if header in (b"\r\n", b"\n", b""): break if method == "POST" and path == "/audit/refresh": await self._write_audit(writer, query) elif method != "GET": await self._write_json(writer, 405, {"error": "method not allowed"}) elif path == "/health": await self._write_json(writer, 200, {"ok": True}) elif path == "/stats": await self._write_json(writer, 200, 
self.stats.snapshot()) elif path == "/flows": await self._write_json(writer, 200, self.stats.flows()) elif path == "/errors": await self._write_json(writer, 200, self.stats.errors()) elif path == "/audit": await self._write_audit(writer, query) else: await self._write_json(writer, 404, {"error": "not found"}) except (ConnectionError, OSError, ValueError): pass finally: try: writer.close() await writer.wait_closed() except (ConnectionError, OSError): pass async def _write_json( self, writer: asyncio.StreamWriter, status: int, payload: dict[str, Any], ) -> None: reason = { 200: "OK", 404: "Not Found", 405: "Method Not Allowed", }.get(status, "Error") body = json.dumps(payload, separators=(",", ":")).encode("utf-8") writer.write( f"HTTP/1.1 {status} {reason}\r\n" "Content-Type: application/json\r\n" f"Content-Length: {len(body)}\r\n" "Connection: close\r\n" "\r\n".encode("ascii") + body ) await writer.drain() async def _write_audit( self, writer: asyncio.StreamWriter, query: dict[str, list[str]], ) -> None: if self.audit is None: await self._write_json(writer, 404, {"error": "audit disabled"}) return await self._write_json( writer, 200, self.audit.snapshot( top=self._int_query(query, "top", 25), device=self._str_query(query, "device"), ), ) @staticmethod def _int_query(query: dict[str, list[str]], name: str, default: int) -> int: values = query.get(name) if not values: return default try: return int(values[0]) except ValueError: return default @staticmethod def _str_query(query: dict[str, list[str]], name: str) -> str | None: values = query.get(name) if not values or not values[0]: return None return values[0] class StatsServer(StatsAPI): """Backward-compatible name for StatsAPI.""" ================================================ FILE: src/asyncio_socks_server/addons/traffic.py ================================================ from __future__ import annotations from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.core.types import Flow class 
TrafficCounter(Addon): """Count bytes flowing through the proxy (TCP and UDP).""" def __init__(self): self.bytes_up: int = 0 self.bytes_down: int = 0 self.connections: int = 0 async def on_connect(self, flow: Flow) -> None: self.connections += 1 async def on_flow_close(self, flow: Flow) -> None: self.bytes_up += flow.bytes_up self.bytes_down += flow.bytes_down ================================================ FILE: src/asyncio_socks_server/addons/udp_over_tcp_entry.py ================================================ from __future__ import annotations import asyncio from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.core.protocol import build_udp_header, parse_udp_header from asyncio_socks_server.core.types import Address, Flow from asyncio_socks_server.server.udp_over_tcp import encode_udp_frame, read_udp_frame from asyncio_socks_server.server.udp_relay import UdpRelayBase class UdpOverTcpEntry(Addon): """Route UDP ASSOCIATE through a downstream SOCKS5 proxy via UDP-over-TCP. The bridge connects to next_hop via SOCKS5 TCP CONNECT, then tunnels SOCKS5 UDP datagrams as length-prefixed TCP frames. 
""" def __init__( self, next_hop: str, username: str | None = None, password: str | None = None, ): host, _, port_str = next_hop.rpartition(":") self._proxy_addr = Address(host, int(port_str)) self._username = username self._password = password async def on_udp_associate(self, flow: Flow) -> UdpRelayBase | None: return _Bridge(self._proxy_addr, self._username, self._password, flow) class _Bridge(UdpRelayBase): """UDP-over-TCP bridge: client-side UDP ↔ TCP frames ↔ downstream proxy.""" def __init__( self, proxy_addr, username: str | None, password: str | None, flow: Flow, ): self._proxy_addr = proxy_addr self._username = username self._password = password self._tcp_reader: asyncio.StreamReader | None = None self._tcp_writer: asyncio.StreamWriter | None = None self._client_transport: asyncio.DatagramTransport | None = None self._pump_task: asyncio.Task | None = None self._route_map: dict[tuple[str, int], tuple[str, int]] = {} self._flow = flow async def start(self) -> Address: # Open a plain TCP connection to the downstream proxy. # We don't use SOCKS5 handshake here — this is just a raw TCP # connection that the downstream UdpOverTcpExit server accepts. 
self._tcp_reader, self._tcp_writer = await asyncio.open_connection( self._proxy_addr.host, self._proxy_addr.port ) sock = self._tcp_writer.get_extra_info("socket") sockname = sock.getsockname() if sock else ("::", 0) # Start the TCP→client pump self._pump_task = asyncio.create_task(self._tcp_to_client()) return Address(sockname[0], sockname[1]) def set_client_transport(self, transport: asyncio.DatagramTransport) -> None: self._client_transport = transport async def stop(self) -> None: if self._pump_task: self._pump_task.cancel() try: await self._pump_task except asyncio.CancelledError: pass if self._tcp_writer: try: self._tcp_writer.close() await self._tcp_writer.wait_closed() except (ConnectionError, OSError): pass def handle_client_datagram(self, data: bytes, client_addr: tuple[str, int]) -> None: if not self._tcp_writer: return try: dst, _, payload = parse_udp_header(data) except Exception: return if not payload: return # Record route: remote → client remote_key = (dst.host, dst.port) self._route_map[remote_key] = client_addr self._flow.bytes_up += len(payload) # Send as TCP frame (async but fire-and-forget via task) async def _send(): try: frame = await encode_udp_frame(dst, payload) self._tcp_writer.write(frame) # type: ignore[union-attr] await self._tcp_writer.drain() # type: ignore[union-attr] except (ConnectionError, OSError): pass asyncio.create_task(_send()) async def _tcp_to_client(self) -> None: try: while True: src_addr, payload = await read_udp_frame(self._tcp_reader) # type: ignore[arg-type] self._flow.bytes_down += len(payload) # Find the client that sent to this remote remote_key = (src_addr.host, src_addr.port) client_addr = self._route_map.get(remote_key) if client_addr is None: continue header = build_udp_header(src_addr) packet = header + payload if self._client_transport: self._client_transport.sendto(packet, client_addr) except (asyncio.IncompleteReadError, ConnectionError, OSError): pass ================================================ FILE: 
src/asyncio_socks_server/cli.py ================================================ from __future__ import annotations import argparse from asyncio_socks_server.server.server import Server def main() -> None: parser = argparse.ArgumentParser( prog="asyncio_socks_server", description="A SOCKS5 proxy server with programmable addons", ) parser.add_argument("--host", default="::", help="bind address") parser.add_argument("--port", type=int, default=1080, help="bind port") parser.add_argument( "--auth", default=None, help="username:password for authentication", ) parser.add_argument( "--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="logging level", ) args = parser.parse_args() auth = None if args.auth: user, _, passwd = args.auth.partition(":") auth = (user, passwd) server = Server( host=args.host, port=args.port, auth=auth, log_level=args.log_level, ) server.run() ================================================ FILE: src/asyncio_socks_server/client/__init__.py ================================================ ================================================ FILE: src/asyncio_socks_server/client/client.py ================================================ from __future__ import annotations import asyncio import socket from itertools import zip_longest from asyncio_socks_server.core.address import decode_address, encode_address from asyncio_socks_server.core.protocol import ProtocolError from asyncio_socks_server.core.types import Address, AuthMethod, Rep from asyncio_socks_server.server.connection import Connection HAPPY_EYEBALLS_DELAY = 0.25 async def connect( proxy_addr: Address, target_addr: Address, username: str | None = None, password: str | None = None, ) -> Connection: """Connect to target through a SOCKS5 proxy using Happy Eyeballs.""" reader, writer = await _happy_eyeballs_connect(proxy_addr) try: await _negotiate(reader, writer, username, password) await _request_connect(reader, writer, target_addr) sock = 
writer.get_extra_info("socket") sockname = sock.getsockname() if sock else ("0.0.0.0", 0) return Connection( reader=reader, writer=writer, address=Address(sockname[0], sockname[1]), ) except Exception: writer.close() raise async def _happy_eyeballs_connect( addr: Address, ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: """Happy Eyeballs-style fallback with staggered IPv6/IPv4 candidates.""" loop = asyncio.get_running_loop() ipv4_hosts: list[str] = [] ipv6_hosts: list[str] = [] try: results = await loop.getaddrinfo(addr.host, addr.port, type=socket.SOCK_STREAM) for family, _, _, _, sockaddr in results: if family == socket.AF_INET6: ipv6_hosts.append(sockaddr[0]) elif family == socket.AF_INET: ipv4_hosts.append(sockaddr[0]) except socket.gaierror: ipv4_hosts = [addr.host] candidates: list[tuple[str, int]] = [] for ipv6_host, ipv4_host in zip_longest(ipv6_hosts, ipv4_hosts): if ipv6_host is not None: candidates.append((ipv6_host, addr.port)) if ipv4_host is not None: candidates.append((ipv4_host, addr.port)) if not candidates: raise ConnectionError(f"cannot resolve {addr.host}") if len(candidates) == 1: return await asyncio.open_connection(candidates[0][0], candidates[0][1]) pending: set[asyncio.Task[tuple[asyncio.StreamReader, asyncio.StreamWriter]]] = ( set() ) errors: list[BaseException] = [] next_candidate = 0 def start_next_candidate() -> None: nonlocal next_candidate if next_candidate >= len(candidates): return host, port = candidates[next_candidate] next_candidate += 1 pending.add(loop.create_task(asyncio.open_connection(host, port))) async def cancel_pending() -> None: for task in pending: task.cancel() for task in pending: try: await task except (asyncio.CancelledError, Exception): pass start_next_candidate() while pending: timeout = HAPPY_EYEBALLS_DELAY if next_candidate < len(candidates) else None done_tasks, pending_tasks = await asyncio.wait( pending, timeout=timeout, return_when=asyncio.FIRST_COMPLETED ) pending = set(pending_tasks) if not 
done_tasks: start_next_candidate() continue for task in done_tasks: try: result = task.result() except Exception as exc: errors.append(exc) else: await cancel_pending() return result if not pending: start_next_candidate() message = f"all connection attempts failed to {addr.host}:{addr.port}" if errors: raise ConnectionError(message) from errors[0] raise ConnectionError(message) async def _negotiate( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, username: str | None, password: str | None, ) -> None: if username and password: writer.write(b"\x05\x01\x02") else: writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) if resp[0] != 0x05: raise ProtocolError(f"unsupported SOCKS version: {resp[0]}") if resp[1] == AuthMethod.NO_AUTH: return if resp[1] == AuthMethod.USERNAME_PASSWORD and username and password: uname = username.encode("utf-8") passwd = password.encode("utf-8") writer.write( b"\x01" + len(uname).to_bytes(1, "big") + uname + len(passwd).to_bytes(1, "big") + passwd ) await writer.drain() auth_resp = await reader.readexactly(2) if auth_resp[1] != 0x00: raise ProtocolError("authentication failed") elif resp[1] == AuthMethod.NO_ACCEPTABLE: raise ProtocolError("no acceptable auth method") else: raise ProtocolError(f"unsupported auth method: {resp[1]}") async def _request_connect( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, target: Address, ) -> None: writer.write(b"\x05\x01\x00" + encode_address(target.host, target.port)) await writer.drain() reply = await reader.readexactly(3) if reply[0] != 0x05: raise ProtocolError(f"unsupported SOCKS version: {reply[0]}") if reply[1] != Rep.SUCCEEDED: raise ProtocolError(f"connect failed with rep={reply[1]:#04x}") # Read bound address await decode_address(reader) ================================================ FILE: src/asyncio_socks_server/core/__init__.py ================================================ ================================================ FILE: 
src/asyncio_socks_server/core/address.py ================================================ from __future__ import annotations import asyncio import ipaddress import struct from ipaddress import IPv4Address, IPv6Address from .types import Address, Atyp, Rep def detect_atyp(host: str) -> Atyp: try: IPv4Address(host) return Atyp.IPV4 except ValueError: pass try: IPv6Address(host) return Atyp.IPV6 except ValueError: pass return Atyp.DOMAIN def encode_address(host: str, port: int) -> bytes: atyp = detect_atyp(host) if atyp == Atyp.IPV4: ADDR = ipaddress.IPv4Address(host).packed elif atyp == Atyp.IPV6: ADDR = ipaddress.IPv6Address(host).packed else: encoded = host.encode("ascii") ADDR = bytes([len(encoded)]) + encoded ATYP = atyp.to_bytes(1, "big") PORT = struct.pack("!H", port) return ATYP + ADDR + PORT async def decode_address(reader: asyncio.StreamReader) -> Address: ATYP = Atyp((await reader.readexactly(1))[0]) if ATYP == Atyp.IPV4: DST_ADDR = ipaddress.IPv4Address(await reader.readexactly(4)).compressed elif ATYP == Atyp.IPV6: DST_ADDR = ipaddress.IPv6Address(await reader.readexactly(16)).compressed elif ATYP == Atyp.DOMAIN: length = (await reader.readexactly(1))[0] DST_ADDR = (await reader.readexactly(length)).decode("ascii") else: raise ValueError(f"unsupported ATYP: {ATYP}") DST_PORT = struct.unpack("!H", await reader.readexactly(2))[0] return Address(DST_ADDR, DST_PORT) def encode_reply( rep: Rep, bind_host: str = "0.0.0.0", bind_port: int = 0, ) -> bytes: VER = b"\x05" REP = rep.to_bytes(1, "big") RSV = b"\x00" return VER + REP + RSV + encode_address(bind_host, bind_port) ================================================ FILE: src/asyncio_socks_server/core/logging.py ================================================ from __future__ import annotations import logging from .types import Address FORMAT = "%(asctime)s | %(levelname)-8s | %(message)s" def setup_logging(level: str = "INFO") -> None: logging.basicConfig( format=FORMAT, level=getattr(logging, 
level.upper()), force=True, ) def get_logger() -> logging.Logger: return logging.getLogger("asyncio_socks_server") def fmt_addr(addr: Address) -> str: return str(addr) def fmt_connection(src: Address, dst: Address) -> str: return f"{src} → {dst}" def fmt_bytes(n: int) -> str: if n < 1024: return f"{n}B" if n < 1024 * 1024: return f"{n / 1024:.1f}KB" return f"{n / (1024 * 1024):.1f}MB" ================================================ FILE: src/asyncio_socks_server/core/protocol.py ================================================ from __future__ import annotations import asyncio import ipaddress import struct from .types import Address, Cmd class ProtocolError(Exception): pass def parse_method_selection(data: bytes) -> tuple[int, set[int]]: if len(data) < 2: raise ProtocolError("method selection too short") VER = data[0] NMETHODS = data[1] if VER != 0x05: raise ProtocolError(f"unsupported SOCKS version: {VER}") METHODS = set(data[2 : 2 + NMETHODS]) return VER, METHODS def build_method_reply(method: int) -> bytes: VER = b"\x05" METHOD = method.to_bytes(1, "big") return VER + METHOD async def parse_username_password( reader: asyncio.StreamReader, ) -> tuple[str, str]: VER = (await reader.readexactly(1))[0] if VER != 0x01: raise ProtocolError(f"unsupported auth version: {VER}") ULEN = (await reader.readexactly(1))[0] UNAME = (await reader.readexactly(ULEN)).decode("utf-8") PLEN = (await reader.readexactly(1))[0] PASSWD = (await reader.readexactly(PLEN)).decode("utf-8") return UNAME, PASSWD def build_auth_reply(success: bool) -> bytes: VER = b"\x01" STATUS = b"\x00" if success else b"\x01" return VER + STATUS async def parse_request(reader: asyncio.StreamReader) -> tuple[Cmd, Address]: VER, CMD, RSV, ATYP_BYTE = await reader.readexactly(4) if VER != 0x05: raise ProtocolError(f"unsupported SOCKS version: {VER}") try: cmd = Cmd(CMD) except ValueError: raise ProtocolError(f"unsupported command: {CMD}") from None if ATYP_BYTE == 0x01: # IPv4 host = 
ipaddress.IPv4Address(await reader.readexactly(4)).compressed elif ATYP_BYTE == 0x04: # IPv6 host = ipaddress.IPv6Address(await reader.readexactly(16)).compressed elif ATYP_BYTE == 0x03: # Domain length = (await reader.readexactly(1))[0] host = (await reader.readexactly(length)).decode("ascii") else: raise ProtocolError(f"unsupported ATYP: {ATYP_BYTE}") DST_PORT = struct.unpack("!H", await reader.readexactly(2))[0] return cmd, Address(host, DST_PORT) def parse_udp_header(data: bytes) -> tuple[Address, int, bytes]: """Parse SOCKS5 UDP request header. Returns (dst_address, header_length, payload). """ if len(data) < 4: raise ProtocolError("UDP header too short") # RSV(2) + FRAG(1) skipped — we don't support fragmentation ATYP_BYTE = data[3] if ATYP_BYTE == 0x01: if len(data) < 10: raise ProtocolError("UDP header truncated (IPv4)") host = ipaddress.IPv4Address(data[4:8]).compressed DST_PORT = struct.unpack("!H", data[8:10])[0] header_length = 10 elif ATYP_BYTE == 0x04: if len(data) < 22: raise ProtocolError("UDP header truncated (IPv6)") host = ipaddress.IPv6Address(data[4:20]).compressed DST_PORT = struct.unpack("!H", data[20:22])[0] header_length = 22 elif ATYP_BYTE == 0x03: length = data[4] if len(data) < 5 + length + 2: raise ProtocolError("UDP header truncated (domain)") host = data[5 : 5 + length].decode("ascii") DST_PORT = struct.unpack("!H", data[5 + length : 5 + length + 2])[0] header_length = 5 + length + 2 else: raise ProtocolError(f"unsupported ATYP: {ATYP_BYTE}") return Address(host, DST_PORT), header_length, data[header_length:] def build_udp_header(address: Address) -> bytes: RSV = b"\x00\x00" FRAG = b"\x00" from .address import encode_address return RSV + FRAG + encode_address(address.host, address.port) ================================================ FILE: src/asyncio_socks_server/core/socket.py ================================================ from __future__ import annotations import ipaddress import socket def _is_ipv6(host: str) -> bool: try: 
ipaddress.IPv6Address(host) return True except ValueError: return False def create_dualstack_tcp_socket(host: str, port: int) -> socket.socket: """Create a TCP server socket with dual-stack (IPv4+IPv6) support.""" if host in ("", "::"): return socket.create_server( ("::", port), family=socket.AF_INET6, dualstack_ipv6=True ) if host == "0.0.0.0": return socket.create_server((host, port), family=socket.AF_INET) if _is_ipv6(host): return socket.create_server((host, port), family=socket.AF_INET6) return socket.create_server((host, port)) def create_dualstack_udp_socket(host: str, port: int = 0) -> socket.socket: """Create a UDP socket with dual-stack support.""" if host in ("0.0.0.0", "", "::"): sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) sock.bind(("::", port)) elif _is_ipv6(host): sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) sock.bind((host, port)) else: sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.bind((host, port)) return sock ================================================ FILE: src/asyncio_socks_server/core/types.py ================================================ from __future__ import annotations import time from dataclasses import dataclass, field from enum import IntEnum, StrEnum from typing import Literal class Rep(IntEnum): """RFC 1928 reply codes.""" SUCCEEDED = 0x00 GENERAL_FAILURE = 0x01 CONNECTION_NOT_ALLOWED = 0x02 NETWORK_UNREACHABLE = 0x03 HOST_UNREACHABLE = 0x04 CONNECTION_REFUSED = 0x05 TTL_EXPIRED = 0x06 COMMAND_NOT_SUPPORTED = 0x07 ADDRESS_TYPE_NOT_SUPPORTED = 0x08 class AuthMethod(IntEnum): """SOCKS5 authentication methods.""" NO_AUTH = 0x00 USERNAME_PASSWORD = 0x02 NO_ACCEPTABLE = 0xFF class Cmd(IntEnum): """SOCKS5 commands.""" CONNECT = 0x01 UDP_ASSOCIATE = 0x03 class Atyp(IntEnum): """SOCKS5 address types.""" IPV4 = 0x01 DOMAIN = 0x03 IPV6 = 0x04 class Direction(StrEnum): """Data flow direction.""" UPSTREAM = "upstream" DOWNSTREAM = 
"downstream" @dataclass(frozen=True) class Address: host: str port: int def __str__(self) -> str: return f"{self.host}:{self.port}" @dataclass class Flow: """Per-connection context carried through the hook lifecycle.""" id: int src: Address dst: Address protocol: Literal["tcp", "udp"] started_at: float # time.monotonic() started_wall_at: float = field(default_factory=time.time) bytes_up: int = 0 # TCP: post-addon; UDP: raw payload (no addon pipeline) bytes_down: int = 0 ================================================ FILE: src/asyncio_socks_server/py.typed ================================================ ================================================ FILE: src/asyncio_socks_server/server/__init__.py ================================================ ================================================ FILE: src/asyncio_socks_server/server/connection.py ================================================ from __future__ import annotations import asyncio from dataclasses import dataclass from asyncio_socks_server.core.types import Address @dataclass class Connection: reader: asyncio.StreamReader writer: asyncio.StreamWriter address: Address ================================================ FILE: src/asyncio_socks_server/server/server.py ================================================ from __future__ import annotations import asyncio import ipaddress import itertools import socket import time from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.addons.manager import AddonManager from asyncio_socks_server.core.address import encode_reply from asyncio_socks_server.core.logging import fmt_bytes, fmt_connection, get_logger from asyncio_socks_server.core.protocol import ( build_auth_reply, build_method_reply, parse_method_selection, parse_request, parse_username_password, ) from asyncio_socks_server.core.socket import ( create_dualstack_tcp_socket, create_dualstack_udp_socket, ) from asyncio_socks_server.core.types import Address, AuthMethod, Cmd, Flow, 
Rep from asyncio_socks_server.server.connection import Connection from asyncio_socks_server.server.tcp_relay import handle_tcp_relay from asyncio_socks_server.server.udp_relay import UdpRelay, UdpRelayBase class Server: def __init__( self, host: str = "::", port: int = 1080, addons: list[Addon] | None = None, auth: tuple[str, str] | None = None, log_level: str = "INFO", shutdown_timeout: float | None = 30.0, ): self.host = host self.port = port self.auth = auth self.log_level = log_level self.shutdown_timeout = shutdown_timeout self._addon_manager = AddonManager(addons) self._shutdown_event = asyncio.Event() self._flow_seq = itertools.count(1) self._client_tasks: set[asyncio.Task] = set() def run(self) -> None: asyncio.run(self._run()) def _install_signal_handlers(self) -> None: import signal loop = asyncio.get_running_loop() def _signal_handler(): self.request_shutdown() for sig in (signal.SIGTERM, signal.SIGINT): loop.add_signal_handler(sig, _signal_handler) async def _run(self) -> None: from asyncio_socks_server.core.logging import setup_logging setup_logging(self.log_level) logger = get_logger() await self._addon_manager.dispatch_start() self._install_signal_handlers() sock = create_dualstack_tcp_socket(self.host, self.port) sock.setblocking(False) srv = await asyncio.start_server( self._handle_client, sock=sock, ) addr = srv.sockets[0].getsockname() self.port = addr[1] logger.info(f"server started on {self.host}:{self.port}") try: await self._shutdown_event.wait() finally: srv.close() await srv.wait_closed() await self._wait_for_client_tasks() await self._addon_manager.dispatch_stop() logger.info("server stopped") async def _wait_for_client_tasks(self) -> None: if not self._client_tasks: return tasks = set(self._client_tasks) try: if self.shutdown_timeout is None: await asyncio.gather(*tasks, return_exceptions=True) else: await asyncio.wait_for( asyncio.gather(*tasks, return_exceptions=True), timeout=self.shutdown_timeout, ) except TimeoutError: for task in 
tasks: if not task.done(): task.cancel() await asyncio.gather(*tasks, return_exceptions=True) async def _handle_client( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: task = asyncio.current_task() if task is not None: self._client_tasks.add(task) try: await self._do_handshake_and_relay(reader, writer) except Exception as e: await self._addon_manager.dispatch_error(e) finally: try: writer.close() await writer.wait_closed() except (ConnectionError, OSError): pass if task is not None: self._client_tasks.discard(task) async def _do_handshake_and_relay( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: header = await reader.readexactly(2) version, method_count = header[0], header[1] method_data = await reader.readexactly(method_count) _, methods = parse_method_selection( bytes([version, method_count]) + method_data ) if self.auth: if AuthMethod.USERNAME_PASSWORD not in methods: writer.write(build_method_reply(AuthMethod.NO_ACCEPTABLE)) await writer.drain() return writer.write(build_method_reply(AuthMethod.USERNAME_PASSWORD)) await writer.drain() username, password = await parse_username_password(reader) auth_result = await self._addon_manager.dispatch_auth(username, password) if auth_result is not None: success = auth_result else: success = username == self.auth[0] and password == self.auth[1] writer.write(build_auth_reply(success)) await writer.drain() if not success: return else: if AuthMethod.NO_AUTH not in methods: writer.write(build_method_reply(AuthMethod.NO_ACCEPTABLE)) await writer.drain() return writer.write(build_method_reply(AuthMethod.NO_AUTH)) await writer.drain() cmd, dst = await parse_request(reader) peername = writer.get_extra_info("peername") src = Address(peername[0], peername[1]) if peername else Address("::", 0) if cmd == Cmd.CONNECT: await self._handle_connect(reader, writer, src, dst) elif cmd == Cmd.UDP_ASSOCIATE: await self._handle_udp_associate(reader, writer, src, dst) else: 
writer.write(encode_reply(Rep.COMMAND_NOT_SUPPORTED)) await writer.drain() async def _handle_connect( self, client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter, src: Address, dst: Address, ) -> None: logger = get_logger() flow = Flow( id=next(self._flow_seq), src=src, dst=dst, protocol="tcp", started_at=time.monotonic(), ) conn: Connection | None = None connected = False try: try: addon_result = await self._addon_manager.dispatch_connect(flow) except Exception as e: logger.error(f"{fmt_connection(src, dst)} | addon error: {e}") client_writer.write(encode_reply(Rep.CONNECTION_NOT_ALLOWED)) await client_writer.drain() return if addon_result is not None and isinstance(addon_result, Connection): conn = addon_result else: try: remote_reader, remote_writer = await asyncio.open_connection( dst.host, dst.port ) sock = remote_writer.get_extra_info("socket") sockname = sock.getsockname() if sock else ("::", 0) conn = Connection( reader=remote_reader, writer=remote_writer, address=Address(sockname[0], sockname[1]), ) except (ConnectionError, OSError) as e: logger.error(f"{fmt_connection(src, dst)} | {e}") rep = self._error_to_rep(e) client_writer.write(encode_reply(rep)) await client_writer.drain() return client_writer.write( encode_reply(Rep.SUCCEEDED, conn.address.host, conn.address.port) ) await client_writer.drain() connected = True logger.info(f"{fmt_connection(src, dst)} | connected") await handle_tcp_relay( client_reader, client_writer, conn.reader, conn.writer, self._addon_manager, flow, ) finally: if connected: elapsed = time.monotonic() - flow.started_at logger.info( f"{fmt_connection(src, dst)} | " f"closed {elapsed:.1f}s " f"↑{fmt_bytes(flow.bytes_up)} ↓{fmt_bytes(flow.bytes_down)}" ) await self._addon_manager.dispatch_flow_close(flow) async def _handle_udp_associate( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, src: Address, dst: Address, ) -> None: logger = get_logger() flow = Flow( id=next(self._flow_seq), src=src, 
dst=dst, protocol="udp", started_at=time.monotonic(), ) try: relay: UdpRelayBase = ( await self._addon_manager.dispatch_udp_associate(flow) ) or UdpRelay(client_addr=src, flow=flow) except Exception as e: logger.error(f"{fmt_connection(src, dst)} | addon error: {e}") writer.write(encode_reply(Rep.CONNECTION_NOT_ALLOWED)) await writer.drain() return client_transport: asyncio.DatagramTransport | None = None reply_sent = False try: await relay.start() loop = asyncio.get_running_loop() client_udp_sock = _create_client_udp_socket(src.host) client_udp_sock.setblocking(False) client_transport, _ = await loop.create_datagram_endpoint( lambda: _ClientUdpProtocol(relay), sock=client_udp_sock, ) client_sock = client_transport.get_extra_info("socket") fallback = ("::", 0) client_sockname = client_sock.getsockname() if client_sock else fallback client_bind = Address(client_sockname[0], client_sockname[1]) relay.set_client_transport(client_transport) writer.write( encode_reply(Rep.SUCCEEDED, client_bind.host, client_bind.port) ) await writer.drain() reply_sent = True logger.info(f"{fmt_connection(src, dst)} | udp associate started") await reader.read() except Exception as e: logger.error(f"{fmt_connection(src, dst)} | udp associate error: {e}") await self._addon_manager.dispatch_error(e) if not reply_sent: try: writer.write(encode_reply(Rep.GENERAL_FAILURE)) await writer.drain() except (ConnectionError, OSError): pass finally: await relay.stop() if client_transport and not client_transport.is_closing(): client_transport.close() logger.info( f"{fmt_connection(src, dst)} | " f"udp closed ↑{fmt_bytes(flow.bytes_up)} ↓{fmt_bytes(flow.bytes_down)}" ) await self._addon_manager.dispatch_flow_close(flow) @staticmethod def _error_to_rep(exc: Exception) -> Rep: if isinstance(exc, ConnectionRefusedError): return Rep.CONNECTION_REFUSED if isinstance(exc, OSError) and exc.errno == 101: return Rep.NETWORK_UNREACHABLE return Rep.GENERAL_FAILURE def request_shutdown(self) -> None: 
self._shutdown_event.set() def _create_client_udp_socket(host: str) -> socket.socket: try: ipaddress.IPv6Address(host) except ValueError: sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.bind(("0.0.0.0", 0)) return sock return create_dualstack_udp_socket("::", 0) class _ClientUdpProtocol(asyncio.DatagramProtocol): def __init__(self, relay: UdpRelayBase) -> None: self._relay = relay def connection_made(self, transport: asyncio.DatagramTransport) -> None: pass def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: self._relay.handle_client_datagram(data, addr) def error_received(self, exc: Exception) -> None: pass ================================================ FILE: src/asyncio_socks_server/server/tcp_relay.py ================================================ from __future__ import annotations import asyncio from asyncio_socks_server.addons.manager import AddonManager from asyncio_socks_server.core.logging import get_logger from asyncio_socks_server.core.types import Direction, Flow async def _copy( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, addon_manager: AddonManager, direction: Direction, flow: Flow, ) -> None: try: while True: data = await reader.read(4096) if not data: break result = await addon_manager.dispatch_data(direction, data, flow) if result is None: continue writer.write(result) await writer.drain() n = len(result) if direction == Direction.UPSTREAM: flow.bytes_up += n else: flow.bytes_down += n except (ConnectionError, asyncio.CancelledError): pass finally: try: writer.close() await writer.wait_closed() except (ConnectionError, OSError): pass async def handle_tcp_relay( client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter, remote_reader: asyncio.StreamReader, remote_writer: asyncio.StreamWriter, addon_manager: AddonManager, flow: Flow, ) -> None: """Bidirectional TCP relay with addon on_data pipeline.""" try: async with asyncio.TaskGroup() as tg: tg.create_task( _copy( client_reader, 
remote_writer, addon_manager, Direction.UPSTREAM, flow, ) ) tg.create_task( _copy( remote_reader, client_writer, addon_manager, Direction.DOWNSTREAM, flow, ) ) except ExceptionGroup as eg: get_logger().debug(f"tcp relay task group ended: {eg.exceptions}") ================================================ FILE: src/asyncio_socks_server/server/udp_over_tcp.py ================================================ from __future__ import annotations import asyncio import struct from asyncio_socks_server.core.address import decode_address, encode_address from asyncio_socks_server.core.types import Address async def encode_udp_frame(address: Address, data: bytes) -> bytes: """Encode a UDP datagram as a TCP frame. Frame format: [4-byte length][ATYP+ADDR+PORT][payload] """ addr_bytes = encode_address(address.host, address.port) payload = addr_bytes + data length = struct.pack("!I", len(payload)) return length + payload async def read_udp_frame( reader: asyncio.StreamReader, ) -> tuple[Address, bytes]: """Read a UDP-over-TCP frame from a stream. Returns (target_address, payload). 
""" length_bytes = await reader.readexactly(4) length = struct.unpack("!I", length_bytes)[0] payload = await reader.readexactly(length) # Parse address from the beginning of payload atyp_byte = payload[0] if atyp_byte == 0x01: # IPv4 addr_len = 1 + 4 + 2 # ATYP + IPv4 + PORT elif atyp_byte == 0x04: # IPv6 addr_len = 1 + 16 + 2 elif atyp_byte == 0x03: # Domain domain_len = payload[1] addr_len = 1 + 1 + domain_len + 2 else: raise ValueError(f"unsupported ATYP: {atyp_byte}") addr_payload = payload[:addr_len] data = payload[addr_len:] # Decode address addr_reader = asyncio.StreamReader() addr_reader.feed_data(addr_payload) addr_reader.feed_eof() address = await decode_address(addr_reader) return address, data ================================================ FILE: src/asyncio_socks_server/server/udp_over_tcp_exit.py ================================================ from __future__ import annotations import asyncio import time from asyncio_socks_server.core.logging import get_logger from asyncio_socks_server.core.socket import ( create_dualstack_tcp_socket, create_dualstack_udp_socket, ) from asyncio_socks_server.core.types import Address from asyncio_socks_server.server.udp_over_tcp import encode_udp_frame, read_udp_frame def _normalize_host(host: str) -> str: if host.startswith("::ffff:"): return host[7:] return host def _map_addr_for_sendto( host: str, port: int ) -> tuple[str, int] | tuple[str, int, int, int]: import ipaddress try: addr = ipaddress.ip_address(host) if isinstance(addr, ipaddress.IPv4Address): return (f"::ffff:{host}", port, 0, 0) except ValueError: pass return (host, port) class UdpOverTcpExitServer: """Accepts TCP connections carrying UDP-over-TCP frames and relays to raw UDP. Used as the exit node in a UDP-over-TCP chain. Not an addon — it's a standalone TCP service that sits at the chain endpoint. 
""" def __init__(self, host: str = "::", port: int = 0, ttl: float = 300.0): self.host = host self.port = port self._ttl = ttl self._shutdown_event = asyncio.Event() def run(self) -> None: asyncio.run(self._run()) def request_shutdown(self) -> None: self._shutdown_event.set() async def _run(self) -> None: logger = get_logger() sock = create_dualstack_udp_socket("0.0.0.0", 0) sock.setblocking(False) loop = asyncio.get_running_loop() udp_transport: asyncio.DatagramTransport | None = None route_map: dict[tuple[str, int], asyncio.StreamWriter] = {} route_ts: dict[tuple[str, int], float] = {} # Shared outbound UDP socket class UdpProtocol(asyncio.DatagramProtocol): def connection_made(self, transport: asyncio.DatagramTransport) -> None: nonlocal udp_transport udp_transport = transport def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: remote_key = (_normalize_host(addr[0]), addr[1]) writer = route_map.get(remote_key) if writer is None: return route_ts[remote_key] = time.monotonic() src_addr = Address(_normalize_host(addr[0]), addr[1]) task = asyncio.create_task(_send_frame(writer, src_addr, data)) task.add_done_callback( lambda t: t.exception() if not t.cancelled() else None ) def error_received(self, exc: Exception) -> None: pass async def _send_frame( writer: asyncio.StreamWriter, src_addr: Address, data: bytes ) -> None: try: frame = await encode_udp_frame(src_addr, data) writer.write(frame) await writer.drain() except (ConnectionError, OSError): pass _, _ = await loop.create_datagram_endpoint(UdpProtocol, sock=sock) # TTL cleanup task async def _ttl_cleanup(): while True: await asyncio.sleep(60) now = time.monotonic() expired = [k for k, ts in route_ts.items() if now - ts > self._ttl] for k in expired: route_map.pop(k, None) route_ts.pop(k, None) ttl_task = asyncio.create_task(_ttl_cleanup()) # TCP server tcp_sock = create_dualstack_tcp_socket(self.host, self.port) tcp_sock.setblocking(False) tcp_srv = await asyncio.start_server( lambda r, w: 
_handle_tcp(r, w, udp_transport, route_map, route_ts), sock=tcp_sock, ) tcp_sockname = tcp_srv.sockets[0].getsockname() self.port = tcp_sockname[1] logger.info(f"udp-over-tcp exit started on {self.host}:{self.port}") try: await self._shutdown_event.wait() finally: ttl_task.cancel() try: await ttl_task except asyncio.CancelledError: pass tcp_srv.close() await tcp_srv.wait_closed() if udp_transport: udp_transport.close() logger.info("udp-over-tcp exit stopped") async def _handle_tcp( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, udp_transport: asyncio.DatagramTransport | None, route_map: dict[tuple[str, int], asyncio.StreamWriter], route_ts: dict[tuple[str, int], float], ) -> None: try: while True: dst_addr, payload = await read_udp_frame(reader) remote_key = (dst_addr.host, dst_addr.port) route_map[remote_key] = writer route_ts[remote_key] = time.monotonic() if udp_transport: udp_transport.sendto( payload, _map_addr_for_sendto(dst_addr.host, dst_addr.port) ) except (asyncio.IncompleteReadError, ConnectionError, OSError): pass finally: try: writer.close() await writer.wait_closed() except (ConnectionError, OSError): pass ================================================ FILE: src/asyncio_socks_server/server/udp_relay.py ================================================ from __future__ import annotations import asyncio import time from typing import Callable from asyncio_socks_server.core.protocol import build_udp_header, parse_udp_header from asyncio_socks_server.core.socket import create_dualstack_udp_socket from asyncio_socks_server.core.types import Address, Flow def _normalize_host(host: str) -> str: """Strip IPv4-mapped IPv6 prefix for consistent routing table keys.""" if host.startswith("::ffff:"): return host[7:] return host def _map_addr_for_sendto( host: str, port: int ) -> tuple[str, int] | tuple[str, int, int, int]: """Return an address tuple suitable for the outbound socket's family. 
AF_INET6 sockets require IPv4-mapped format (::ffff:x.x.x.x) for IPv4 targets. """ import ipaddress try: addr = ipaddress.ip_address(host) if isinstance(addr, ipaddress.IPv4Address): return (f"::ffff:{host}", port, 0, 0) except ValueError: pass return (host, port) class UdpRelayBase: """Interface for UDP relay handlers used by the server and addon system.""" async def start(self) -> Address: raise NotImplementedError def set_client_transport(self, transport: asyncio.DatagramTransport) -> None: raise NotImplementedError async def stop(self) -> None: raise NotImplementedError def handle_client_datagram(self, data: bytes, client_addr: tuple[str, int]) -> None: raise NotImplementedError class UdpRelay(UdpRelayBase): """UDP relay using a shared outbound socket + bidirectional routing table. All clients share one outbound UDP socket. A routing table maps remote addresses back to client addresses for response routing. Entries expire after TTL seconds of inactivity. """ def __init__(self, client_addr: Address, flow: Flow, ttl: float = 300.0): self._client_addr = client_addr self._ttl = ttl self._transport: asyncio.DatagramTransport | None = None self._route_map: dict[tuple[str, int], tuple[str, int]] = {} self._route_timestamps: dict[tuple[str, int], float] = {} self._ttl_task: asyncio.Task | None = None self._client_transport: asyncio.DatagramTransport | None = None self._bind_addr: Address | None = None self._flow = flow async def start(self) -> Address: loop = asyncio.get_running_loop() outbound_sock = create_dualstack_udp_socket("0.0.0.0", 0) outbound_sock.setblocking(False) transport, _ = await loop.create_datagram_endpoint( lambda: _UdpProtocol(self._on_remote_data), sock=outbound_sock, ) self._transport = transport sock = transport.get_extra_info("socket") sockname = sock.getsockname() if sock else ("::", 0) self._bind_addr = Address(sockname[0], sockname[1]) self._ttl_task = asyncio.create_task(self._ttl_cleanup_loop()) return self._bind_addr def 
set_client_transport(self, transport: asyncio.DatagramTransport) -> None: self._client_transport = transport async def stop(self) -> None: if self._ttl_task: self._ttl_task.cancel() try: await self._ttl_task except asyncio.CancelledError: pass if self._transport: self._transport.close() def handle_client_datagram(self, data: bytes, client_addr: tuple[str, int]) -> None: try: dst, _, payload = parse_udp_header(data) except Exception: return if not payload: return remote_key = (dst.host, dst.port) self._route_map[remote_key] = client_addr self._route_timestamps[remote_key] = time.monotonic() if self._transport: self._transport.sendto(payload, _map_addr_for_sendto(dst.host, dst.port)) self._flow.bytes_up += len(payload) def _on_remote_data(self, data: bytes, remote_addr: tuple[str, int]) -> None: self._flow.bytes_down += len(data) remote_key = (_normalize_host(remote_addr[0]), remote_addr[1]) client_addr = self._route_map.get(remote_key) if client_addr is None: return self._route_timestamps[remote_key] = time.monotonic() # Build SOCKS5 UDP reply header src_addr = Address(_normalize_host(remote_addr[0]), remote_addr[1]) header = build_udp_header(src_addr) packet = header + data if self._client_transport: self._client_transport.sendto(packet, client_addr) async def _ttl_cleanup_loop(self) -> None: while True: await asyncio.sleep(60) now = time.monotonic() expired = [ key for key, ts in self._route_timestamps.items() if now - ts > self._ttl ] for key in expired: self._route_map.pop(key, None) self._route_timestamps.pop(key, None) class _UdpProtocol(asyncio.DatagramProtocol): def __init__(self, on_data: Callable[[bytes, tuple[str, int]], None]) -> None: self._on_data = on_data self._transport: asyncio.DatagramTransport | None = None def connection_made(self, transport: asyncio.DatagramTransport) -> None: self._transport = transport def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: self._on_data(data, addr) def error_received(self, exc: Exception) -> 
None: pass ================================================ FILE: tests/__init__.py ================================================ ================================================ FILE: tests/conftest.py ================================================ import asyncio import pytest from asyncio_socks_server.core.types import Address from asyncio_socks_server.server.server import Server @pytest.fixture async def echo_server(): """TCP echo server for testing.""" async def handler(reader, writer): try: while True: data = await reader.read(4096) if not data: break writer.write(data) await writer.drain() finally: writer.close() await writer.wait_closed() srv = await asyncio.start_server(handler, "127.0.0.1", 0) addr = srv.sockets[0].getsockname() yield Address(addr[0], addr[1]) srv.close() await srv.wait_closed() @pytest.fixture async def udp_echo_server(): """UDP echo server for testing.""" received = [] class Protocol(asyncio.DatagramProtocol): def connection_made(self, transport): self.transport = transport def datagram_received(self, data, addr): received.append((data, addr)) self.transport.sendto(data, addr) loop = asyncio.get_running_loop() transport, _ = await loop.create_datagram_endpoint( Protocol, local_addr=("127.0.0.1", 0) ) sock = transport.get_extra_info("socket") sockname = sock.getsockname() if sock else ("127.0.0.1", 0) yield Address(sockname[0], sockname[1]), received transport.close() async def _start_server(**kwargs): server = Server(host="127.0.0.1", port=0, **kwargs) task = asyncio.create_task(server._run()) for _ in range(50): if server.port != 0: break await asyncio.sleep(0.01) return server, task async def _stop_server(server, task): server.request_shutdown() await task ================================================ FILE: tests/e2e_helpers.py ================================================ import asyncio import ipaddress import struct from asyncio_socks_server.core.address import encode_address from asyncio_socks_server.core.types import 
Address


async def socks5_connect(
    proxy: Address,
    target: Address,
    auth: tuple[str, str] | None = None,
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
    """Open a SOCKS5 session through *proxy* and send a CONNECT for *target*.

    Offers only no-auth, or only username/password when *auth* is given, and
    asserts the proxy selects that method. The CONNECT reply is not read
    here; callers read it from the returned reader (e.g. with
    ``read_socks_reply``).
    """
    reader, writer = await asyncio.open_connection(proxy.host, proxy.port)
    writer.write(b"\x05\x01\x02" if auth else b"\x05\x01\x00")
    await writer.drain()
    resp = await reader.readexactly(2)
    assert resp[0] == 0x05
    if auth is None:
        assert resp[1] == 0x00
    else:
        assert resp[1] == 0x02
        username, password = auth
        uname = username.encode()
        passwd = password.encode()
        # Username/password sub-negotiation: VER=1, ULEN, UNAME, PLEN, PASSWD.
        writer.write(
            b"\x01"
            + len(uname).to_bytes(1, "big")
            + uname
            + len(passwd).to_bytes(1, "big")
            + passwd
        )
        await writer.drain()
        assert await reader.readexactly(2) == b"\x01\x00"
    writer.write(b"\x05\x01\x00" + encode_address(target.host, target.port))
    await writer.drain()
    return reader, writer


async def read_socks_reply(reader: asyncio.StreamReader) -> bytes:
    """Consume one SOCKS5 reply and return its first three bytes.

    For ATYP 0x01 (IPv4), 0x04 (IPv6) and 0x03 (domain) the bound address
    and port that follow are read and discarded.
    """
    reply = await reader.readexactly(3)
    atyp = (await reader.readexactly(1))[0]
    if atyp == 0x01:
        await reader.readexactly(4 + 2)
    elif atyp == 0x04:
        await reader.readexactly(16 + 2)
    elif atyp == 0x03:
        length = (await reader.readexactly(1))[0]
        await reader.readexactly(length + 2)
    # NOTE(review): any other ATYP leaves the address bytes unread.
    return reply


async def open_udp_associate(
    proxy: Address,
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter, Address]:
    """Negotiate no-auth with *proxy* and send a UDP ASSOCIATE for 0.0.0.0:0.

    Returns the control connection's reader and writer plus the relay
    address parsed from the proxy's reply.
    """
    reader, writer = await asyncio.open_connection(proxy.host, proxy.port)
    writer.write(b"\x05\x01\x00")
    await writer.drain()
    assert await reader.readexactly(2) == b"\x05\x00"
    writer.write(b"\x05\x03\x00" + encode_address("0.0.0.0", 0))
    await writer.drain()
    reply = await reader.readexactly(3)
    assert reply[1] == 0x00
    atyp = (await reader.readexactly(1))[0]
    if atyp == 0x01:
        bind_host = ipaddress.IPv4Address(await reader.readexactly(4)).compressed
    elif atyp == 0x04:
        bind_host = str(ipaddress.IPv6Address(await reader.readexactly(16)))
    else:
        length = (await reader.readexactly(1))[0]
        bind_host = (await reader.readexactly(length)).decode("ascii")
    bind_port = struct.unpack("!H", await reader.readexactly(2))[0]
    return reader, writer, Address(bind_host, bind_port)


================================================
FILE: tests/test_addon_builtins.py
================================================
import json

import pytest

from asyncio_socks_server.addons.auth import FileAuth
from asyncio_socks_server.addons.ip_filter import IPFilter
from asyncio_socks_server.addons.logger import Logger
from asyncio_socks_server.core.types import Address, Direction, Flow


def _make_flow(**kwargs):
    """Build a Flow with test defaults; keyword arguments override them."""
    defaults = dict(
        id=1,
        src=Address("127.0.0.1", 0),
        dst=Address("0.0.0.0", 0),
        protocol="tcp",
        started_at=0.0,
    )
    defaults.update(kwargs)
    return Flow(**defaults)


class TestFileAuth:
    async def test_valid_credentials(self, tmp_path):
        cred_file = tmp_path / "creds.json"
        cred_file.write_text(json.dumps({"admin": "secret", "user": "pass"}))
        auth = FileAuth(cred_file)
        assert await auth.on_auth("admin", "secret") is True
        assert await auth.on_auth("admin", "wrong") is False

    async def test_unknown_user(self, tmp_path):
        cred_file = tmp_path / "creds.json"
        cred_file.write_text(json.dumps({"admin": "secret"}))
        auth = FileAuth(cred_file)
        assert await auth.on_auth("unknown", "any") is None

    async def test_missing_file(self, tmp_path):
        auth = FileAuth(tmp_path / "nonexistent.json")
        assert await auth.on_auth("any", "any") is None


class TestIPFilter:
    async def test_blocked(self):
        f = IPFilter(blocked=["10.0.0.0/8", "192.168.1.1"])
        # 172.16.0.1 is NOT blocked → returns None
        result = await f.on_connect(_make_flow(src=Address("172.16.0.1", 0)))
        assert result is None
        # 10.0.0.5 IS blocked → raises
        with pytest.raises(ConnectionRefusedError):
            await f.on_connect(_make_flow(src=Address("10.0.0.5", 0)))

    async def test_allowed_list(self):
        f = IPFilter(allowed=["127.0.0.0/8"])
        # 127.0.0.1 should be allowed (returns None)
        result = await f.on_connect(_make_flow())
        assert result is None

    async def test_not_in_allowed(self):
        f = IPFilter(allowed=["127.0.0.0/8"])
        with pytest.raises(ConnectionRefusedError):
            await f.on_connect(_make_flow(src=Address("10.0.0.1", 0)))

    async def test_no_rules(self):
        f = IPFilter()
        result = await f.on_connect(_make_flow(src=Address("10.0.0.1", 0)))
        assert result is None


class TestLogger:
    async def test_on_connect(self):
        logger = Logger()
        result = await logger.on_connect(
            _make_flow(src=Address("127.0.0.1", 1080), dst=Address("example.com", 80))
        )
        assert result is None

    async def test_on_data(self):
        logger = Logger()
        flow = _make_flow()
        result = await logger.on_data(Direction.UPSTREAM, b"hello", flow)
        assert result == b"hello"

    async def test_on_error(self):
        logger = Logger()
        await logger.on_error(ValueError("test"))


================================================
FILE: tests/test_addon_builtins_extended.py
================================================
"""Extended tests for built-in addons: FileAuth, IPFilter, Logger."""

import json

import pytest

from asyncio_socks_server.addons.auth import FileAuth
from asyncio_socks_server.addons.ip_filter import IPFilter
from asyncio_socks_server.addons.logger import Logger
from asyncio_socks_server.core.types import Address, Direction, Flow


def _make_flow(**kwargs):
    """Build a Flow with test defaults; keyword arguments override them."""
    defaults = dict(
        id=1,
        src=Address("127.0.0.1", 0),
        dst=Address("0.0.0.0", 0),
        protocol="tcp",
        started_at=0.0,
    )
    defaults.update(kwargs)
    return Flow(**defaults)


class TestFileAuthExtended:
    async def test_corrupted_json_file(self, tmp_path):
        bad_file = tmp_path / "auth.json"
        bad_file.write_text("not json at all {{{")
        auth = FileAuth(str(bad_file))
        result = await auth.on_auth("user", "pass")
        assert result is None

    async def test_empty_json_file(self, tmp_path):
        empty_file = tmp_path / "auth.json"
        empty_file.write_text("{}")
        auth = FileAuth(str(empty_file))
        result = await auth.on_auth("user", "pass")
        assert result is None

    async def test_credentials_cached_after_first_load(self, tmp_path):
        auth_file = tmp_path / "auth.json"
        auth_file.write_text(json.dumps({"user": "pass"}))
        auth = FileAuth(str(auth_file))
        # First load
        result = await auth.on_auth("user", "pass")
        assert result is True
        # Modify file
        auth_file.write_text("{}")
        # Should still use cached credentials
        result = await auth.on_auth("user", "pass")
        assert result is True

    async def test_unicode_credentials(self, tmp_path):
        auth_file = tmp_path / "auth.json"
        auth_file.write_text(json.dumps({"用户": "密码"}))
        auth = FileAuth(str(auth_file))
        result = await auth.on_auth("用户", "密码")
        assert result is True

    async def test_unknown_user(self, tmp_path):
        auth_file = tmp_path / "auth.json"
        auth_file.write_text(json.dumps({"admin": "secret"}))
        auth = FileAuth(str(auth_file))
        result = await auth.on_auth("unknown", "pass")
        assert result is None


class TestIPFilterExtended:
    async def test_ipv6_blocked(self):
        filt = IPFilter(blocked=["::1/128"])
        with pytest.raises(ConnectionRefusedError, match="IP blocked"):
            await filt.on_connect(
                _make_flow(src=Address("::1", 1234), dst=Address("1.2.3.4", 80))
            )

    async def test_ipv6_allowed(self):
        filt = IPFilter(allowed=["::1/128", "127.0.0.1/32"])
        # 127.0.0.1 should be allowed (returns None)
        result = await filt.on_connect(
            _make_flow(src=Address("127.0.0.1", 1234), dst=Address("1.2.3.4", 80))
        )
        assert result is None

    async def test_domain_source_falls_back(self):
        filt = IPFilter(blocked=["10.0.0.0/8"])
        # Domain source host — ip_address will raise ValueError
        try:
            await filt.on_connect(
                _make_flow(src=Address("example.com", 1234), dst=Address("1.2.3.4", 80))
            )
        except (ValueError, ConnectionRefusedError):
            pass  # Either is acceptable

    async def test_empty_rules(self):
        filt = IPFilter()
        # No rules → nothing blocked, should return None
        result = await filt.on_connect(
            _make_flow(src=Address("10.0.0.1", 1234), dst=Address("1.2.3.4", 80))
        )
        assert result is None


class TestLoggerExtended:
    async def test_on_data_returns_data_passthrough(self):
        logger = Logger()
        flow = _make_flow()
        result = await logger.on_data(Direction.UPSTREAM, b"test", flow)
        assert result == b"test"

    async def test_on_connect_returns_none(self):
        logger = Logger()
        result = await logger.on_connect(
            _make_flow(src=Address("1.2.3.4", 1234), dst=Address("5.6.7.8", 80))
        )
        assert result is None

    async def test_on_error_does_not_raise(self):
        logger = Logger()
        await logger.on_error(RuntimeError("test"))
        await logger.on_error(ConnectionError("test"))
        await logger.on_error(ValueError("test"))


================================================
FILE: tests/test_addon_chain.py
================================================
import asyncio

import pytest

from asyncio_socks_server.addons.chain import ChainRouter
from asyncio_socks_server.client.client import connect
from asyncio_socks_server.core.types import Address
from asyncio_socks_server.server.server import Server


async def _start_server(**kwargs):
    """Start a Server on an ephemeral port, polling up to ~0.5 s for its port."""
    server = Server(host="127.0.0.1", port=0, **kwargs)
    task = asyncio.create_task(server._run())
    for _ in range(50):
        if server.port != 0:
            break
        await asyncio.sleep(0.01)
    return server, task


async def _stop_server(server, task):
    """Request shutdown and wait for the server task to finish."""
    server.request_shutdown()
    await task


@pytest.fixture
async def echo_server():
    async def handler(reader, writer):
        try:
            while True:
                data = await reader.read(4096)
                if not data:
                    break
                writer.write(data)
                await writer.drain()
        finally:
            writer.close()
            await writer.wait_closed()

    srv = await asyncio.start_server(handler, "127.0.0.1", 0)
    addr = srv.sockets[0].getsockname()
    yield Address(addr[0], addr[1])
    srv.close()
    await srv.wait_closed()


class TestChainRouter:
    async def test_two_hop_chain(self, echo_server):
        # Exit node: direct to target
        exit_server, exit_task = await _start_server()
        # Entry node: routes through exit node
        chain_addon = ChainRouter(next_hop=f"127.0.0.1:{exit_server.port}")
        entry_server, entry_task = await _start_server(addons=[chain_addon])
        try:
            conn = await connect(
                Address(entry_server.host, entry_server.port),
                echo_server,
            )
            conn.writer.write(b"through the chain")
            await conn.writer.drain()
            data = await conn.reader.read(4096)
            assert data == b"through the chain"
            conn.writer.close()
            await conn.writer.wait_closed()
        finally:
            await _stop_server(entry_server, entry_task)
            await _stop_server(exit_server, exit_task)

    async def test_chain_with_auth(self, echo_server):
        exit_server, exit_task = await _start_server(auth=("proxy", "secret"))
        chain_addon = ChainRouter(
            next_hop=f"127.0.0.1:{exit_server.port}",
            username="proxy",
            password="secret",
        )
        entry_server, entry_task = await _start_server(addons=[chain_addon])
        try:
            conn = await connect(
                Address(entry_server.host, entry_server.port),
                echo_server,
            )
            conn.writer.write(b"auth chain")
            await conn.writer.drain()
            data = await conn.reader.read(4096)
            assert data == b"auth chain"
            conn.writer.close()
            await conn.writer.wait_closed()
        finally:
            await _stop_server(entry_server, entry_task)
            await _stop_server(exit_server, exit_task)


================================================
FILE: tests/test_addon_edge_cases.py
================================================
"""Addon dispatch edge cases: competitive, pipeline, exceptions."""

from asyncio_socks_server.addons.base import Addon
from asyncio_socks_server.addons.manager import AddonManager
from asyncio_socks_server.core.types import Address, Direction, Flow


def _make_flow(**kwargs):
    """Build a Flow with test defaults; keyword arguments override them."""
    defaults = dict(
        id=1,
        src=Address("127.0.0.1", 0),
        dst=Address("0.0.0.0", 0),
        protocol="tcp",
        started_at=0.0,
    )
    defaults.update(kwargs)
    return Flow(**defaults)


class ConnectReturning(Addon):
    # Returns a fixed value from on_connect.
    def __init__(self, value=None):
        self._value = value

    async def on_connect(self, flow):
        return self._value


class TrackingAddon(Addon):
    # Appends "<name>:on_start" / "<name>:on_stop" to a shared list.
    def __init__(self, name, calls):
        self._name = name
        self._calls = calls

    async def on_start(self):
        self._calls.append(f"{self._name}:on_start")

    async def on_stop(self):
        self._calls.append(f"{self._name}:on_stop")


class DataTransform(Addon):
    # Applies transform_fn to every data chunk.
    def __init__(self, transform_fn):
        self._fn = transform_fn

    async def on_data(self, direction, data, flow):
        return self._fn(data)


class ErrorRaiser(Addon):
    # Records errors it observes; optionally raises from on_error.
    def __init__(self, raise_on_error=False):
        self._raise_on_error = raise_on_error
        self.errors = []

    async def on_error(self, error):
        self.errors.append(error)
        if self._raise_on_error:
            raise RuntimeError("addon error")


class AuthAddon(Addon):
    # Returns a fixed auth decision (True / False / None).
    def __init__(self, result):
        self._result = result

    async def on_auth(self, username, password):
        return self._result


class TestCompetitiveConnect:
    async def test_first_addon_returns_connection(self):
        # Use a simple object as Connection proxy
        sentinel = object()
        a1 = ConnectReturning(value=sentinel)
        a2 = ConnectReturning(value=None)
        manager = AddonManager([a1, a2])
        result = await manager.dispatch_connect(
            _make_flow(src=Address("1.2.3.4", 0), dst=Address("5.6.7.8", 80))
        )
        assert result is sentinel

    async def test_second_addon_returns_connection(self):
        sentinel = object()
        a1 = ConnectReturning(value=None)
        a2 = ConnectReturning(value=sentinel)
        manager = AddonManager([a1, a2])
        result = await manager.dispatch_connect(
            _make_flow(src=Address("1.2.3.4", 0), dst=Address("5.6.7.8", 80))
        )
        assert result is sentinel

    async def test_no_addon_returns_connection(self):
        a1 = ConnectReturning(value=None)
        a2 = ConnectReturning(value=None)
        manager = AddonManager([a1, a2])
        result = await manager.dispatch_connect(
            _make_flow(src=Address("1.2.3.4", 0), dst=Address("5.6.7.8", 80))
        )
        assert result is None


class TestPipelineEdgeCases:
    async def test_pipeline_with_intermediate_none(self):
        call_log = []

        class LogAddon(Addon):
            def __init__(self, name, ret):
                self._name = name
                self._ret = ret

            async def on_data(self, direction, data, flow):
                call_log.append(self._name)
                return self._ret

        # First transforms, second returns None (drops), third should NOT be called
        manager = AddonManager(
            [
                LogAddon("upper", b"HELLO"),
                LogAddon("drop", None),
                LogAddon("lower", b"hello"),
            ]
        )
        result = await manager.dispatch_data(Direction.UPSTREAM, b"hello", _make_flow())
        assert result is None
        assert call_log == ["upper", "drop"]

    async def test_pipeline_empty_bytes(self):
        received = []

        class Capture(Addon):
            async def on_data(self, direction, data, flow):
                received.append(data)
                return data

        manager = AddonManager([Capture()])
        result = await manager.dispatch_data(Direction.UPSTREAM, b"", _make_flow())
        assert result == b""
        assert received == [b""]


class TestAddonExceptions:
    async def test_auth_addon_raises_exception(self):
        class FailAuth(Addon):
            async def on_auth(self, username, password):
                raise PermissionError("blocked")

        manager = AddonManager([FailAuth()])
        try:
            await manager.dispatch_auth("user", "pass")
            assert False, "Should have raised"
        except PermissionError:
            pass

    async def test_data_addon_raises_exception(self):
        class FailData(Addon):
            async def on_data(self, direction, data, flow):
                raise ValueError("bad data")

        manager = AddonManager([FailData()])
        try:
            await manager.dispatch_data(Direction.UPSTREAM, b"test", _make_flow())
            assert False, "Should have raised"
        except ValueError:
            pass

    async def test_error_addon_exception_suppressed(self):
        a1 = ErrorRaiser(raise_on_error=True)
        a2 = ErrorRaiser()
        manager = AddonManager([a1, a2])
        # Should not raise even though a1 raises in on_error
        await manager.dispatch_error(RuntimeError("test"))
        # a2 should still have been called
        assert len(a2.errors) == 1

    async def test_error_addon_all_called(self):
        a1 = ErrorRaiser()
        a2 = ErrorRaiser()
        a3 = ErrorRaiser()
        manager = AddonManager([a1, a2, a3])
        err = RuntimeError("test")
        await manager.dispatch_error(err)
        assert len(a1.errors) == 1
        assert len(a2.errors) == 1
        assert len(a3.errors) == 1


class TestLifecycleOrder:
    async def test_multiple_addons_start_stop_order(self):
        calls = []
        a1 = TrackingAddon("a1", calls)
        a2 = TrackingAddon("a2", calls)
        a3 = TrackingAddon("a3", calls)
        manager = AddonManager([a1, a2, a3])
        await manager.dispatch_start()
        assert calls == ["a1:on_start", "a2:on_start", "a3:on_start"]
        calls.clear()
        await manager.dispatch_stop()
        assert calls == ["a1:on_stop", "a2:on_stop", "a3:on_stop"]

    async def test_addon_with_only_data_override(self):
        class DataOnly(Addon):
            async def on_data(self, direction, data, flow):
                return data

        manager = AddonManager([DataOnly()])
        # auth should return None (not overridden)
        result = await manager.dispatch_auth("user", "pass")
        assert result is None
        # connect should return None
        result = await manager.dispatch_connect(
            _make_flow(src=Address("a", 1), dst=Address("b", 2))
        )
        assert result is None
        # data should pass through
        result = await manager.dispatch_data(Direction.UPSTREAM, b"test", _make_flow())
        assert result == b"test"


class TestCompetitiveAuth:
    async def test_first_auth_wins_true(self):
        manager = AddonManager([AuthAddon(True), AuthAddon(False)])
        result = await manager.dispatch_auth("user", "pass")
        assert result is True

    async def test_first_auth_wins_false(self):
        manager = AddonManager([AuthAddon(False), AuthAddon(True)])
        result = await manager.dispatch_auth("user", "pass")
        assert result is False

    async def test_all_none_passes_through(self):
        manager = AddonManager([AuthAddon(None), AuthAddon(None)])
        result = await manager.dispatch_auth("user", "pass")
        assert result is None


================================================
FILE: tests/test_addon_manager.py
================================================
from asyncio_socks_server.addons.base import Addon
from asyncio_socks_server.addons.manager import AddonManager
from asyncio_socks_server.core.types import Address, Direction, Flow


def _make_flow(**kwargs):
    """Build a Flow with test defaults; keyword arguments override them."""
    defaults = dict(
        id=1,
        src=Address("127.0.0.1", 0),
        dst=Address("0.0.0.0", 0),
        protocol="tcp",
        started_at=0.0,
    )
    defaults.update(kwargs)
    return Flow(**defaults)


class LifeCycleAddon(Addon):
    # Flags whether on_start / on_stop were called.
    def __init__(self):
        self.started = False
        self.stopped = False

    async def on_start(self):
        self.started = True

    async def on_stop(self):
        self.stopped = True


class TestLifecycle:
    async def test_start_stop(self):
        addon = LifeCycleAddon()
        mgr = AddonManager([addon])
        await mgr.dispatch_start()
        assert addon.started
        await mgr.dispatch_stop()
        assert addon.stopped

    async def test_empty_manager(self):
        mgr = AddonManager([])
        await mgr.dispatch_start()
        await mgr.dispatch_stop()

    async def test_base_addon_skipped(self):
        mgr = AddonManager([Addon()])
        await mgr.dispatch_start()  # should not raise


class AuthAllow(Addon):
    async def on_auth(self, username, password):
        return True


class AuthDeny(Addon):
    async def on_auth(self, username, password):
        return False


class AuthPass(Addon):
    async def on_auth(self, username, password):
        return None


class TestCompetitiveAuth:
    async def test_first_allow_wins(self):
        mgr = AddonManager([AuthAllow(), AuthDeny()])
        result = await mgr.dispatch_auth("user", "pass")
        assert result is True

    async def test_first_deny_wins(self):
        mgr = AddonManager([AuthDeny(), AuthAllow()])
        result = await mgr.dispatch_auth("user", "pass")
        assert result is False

    async def test_all_pass(self):
        mgr = AddonManager([AuthPass(), AuthPass()])
        result = await mgr.dispatch_auth("user", "pass")
        assert result is None

    async def test_passthrough_then_allow(self):
        mgr = AddonManager([AuthPass(), AuthAllow()])
        result = await mgr.dispatch_auth("user", "pass")
        assert result is True


class UpperAddon(Addon):
    async def on_data(self, direction, data, flow):
        return data.upper()


class AppendAddon(Addon):
    async def on_data(self, direction, data, flow):
        return data + b"!"


class DropAddon(Addon):
    async def on_data(self, direction, data, flow):
        return None


class TestPipelineData:
    async def test_single_transform(self):
        mgr = AddonManager([UpperAddon()])
        result = await mgr.dispatch_data(Direction.UPSTREAM, b"hello", _make_flow())
        assert result == b"HELLO"

    async def test_chain_transforms(self):
        mgr = AddonManager([UpperAddon(), AppendAddon()])
        result = await mgr.dispatch_data(Direction.UPSTREAM, b"hello", _make_flow())
        assert result == b"HELLO!"

    async def test_drop_stops_pipeline(self):
        mgr = AddonManager([DropAddon(), UpperAddon()])
        result = await mgr.dispatch_data(Direction.UPSTREAM, b"hello", _make_flow())
        assert result is None

    async def test_no_addons(self):
        mgr = AddonManager([])
        result = await mgr.dispatch_data(Direction.UPSTREAM, b"hello", _make_flow())
        assert result == b"hello"


class ErrorAddon(Addon):
    def __init__(self):
        self.errors: list[Exception] = []

    async def on_error(self, error):
        self.errors.append(error)


class ErrorRaisingAddon(Addon):
    async def on_error(self, error):
        raise RuntimeError("observer crashed")


class TestObservationalError:
    async def test_all_called(self):
        a1 = ErrorAddon()
        a2 = ErrorAddon()
        mgr = AddonManager([a1, a2])
        err = ValueError("test")
        await mgr.dispatch_error(err)
        assert len(a1.errors) == 1
        assert len(a2.errors) == 1

    async def test_exception_doesnt_propagate(self):
        a1 = ErrorRaisingAddon()
        a2 = ErrorAddon()
        mgr = AddonManager([a1, a2])
        await mgr.dispatch_error(ValueError("test"))
        assert len(a2.errors) == 1  # second addon still called


================================================
FILE: tests/test_addon_stats.py
================================================
import asyncio
import json

from asyncio_socks_server import (
    Address,
    FlowAudit,
    FlowStats,
    Server,
    StatsAPI,
    StatsServer,
    connect,
)


async def _start_server(**kwargs):
    """Start a Server on an ephemeral port, polling up to ~0.5 s for its port."""
    server = Server(host="127.0.0.1", port=0, **kwargs)
    task = asyncio.create_task(server._run())
    for _ in range(50):
        if server.port != 0:
            break
        await asyncio.sleep(0.01)
    return server, task


async def _stop_server(server, task):
    """Request shutdown and wait for the server task to finish."""
    server.request_shutdown()
    await task


async def _get_json(port: int, path: str):
    """GET *path* from the local HTTP endpoint on *port*; see _request_json."""
    return await _request_json(port, "GET", path)


async def _request_json(port: int, method: str, path: str):
    """Send a bare HTTP/1.1 request and return (status code, parsed JSON body).

    The response is read until the server closes the connection.
    """
    reader, writer = await asyncio.open_connection("127.0.0.1", port)
    writer.write(f"{method} {path} HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n".encode("ascii"))
    await writer.drain()
    data = await reader.read()
    writer.close()
    await writer.wait_closed()
    header, body = data.split(b"\r\n\r\n", 1)
    status = int(header.split(b" ", 2)[1])
    return status, json.loads(body)


class TestStatsServer:
    async def test_flow_stats_has_no_network_side_effects(self, echo_server):
        stats = FlowStats()
        server, task = await _start_server(addons=[stats])
        conn = None
        try:
            conn = await connect(Address(server.host, server.port), echo_server)
            conn.writer.write(b"flowstats")
            await conn.writer.drain()
            data = await conn.reader.read(4096)
            assert data == b"flowstats"

            payload = stats.snapshot()
            assert payload["active_flows"] == 1
            assert payload["active_bytes_up"] == 9
            assert payload["active_bytes_down"] == 9
            assert payload["active"][0]["started_at"].endswith("Z")
        finally:
            if conn is not None:
                conn.writer.close()
                await conn.writer.wait_closed()
            await _stop_server(server, task)

    async def test_health_endpoint(self):
        stats = StatsAPI()
        server, task = await _start_server(addons=[stats])
        try:
            status, payload = await _get_json(stats.port, "/health")
            assert status == 200
            assert payload == {"ok": True}
        finally:
            await _stop_server(server, task)

    async def test_stats_api_can_present_external_flow_stats_without_double_counting(
        self,
        echo_server,
    ):
        stats = FlowStats()
        api = StatsAPI(stats=stats)
        server, task = await _start_server(addons=[stats, api])
        conn = None
        try:
            conn = await connect(Address(server.host, server.port), echo_server)
            conn.writer.write(b"external")
            await conn.writer.drain()
            data = await conn.reader.read(4096)
            assert data == b"external"

            status, payload = await _get_json(api.port, "/stats")
            assert status == 200
            assert payload["total_flows"] == 1
            assert stats.snapshot()["total_flows"] == 1
        finally:
            if conn is not None:
                conn.writer.close()
                await conn.writer.wait_closed()
            await _stop_server(server, task)

    async def test_errors_endpoint(self):
        stats = StatsAPI()
        await stats.on_error(RuntimeError("boom"))
        server, task = await _start_server(addons=[stats])
        try:
            status, payload = await _get_json(stats.port, "/errors")
            assert status == 200
            assert payload["total"] == 1
            assert payload["by_type"] == {"RuntimeError": 1}
            assert payload["recent"][0]["message"] == "boom"
        finally:
            await _stop_server(server, task)

    async def test_flow_audit_has_no_network_side_effects(self, echo_server):
        audit = FlowAudit()
        server, task = await _start_server(addons=[audit])
        conn = None
        try:
            conn = await connect(Address(server.host, server.port), echo_server)
            conn.writer.write(b"audit")
            await conn.writer.drain()
            data = await conn.reader.read(4096)
            assert data == b"audit"
            conn.writer.close()
            await conn.writer.wait_closed()
            await asyncio.sleep(0.05)

            payload = audit.snapshot()
            assert payload["status"] == "ready"
            assert payload["records"] == 1
            assert payload["total"] == {"upload": 5, "download": 5, "total": 10}
            assert payload["devices"][0]["total"] == 10
            assert payload["traffic"][0]["total"] == 10
            assert payload["recent"][0]["started_at"].endswith("Z")
        finally:
            if conn is not None and not conn.writer.is_closing():
                conn.writer.close()
                await conn.writer.wait_closed()
            await _stop_server(server, task)

    async def test_stats_api_exposes_flow_audit(self, echo_server):
        audit = FlowAudit()
        api = StatsAPI(audit=audit)
        server, task = await _start_server(addons=[audit, api])
        conn = None
        try:
            conn = await connect(Address(server.host, server.port), echo_server)
            conn.writer.write(b"audit-api")
            await conn.writer.drain()
            data = await conn.reader.read(4096)
            assert data == b"audit-api"
            conn.writer.close()
            await conn.writer.wait_closed()
            await asyncio.sleep(0.05)

            status, payload = await _get_json(api.port, "/audit?top=1")
            assert status == 200
            assert payload["records"] == 1
            assert len(payload["devices"]) == 1
            assert len(payload["traffic"]) == 1

            status, payload = await _request_json(api.port, "POST", "/audit/refresh")
            assert status == 200
            assert payload["records"] == 1
        finally:
            if conn is not None and not conn.writer.is_closing():
                conn.writer.close()
                await conn.writer.wait_closed()
            await _stop_server(server, task)

    async def test_stats_api_reports_audit_disabled(self):
        stats = StatsAPI()
        server, task = await _start_server(addons=[stats])
        try:
            status, payload = await _get_json(stats.port, "/audit")
            assert status == 404
            assert payload == {"error": "audit disabled"}
        finally:
            await _stop_server(server, task)

    async def test_tracks_active_and_closed_tcp_flows(self, echo_server):
        stats = StatsServer()
        server, task = await _start_server(addons=[stats])
        try:
            conn = await connect(Address(server.host, server.port), echo_server)
            conn.writer.write(b"stats")
            await conn.writer.drain()
            data = await conn.reader.read(4096)
            assert data == b"stats"

            status, payload = await _get_json(stats.port, "/stats")
            assert status == 200
            assert payload["started_at"].endswith("Z")
            assert payload["active_flows"] == 1
            assert payload["closed_flows"] == 0
            assert payload["total_closed_flows"] == 0
            assert payload["total_flows"] == 1
            assert payload["total_tcp_flows"] == 1
            assert payload["active_bytes_up"] == 5
            assert payload["active_bytes_down"] == 5
            assert payload["total_bytes_up"] == 5
            assert payload["total_bytes_down"] == 5
            assert payload["upload_rate"] >= 0
            assert payload["download_rate"] >= 0
            assert payload["errors"] == {"total": 0, "by_type": {}, "recent": []}
            assert payload["active"][0]["started_at"].endswith("Z")
            assert payload["active"][0]["bytes_up"] == 5
            assert payload["active"][0]["bytes_down"] == 5
            assert payload["active"][0]["upload_rate"] >= 0
            assert payload["active"][0]["download_rate"] >= 0

            conn.writer.close()
            await conn.writer.wait_closed()
            await asyncio.sleep(0.05)

            status, flows = await _get_json(stats.port, "/flows")
            assert status == 200
            assert flows["active"] == []
            assert len(flows["recent_closed"]) == 1
            assert flows["recent_closed"][0]["bytes_up"] == 5
            assert flows["recent_closed"][0]["bytes_down"] == 5

            snapshot = stats.snapshot()
            assert snapshot["active_flows"] == 0
            assert snapshot["closed_flows"] == 1
            assert snapshot["total_closed_flows"] == 1
            assert snapshot["closed_bytes_up"] == 5
            assert snapshot["closed_bytes_down"] == 5
            assert snapshot["total_bytes_up"] == 5
            assert snapshot["total_bytes_down"] == 5
        finally:
            await _stop_server(server, task)

    async def test_tracks_errors(self):
        stats = StatsServer()
        await stats.on_error(RuntimeError("boom"))

        snapshot = stats.snapshot()
        assert snapshot["errors"]["total"] == 1
        assert snapshot["errors"]["by_type"] == {"RuntimeError": 1}
        assert snapshot["errors"]["recent"][0]["type"] == "RuntimeError"
        assert snapshot["errors"]["recent"][0]["message"] == "boom"

    async def test_not_found(self):
        stats = StatsServer()
        server, task = await _start_server(addons=[stats])
        try:
            status, payload = await _get_json(stats.port, "/missing")
            assert status == 404
            assert payload == {"error": "not found"}
        finally:
            await _stop_server(server, task)


================================================
FILE: tests/test_cli.py
================================================
"""Tests for CLI argument parsing."""

from unittest.mock import MagicMock, patch

import pytest

from asyncio_socks_server.cli import main


class TestCliArgs:
    @patch("asyncio_socks_server.cli.Server")
    def test_default_values(self, mock_server_cls):
        with pytest.raises(SystemExit):  # argparse exits on --help
            with patch("sys.argv", ["asyncio_socks_server", "--help"]):
                main()

    @patch("asyncio_socks_server.cli.Server")
    def test_custom_host_port(self, mock_server_cls):
        mock_instance = MagicMock()
        mock_server_cls.return_value = mock_instance
        with patch(
            "sys.argv",
            ["asyncio_socks_server", "--host", "127.0.0.1", "--port", "9050"],
        ):
            main()
        mock_server_cls.assert_called_once_with(
            host="127.0.0.1", port=9050, auth=None, log_level="INFO"
        )
        mock_instance.run.assert_called_once()

    @patch("asyncio_socks_server.cli.Server")
    def test_auth_parsing(self, mock_server_cls):
        mock_instance = MagicMock()
        mock_server_cls.return_value = mock_instance
        with patch("sys.argv", ["asyncio_socks_server", "--auth", "user:pass"]):
            main()
        mock_server_cls.assert_called_once_with(
            host="::", port=1080, auth=("user", "pass"), log_level="INFO"
        )

    @patch("asyncio_socks_server.cli.Server")
    def test_auth_with_colon_in_password(self, mock_server_cls):
        mock_instance = MagicMock()
        mock_server_cls.return_value = mock_instance
        with patch("sys.argv", ["asyncio_socks_server", "--auth", "user:pass:word"]):
            main()
        mock_server_cls.assert_called_once_with(
            host="::", port=1080, auth=("user", "pass:word"), log_level="INFO"
        )

    def test_invalid_log_level(self):
        with pytest.raises(SystemExit):
            with patch("sys.argv", ["asyncio_socks_server", "--log-level", "INVALID"]):
                main()

    @patch("asyncio_socks_server.cli.Server")
    def test_debug_log_level(self, mock_server_cls):
        mock_instance = MagicMock()
        mock_server_cls.return_value = mock_instance
        with patch("sys.argv", ["asyncio_socks_server", "--log-level", "DEBUG"]):
            main()
        mock_server_cls.assert_called_once_with(
            host="::", port=1080, auth=None, log_level="DEBUG"
        )

    @patch("asyncio_socks_server.cli.Server")
    def test_no_auth_flag(self, mock_server_cls):
        mock_instance = MagicMock()
        mock_server_cls.return_value = mock_instance
        with patch("sys.argv", ["asyncio_socks_server"]):
            main()
        call_kwargs = mock_server_cls.call_args[1]
        assert call_kwargs["auth"] is None


================================================
FILE: tests/test_client.py
================================================
import asyncio

import pytest

from asyncio_socks_server.client.client import connect
from asyncio_socks_server.core.types import Address
from asyncio_socks_server.server.server import Server


async def _start_server(**kwargs):
    """Start a Server on an ephemeral port, polling up to ~0.5 s for its port."""
    server = Server(host="127.0.0.1", port=0, **kwargs)
    task = asyncio.create_task(server._run())
    for _ in range(50):
        if server.port != 0:
            break
        await asyncio.sleep(0.01)
    return server, task


async def _stop_server(server, task):
    """Request shutdown and wait for the server task to finish."""
    server.request_shutdown()
    await task


@pytest.fixture
async def echo_server():
    async def handler(reader, writer):
        try:
            while True:
                data = await reader.read(4096)
                if not data:
                    break
                writer.write(data)
                await writer.drain()
        finally:
            writer.close()
            await writer.wait_closed()

    srv = await asyncio.start_server(handler, "127.0.0.1", 0)
    addr = srv.sockets[0].getsockname()
    yield Address(addr[0], addr[1])
    srv.close()
    await srv.wait_closed()


class TestClientConnect:
    async def test_no_auth(self, echo_server):
        server, task = await _start_server()
        try:
            conn = await connect(Address(server.host, server.port), echo_server)
            conn.writer.write(b"hello")
            await conn.writer.drain()
            data = await conn.reader.read(4096)
            assert data == b"hello"
            conn.writer.close()
            await conn.writer.wait_closed()
        finally:
            await _stop_server(server, task)

    async def test_with_auth(self, echo_server):
        server, task = await _start_server(auth=("user", "pass"))
        try:
            conn = await connect(
                Address(server.host, server.port),
                echo_server,
                username="user",
                password="pass",
            )
            conn.writer.write(b"secret")
            await conn.writer.drain()
            data = await conn.reader.read(4096)
            assert data == b"secret"
            conn.writer.close()
            await conn.writer.wait_closed()
        finally:
            await _stop_server(server, task)

    async def test_auth_failure(self, echo_server):
        server, task = await _start_server(auth=("user", "pass"))
        try:
            from asyncio_socks_server.core.protocol import ProtocolError

            with pytest.raises(ProtocolError, match="authentication failed"):
                await connect(
                    Address(server.host, server.port),
                    echo_server,
                    username="user",
                    password="wrong",
                )
        finally:
            await _stop_server(server, task)

    async def test_connection_refused(self):
        server, task = await _start_server()
        try:
            with pytest.raises(Exception):
                await connect(
                    Address(server.host, server.port),
                    Address("127.0.0.1", 1),  # port 1 should refuse
                )
        finally:
            await _stop_server(server, task)


================================================
FILE: tests/test_client_edge_cases.py
================================================
"""Client edge cases: negotiation failures, unexpected responses."""

import asyncio
import socket

import pytest

from asyncio_socks_server.client.client import _happy_eyeballs_connect, connect
from
asyncio_socks_server.core.types import Address async def _fake_socks_server(*responses): """Start a fake SOCKS server that sends predefined responses. Returns (Address, close_event) where Address is the server's listen address. """ close_event = asyncio.Event() received = [] async def handler(reader, writer): try: while not close_event.is_set(): try: data = await asyncio.wait_for(reader.read(4096), timeout=0.5) if not data: break received.append(data) except asyncio.TimeoutError: continue finally: writer.close() await writer.wait_closed() srv = await asyncio.start_server(handler, "127.0.0.1", 0) addr = srv.sockets[0].getsockname() return Address(addr[0], addr[1]), srv, close_event, received async def _fake_socks_server_with_responses(responses): """Start a fake SOCKS server that sends specific byte sequences. Each response is sent after receiving data from the client. Returns (Address, server, close_event). """ close_event = asyncio.Event() resp_idx = [0] async def handler(reader, writer): try: while resp_idx[0] < len(responses): try: data = await asyncio.wait_for(reader.read(4096), timeout=1.0) if not data: break if resp_idx[0] < len(responses): writer.write(responses[resp_idx[0]]) await writer.drain() resp_idx[0] += 1 except asyncio.TimeoutError: break finally: writer.close() await writer.wait_closed() srv = await asyncio.start_server(handler, "127.0.0.1", 0) addr = srv.sockets[0].getsockname() return Address(addr[0], addr[1]), srv, close_event class TestClientNegotiationFailures: async def test_proxy_returns_wrong_version(self): # Reply with version 0x04 instead of 0x05 proxy_addr, srv, close = await _fake_socks_server_with_responses([b"\x04\x00"]) try: with pytest.raises(Exception): await connect(proxy_addr, Address("127.0.0.1", 80)) finally: close.set() srv.close() await srv.wait_closed() async def test_proxy_returns_no_acceptable_method(self): # Reply with 0xFF method (NO_ACCEPTABLE) proxy_addr, srv, close = await 
_fake_socks_server_with_responses([b"\x05\xff"]) try: with pytest.raises(Exception, match="no acceptable"): await connect(proxy_addr, Address("127.0.0.1", 80)) finally: close.set() srv.close() await srv.wait_closed() async def test_connect_reply_failure(self): # Method selection: accept NO_AUTH, then reply CONNECTION_REFUSED proxy_addr, srv, close = await _fake_socks_server_with_responses( [ b"\x05\x00", # Method reply: NO_AUTH # CONNECT reply: CONNECTION_REFUSED b"\x05\x05\x00\x01\x7f\x00\x00\x01\x00\x50", ] ) try: with pytest.raises(Exception, match="refused|failed|CONNECT"): await connect(proxy_addr, Address("127.0.0.1", 80)) finally: close.set() srv.close() await srv.wait_closed() async def test_connect_reply_wrong_version(self): proxy_addr, srv, close = await _fake_socks_server_with_responses( [ b"\x05\x00", # Method reply OK b"\x04\x00\x00\x01\x7f\x00\x00\x01\x00\x50", # Wrong version in reply ] ) try: with pytest.raises(Exception): await connect(proxy_addr, Address("127.0.0.1", 80)) finally: close.set() srv.close() await srv.wait_closed() class TestClientConnectionFailures: async def test_happy_eyeballs_falls_back_after_fast_first_failure( self, monkeypatch ): loop = asyncio.get_running_loop() attempts = [] async def fake_getaddrinfo(host, port, type): return [ ( socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("2001:db8::1", port, 0, 0), ), ( socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", port), ), ] async def fake_open_connection(host, port): attempts.append(host) if host == "2001:db8::1": raise OSError("ipv6 unavailable") return "reader", "writer" monkeypatch.setattr(loop, "getaddrinfo", fake_getaddrinfo) monkeypatch.setattr(asyncio, "open_connection", fake_open_connection) result = await _happy_eyeballs_connect(Address("example.test", 1080)) assert result == ("reader", "writer") assert attempts == ["2001:db8::1", "127.0.0.1"] async def test_connection_refused(self): # Connect to a port that nobody is listening on proxy_addr = Address("127.0.0.1", 
1) with pytest.raises((ConnectionError, OSError)): await connect(proxy_addr, Address("127.0.0.1", 80)) async def test_auth_failure(self): # Test auth failure via real server from tests.conftest import _start_server, _stop_server server, task = await _start_server(auth=("user", "pass")) try: with pytest.raises(Exception, match="authentication failed"): await connect( Address(server.host, server.port), Address("127.0.0.1", 80), username="user", password="wrong", ) finally: await _stop_server(server, task) ================================================ FILE: tests/test_concurrent.py ================================================ """Concurrency and stress tests.""" import asyncio from asyncio_socks_server import TrafficCounter from asyncio_socks_server.core.types import Address from tests.conftest import _start_server, _stop_server async def _socks5_proxy_connect(proxy: Address, target: Address): """Quick SOCKS5 CONNECT through proxy.""" reader, writer = await asyncio.open_connection(proxy.host, proxy.port) writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp == b"\x05\x00" from asyncio_socks_server.core.address import encode_address writer.write(b"\x05\x01\x00" + encode_address(target.host, target.port)) await writer.drain() reply = await reader.readexactly(3) assert reply[1] == 0x00 atyp = (await reader.readexactly(1))[0] if atyp == 0x01: await reader.readexactly(4 + 2) elif atyp == 0x04: await reader.readexactly(16 + 2) elif atyp == 0x03: length = (await reader.readexactly(1))[0] await reader.readexactly(length + 2) return reader, writer class TestConcurrentConnections: async def test_20_simultaneous_connections(self, echo_server): server, task = await _start_server() try: conns = await asyncio.gather( *[ _socks5_proxy_connect( Address(server.host, server.port), echo_server ) for _ in range(20) ] ) # Send data on all for r, w in conns: w.write(b"ping") await w.drain() # Read all responses for r, w in conns: data = 
await asyncio.wait_for(r.read(4096), timeout=2.0) assert data == b"ping" w.close() await w.wait_closed() finally: await _stop_server(server, task) async def test_concurrent_with_addon(self, echo_server): counter = TrafficCounter() server, task = await _start_server(addons=[counter]) try: conns = await asyncio.gather( *[ _socks5_proxy_connect( Address(server.host, server.port), echo_server ) for _ in range(10) ] ) for r, w in conns: w.write(b"test") await w.drain() for r, w in conns: await r.read(4096) w.close() await w.wait_closed() await asyncio.sleep(0.3) assert counter.bytes_up == 40 # 10 * 4 bytes assert counter.bytes_down == 40 assert counter.connections == 10 finally: await _stop_server(server, task) class TestLargePayloads: async def test_1mb_payload(self, echo_server): server, task = await _start_server() try: r, w = await _socks5_proxy_connect( Address(server.host, server.port), echo_server ) payload = b"A" * (1024 * 1024) w.write(payload) await w.drain() received = b"" while len(received) < len(payload): chunk = await asyncio.wait_for(r.read(65536), timeout=5.0) if not chunk: break received += chunk assert received == payload w.close() await w.wait_closed() finally: await _stop_server(server, task) async def test_many_small_writes(self, echo_server): server, task = await _start_server() try: r, w = await _socks5_proxy_connect( Address(server.host, server.port), echo_server ) expected = b"".join(f"msg{i:03d}".encode() for i in range(100)) w.write(expected) await w.drain() # Read all echoed data total = b"" while len(total) < len(expected): chunk = await asyncio.wait_for(r.read(65536), timeout=3.0) if not chunk: break total += chunk assert total == expected w.close() await w.wait_closed() finally: await _stop_server(server, task) class TestRapidConnectDisconnect: async def test_rapid_10_cycles(self): server, task = await _start_server() try: for _ in range(10): r, w = await asyncio.open_connection(server.host, server.port) w.write(b"\x05\x01\x00") await 
w.drain() resp = await r.readexactly(2) assert resp == b"\x05\x00" w.close() await w.wait_closed() # Verify server is still responsive r, w = await asyncio.open_connection(server.host, server.port) w.write(b"\x05\x01\x00") await w.drain() resp = await r.readexactly(2) assert resp == b"\x05\x00" w.close() await w.wait_closed() finally: await _stop_server(server, task) ================================================ FILE: tests/test_connection.py ================================================ """Tests for Connection dataclass.""" import asyncio from asyncio_socks_server.core.types import Address from asyncio_socks_server.server.connection import Connection class TestConnection: async def test_dataclass_fields(self): reader = asyncio.StreamReader() writer = None # Writer not needed for this test addr = Address("127.0.0.1", 1080) conn = Connection(reader=reader, writer=writer, address=addr) assert conn.reader is reader assert conn.writer is writer assert conn.address == addr async def test_address_type(self): reader = asyncio.StreamReader() conn = Connection(reader=reader, writer=None, address=Address("::1", 443)) assert isinstance(conn.address, Address) assert conn.address.host == "::1" assert conn.address.port == 443 ================================================ FILE: tests/test_core_address.py ================================================ import asyncio from asyncio_socks_server.core.address import ( decode_address, detect_atyp, encode_address, encode_reply, ) from asyncio_socks_server.core.types import Atyp, Rep class TestDetectAtyp: def test_ipv4(self): assert detect_atyp("127.0.0.1") == Atyp.IPV4 assert detect_atyp("0.0.0.0") == Atyp.IPV4 def test_ipv6(self): assert detect_atyp("::1") == Atyp.IPV6 assert detect_atyp("2001:db8::1") == Atyp.IPV6 def test_domain(self): assert detect_atyp("example.com") == Atyp.DOMAIN assert detect_atyp("sub.example.com") == Atyp.DOMAIN class TestEncodeDecodeAddress: def _roundtrip(self, host: str, port: int): encoded = 
encode_address(host, port) reader = asyncio.StreamReader() async def do(): reader.feed_data(encoded) reader.feed_eof() return await decode_address(reader) result = asyncio.get_event_loop().run_until_complete(do()) return result def test_ipv4_roundtrip(self): result = self._roundtrip("127.0.0.1", 1080) assert result.host == "127.0.0.1" assert result.port == 1080 def test_ipv6_roundtrip(self): result = self._roundtrip("::1", 443) assert result.host == "::1" assert result.port == 443 def test_domain_roundtrip(self): result = self._roundtrip("example.com", 80) assert result.host == "example.com" assert result.port == 80 def test_encode_ipv4_binary(self): # ATYP(1) + IPv4(4) + PORT(2) = 7 bytes data = encode_address("0.0.0.0", 0) assert len(data) == 7 assert data[0] == 0x01 def test_encode_ipv6_binary(self): # ATYP(1) + IPv6(16) + PORT(2) = 19 bytes data = encode_address("::1", 0) assert len(data) == 19 assert data[0] == 0x04 def test_encode_domain_binary(self): # ATYP(1) + LEN(1) + "example.com"(11) + PORT(2) = 15 bytes data = encode_address("example.com", 80) assert len(data) == 15 assert data[0] == 0x03 assert data[1] == 11 class TestEncodeReply: def test_success_reply(self): reply = encode_reply(Rep.SUCCEEDED, "0.0.0.0", 0) assert reply[0] == 0x05 # VER assert reply[1] == 0x00 # REP = succeeded assert reply[2] == 0x00 # RSV def test_failure_reply(self): reply = encode_reply(Rep.CONNECTION_REFUSED) assert reply[1] == 0x05 # REP = connection refused ================================================ FILE: tests/test_core_protocol.py ================================================ import asyncio import pytest from asyncio_socks_server.core.protocol import ( ProtocolError, build_auth_reply, build_method_reply, build_udp_header, parse_method_selection, parse_request, parse_udp_header, parse_username_password, ) from asyncio_socks_server.core.types import Address, AuthMethod, Cmd class TestMethodSelection: def test_valid_no_auth(self): data = b"\x05\x01\x00" # VER=5, 
NMETHODS=1, METHOD=NO_AUTH ver, methods = parse_method_selection(data) assert ver == 0x05 assert AuthMethod.NO_AUTH in methods def test_valid_username_password(self): data = b"\x05\x02\x00\x02" ver, methods = parse_method_selection(data) assert AuthMethod.NO_AUTH in methods assert AuthMethod.USERNAME_PASSWORD in methods def test_wrong_version(self): with pytest.raises(ProtocolError, match="unsupported SOCKS version"): parse_method_selection(b"\x04\x01\x00") def test_too_short(self): with pytest.raises(ProtocolError, match="too short"): parse_method_selection(b"\x05") def test_build_method_reply(self): assert build_method_reply(0x00) == b"\x05\x00" assert build_method_reply(0x02) == b"\x05\x02" assert build_method_reply(0xFF) == b"\x05\xff" class TestUsernamePassword: def test_parse(self): # VER=1, ULEN=4, UNAME="user", PLEN=4, PASSWD="pass" data = b"\x01\x04user\x04pass" reader = asyncio.StreamReader() async def do(): reader.feed_data(data) reader.feed_eof() return await parse_username_password(reader) username, password = asyncio.get_event_loop().run_until_complete(do()) assert username == "user" assert password == "pass" def test_wrong_version(self): data = b"\x02\x04user\x04pass" reader = asyncio.StreamReader() async def do(): reader.feed_data(data) reader.feed_eof() return await parse_username_password(reader) with pytest.raises(ProtocolError, match="unsupported auth version"): asyncio.get_event_loop().run_until_complete(do()) def test_build_auth_reply(self): assert build_auth_reply(True) == b"\x01\x00" assert build_auth_reply(False) == b"\x01\x01" class TestParseRequest: def _make_request(self, cmd: int, host: str, port: int) -> bytes: from asyncio_socks_server.core.address import encode_address VER = b"\x05" CMD = cmd.to_bytes(1, "big") RSV = b"\x00" return VER + CMD + RSV + encode_address(host, port) def test_connect_ipv4(self): data = self._make_request(0x01, "127.0.0.1", 1080) reader = asyncio.StreamReader() async def do(): reader.feed_data(data) 
reader.feed_eof() return await parse_request(reader) cmd, addr = asyncio.get_event_loop().run_until_complete(do()) assert cmd == Cmd.CONNECT assert addr.host == "127.0.0.1" assert addr.port == 1080 def test_connect_ipv6(self): data = self._make_request(0x01, "::1", 443) reader = asyncio.StreamReader() async def do(): reader.feed_data(data) reader.feed_eof() return await parse_request(reader) cmd, addr = asyncio.get_event_loop().run_until_complete(do()) assert cmd == Cmd.CONNECT assert addr.host == "::1" assert addr.port == 443 def test_connect_domain(self): data = self._make_request(0x01, "example.com", 80) reader = asyncio.StreamReader() async def do(): reader.feed_data(data) reader.feed_eof() return await parse_request(reader) cmd, addr = asyncio.get_event_loop().run_until_complete(do()) assert cmd == Cmd.CONNECT assert addr.host == "example.com" assert addr.port == 80 def test_udp_associate(self): data = self._make_request(0x03, "0.0.0.0", 0) reader = asyncio.StreamReader() async def do(): reader.feed_data(data) reader.feed_eof() return await parse_request(reader) cmd, addr = asyncio.get_event_loop().run_until_complete(do()) assert cmd == Cmd.UDP_ASSOCIATE def test_wrong_version(self): data = b"\x04\x01\x00\x01\x7f\x00\x00\x01\x04\x38" reader = asyncio.StreamReader() async def do(): reader.feed_data(data) reader.feed_eof() return await parse_request(reader) with pytest.raises(ProtocolError, match="unsupported SOCKS version"): asyncio.get_event_loop().run_until_complete(do()) def test_unsupported_command(self): data = self._make_request(0x02, "127.0.0.1", 1080) # BIND reader = asyncio.StreamReader() async def do(): reader.feed_data(data) reader.feed_eof() return await parse_request(reader) with pytest.raises(ProtocolError, match="unsupported command"): asyncio.get_event_loop().run_until_complete(do()) class TestUdpHeader: def test_parse_ipv4(self): from asyncio_socks_server.core.address import encode_address header = b"\x00\x00\x00" + encode_address("127.0.0.1", 
1080) payload = b"hello" data = header + payload addr, hdr_len, pl = parse_udp_header(data) assert addr.host == "127.0.0.1" assert addr.port == 1080 assert hdr_len == 3 + 7 # RSV(2)+FRAG(1)+ATYP(1)+IPv4(4)+PORT(2) assert pl == b"hello" def test_parse_domain(self): from asyncio_socks_server.core.address import encode_address header = b"\x00\x00\x00" + encode_address("example.com", 80) payload = b"world" data = header + payload addr, hdr_len, pl = parse_udp_header(data) assert addr.host == "example.com" assert addr.port == 80 assert pl == b"world" def test_build_udp_header(self): header = build_udp_header(Address("127.0.0.1", 1080)) assert header[0:2] == b"\x00\x00" # RSV assert header[2] == 0x00 # FRAG assert header[3] == 0x01 # ATYP IPv4 def test_too_short(self): with pytest.raises(ProtocolError, match="too short"): parse_udp_header(b"\x00\x00") ================================================ FILE: tests/test_core_socket.py ================================================ import socket from asyncio_socks_server.core.socket import ( create_dualstack_tcp_socket, create_dualstack_udp_socket, ) def test_tcp_unspecified_ipv4_uses_ipv4_socket(): sock = create_dualstack_tcp_socket("0.0.0.0", 0) try: assert sock.family == socket.AF_INET finally: sock.close() def test_udp_unspecified_ipv4_uses_ipv4_socket(): sock = create_dualstack_udp_socket("0.0.0.0", 0) try: assert sock.family == socket.AF_INET6 assert sock.getsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY) == 0 finally: sock.close() def test_tcp_unspecified_ipv6_keeps_dualstack_socket(): sock = create_dualstack_tcp_socket("::", 0) try: assert sock.family == socket.AF_INET6 assert sock.getsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY) == 0 finally: sock.close() ================================================ FILE: tests/test_core_types.py ================================================ from asyncio_socks_server.core.types import ( Address, Atyp, AuthMethod, Cmd, Direction, Rep, ) def test_rep_values(): assert 
Rep.SUCCEEDED == 0x00 assert Rep.GENERAL_FAILURE == 0x01 assert Rep.COMMAND_NOT_SUPPORTED == 0x07 def test_auth_method_values(): assert AuthMethod.NO_AUTH == 0x00 assert AuthMethod.USERNAME_PASSWORD == 0x02 assert AuthMethod.NO_ACCEPTABLE == 0xFF def test_cmd_values(): assert Cmd.CONNECT == 0x01 assert Cmd.UDP_ASSOCIATE == 0x03 def test_atyp_values(): assert Atyp.IPV4 == 0x01 assert Atyp.DOMAIN == 0x03 assert Atyp.IPV6 == 0x04 def test_direction_constants(): assert Direction.UPSTREAM == "upstream" assert Direction.DOWNSTREAM == "downstream" def test_address_frozen(): addr = Address("127.0.0.1", 1080) assert addr.host == "127.0.0.1" assert addr.port == 1080 def test_address_str(): assert str(Address("127.0.0.1", 1080)) == "127.0.0.1:1080" assert str(Address("example.com", 443)) == "example.com:443" ================================================ FILE: tests/test_e2e.py ================================================ """End-to-end tests: full proxy scenarios.""" import asyncio from asyncio_socks_server import Addon, ChainRouter, TrafficCounter, connect from asyncio_socks_server.core.types import Address async def _socks5_proxy_connect(proxy: Address, target: Address, auth=None): """Raw SOCKS5 CONNECT through proxy.""" reader, writer = await asyncio.open_connection(proxy.host, proxy.port) if auth: writer.write(b"\x05\x01\x02") else: writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp[0] == 0x05 if auth: assert resp[1] == 0x02 uname = auth[0].encode() passwd = auth[1].encode() writer.write( b"\x01" + len(uname).to_bytes(1, "big") + uname + len(passwd).to_bytes(1, "big") + passwd ) await writer.drain() auth_resp = await reader.readexactly(2) assert auth_resp == b"\x01\x00" else: assert resp[1] == 0x00 from asyncio_socks_server.core.address import encode_address writer.write(b"\x05\x01\x00" + encode_address(target.host, target.port)) await writer.drain() reply = await reader.readexactly(3) assert reply[1] == 0x00 # 
succeeded # Skip bound address atyp = (await reader.readexactly(1))[0] if atyp == 0x01: await reader.readexactly(4 + 2) elif atyp == 0x04: await reader.readexactly(16 + 2) elif atyp == 0x03: length = (await reader.readexactly(1))[0] await reader.readexactly(length + 2) return reader, writer class TestE2ETcp: async def test_bidirectional_relay(self, echo_server): from tests.conftest import _start_server, _stop_server server, task = await _start_server() try: r, w = await _socks5_proxy_connect( Address(server.host, server.port), echo_server ) w.write(b"ping") await w.drain() assert await r.read(4096) == b"ping" w.write(b"pong") await w.drain() assert await r.read(4096) == b"pong" w.close() await w.wait_closed() finally: await _stop_server(server, task) async def test_multiple_connections(self, echo_server): from tests.conftest import _start_server, _stop_server server, task = await _start_server() try: conns = [] for _ in range(5): r, w = await _socks5_proxy_connect( Address(server.host, server.port), echo_server ) conns.append((r, w)) for r, w in conns: w.write(b"test") await w.drain() assert await r.read(4096) == b"test" w.close() await w.wait_closed() finally: await _stop_server(server, task) async def test_client_library(self, echo_server): from tests.conftest import _start_server, _stop_server server, task = await _start_server() try: conn = await connect(Address(server.host, server.port), echo_server) conn.writer.write(b"via client library") await conn.writer.drain() data = await conn.reader.read(4096) assert data == b"via client library" conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(server, task) class TestE2EChain: async def test_three_hop_chain(self, echo_server): from tests.conftest import _start_server, _stop_server exit_server, exit_task = await _start_server() mid_addon = ChainRouter(next_hop=f"127.0.0.1:{exit_server.port}") mid_server, mid_task = await _start_server(addons=[mid_addon]) entry_addon = 
ChainRouter(next_hop=f"127.0.0.1:{mid_server.port}") entry_server, entry_task = await _start_server(addons=[entry_addon]) try: conn = await connect( Address(entry_server.host, entry_server.port), echo_server ) conn.writer.write(b"three hops!") await conn.writer.drain() data = await conn.reader.read(4096) assert data == b"three hops!" conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(entry_server, entry_task) await _stop_server(mid_server, mid_task) await _stop_server(exit_server, exit_task) class UpperAddon(Addon): async def on_data(self, direction, data, flow): return data.upper() class TestE2EAddons: async def test_pipeline_transform(self, echo_server): from tests.conftest import _start_server, _stop_server server, task = await _start_server(addons=[UpperAddon()]) try: r, w = await _socks5_proxy_connect( Address(server.host, server.port), echo_server ) w.write(b"hello") await w.drain() # Echo server receives "HELLO" and echoes it back # The upstream addon transforms to uppercase # The downstream addon also transforms, but echo returns it data = await r.read(4096) assert data == b"HELLO" w.close() await w.wait_closed() finally: await _stop_server(server, task) async def test_traffic_counter(self, echo_server): from tests.conftest import _start_server, _stop_server counter = TrafficCounter() server, task = await _start_server(addons=[counter]) try: r, w = await _socks5_proxy_connect( Address(server.host, server.port), echo_server ) w.write(b"count me") await w.drain() await r.read(4096) # TrafficCounter now reads from flow on close, # so assertions must happen after connection teardown. 
w.close() await w.wait_closed() await asyncio.sleep(0.2) assert counter.bytes_up == 8 assert counter.bytes_down == 8 assert counter.connections == 1 finally: await _stop_server(server, task) ================================================ FILE: tests/test_e2e_auth_chain.py ================================================ import pytest from asyncio_socks_server import ChainRouter, connect from asyncio_socks_server.core.types import Address from tests.conftest import _start_server, _stop_server class TestAuthChain: async def test_chain_with_auth_at_entry(self, echo_server): exit_server, exit_task = await _start_server() chain = ChainRouter(next_hop=f"127.0.0.1:{exit_server.port}") entry_server, entry_task = await _start_server( auth=("admin", "secret"), addons=[chain], ) try: conn = await connect( Address(entry_server.host, entry_server.port), echo_server, username="admin", password="secret", ) conn.writer.write(b"auth-chain") await conn.writer.drain() assert await conn.reader.read(4096) == b"auth-chain" conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(entry_server, entry_task) await _stop_server(exit_server, exit_task) async def test_chain_with_auth_at_exit(self, echo_server): exit_server, exit_task = await _start_server(auth=("u", "p")) chain = ChainRouter( next_hop=f"127.0.0.1:{exit_server.port}", username="u", password="p", ) entry_server, entry_task = await _start_server(addons=[chain]) try: conn = await connect( Address(entry_server.host, entry_server.port), echo_server ) conn.writer.write(b"chain-auth-exit") await conn.writer.drain() assert await conn.reader.read(4096) == b"chain-auth-exit" conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(entry_server, entry_task) await _stop_server(exit_server, exit_task) async def test_chain_both_hops_require_auth(self, echo_server): exit_server, exit_task = await _start_server(auth=("exit_u", "exit_p")) chain = ChainRouter( 
next_hop=f"127.0.0.1:{exit_server.port}", username="exit_u", password="exit_p", ) entry_server, entry_task = await _start_server( auth=("entry_u", "entry_p"), addons=[chain], ) try: conn = await connect( Address(entry_server.host, entry_server.port), echo_server, username="entry_u", password="entry_p", ) conn.writer.write(b"double-auth") await conn.writer.drain() assert await conn.reader.read(4096) == b"double-auth" conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(entry_server, entry_task) await _stop_server(exit_server, exit_task) async def test_chain_auth_failure_propagates(self, echo_server): from asyncio_socks_server.core.protocol import ProtocolError exit_server, exit_task = await _start_server() chain = ChainRouter(next_hop=f"127.0.0.1:{exit_server.port}") entry_server, entry_task = await _start_server( auth=("good", "creds"), addons=[chain], ) try: with pytest.raises(ProtocolError, match="authentication failed"): await connect( Address(entry_server.host, entry_server.port), echo_server, username="good", password="wrong", ) finally: await _stop_server(entry_server, entry_task) await _stop_server(exit_server, exit_task) ================================================ FILE: tests/test_e2e_data_paths.py ================================================ import asyncio import pytest from asyncio_socks_server import Addon, ChainRouter, IPFilter, TrafficCounter, connect from asyncio_socks_server.core.protocol import build_udp_header, parse_udp_header from asyncio_socks_server.core.types import Address, Direction from tests.conftest import _start_server, _stop_server from tests.e2e_helpers import open_udp_associate, read_socks_reply, socks5_connect class TestBidirectionalData: async def test_simultaneous_bidirectional(self, echo_server): server, task = await _start_server() try: reader, writer = await socks5_connect( Address(server.host, server.port), echo_server ) reply = await read_socks_reply(reader) assert reply[1] == 0x00 
writer.write(b"simul") await writer.drain() assert await asyncio.wait_for(reader.read(4096), timeout=2.0) == b"simul" writer.write(b"simul2") await writer.drain() assert await asyncio.wait_for(reader.read(4096), timeout=2.0) == b"simul2" writer.close() await writer.wait_closed() finally: await _stop_server(server, task) class TestLargeDataChain: async def test_512kb_through_chain(self, echo_server): exit_server, exit_task = await _start_server() chain = ChainRouter(next_hop=f"127.0.0.1:{exit_server.port}") entry_server, entry_task = await _start_server(addons=[chain]) try: conn = await connect( Address(entry_server.host, entry_server.port), echo_server ) payload = b"X" * (512 * 1024) conn.writer.write(payload) await conn.writer.drain() received = b"" while len(received) < len(payload): chunk = await asyncio.wait_for(conn.reader.read(65536), timeout=5.0) if not chunk: break received += chunk assert received == payload conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(entry_server, entry_task) await _stop_server(exit_server, exit_task) class TestMultiAddonComposition: async def test_ipfilter_and_traffic_counter(self, echo_server): counter = TrafficCounter() filter_addon = IPFilter(allowed=["127.0.0.0/8"]) server, task = await _start_server(addons=[filter_addon, counter]) try: conn = await connect(Address(server.host, server.port), echo_server) conn.writer.write(b"filtered") await conn.writer.drain() assert await conn.reader.read(4096) == b"filtered" conn.writer.close() await conn.writer.wait_closed() await asyncio.sleep(0.2) assert counter.connections == 1 assert counter.bytes_up == 8 assert counter.bytes_down == 8 finally: await _stop_server(server, task) async def test_ipfilter_blocks_then_traffic_zero(self, echo_server): counter = TrafficCounter() filter_addon = IPFilter(blocked=["127.0.0.0/8"]) server, task = await _start_server(addons=[filter_addon, counter]) try: reader, writer = await socks5_connect( Address(server.host, 
server.port), echo_server ) reply = await read_socks_reply(reader) assert reply[1] == 0x02 writer.close() await writer.wait_closed() await asyncio.sleep(0.1) assert counter.connections == 0 finally: await _stop_server(server, task) async def test_pipeline_and_chain_combined(self, echo_server): class UpperAddon(Addon): async def on_data(self, direction, data, flow): return data.upper() exit_server, exit_task = await _start_server(addons=[UpperAddon()]) chain = ChainRouter(next_hop=f"127.0.0.1:{exit_server.port}") entry_server, entry_task = await _start_server(addons=[chain]) try: conn = await connect( Address(entry_server.host, entry_server.port), echo_server ) conn.writer.write(b"transform-me") await conn.writer.drain() assert await conn.reader.read(4096) == b"TRANSFORM-ME" conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(entry_server, entry_task) await _stop_server(exit_server, exit_task) class TestAddonDataDrop: async def test_drop_addon_silences_upstream(self, echo_server): class DropUpstream(Addon): async def on_data(self, direction, data, flow): if direction == Direction.UPSTREAM: return None return data server, task = await _start_server(addons=[DropUpstream()]) try: reader, writer = await socks5_connect( Address(server.host, server.port), echo_server ) reply = await read_socks_reply(reader) assert reply[1] == 0x00 writer.write(b"dropped") await writer.drain() with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(reader.read(4096), timeout=0.5) writer.close() await writer.wait_closed() finally: await _stop_server(server, task) class TestFlowBytesAccuracy: async def test_traffic_counter_through_chain(self, echo_server): exit_counter = TrafficCounter() exit_server, exit_task = await _start_server(addons=[exit_counter]) entry_counter = TrafficCounter() chain = ChainRouter(next_hop=f"127.0.0.1:{exit_server.port}") entry_server, entry_task = await _start_server(addons=[entry_counter, chain]) try: conn = await connect( 
Address(entry_server.host, entry_server.port), echo_server ) conn.writer.write(b"12345") await conn.writer.drain() assert await conn.reader.read(4096) == b"12345" conn.writer.close() await conn.writer.wait_closed() await asyncio.sleep(0.3) assert entry_counter.connections == 1 assert entry_counter.bytes_up == 5 assert entry_counter.bytes_down == 5 assert exit_counter.connections == 1 assert exit_counter.bytes_up == 5 assert exit_counter.bytes_down == 5 finally: await _stop_server(entry_server, entry_task) await _stop_server(exit_server, exit_task) class TestMixedProtocol: async def test_tcp_and_udp_concurrent(self, echo_server, udp_echo_server): server, task = await _start_server() try: tcp_reader, tcp_writer = await socks5_connect( Address(server.host, server.port), echo_server ) reply = await read_socks_reply(tcp_reader) assert reply[1] == 0x00 _, udp_writer, udp_bind = await open_udp_associate( Address(server.host, server.port) ) tcp_writer.write(b"tcp-data") await tcp_writer.drain() echo_addr, _ = udp_echo_server loop = asyncio.get_running_loop() udp_received = loop.create_future() class ClientProto(asyncio.DatagramProtocol): def datagram_received(self, data, addr): if not udp_received.done(): udp_received.set_result(data) transport, _ = await loop.create_datagram_endpoint( ClientProto, local_addr=("127.0.0.1", 0), ) try: transport.sendto( build_udp_header(echo_addr) + b"udp-data", (udp_bind.host, udp_bind.port), ) tcp_data = await asyncio.wait_for(tcp_reader.read(4096), timeout=2.0) assert tcp_data == b"tcp-data" udp_data = await asyncio.wait_for(udp_received, timeout=2.0) _, _, payload = parse_udp_header(udp_data) assert payload == b"udp-data" finally: transport.close() tcp_writer.close() await tcp_writer.wait_closed() udp_writer.close() await udp_writer.wait_closed() finally: await _stop_server(server, task) class TestBinaryDataRoundtrip: async def test_null_bytes_and_binary(self, echo_server): server, task = await _start_server() try: conn = await 
connect(Address(server.host, server.port), echo_server) payload = b"\x00\x01\x02\xff\xfe\xfd" + bytes(range(256)) conn.writer.write(payload) await conn.writer.drain() received = b"" while len(received) < len(payload): chunk = await asyncio.wait_for(conn.reader.read(4096), timeout=3.0) if not chunk: break received += chunk assert received == payload conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(server, task) async def test_binary_through_chain(self, echo_server): exit_server, exit_task = await _start_server() chain = ChainRouter(next_hop=f"127.0.0.1:{exit_server.port}") entry_server, entry_task = await _start_server(addons=[chain]) try: conn = await connect( Address(entry_server.host, entry_server.port), echo_server ) payload = bytes(range(256)) * 4 conn.writer.write(payload) await conn.writer.drain() received = b"" while len(received) < len(payload): chunk = await asyncio.wait_for(conn.reader.read(4096), timeout=3.0) if not chunk: break received += chunk assert received == payload conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(entry_server, entry_task) await _stop_server(exit_server, exit_task) ================================================ FILE: tests/test_e2e_lifecycle.py ================================================ import asyncio from asyncio_socks_server import connect from asyncio_socks_server.core.types import Address from tests.conftest import _start_server, _stop_server class TestClientDisconnect: async def test_abrupt_client_disconnect_no_crash(self, echo_server): server, task = await _start_server() try: conn = await connect(Address(server.host, server.port), echo_server) conn.writer.write(b"before-disconnect") await conn.writer.drain() assert await conn.reader.read(4096) == b"before-disconnect" conn.writer.close() await conn.writer.wait_closed() conn2 = await connect(Address(server.host, server.port), echo_server) conn2.writer.write(b"after-disconnect") await conn2.writer.drain() 
assert await conn2.reader.read(4096) == b"after-disconnect" conn2.writer.close() await conn2.writer.wait_closed() finally: await _stop_server(server, task) async def test_target_disconnect_mid_relay(self): async def disconnect_after_first(reader, writer): try: data = await reader.read(4096) writer.write(data) await writer.drain() writer.close() await writer.wait_closed() except Exception: pass srv = await asyncio.start_server(disconnect_after_first, "127.0.0.1", 0) addr = srv.sockets[0].getsockname() target = Address(addr[0], addr[1]) server, task = await _start_server() try: conn = await connect(Address(server.host, server.port), target) conn.writer.write(b"first") await conn.writer.drain() assert await conn.reader.read(4096) == b"first" assert await conn.reader.read(4096) == b"" conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(server, task) srv.close() await srv.wait_closed() class TestGracefulShutdown: async def test_active_connections_complete_on_shutdown(self): async def slow_echo(reader, writer): try: data = await reader.read(4096) await asyncio.sleep(0.2) writer.write(data) await writer.drain() finally: writer.close() await writer.wait_closed() srv = await asyncio.start_server(slow_echo, "127.0.0.1", 0) addr = srv.sockets[0].getsockname() target = Address(addr[0], addr[1]) server, task = await _start_server() try: conn = await connect(Address(server.host, server.port), target) conn.writer.write(b"slow") await conn.writer.drain() server.request_shutdown() data = await asyncio.wait_for(conn.reader.read(4096), timeout=3.0) assert data == b"slow" conn.writer.close() await conn.writer.wait_closed() finally: await task srv.close() await srv.wait_closed() class TestRepeatedConnections: async def test_50_sequential_connections(self, echo_server): server, task = await _start_server() try: for i in range(50): conn = await connect(Address(server.host, server.port), echo_server) msg = f"msg-{i:03d}".encode() conn.writer.write(msg) await 
conn.writer.drain() assert await conn.reader.read(4096) == msg conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(server, task) async def test_connection_reuse_stability(self, echo_server): server, task = await _start_server() try: for round_num in range(3): conns = [] for i in range(10): conn = await connect(Address(server.host, server.port), echo_server) msg = f"r{round_num}-{i}".encode() conn.writer.write(msg) await conn.writer.drain() conns.append((conn, msg)) for conn, msg in conns: assert await conn.reader.read(4096) == msg conn.writer.close() await conn.writer.wait_closed() await asyncio.sleep(0.1) finally: await _stop_server(server, task) ================================================ FILE: tests/test_e2e_policy_errors.py ================================================ import asyncio from asyncio_socks_server import ChainRouter, IPFilter, connect from asyncio_socks_server.core.address import encode_address from asyncio_socks_server.core.types import Address from tests.conftest import _start_server, _stop_server from tests.e2e_helpers import read_socks_reply, socks5_connect class TestIPFilterE2E: async def test_allowed_ip_connects(self, echo_server): filter_addon = IPFilter(allowed=["127.0.0.0/8"]) server, task = await _start_server(addons=[filter_addon]) try: conn = await connect(Address(server.host, server.port), echo_server) conn.writer.write(b"allowed") await conn.writer.drain() assert await conn.reader.read(4096) == b"allowed" conn.writer.close() await conn.writer.wait_closed() finally: await _stop_server(server, task) async def test_blocked_ip_rejected(self, echo_server): filter_addon = IPFilter(blocked=["127.0.0.0/8"]) server, task = await _start_server(addons=[filter_addon]) try: reader, writer = await socks5_connect( Address(server.host, server.port), echo_server ) reply = await read_socks_reply(reader) assert reply[1] == 0x02 writer.close() await writer.wait_closed() finally: await _stop_server(server, task) class 
TestConnectionRefusedE2E: async def test_target_refused_returns_error_reply(self): server, task = await _start_server() try: reader, writer = await socks5_connect( Address(server.host, server.port), Address("127.0.0.1", 1), ) reply = await read_socks_reply(reader) assert reply[1] != 0x00 writer.close() await writer.wait_closed() finally: await _stop_server(server, task) async def test_unreachable_target_through_chain(self): exit_server, exit_task = await _start_server() chain = ChainRouter(next_hop=f"127.0.0.1:{exit_server.port}") entry_server, entry_task = await _start_server(addons=[chain]) try: reader, writer = await socks5_connect( Address(entry_server.host, entry_server.port), Address("127.0.0.1", 1), ) reply = await read_socks_reply(reader) assert reply[1] == 0x02 writer.close() await writer.wait_closed() finally: await _stop_server(entry_server, entry_task) await _stop_server(exit_server, exit_task) class TestDomainNameTarget: async def test_domain_target_resolved(self, echo_server): server, task = await _start_server() try: reader, writer = await asyncio.open_connection(server.host, server.port) writer.write(b"\x05\x01\x00") await writer.drain() assert await reader.readexactly(2) == b"\x05\x00" writer.write( b"\x05\x01\x00" + encode_address("127.0.0.1", echo_server.port) ) await writer.drain() reply = await read_socks_reply(reader) assert reply[1] == 0x00 writer.write(b"domain-test") await writer.drain() assert await reader.read(4096) == b"domain-test" writer.close() await writer.wait_closed() finally: await _stop_server(server, task) ================================================ FILE: tests/test_flow.py ================================================ """Flow context tests: on_flow_close, bytes accuracy, dataclass, UdpRelay injection.""" import asyncio import time from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.addons.manager import AddonManager from asyncio_socks_server.core.protocol import build_udp_header from 
asyncio_socks_server.core.types import Address, Direction, Flow from asyncio_socks_server.server.tcp_relay import _copy, handle_tcp_relay from asyncio_socks_server.server.udp_relay import UdpRelay def _make_flow(**kwargs): defaults = dict( id=1, src=Address("127.0.0.1", 1000), dst=Address("127.0.0.1", 2000), protocol="tcp", started_at=0.0, ) defaults.update(kwargs) return Flow(**defaults) # --- Flow dataclass tests --- class TestFlowDataclass: def test_construction_with_defaults(self): flow = Flow( id=42, src=Address("10.0.0.1", 1234), dst=Address("10.0.0.2", 5678), protocol="tcp", started_at=100.0, ) assert flow.id == 42 assert flow.src.host == "10.0.0.1" assert flow.bytes_up == 0 assert flow.bytes_down == 0 def test_mutable_bytes(self): flow = _make_flow() flow.bytes_up += 100 flow.bytes_down += 200 assert flow.bytes_up == 100 assert flow.bytes_down == 200 def test_protocol_literal(self): flow_tcp = _make_flow(protocol="tcp") flow_udp = _make_flow(protocol="udp") assert flow_tcp.protocol == "tcp" assert flow_udp.protocol == "udp" def test_started_at_monotonic(self): before = time.monotonic() flow = Flow( id=1, src=Address("::", 0), dst=Address("::", 0), protocol="tcp", started_at=time.monotonic(), ) after = time.monotonic() assert before <= flow.started_at <= after # --- on_flow_close hook tests --- class CloseCapture(Addon): def __init__(self): self.closed_flows: list[Flow] = [] async def on_flow_close(self, flow): self.closed_flows.append(flow) class CloseCrasher(Addon): async def on_flow_close(self, flow): raise RuntimeError("crash in on_flow_close") class TestOnFlowClose: async def test_called_for_all_addons(self): a1 = CloseCapture() a2 = CloseCapture() mgr = AddonManager([a1, a2]) flow = _make_flow() await mgr.dispatch_flow_close(flow) assert len(a1.closed_flows) == 1 assert len(a2.closed_flows) == 1 assert a1.closed_flows[0] is flow async def test_exception_does_not_propagate(self): a1 = CloseCrasher() a2 = CloseCapture() mgr = AddonManager([a1, a2]) flow 
= _make_flow() await mgr.dispatch_flow_close(flow) assert len(a2.closed_flows) == 1 async def test_receives_final_flow_snapshot(self): capture = CloseCapture() mgr = AddonManager([capture]) flow = _make_flow() flow.bytes_up = 1024 flow.bytes_down = 2048 await mgr.dispatch_flow_close(flow) assert capture.closed_flows[0].bytes_up == 1024 assert capture.closed_flows[0].bytes_down == 2048 async def test_base_addon_skipped(self): mgr = AddonManager([Addon()]) await mgr.dispatch_flow_close(_make_flow()) async def test_no_addons(self): mgr = AddonManager([]) await mgr.dispatch_flow_close(_make_flow()) # --- flow.bytes accuracy for TCP path --- class TestTcpFlowBytes: async def _pipe(self): """Create a pipe: write to [0] → read from [3]. Keep all refs.""" import socket sock_a, sock_b = socket.socketpair() sock_a.setblocking(False) sock_b.setblocking(False) reader_a, writer_a = await asyncio.open_connection(sock=sock_a) reader_b, writer_b = await asyncio.open_connection(sock=sock_b) return writer_a, writer_b, reader_a, reader_b async def test_copy_updates_bytes_up(self): in_wa, _in_wb, _in_ra, in_rb = await self._pipe() out_wa, _out_wb, _out_ra, out_rb = await self._pipe() flow = _make_flow() mgr = AddonManager() task = asyncio.create_task(_copy(in_rb, out_wa, mgr, Direction.UPSTREAM, flow)) in_wa.write(b"hello") await in_wa.drain() data = await asyncio.wait_for(out_rb.read(4096), timeout=1.0) assert data == b"hello" in_wa.close() await in_wa.wait_closed() await asyncio.wait_for(task, timeout=1.0) assert flow.bytes_up == 5 assert flow.bytes_down == 0 async def test_copy_updates_bytes_down(self): in_wa, _in_wb, _in_ra, in_rb = await self._pipe() out_wa, _out_wb, _out_ra, out_rb = await self._pipe() flow = _make_flow() mgr = AddonManager() task = asyncio.create_task( _copy(in_rb, out_wa, mgr, Direction.DOWNSTREAM, flow) ) in_wa.write(b"world") await in_wa.drain() data = await asyncio.wait_for(out_rb.read(4096), timeout=1.0) assert data == b"world" in_wa.close() await 
in_wa.wait_closed() await asyncio.wait_for(task, timeout=1.0) assert flow.bytes_up == 0 assert flow.bytes_down == 5 async def test_bidirectional_relay_bytes(self): import socket # Client pipe c_a, c_b = socket.socketpair() c_a.setblocking(False) c_b.setblocking(False) cr_app, cw_app = await asyncio.open_connection(sock=c_a) cr_relay, cw_relay = await asyncio.open_connection(sock=c_b) # Remote pipe r_a, r_b = socket.socketpair() r_a.setblocking(False) r_b.setblocking(False) rr_app, rw_app = await asyncio.open_connection(sock=r_a) rr_relay, rw_relay = await asyncio.open_connection(sock=r_b) flow = _make_flow() mgr = AddonManager() relay_task = asyncio.create_task( handle_tcp_relay(cr_relay, cw_relay, rr_relay, rw_relay, mgr, flow) ) cw_app.write(b"abc") await cw_app.drain() data = await asyncio.wait_for(rr_app.read(4096), timeout=1.0) assert data == b"abc" rw_app.write(b"xyz") await rw_app.drain() data = await asyncio.wait_for(cr_app.read(4096), timeout=1.0) assert data == b"xyz" cw_app.close() await cw_app.wait_closed() rw_app.close() await rw_app.wait_closed() await asyncio.wait_for(relay_task, timeout=2.0) assert flow.bytes_up == 3 assert flow.bytes_down == 3 # --- UdpRelay constructor injection tests --- class TestUdpRelayFlowInjection: async def test_constructor_stores_flow(self): flow = _make_flow(protocol="udp") relay = UdpRelay(client_addr=Address("127.0.0.1", 12345), flow=flow) assert relay._flow is flow async def test_udp_bytes_single_write(self, udp_echo_server): echo_addr, _ = udp_echo_server flow = _make_flow(protocol="udp") relay = UdpRelay(client_addr=Address("127.0.0.1", 12345), flow=flow) try: await relay.start() datagram = build_udp_header(echo_addr) + b"hello" relay.handle_client_datagram(datagram, ("127.0.0.1", 12345)) await asyncio.sleep(0.1) assert flow.bytes_up == 5 finally: await relay.stop() ================================================ FILE: tests/test_ipv6.py ================================================ """Tests for IPv6 dual-stack 
support.""" import asyncio import ipaddress import socket import struct import pytest from asyncio_socks_server.core.address import encode_address from asyncio_socks_server.core.protocol import build_udp_header from asyncio_socks_server.core.types import Address from asyncio_socks_server.server.server import Server async def _start_server_ipv6(**kwargs): server = Server(host="::", port=0, **kwargs) task = asyncio.create_task(server._run()) for _ in range(50): if server.port != 0: break await asyncio.sleep(0.01) return server, task async def _stop_server(server, task): server.request_shutdown() await task async def _skip_bind_address(reader): atyp = (await reader.readexactly(1))[0] if atyp == 0x01: await reader.readexactly(4 + 2) elif atyp == 0x04: await reader.readexactly(16 + 2) elif atyp == 0x03: length = (await reader.readexactly(1))[0] await reader.readexactly(length + 2) async def _socks5_connect_ipv6(proxy_addr: Address, target: Address): reader, writer = await asyncio.open_connection(proxy_addr.host, proxy_addr.port) writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp[0] == 0x05 and resp[1] == 0x00 writer.write(b"\x05\x01\x00" + encode_address(target.host, target.port)) await writer.drain() reply = await reader.readexactly(3) assert reply[1] == 0x00 await _skip_bind_address(reader) return reader, writer class TestIPv6TCP: @pytest.fixture async def ipv6_echo_server(self): async def handler(reader, writer): try: while True: data = await reader.read(4096) if not data: break writer.write(data) await writer.drain() finally: writer.close() await writer.wait_closed() srv = await asyncio.start_server(handler, "::1", 0) addr = srv.sockets[0].getsockname() yield Address(addr[0], addr[1]) srv.close() await srv.wait_closed() async def test_tcp_connect_ipv6_loopback(self, ipv6_echo_server): server, task = await _start_server_ipv6() try: tcp_r, tcp_w = await _socks5_connect_ipv6( Address("::1", server.port), ipv6_echo_server ) 
tcp_w.write(b"hello ipv6") await tcp_w.drain() data = await tcp_r.read(1024) assert data == b"hello ipv6" tcp_w.close() await tcp_w.wait_closed() finally: await _stop_server(server, task) async def test_tcp_ipv4_on_dualstack(self): async def echo_handler(reader, writer): try: while True: data = await reader.read(4096) if not data: break writer.write(data) await writer.drain() finally: writer.close() await writer.wait_closed() echo_srv = await asyncio.start_server(echo_handler, "127.0.0.1", 0) echo_addr = echo_srv.sockets[0].getsockname() echo_target = Address(echo_addr[0], echo_addr[1]) server, task = await _start_server_ipv6() try: r, w = await asyncio.open_connection("127.0.0.1", server.port) w.write(b"\x05\x01\x00") await w.drain() resp = await r.readexactly(2) assert resp[1] == 0x00 target_bytes = encode_address(echo_target.host, echo_target.port) w.write(b"\x05\x01\x00" + target_bytes) await w.drain() reply = await r.readexactly(3) assert reply[1] == 0x00 await _skip_bind_address(r) w.write(b"dualstack works") await w.drain() data = await r.read(1024) assert data == b"dualstack works" w.close() await w.wait_closed() finally: await _stop_server(server, task) echo_srv.close() await echo_srv.wait_closed() class TestIPv6UDP: @pytest.fixture async def ipv6_udp_echo_server(self): received = [] class Protocol(asyncio.DatagramProtocol): def connection_made(self, transport): self.transport = transport def datagram_received(self, data, addr): received.append((data, addr)) self.transport.sendto(data, addr) loop = asyncio.get_running_loop() s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) s.bind(("::1", 0)) s.setblocking(False) transport, _ = await loop.create_datagram_endpoint(Protocol, sock=s) sockname = s.getsockname() yield Address(sockname[0], sockname[1]), received transport.close() async def test_udp_associate_ipv6(self, ipv6_udp_echo_server): echo_addr, _ = ipv6_udp_echo_server server, task = await _start_server_ipv6() try: # SOCKS5 handshake via IPv6 reader, 
writer = await asyncio.open_connection("::1", server.port) writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp[0] == 0x05 and resp[1] == 0x00 # UDP ASSOCIATE writer.write(b"\x05\x03\x00\x01\x00\x00\x00\x00\x00\x00") await writer.drain() reply = await reader.readexactly(3) assert reply[0] == 0x05 assert reply[1] == 0x00 atyp = (await reader.readexactly(1))[0] if atyp == 0x04: host_bytes = await reader.readexactly(16) host = str(ipaddress.IPv6Address(host_bytes)) elif atyp == 0x01: host_bytes = await reader.readexactly(4) host = ipaddress.IPv4Address(host_bytes).compressed else: length = (await reader.readexactly(1))[0] host = (await reader.readexactly(length)).decode("ascii") port_bytes = await reader.readexactly(2) port = struct.unpack("!H", port_bytes)[0] udp_bind = Address(host, port) # Client UDP socket on IPv6 loop = asyncio.get_running_loop() received = loop.create_future() class ClientProtocol(asyncio.DatagramProtocol): def datagram_received(self, data, addr): if not received.done(): received.set_result(data) client_sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) client_sock.bind(("::1", 0)) client_sock.setblocking(False) transport, _ = await loop.create_datagram_endpoint( ClientProtocol, sock=client_sock ) datagram = build_udp_header(echo_addr) + b"hello ipv6 udp" transport.sendto(datagram, (udp_bind.host, udp_bind.port)) try: resp_data = await asyncio.wait_for(received, timeout=2.0) from asyncio_socks_server.core.protocol import parse_udp_header _, _, payload = parse_udp_header(resp_data) assert payload == b"hello ipv6 udp" finally: transport.close() writer.close() await writer.wait_closed() finally: await _stop_server(server, task) ================================================ FILE: tests/test_logging.py ================================================ """Tests for core logging module.""" import logging from asyncio_socks_server.core.logging import ( fmt_addr, fmt_bytes, fmt_connection, get_logger, 
setup_logging, ) from asyncio_socks_server.core.types import Address class TestSetupLogging: def test_sets_level_debug(self): setup_logging("DEBUG") logger = get_logger() assert logger.parent.level == logging.DEBUG def test_sets_level_info(self): setup_logging("INFO") logger = get_logger() assert logger.parent.level == logging.INFO class TestGetLogger: def test_returns_named_logger(self): logger = get_logger() assert logger.name == "asyncio_socks_server" class TestFmtAddr: def test_ipv4(self): assert fmt_addr(Address("127.0.0.1", 1080)) == "127.0.0.1:1080" def test_ipv6(self): assert fmt_addr(Address("::1", 443)) == "::1:443" def test_domain(self): assert fmt_addr(Address("example.com", 80)) == "example.com:80" class TestFmtConnection: def test_format(self): src = Address("10.0.0.1", 54321) dst = Address("93.184.216.34", 443) result = fmt_connection(src, dst) assert result == "10.0.0.1:54321 → 93.184.216.34:443" class TestFmtBytes: def test_zero(self): assert fmt_bytes(0) == "0B" def test_bytes(self): assert fmt_bytes(512) == "512B" def test_boundary_1023(self): assert fmt_bytes(1023) == "1023B" def test_boundary_1024(self): assert fmt_bytes(1024) == "1.0KB" def test_kilobytes(self): assert fmt_bytes(2048) == "2.0KB" def test_just_under_mb(self): assert fmt_bytes(1024 * 1024 - 1) == "1024.0KB" def test_exact_mb(self): assert fmt_bytes(1024 * 1024) == "1.0MB" def test_megabytes(self): assert fmt_bytes(5 * 1024 * 1024) == "5.0MB" ================================================ FILE: tests/test_protocol_robustness.py ================================================ """Protocol parser edge cases and boundary conditions.""" import asyncio import pytest from asyncio_socks_server.core.protocol import ( ProtocolError, parse_method_selection, parse_request, parse_udp_header, parse_username_password, ) class TestMethodSelectionEdgeCases: def test_empty_data(self): with pytest.raises(ProtocolError, match="too short"): parse_method_selection(b"") def test_single_byte(self): 
with pytest.raises(ProtocolError, match="too short"): parse_method_selection(b"\x05") def test_wrong_version(self): with pytest.raises(ProtocolError, match="unsupported SOCKS version"): parse_method_selection(b"\x04\x01\x00") def test_nmethods_zero(self): ver, methods = parse_method_selection(b"\x05\x00") assert ver == 0x05 assert methods == set() def test_extra_bytes_beyond_methods(self): # NMETHODS=1 but data has 3 bytes — extra ignored ver, methods = parse_method_selection(b"\x05\x01\x00\xff\xfe") assert ver == 0x05 assert methods == {0x00} def test_all_methods(self): data = b"\x05\xff" + bytes(range(255)) ver, methods = parse_method_selection(data) assert len(methods) == 255 class TestUsernamePasswordEdgeCases: async def test_empty_username(self): reader = asyncio.StreamReader() reader.feed_data(b"\x01\x00\x04test") reader.feed_eof() username, password = await parse_username_password(reader) assert username == "" assert password == "test" async def test_empty_password(self): reader = asyncio.StreamReader() reader.feed_data(b"\x01\x04test\x00") reader.feed_eof() username, password = await parse_username_password(reader) assert username == "test" assert password == "" async def test_max_length_username(self): reader = asyncio.StreamReader() uname = b"a" * 255 reader.feed_data(b"\x01\xff" + uname + b"\x01x") reader.feed_eof() username, password = await parse_username_password(reader) assert len(username) == 255 assert password == "x" async def test_wrong_auth_version(self): reader = asyncio.StreamReader() reader.feed_data(b"\x02\x04test\x04test") reader.feed_eof() with pytest.raises(ProtocolError, match="unsupported auth version"): await parse_username_password(reader) async def test_truncated_password(self): reader = asyncio.StreamReader() # Claims PLEN=10 but only provides 4 bytes then EOF reader.feed_data(b"\x01\x04test\x0ashor") reader.feed_eof() with pytest.raises(asyncio.IncompleteReadError): await parse_username_password(reader) class 
TestParseRequestEdgeCases: async def test_unsupported_atyp(self): reader = asyncio.StreamReader() # VER=5, CMD=CONNECT(1), RSV=0, ATYP=0x05 (invalid) reader.feed_data(b"\x05\x01\x00\x05") reader.feed_eof() with pytest.raises(ProtocolError, match="unsupported ATYP"): await parse_request(reader) async def test_unsupported_command(self): reader = asyncio.StreamReader() # VER=5, CMD=0x02 (BIND, not supported), RSV=0, ATYP=1 reader.feed_data(b"\x05\x02\x00\x01") reader.feed_eof() with pytest.raises(ProtocolError, match="unsupported command"): await parse_request(reader) async def test_wrong_version_in_request(self): reader = asyncio.StreamReader() reader.feed_data(b"\x04\x01\x00\x01") reader.feed_eof() with pytest.raises(ProtocolError, match="unsupported SOCKS version"): await parse_request(reader) async def test_ipv4_truncated(self): reader = asyncio.StreamReader() # ATYP=0x01 (IPv4, needs 4 bytes) but only 2 bytes reader.feed_data(b"\x05\x01\x00\x01\x7f\x00") reader.feed_eof() with pytest.raises(asyncio.IncompleteReadError): await parse_request(reader) async def test_domain_empty_label(self): reader = asyncio.StreamReader() # ATYP=0x03, length=0, then 2 port bytes reader.feed_data(b"\x05\x01\x00\x03\x00\x00\x50") cmd, addr = await parse_request(reader) assert addr.host == "" async def test_domain_max_length(self): reader = asyncio.StreamReader() domain = b"a" * 255 reader.feed_data(b"\x05\x01\x00\x03" + bytes([255]) + domain + b"\x00\x50") cmd, addr = await parse_request(reader) assert len(addr.host) == 255 async def test_domain_truncated(self): reader = asyncio.StreamReader() # Claims domain length 20 but only 5 bytes reader.feed_data(b"\x05\x01\x00\x03\x14hello") reader.feed_eof() with pytest.raises(asyncio.IncompleteReadError): await parse_request(reader) async def test_port_truncated(self): reader = asyncio.StreamReader() # IPv4 address present, but only 1 port byte reader.feed_data(b"\x05\x01\x00\x01\x7f\x00\x00\x01\x00") reader.feed_eof() with 
pytest.raises(asyncio.IncompleteReadError): await parse_request(reader) class TestUdpHeaderEdgeCases: def test_too_short(self): with pytest.raises(ProtocolError, match="too short"): parse_udp_header(b"\x00\x00") def test_ipv4_truncated(self): # ATYP=0x01 but only 2 of 4 IPv4 bytes with pytest.raises(ProtocolError, match="truncated"): parse_udp_header(b"\x00\x00\x00\x01\x7f\x00") def test_ipv6_truncated(self): # ATYP=0x04 but only 10 of 16 IPv6 bytes with pytest.raises(ProtocolError, match="truncated"): parse_udp_header(b"\x00\x00\x00\x04" + b"\x00" * 10) def test_domain_truncated(self): # ATYP=0x03, length=20 but only 5 domain bytes with pytest.raises(ProtocolError, match="truncated"): parse_udp_header(b"\x00\x00\x00\x03\x14hello") def test_unsupported_atyp(self): with pytest.raises(ProtocolError, match="unsupported ATYP"): parse_udp_header(b"\x00\x00\x00\x02" + b"\x00" * 10) def test_header_only_no_payload(self): # Valid IPv4 header with zero payload addr, hdr_len, payload = parse_udp_header( b"\x00\x00\x00\x01\x7f\x00\x00\x01\x00\x50" ) assert payload == b"" assert hdr_len == 10 def test_ipv6_full_roundtrip(self): import ipaddress ipv6 = ipaddress.IPv6Address("::1").compressed header = b"\x00\x00\x00\x04" + ipaddress.IPv6Address("::1").packed + b"\x01\xbb" header += b"payload" addr, hdr_len, payload = parse_udp_header(header) assert addr.host == ipv6 assert addr.port == 443 assert payload == b"payload" assert hdr_len == 22 def test_domain_roundtrip(self): header = b"\x00\x00\x00\x03\x0bexample.com\x00\x50payload" addr, hdr_len, payload = parse_udp_header(header) assert addr.host == "example.com" assert addr.port == 80 assert payload == b"payload" assert hdr_len == 18 ================================================ FILE: tests/test_server.py ================================================ import asyncio import pytest from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.core.types import Address, Direction from 
asyncio_socks_server.server.server import Server async def _start_server( host: str = "127.0.0.1", auth: tuple[str, str] | None = None, addons: list[Addon] | None = None, ) -> tuple[Server, asyncio.Task]: server = Server(host=host, port=0, auth=auth, addons=addons) task = asyncio.create_task(server._run()) # Wait for the server to be ready for _ in range(50): if server.port != 0: break await asyncio.sleep(0.01) return server, task async def _stop_server(server: Server, task: asyncio.Task): server.request_shutdown() await task @pytest.fixture async def echo_server(): async def handler(reader, writer): try: while True: data = await reader.read(4096) if not data: break writer.write(data) await writer.drain() finally: writer.close() await writer.wait_closed() srv = await asyncio.start_server(handler, "127.0.0.1", 0) addr = srv.sockets[0].getsockname() yield Address(addr[0], addr[1]) srv.close() await srv.wait_closed() async def _socks5_connect( proxy_addr: Address, target_addr: Address, auth: tuple[str, str] | None = None ): reader, writer = await asyncio.open_connection(proxy_addr.host, proxy_addr.port) if auth: writer.write(b"\x05\x01\x02") else: writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp[0] == 0x05 if auth: assert resp[1] == 0x02 uname = auth[0].encode() passwd = auth[1].encode() writer.write( b"\x01" + len(uname).to_bytes(1, "big") + uname + len(passwd).to_bytes(1, "big") + passwd ) await writer.drain() auth_resp = await reader.readexactly(2) assert auth_resp == b"\x01\x00" else: assert resp[1] == 0x00 from asyncio_socks_server.core.address import encode_address writer.write(b"\x05\x01\x00" + encode_address(target_addr.host, target_addr.port)) await writer.drain() reply = await reader.readexactly(3) assert reply[1] == 0x00 atyp = (await reader.readexactly(1))[0] if atyp == 0x01: await reader.readexactly(4 + 2) elif atyp == 0x04: await reader.readexactly(16 + 2) elif atyp == 0x03: length = (await 
reader.readexactly(1))[0] await reader.readexactly(length + 2) return reader, writer class TestServerConnect: async def test_no_auth_connect(self, echo_server): server, task = await _start_server() try: reader, writer = await _socks5_connect( Address(server.host, server.port), echo_server ) writer.write(b"hello") await writer.drain() data = await reader.read(4096) assert data == b"hello" writer.close() await writer.wait_closed() finally: await _stop_server(server, task) async def test_no_auth_rejected_when_auth_required(self, echo_server): server, task = await _start_server(auth=("user", "pass")) try: reader, writer = await asyncio.open_connection(server.host, server.port) writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp[1] == 0xFF writer.close() await writer.wait_closed() finally: await _stop_server(server, task) async def test_auth_success(self, echo_server): server, task = await _start_server(auth=("user", "pass")) try: reader, writer = await _socks5_connect( Address(server.host, server.port), echo_server, auth=("user", "pass"), ) writer.write(b"secret") await writer.drain() data = await reader.read(4096) assert data == b"secret" writer.close() await writer.wait_closed() finally: await _stop_server(server, task) async def test_auth_failure(self, echo_server): server, task = await _start_server(auth=("user", "pass")) try: reader, writer = await asyncio.open_connection(server.host, server.port) writer.write(b"\x05\x01\x02") await writer.drain() resp = await reader.readexactly(2) assert resp[1] == 0x02 writer.write(b"\x01\x04user\x04xxxx") await writer.drain() auth_resp = await reader.readexactly(2) assert auth_resp == b"\x01\x01" writer.close() await writer.wait_closed() finally: await _stop_server(server, task) class DataCounter(Addon): def __init__(self): self.bytes_up = 0 self.bytes_down = 0 async def on_data(self, direction, data, flow): if direction == Direction.UPSTREAM: self.bytes_up += len(data) else: 
self.bytes_down += len(data) return data class TestServerWithAddon: async def test_data_counting(self, echo_server): addon = DataCounter() server, task = await _start_server(addons=[addon]) try: reader, writer = await _socks5_connect( Address(server.host, server.port), echo_server ) writer.write(b"hello world") await writer.drain() data = await reader.read(4096) assert data == b"hello world" await asyncio.sleep(0.1) assert addon.bytes_up == 11 assert addon.bytes_down == 11 writer.close() await writer.wait_closed() finally: await _stop_server(server, task) ================================================ FILE: tests/test_server_errors.py ================================================ """Server error handling tests: malformed input, disconnects, error mapping.""" import asyncio from asyncio_socks_server import FlowStats from asyncio_socks_server.core.types import Address, Rep from asyncio_socks_server.server.server import Server from tests.conftest import _start_server, _stop_server async def _raw_connect(proxy): """Open a raw TCP connection to the proxy.""" return await asyncio.open_connection(proxy.host, proxy.port) async def _read_reply(reader): """Read a SOCKS5 CONNECT reply and return (rep_code, bound_addr).""" ver, rep, rsv = await reader.readexactly(3) atyp = (await reader.readexactly(1))[0] if atyp == 0x01: await reader.readexactly(4 + 2) elif atyp == 0x04: await reader.readexactly(16 + 2) elif atyp == 0x03: length = (await reader.readexactly(1))[0] await reader.readexactly(length + 2) return rep class TestHandshakeErrors: async def test_truncated_method_selection(self): server, task = await _start_server() try: reader, writer = await _raw_connect(Address(server.host, server.port)) writer.write(b"\x05") await writer.drain() # Server expects 2 bytes minimum; send 1 then close writer.close() await writer.wait_closed() # Server should handle this without crashing await asyncio.sleep(0.1) finally: await _stop_server(server, task) async def 
test_wrong_socks_version(self): server, task = await _start_server() try: reader, writer = await _raw_connect(Address(server.host, server.port)) writer.write(b"\x04\x01\x00") await writer.drain() # Server should close or reject — just verify no crash writer.close() await writer.wait_closed() finally: await _stop_server(server, task) async def test_disconnect_after_method_reply(self): server, task = await _start_server() try: reader, writer = await _raw_connect(Address(server.host, server.port)) writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp == b"\x05\x00" # NO_AUTH selected # Now disconnect without sending a request writer.close() await writer.wait_closed() await asyncio.sleep(0.1) finally: await _stop_server(server, task) async def test_disconnect_during_auth(self): server, task = await _start_server(auth=("user", "pass")) try: reader, writer = await _raw_connect(Address(server.host, server.port)) writer.write(b"\x05\x01\x02") await writer.drain() resp = await reader.readexactly(2) assert resp[1] == 0x02 # USERNAME_PASSWORD selected # Disconnect without sending credentials writer.close() await writer.wait_closed() await asyncio.sleep(0.1) finally: await _stop_server(server, task) async def test_nmethods_zero(self): server, task = await _start_server() try: reader, writer = await _raw_connect(Address(server.host, server.port)) writer.write(b"\x05\x00") await writer.drain() resp = await reader.readexactly(2) assert resp[1] == 0xFF # NO_ACCEPTABLE writer.close() await writer.wait_closed() finally: await _stop_server(server, task) class TestRequestErrors: async def test_connect_to_refused_port(self): server, task = await _start_server() try: reader, writer = await _raw_connect(Address(server.host, server.port)) # Method selection writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp[1] == 0x00 # CONNECT to 127.0.0.1:1 (should refuse) 
writer.write(b"\x05\x01\x00\x01\x7f\x00\x00\x01\x00\x01") await writer.drain() rep = await _read_reply(reader) assert rep == Rep.CONNECTION_REFUSED writer.close() await writer.wait_closed() finally: await _stop_server(server, task) async def test_failed_connect_closes_observed_flow(self): stats = FlowStats() server, task = await _start_server(addons=[stats]) try: reader, writer = await _raw_connect(Address(server.host, server.port)) writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp[1] == 0x00 writer.write(b"\x05\x01\x00\x01\x7f\x00\x00\x01\x00\x01") await writer.drain() rep = await _read_reply(reader) assert rep == Rep.CONNECTION_REFUSED await asyncio.sleep(0.05) snapshot = stats.snapshot() assert snapshot["active_flows"] == 0 assert snapshot["total_flows"] == 1 assert snapshot["total_closed_flows"] == 1 assert snapshot["closed_flows"] == 1 writer.close() await writer.wait_closed() finally: await _stop_server(server, task) class TestConnectionDrop: async def test_drop_during_relay(self, echo_server): server, task = await _start_server() try: reader, writer = await _raw_connect(Address(server.host, server.port)) writer.write(b"\x05\x01\x00") await writer.drain() await reader.readexactly(2) from asyncio_socks_server.core.address import encode_address writer.write( b"\x05\x01\x00" + encode_address(echo_server.host, echo_server.port) ) await writer.drain() await _read_reply(reader) # Abruptly close writer.close() await writer.wait_closed() await asyncio.sleep(0.1) finally: await _stop_server(server, task) async def test_multiple_rapid_connect_disconnect(self): server, task = await _start_server() try: for _ in range(10): reader, writer = await _raw_connect(Address(server.host, server.port)) writer.write(b"\x05\x01\x00") await writer.drain() await reader.readexactly(2) writer.close() await writer.wait_closed() # Verify server is still responsive reader, writer = await _raw_connect(Address(server.host, server.port)) 
writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp == b"\x05\x00" writer.close() await writer.wait_closed() finally: await _stop_server(server, task) class TestErrorToRep: def test_connection_refused(self): assert Server._error_to_rep(ConnectionRefusedError()) == Rep.CONNECTION_REFUSED def test_network_unreachable(self): exc = OSError(101, "Network is unreachable") assert Server._error_to_rep(exc) == Rep.NETWORK_UNREACHABLE def test_generic_oserror(self): exc = OSError(99, "Some error") assert Server._error_to_rep(exc) == Rep.GENERAL_FAILURE def test_generic_exception(self): assert Server._error_to_rep(RuntimeError("oops")) == Rep.GENERAL_FAILURE ================================================ FILE: tests/test_server_lifecycle.py ================================================ """Server lifecycle tests: startup, shutdown, addon lifecycle during shutdown.""" import asyncio from asyncio_socks_server.addons.base import Addon from tests.conftest import _start_server, _stop_server class StopTracker(Addon): def __init__(self): self.started = False self.stopped = False async def on_start(self): self.started = True async def on_stop(self): self.stopped = True class TestServerStartup: async def test_server_binds_to_port(self): server, task = await _start_server() try: assert server.port > 0 finally: await _stop_server(server, task) async def test_server_with_zero_port_gets_ephemeral(self): server, task = await _start_server() try: assert server.port > 0 # Verify we can connect _, writer = await asyncio.open_connection(server.host, server.port) writer.close() await writer.wait_closed() finally: await _stop_server(server, task) class TestServerShutdown: async def test_request_shutdown_stops_server(self): server, task = await _start_server() await _stop_server(server, task) # Task should be done assert task.done() async def test_shutdown_calls_addon_stop(self): tracker = StopTracker() server, task = await 
_start_server(addons=[tracker]) assert tracker.started await _stop_server(server, task) assert tracker.stopped async def test_shutdown_closes_listening_socket(self): server, task = await _start_server() port = server.port await _stop_server(server, task) # New connections should be refused for _ in range(5): try: _, writer = await asyncio.open_connection(server.host, port) writer.close() await writer.wait_closed() # Connection succeeded — server might still be closing await asyncio.sleep(0.1) except (ConnectionError, OSError): return # Expected # Should have failed by now assert False, "Server should have closed listening socket" ================================================ FILE: tests/test_tcp_relay.py ================================================ """Unit tests for TCP relay: _copy() and handle_tcp_relay().""" import asyncio import socket import pytest from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.addons.manager import AddonManager from asyncio_socks_server.core.types import Address, Direction, Flow from asyncio_socks_server.server.tcp_relay import _copy, handle_tcp_relay def _make_flow(**kwargs): defaults = dict( id=1, src=Address("127.0.0.1", 1000), dst=Address("127.0.0.1", 2000), protocol="tcp", started_at=0.0, ) defaults.update(kwargs) return Flow(**defaults) async def _pipe(): """Create a pipe with two ends. Returns (write_end, read_end): - Write to write_end → data appears on read_end. 
""" sock_a, sock_b = socket.socketpair() sock_a.setblocking(False) sock_b.setblocking(False) # writer_a writes to sock_a → reader_b reads from sock_b reader_a, writer_a = await asyncio.open_connection(sock=sock_a) reader_b, writer_b = await asyncio.open_connection(sock=sock_b) # writer_a → reader_b is our pipe direction return (writer_a, writer_b, reader_a, reader_b) class UpperAddon(Addon): async def on_data(self, direction, data, flow): return data.upper() class DropAddon(Addon): async def on_data(self, direction, data, flow): return None class TestCopy: async def test_copies_data(self): # Input: app writes → _copy reads in_wa, in_wb, in_ra, in_rb = await _pipe() # in_wa → in_rb # Output: _copy writes → app reads out_wa, out_wb, out_ra, out_rb = await _pipe() # out_wa → out_rb flow = _make_flow() manager = AddonManager() copy_task = asyncio.create_task( _copy(in_rb, out_wa, manager, Direction.UPSTREAM, flow) ) in_wa.write(b"hello") await in_wa.drain() data = await asyncio.wait_for(out_rb.read(4096), timeout=1.0) assert data == b"hello" in_wa.close() await in_wa.wait_closed() await asyncio.wait_for(copy_task, timeout=2.0) async def test_stops_on_eof(self): in_wa, in_wb, in_ra, in_rb = await _pipe() out_wa, out_wb, out_ra, out_rb = await _pipe() flow = _make_flow() manager = AddonManager() in_wa.close() await in_wa.wait_closed() copy_task = asyncio.create_task( _copy(in_rb, out_wa, manager, Direction.UPSTREAM, flow) ) await asyncio.wait_for(copy_task, timeout=1.0) async def test_addon_pipeline_applied(self): in_wa, in_wb, in_ra, in_rb = await _pipe() out_wa, out_wb, out_ra, out_rb = await _pipe() flow = _make_flow() manager = AddonManager([UpperAddon()]) copy_task = asyncio.create_task( _copy(in_rb, out_wa, manager, Direction.UPSTREAM, flow) ) in_wa.write(b"hello") await in_wa.drain() data = await asyncio.wait_for(out_rb.read(4096), timeout=1.0) assert data == b"HELLO" in_wa.close() await in_wa.wait_closed() await asyncio.wait_for(copy_task, timeout=2.0) async def 
test_addon_returns_none_skips_write(self): in_wa, in_wb, in_ra, in_rb = await _pipe() out_wa, out_wb, out_ra, out_rb = await _pipe() flow = _make_flow() manager = AddonManager([DropAddon()]) copy_task = asyncio.create_task( _copy(in_rb, out_wa, manager, Direction.UPSTREAM, flow) ) in_wa.write(b"dropped") await in_wa.drain() with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(out_rb.read(4096), timeout=0.3) in_wa.close() await in_wa.wait_closed() await asyncio.wait_for(copy_task, timeout=2.0) async def test_connection_error_handled(self): in_wa, in_wb, in_ra, in_rb = await _pipe() out_wa, out_wb, out_ra, out_rb = await _pipe() flow = _make_flow() manager = AddonManager() out_wa.close() await out_wa.wait_closed() copy_task = asyncio.create_task( _copy(in_rb, out_wa, manager, Direction.UPSTREAM, flow) ) in_wa.write(b"trigger") await in_wa.drain() await asyncio.wait_for(copy_task, timeout=2.0) in_wa.close() await in_wa.wait_closed() async def test_writer_closed_on_finish(self): in_wa, in_wb, in_ra, in_rb = await _pipe() out_wa, out_wb, out_ra, out_rb = await _pipe() flow = _make_flow() manager = AddonManager() in_wa.close() await in_wa.wait_closed() await _copy(in_rb, out_wa, manager, Direction.UPSTREAM, flow) assert out_wa.is_closing() class TestHandleTcpRelay: async def test_bidirectional_relay(self): # Client pipe: cw_app → cr_relay, cw_relay → cr_app c_sock_a, c_sock_b = socket.socketpair() c_sock_a.setblocking(False) c_sock_b.setblocking(False) cr_app, cw_app = await asyncio.open_connection(sock=c_sock_a) cr_relay, cw_relay = await asyncio.open_connection(sock=c_sock_b) # Remote pipe: rw_app → rr_relay, rw_relay → rr_app r_sock_a, r_sock_b = socket.socketpair() r_sock_a.setblocking(False) r_sock_b.setblocking(False) rr_app, rw_app = await asyncio.open_connection(sock=r_sock_a) rr_relay, rw_relay = await asyncio.open_connection(sock=r_sock_b) flow = _make_flow() manager = AddonManager() relay_task = asyncio.create_task( handle_tcp_relay(cr_relay, 
cw_relay, rr_relay, rw_relay, manager, flow) ) # Client → Remote cw_app.write(b"to-remote") await cw_app.drain() data = await asyncio.wait_for(rr_app.read(4096), timeout=1.0) assert data == b"to-remote" # Remote → Client rw_app.write(b"to-client") await rw_app.drain() data = await asyncio.wait_for(cr_app.read(4096), timeout=1.0) assert data == b"to-client" cw_app.close() await cw_app.wait_closed() rw_app.close() await rw_app.wait_closed() await asyncio.wait_for(relay_task, timeout=2.0) async def test_relay_stops_when_client_closes(self): c_sock_a, c_sock_b = socket.socketpair() c_sock_a.setblocking(False) c_sock_b.setblocking(False) cr_app, cw_app = await asyncio.open_connection(sock=c_sock_a) cr_relay, cw_relay = await asyncio.open_connection(sock=c_sock_b) r_sock_a, r_sock_b = socket.socketpair() r_sock_a.setblocking(False) r_sock_b.setblocking(False) rr_app, rw_app = await asyncio.open_connection(sock=r_sock_a) rr_relay, rw_relay = await asyncio.open_connection(sock=r_sock_b) flow = _make_flow() manager = AddonManager() relay_task = asyncio.create_task( handle_tcp_relay(cr_relay, cw_relay, rr_relay, rw_relay, manager, flow) ) cw_app.close() await cw_app.wait_closed() await asyncio.wait_for(relay_task, timeout=2.0) ================================================ FILE: tests/test_udp_associate_hook.py ================================================ """Tests for the on_udp_associate competitive hook.""" import asyncio from asyncio_socks_server.addons.base import Addon from asyncio_socks_server.core.types import Address from asyncio_socks_server.server.server import Server from asyncio_socks_server.server.udp_relay import UdpRelayBase class _CustomRelay(UdpRelayBase): """Minimal UdpRelayBase that records calls.""" def __init__(self): self.started = False self.stopped = False self.client_transport_set = False self.datagrams: list[bytes] = [] async def start(self) -> Address: self.started = True return Address("127.0.0.1", 12345) def set_client_transport(self, 
transport: asyncio.DatagramTransport) -> None: self.client_transport_set = True async def stop(self) -> None: self.stopped = True def handle_client_datagram(self, data: bytes, client_addr: tuple[str, int]) -> None: self.datagrams.append(data) class _CustomAddon(Addon): def __init__(self, relay: UdpRelayBase): self._relay = relay async def on_udp_associate(self, flow) -> UdpRelayBase | None: return self._relay class _PassAddon(Addon): async def on_udp_associate(self, flow) -> UdpRelayBase | None: return None class _FailingRelay(UdpRelayBase): def __init__(self): self.stopped = False async def start(self) -> Address: raise RuntimeError("udp relay failed") def set_client_transport(self, transport: asyncio.DatagramTransport) -> None: pass async def stop(self) -> None: self.stopped = True def handle_client_datagram(self, data: bytes, client_addr: tuple[str, int]) -> None: pass class _ErrorTracker(Addon): def __init__(self): self.errors: list[Exception] = [] async def on_error(self, error: Exception) -> None: self.errors.append(error) async def _start_server(**kwargs): server = Server(host="127.0.0.1", port=0, **kwargs) task = asyncio.create_task(server._run()) for _ in range(50): if server.port != 0: break await asyncio.sleep(0.01) return server, task async def _stop_server(server, task): server.request_shutdown() await task class TestUdpAssociateHook: async def test_addon_returns_custom_handler(self): """Addon returning a custom UdpRelayBase replaces the default.""" relay = _CustomRelay() addon = _CustomAddon(relay) server, task = await _start_server(addons=[addon]) try: reader, writer = await asyncio.open_connection("127.0.0.1", server.port) # SOCKS5 handshake writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp == b"\x05\x00" # UDP ASSOCIATE writer.write(b"\x05\x03\x00\x01\x00\x00\x00\x00\x00\x00") await writer.drain() # Read reply reply = await reader.readexactly(3) assert reply[0] == 0x05 # reply[1] == 0x00 means success 
(custom relay returned its addr) # Read bound address atyp = (await reader.readexactly(1))[0] if atyp == 0x01: await reader.readexactly(4 + 2) elif atyp == 0x04: await reader.readexactly(16 + 2) assert relay.started writer.close() await writer.wait_closed() finally: await _stop_server(server, task) async def test_addon_returns_none_uses_default(self): """Addon returning None falls through to default UdpRelay.""" addon = _PassAddon() server, task = await _start_server(addons=[addon]) try: reader, writer = await asyncio.open_connection("127.0.0.1", server.port) writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp == b"\x05\x00" writer.write(b"\x05\x03\x00\x01\x00\x00\x00\x00\x00\x00") await writer.drain() reply = await reader.readexactly(3) assert reply[0] == 0x05 assert reply[1] == 0x00 writer.close() await writer.wait_closed() finally: await _stop_server(server, task) async def test_competitive_first_wins(self): """Multiple addons: first non-None result wins.""" relay_a = _CustomRelay() relay_b = _CustomRelay() class AddonA(Addon): async def on_udp_associate(self, flow): return relay_a class AddonB(Addon): async def on_udp_associate(self, flow): return relay_b server, task = await _start_server(addons=[AddonA(), AddonB()]) try: reader, writer = await asyncio.open_connection("127.0.0.1", server.port) writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp == b"\x05\x00" writer.write(b"\x05\x03\x00\x01\x00\x00\x00\x00\x00\x00") await writer.drain() reply = await reader.readexactly(3) assert reply[1] == 0x00 await reader.readexactly(1) # atyp await reader.readexactly(4 + 2) # ipv4+port assert relay_a.started assert not relay_b.started writer.close() await writer.wait_closed() finally: await _stop_server(server, task) async def test_relay_start_failure_returns_socks_error_and_dispatches_error(self): relay = _FailingRelay() tracker = _ErrorTracker() server, task = await _start_server( 
addons=[_CustomAddon(relay), tracker], ) try: reader, writer = await asyncio.open_connection("127.0.0.1", server.port) writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp == b"\x05\x00" writer.write(b"\x05\x03\x00\x01\x00\x00\x00\x00\x00\x00") await writer.drain() reply = await reader.readexactly(3) assert reply == b"\x05\x01\x00" atyp = await reader.readexactly(1) assert atyp == b"\x01" await reader.readexactly(4 + 2) assert relay.stopped assert len(tracker.errors) == 1 assert isinstance(tracker.errors[0], RuntimeError) writer.close() await writer.wait_closed() finally: await _stop_server(server, task) ================================================ FILE: tests/test_udp_over_tcp.py ================================================ import asyncio from asyncio_socks_server.core.types import Address from asyncio_socks_server.server.udp_over_tcp import encode_udp_frame, read_udp_frame class TestUdpOverTcpFrame: async def test_roundtrip_ipv4(self): addr = Address("127.0.0.1", 1080) payload = b"hello world" frame = await encode_udp_frame(addr, payload) reader = asyncio.StreamReader() reader.feed_data(frame) reader.feed_eof() result_addr, result_data = await read_udp_frame(reader) assert result_addr.host == "127.0.0.1" assert result_addr.port == 1080 assert result_data == payload async def test_roundtrip_ipv6(self): addr = Address("::1", 443) payload = b"test data" frame = await encode_udp_frame(addr, payload) reader = asyncio.StreamReader() reader.feed_data(frame) reader.feed_eof() result_addr, result_data = await read_udp_frame(reader) assert result_addr.host == "::1" assert result_addr.port == 443 assert result_data == payload async def test_roundtrip_domain(self): addr = Address("example.com", 80) payload = b"http request" frame = await encode_udp_frame(addr, payload) reader = asyncio.StreamReader() reader.feed_data(frame) reader.feed_eof() result_addr, result_data = await read_udp_frame(reader) assert result_addr.host == 
"example.com" assert result_addr.port == 80 assert result_data == payload async def test_multiple_frames(self): frames_data = b"" expected = [] for i in range(3): addr = Address("127.0.0.1", 1000 + i) payload = f"packet {i}".encode() frame = await encode_udp_frame(addr, payload) frames_data += frame expected.append((addr, payload)) reader = asyncio.StreamReader() reader.feed_data(frames_data) reader.feed_eof() for exp_addr, exp_data in expected: result_addr, result_data = await read_udp_frame(reader) assert result_addr.host == exp_addr.host assert result_addr.port == exp_addr.port assert result_data == exp_data async def test_empty_payload(self): addr = Address("10.0.0.1", 53) payload = b"" frame = await encode_udp_frame(addr, payload) reader = asyncio.StreamReader() reader.feed_data(frame) reader.feed_eof() result_addr, result_data = await read_udp_frame(reader) assert result_addr.host == "10.0.0.1" assert result_addr.port == 53 assert result_data == b"" ================================================ FILE: tests/test_udp_over_tcp_e2e.py ================================================ """End-to-end test: UDP client → Entry SOCKS5 server → Exit server → UDP echo.""" import asyncio import socket import struct from asyncio_socks_server.addons.udp_over_tcp_entry import UdpOverTcpEntry from asyncio_socks_server.core.protocol import build_udp_header, parse_udp_header from asyncio_socks_server.core.types import Address from asyncio_socks_server.server.server import Server from asyncio_socks_server.server.udp_over_tcp_exit import UdpOverTcpExitServer async def _start_server(**kwargs): server = Server(host="127.0.0.1", port=0, **kwargs) task = asyncio.create_task(server._run()) for _ in range(50): if server.port != 0: break await asyncio.sleep(0.01) return server, task async def _stop_server(server, task): server.request_shutdown() await task async def _start_exit_server(**kwargs): server = UdpOverTcpExitServer(host="127.0.0.1", port=0, **kwargs) task = 
asyncio.create_task(server._run()) for _ in range(50): if server.port != 0: break await asyncio.sleep(0.01) return server, task async def _stop_exit_server(server, task): server.request_shutdown() await task async def _skip_bind_address(reader): atyp = (await reader.readexactly(1))[0] if atyp == 0x01: await reader.readexactly(4 + 2) elif atyp == 0x04: await reader.readexactly(16 + 2) elif atyp == 0x03: length = (await reader.readexactly(1))[0] await reader.readexactly(length + 2) class TestUdpOverTcpE2E: async def test_full_chain_udp_roundtrip(self): """UDP client → Entry SOCKS5 → Exit server → UDP echo → back.""" # 1. UDP echo server class EchoProtocol(asyncio.DatagramProtocol): def connection_made(self, transport): self.transport = transport def datagram_received(self, data, addr): self.transport.sendto(data, addr) loop = asyncio.get_running_loop() echo_transport, _ = await loop.create_datagram_endpoint( EchoProtocol, local_addr=("127.0.0.1", 0) ) echo_sock = echo_transport.get_extra_info("socket") echo_sockname = echo_sock.getsockname() if echo_sock else ("127.0.0.1", 0) echo_addr = Address(echo_sockname[0], echo_sockname[1]) # 2. Exit server exit_srv, exit_task = await _start_exit_server() # 3. Entry SOCKS5 server with UdpOverTcpEntry addon entry_addon = UdpOverTcpEntry(f"127.0.0.1:{exit_srv.port}") entry_srv, entry_task = await _start_server(addons=[entry_addon]) try: # 4. Client: SOCKS5 handshake reader, writer = await asyncio.open_connection("127.0.0.1", entry_srv.port) writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp == b"\x05\x00" # 5. 
UDP ASSOCIATE writer.write(b"\x05\x03\x00\x01\x00\x00\x00\x00\x00\x00") await writer.drain() reply = await reader.readexactly(3) assert reply[0] == 0x05 assert reply[1] == 0x00 # Read bind address atyp = (await reader.readexactly(1))[0] if atyp == 0x01: host_bytes = await reader.readexactly(4) import ipaddress host = ipaddress.IPv4Address(host_bytes).compressed elif atyp == 0x04: host_bytes = await reader.readexactly(16) host = str(ipaddress.IPv6Address(host_bytes)) else: length = (await reader.readexactly(1))[0] host = (await reader.readexactly(length)).decode("ascii") port_bytes = await reader.readexactly(2) port = struct.unpack("!H", port_bytes)[0] udp_bind = Address(host, port) # 6. Client sends UDP datagram through the entry server's UDP bind received_future = loop.create_future() class ClientProtocol(asyncio.DatagramProtocol): def datagram_received(self, data, addr): if not received_future.done(): received_future.set_result(data) client_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) client_sock.bind(("127.0.0.1", 0)) client_sock.setblocking(False) client_transport, _ = await loop.create_datagram_endpoint( ClientProtocol, sock=client_sock ) datagram = build_udp_header(echo_addr) + b"hello chain" client_transport.sendto(datagram, (udp_bind.host, udp_bind.port)) resp_data = await asyncio.wait_for(received_future, timeout=3.0) _, _, payload = parse_udp_header(resp_data) assert payload == b"hello chain" client_transport.close() writer.close() await writer.wait_closed() finally: await _stop_server(entry_srv, entry_task) await _stop_exit_server(exit_srv, exit_task) echo_transport.close() async def test_chain_multiple_datagrams(self): """Multiple UDP datagrams through the full chain.""" class EchoProtocol(asyncio.DatagramProtocol): def connection_made(self, transport): self.transport = transport def datagram_received(self, data, addr): self.transport.sendto(data, addr) loop = asyncio.get_running_loop() echo_transport, _ = await loop.create_datagram_endpoint( 
EchoProtocol, local_addr=("127.0.0.1", 0) ) echo_sock = echo_transport.get_extra_info("socket") echo_sockname = echo_sock.getsockname() if echo_sock else ("127.0.0.1", 0) echo_addr = Address(echo_sockname[0], echo_sockname[1]) exit_srv, exit_task = await _start_exit_server() entry_addon = UdpOverTcpEntry(f"127.0.0.1:{exit_srv.port}") entry_srv, entry_task = await _start_server(addons=[entry_addon]) try: reader, writer = await asyncio.open_connection("127.0.0.1", entry_srv.port) writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp == b"\x05\x00" writer.write(b"\x05\x03\x00\x01\x00\x00\x00\x00\x00\x00") await writer.drain() reply = await reader.readexactly(3) assert reply[1] == 0x00 atyp = (await reader.readexactly(1))[0] if atyp == 0x01: host_bytes = await reader.readexactly(4) import ipaddress host = ipaddress.IPv4Address(host_bytes).compressed elif atyp == 0x04: host_bytes = await reader.readexactly(16) host = str(ipaddress.IPv6Address(host_bytes)) else: length = (await reader.readexactly(1))[0] host = (await reader.readexactly(length)).decode("ascii") port_bytes = await reader.readexactly(2) port = struct.unpack("!H", port_bytes)[0] udp_bind = Address(host, port) received_queue: asyncio.Queue[bytes] = asyncio.Queue() class ClientProtocol(asyncio.DatagramProtocol): def datagram_received(self, data, addr): received_queue.put_nowait(data) client_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) client_sock.bind(("127.0.0.1", 0)) client_sock.setblocking(False) client_transport, _ = await loop.create_datagram_endpoint( ClientProtocol, sock=client_sock ) for i in range(3): datagram = build_udp_header(echo_addr) + f"pkt{i}".encode() client_transport.sendto(datagram, (udp_bind.host, udp_bind.port)) await asyncio.sleep(0.05) for i in range(3): resp_data = await asyncio.wait_for(received_queue.get(), timeout=3.0) _, _, payload = parse_udp_header(resp_data) assert payload == f"pkt{i}".encode() client_transport.close() 
writer.close() await writer.wait_closed() finally: await _stop_server(entry_srv, entry_task) await _stop_exit_server(exit_srv, exit_task) echo_transport.close() ================================================ FILE: tests/test_udp_over_tcp_exit.py ================================================ """Tests for UdpOverTcpExitServer.""" import asyncio from asyncio_socks_server.core.types import Address from asyncio_socks_server.server.udp_over_tcp import encode_udp_frame, read_udp_frame from asyncio_socks_server.server.udp_over_tcp_exit import UdpOverTcpExitServer async def _start_exit_server(**kwargs): server = UdpOverTcpExitServer(host="127.0.0.1", port=0, **kwargs) task = asyncio.create_task(server._run()) for _ in range(50): if server.port != 0: break await asyncio.sleep(0.01) return server, task async def _stop_exit_server(server, task): server.request_shutdown() await task class TestUdpOverTcpExit: async def test_tcp_to_udp_roundtrip(self): """Send UDP-over-TCP frame → exit server → UDP echo → TCP frame back.""" # UDP echo server received = [] class EchoProtocol(asyncio.DatagramProtocol): def connection_made(self, transport): self.transport = transport def datagram_received(self, data, addr): received.append((data, addr)) self.transport.sendto(data, addr) loop = asyncio.get_running_loop() echo_transport, _ = await loop.create_datagram_endpoint( EchoProtocol, local_addr=("127.0.0.1", 0) ) echo_sock = echo_transport.get_extra_info("socket") echo_sockname = echo_sock.getsockname() if echo_sock else ("127.0.0.1", 0) echo_addr = Address(echo_sockname[0], echo_sockname[1]) exit_srv, exit_task = await _start_exit_server() try: # Connect to exit server via TCP reader, writer = await asyncio.open_connection("127.0.0.1", exit_srv.port) # Send UDP-over-TCP frame targeting the echo server frame = await encode_udp_frame(echo_addr, b"hello exit") writer.write(frame) await writer.drain() # Wait for echo reply to come back via TCP src_addr, payload = await asyncio.wait_for( 
read_udp_frame(reader), timeout=2.0 ) assert payload == b"hello exit" assert src_addr.host == echo_addr.host assert src_addr.port == echo_addr.port writer.close() await writer.wait_closed() finally: await _stop_exit_server(exit_srv, exit_task) echo_transport.close() async def test_multiple_datagrams(self): """Multiple frames in sequence.""" class EchoProtocol(asyncio.DatagramProtocol): def connection_made(self, transport): self.transport = transport def datagram_received(self, data, addr): self.transport.sendto(data, addr) loop = asyncio.get_running_loop() echo_transport, _ = await loop.create_datagram_endpoint( EchoProtocol, local_addr=("127.0.0.1", 0) ) echo_sock = echo_transport.get_extra_info("socket") echo_sockname = echo_sock.getsockname() if echo_sock else ("127.0.0.1", 0) echo_addr = Address(echo_sockname[0], echo_sockname[1]) exit_srv, exit_task = await _start_exit_server() try: reader, writer = await asyncio.open_connection("127.0.0.1", exit_srv.port) for i in range(5): frame = await encode_udp_frame(echo_addr, f"msg{i}".encode()) writer.write(frame) await writer.drain() for i in range(5): src_addr, payload = await asyncio.wait_for( read_udp_frame(reader), timeout=2.0 ) assert payload == f"msg{i}".encode() writer.close() await writer.wait_closed() finally: await _stop_exit_server(exit_srv, exit_task) echo_transport.close() ================================================ FILE: tests/test_udp_relay.py ================================================ """UDP relay tests: UdpRelay unit + UDP ASSOCIATE end-to-end.""" import asyncio import time from asyncio_socks_server.core.address import encode_address from asyncio_socks_server.core.protocol import build_udp_header from asyncio_socks_server.core.types import Address, Flow from asyncio_socks_server.server.udp_relay import UdpRelay from tests.conftest import _start_server, _stop_server def _udp_flow(): return Flow( id=1, src=Address("127.0.0.1", 12345), dst=Address("0.0.0.0", 0), protocol="udp", started_at=0.0, 
) def _build_udp_datagram(dst: Address, payload: bytes) -> bytes: """Build a SOCKS5-encapsulated UDP datagram.""" return build_udp_header(dst) + payload async def _socks5_udp_associate(proxy: Address, auth=None): """Perform SOCKS5 handshake + UDP ASSOCIATE. Returns (tcp_reader, tcp_writer, udp_bind_addr). """ reader, writer = await asyncio.open_connection(proxy.host, proxy.port) # Method selection if auth: writer.write(b"\x05\x01\x02") else: writer.write(b"\x05\x01\x00") await writer.drain() resp = await reader.readexactly(2) assert resp[0] == 0x05 if auth: assert resp[1] == 0x02 uname = auth[0].encode() passwd = auth[1].encode() writer.write( b"\x01" + len(uname).to_bytes(1, "big") + uname + len(passwd).to_bytes(1, "big") + passwd ) await writer.drain() auth_resp = await reader.readexactly(2) assert auth_resp == b"\x01\x00" else: assert resp[1] == 0x00 # UDP ASSOCIATE request (dst = 0.0.0.0:0) writer.write(b"\x05\x03\x00" + encode_address("0.0.0.0", 0)) await writer.drain() reply = await reader.readexactly(3) assert reply[0] == 0x05 assert reply[1] == 0x00 # succeeded # Read bound address atyp = (await reader.readexactly(1))[0] if atyp == 0x01: host_bytes = await reader.readexactly(4) import ipaddress bind_host = ipaddress.IPv4Address(host_bytes).compressed elif atyp == 0x04: host_bytes = await reader.readexactly(16) import ipaddress bind_host = ipaddress.IPv6Address(host_bytes).compressed else: length = (await reader.readexactly(1))[0] bind_host = (await reader.readexactly(length)).decode("ascii") port_bytes = await reader.readexactly(2) import struct bind_port = struct.unpack("!H", port_bytes)[0] return reader, writer, Address(bind_host, bind_port) class TestUdpRelayUnit: async def test_start_returns_bind_address(self): relay = UdpRelay(client_addr=Address("127.0.0.1", 12345), flow=_udp_flow()) try: bind_addr = await relay.start() assert bind_addr.port > 0 assert bind_addr.host != "" finally: await relay.stop() async def test_stop_cancels_ttl_task(self): relay = 
UdpRelay(client_addr=Address("127.0.0.1", 12345), flow=_udp_flow()) await relay.start() ttl_task = relay._ttl_task await relay.stop() assert ttl_task.cancelled() or ttl_task.done() async def test_handle_client_datagram_routes_outbound(self, udp_echo_server): echo_addr, received = udp_echo_server relay = UdpRelay(client_addr=Address("127.0.0.1", 12345), flow=_udp_flow()) try: await relay.start() datagram = _build_udp_datagram(echo_addr, b"hello") relay.handle_client_datagram(datagram, ("127.0.0.1", 12345)) # Wait for echo server to receive for _ in range(50): if received: break await asyncio.sleep(0.02) assert len(received) == 1 assert received[0][0] == b"hello" finally: await relay.stop() async def test_handle_client_datagram_empty_payload_ignored(self): relay = UdpRelay(client_addr=Address("127.0.0.1", 12345), flow=_udp_flow()) try: await relay.start() # Valid header but empty payload datagram = _build_udp_datagram(Address("127.0.0.1", 80), b"") relay.handle_client_datagram(datagram, ("127.0.0.1", 12345)) # No route should be created (empty payload is ignored) assert len(relay._route_map) == 0 finally: await relay.stop() async def test_handle_client_datagram_malformed_ignored(self): relay = UdpRelay(client_addr=Address("127.0.0.1", 12345), flow=_udp_flow()) try: await relay.start() relay.handle_client_datagram(b"garbage", ("127.0.0.1", 12345)) assert len(relay._route_map) == 0 finally: await relay.stop() async def test_routing_table_entries_created(self): relay = UdpRelay(client_addr=Address("127.0.0.1", 12345), flow=_udp_flow()) try: await relay.start() datagram = _build_udp_datagram(Address("127.0.0.1", 80), b"data") relay.handle_client_datagram(datagram, ("127.0.0.1", 12345)) assert ("127.0.0.1", 80) in relay._route_map finally: await relay.stop() async def test_routing_table_entries_refreshed(self): relay = UdpRelay(client_addr=Address("127.0.0.1", 12345), flow=_udp_flow()) try: await relay.start() datagram = _build_udp_datagram(Address("127.0.0.1", 80), 
b"data") relay.handle_client_datagram(datagram, ("127.0.0.1", 12345)) ts1 = relay._route_timestamps[("127.0.0.1", 80)] await asyncio.sleep(0.05) relay.handle_client_datagram(datagram, ("127.0.0.1", 12345)) ts2 = relay._route_timestamps[("127.0.0.1", 80)] assert ts2 > ts1 finally: await relay.stop() class TestUdpRelayTTL: async def test_ttl_cleanup_removes_expired(self): relay = UdpRelay( client_addr=Address("127.0.0.1", 12345), flow=_udp_flow(), ttl=0.1 ) try: await relay.start() # Manually inject a route with old timestamp relay._route_map[("10.0.0.1", 80)] = ("127.0.0.1", 12345) relay._route_timestamps[("10.0.0.1", 80)] = time.monotonic() - 1.0 # Wait for TTL cleanup (runs every 60s, but we can trigger manually) # Let's just wait enough time — the cleanup loop runs every 60s, # so we manually trigger it now = time.monotonic() expired = [ key for key, ts in relay._route_timestamps.items() if now - ts > relay._ttl ] for key in expired: relay._route_map.pop(key, None) relay._route_timestamps.pop(key, None) assert ("10.0.0.1", 80) not in relay._route_map finally: await relay.stop() async def test_ttl_cleanup_keeps_active(self): relay = UdpRelay( client_addr=Address("127.0.0.1", 12345), flow=_udp_flow(), ttl=300 ) try: await relay.start() relay._route_map[("10.0.0.1", 80)] = ("127.0.0.1", 12345) relay._route_timestamps[("10.0.0.1", 80)] = time.monotonic() now = time.monotonic() expired = [ key for key, ts in relay._route_timestamps.items() if now - ts > relay._ttl ] assert len(expired) == 0 assert ("10.0.0.1", 80) in relay._route_map finally: await relay.stop() class TestUdpAssociateE2E: async def test_udp_associate_handshake(self): """Test UDP ASSOCIATE returns a valid bind address.""" server, task = await _start_server() try: tcp_r, tcp_w, udp_bind = await _socks5_udp_associate( Address(server.host, server.port) ) assert udp_bind.port > 0 tcp_w.close() await tcp_w.wait_closed() finally: await _stop_server(server, task) async def test_udp_associate_with_auth(self): 
server, task = await _start_server(auth=("user", "pass")) try: tcp_r, tcp_w, udp_bind = await _socks5_udp_associate( Address(server.host, server.port), auth=("user", "pass") ) assert udp_bind.port > 0 tcp_w.close() await tcp_w.wait_closed() finally: await _stop_server(server, task) async def test_udp_associate_send_and_receive(self, udp_echo_server): """Test sending a UDP datagram through the proxy and receiving the echo.""" echo_addr, _ = udp_echo_server server, task = await _start_server() try: tcp_r, tcp_w, udp_bind = await _socks5_udp_associate( Address(server.host, server.port) ) await asyncio.sleep(0.05) # Let server settle # Set up a UDP client that can send and receive loop = asyncio.get_running_loop() received = asyncio.get_event_loop().create_future() class ClientProtocol(asyncio.DatagramProtocol): def datagram_received(self, data, addr): if not received.done(): received.set_result(data) transport, _ = await loop.create_datagram_endpoint( ClientProtocol, local_addr=("127.0.0.1", 0) ) # Send SOCKS5-encapsulated datagram to proxy's UDP bind datagram = _build_udp_datagram(echo_addr, b"hello") transport.sendto(datagram, (udp_bind.host, udp_bind.port)) # Wait for echo response try: resp_data = await asyncio.wait_for(received, timeout=2.0) from asyncio_socks_server.core.protocol import parse_udp_header resp_addr, _, resp_payload = parse_udp_header(resp_data) assert resp_payload == b"hello" finally: transport.close() tcp_w.close() await tcp_w.wait_closed() finally: await _stop_server(server, task) async def test_tcp_close_ends_relay(self): server, task = await _start_server() try: tcp_r, tcp_w, udp_bind = await _socks5_udp_associate( Address(server.host, server.port) ) assert udp_bind.port > 0 tcp_w.close() await tcp_w.wait_closed() await asyncio.sleep(0.3) # Sending to the closed relay should not crash loop = asyncio.get_running_loop() class SilentProtocol(asyncio.DatagramProtocol): def datagram_received(self, data, addr): pass transport, _ = await 
loop.create_datagram_endpoint( SilentProtocol, local_addr=("127.0.0.1", 0) ) try: transport.sendto( _build_udp_datagram(Address("127.0.0.1", 1), b"x"), (udp_bind.host, udp_bind.port), ) except OSError: pass transport.close() finally: await _stop_server(server, task)