Repository: genmeta/gm-quic Branch: main Commit: 5e8392136608 Files: 300 Total size: 2.0 MB Directory structure: gitextract_c16ycxax/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── feature_request.md │ │ └── r2cn.md │ ├── dependabot.yml │ └── workflows/ │ ├── benchmark.yml │ ├── codecov.yml │ ├── commitlint.yml │ ├── feishu-bot.yml │ ├── rust.yml │ └── traversal.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .rustfmt.toml ├── .rusty-hook.toml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE ├── README.md ├── README_CN.md ├── SECURITY.md ├── benchmark/ │ └── launch.py ├── codecov.yml ├── commitlint.config.js ├── dquic/ │ ├── Cargo.toml │ ├── examples/ │ │ ├── echo-client.rs │ │ ├── echo-server.rs │ │ ├── http-client.rs │ │ ├── http-server.rs │ │ ├── traversal-client.rs │ │ └── traversal-server.rs │ ├── src/ │ │ ├── cert.rs │ │ ├── client.rs │ │ ├── common.rs │ │ ├── lib.rs │ │ └── server.rs │ └── tests/ │ ├── auth.rs │ ├── common/ │ │ └── mod.rs │ ├── echo.rs │ ├── echo_common/ │ │ └── mod.rs │ └── traversal.rs ├── h3-shim/ │ ├── Cargo.toml │ ├── examples/ │ │ ├── README.md │ │ ├── h3-client.rs │ │ └── h3-server.rs │ └── src/ │ ├── conn.rs │ ├── error.rs │ ├── ext.rs │ ├── lib.rs │ ├── pool.rs │ └── streams.rs ├── interop/ │ ├── Dockerfile │ └── run_endpoint.sh ├── qbase/ │ ├── Cargo.toml │ └── src/ │ ├── cid/ │ │ ├── connection_id.rs │ │ ├── local_cid.rs │ │ └── remote_cid.rs │ ├── cid.rs │ ├── error.rs │ ├── flow.rs │ ├── frame/ │ │ ├── ack.rs │ │ ├── add_address.rs │ │ ├── connection_close.rs │ │ ├── crypto.rs │ │ ├── data_blocked.rs │ │ ├── datagram.rs │ │ ├── error.rs │ │ ├── handshake_done.rs │ │ ├── io.rs │ │ ├── max_data.rs │ │ ├── max_stream_data.rs │ │ ├── max_streams.rs │ │ ├── new_connection_id.rs │ │ ├── new_token.rs │ │ ├── padding.rs │ │ ├── path_challenge.rs │ │ ├── path_response.rs │ │ ├── ping.rs │ │ ├── punch_done.rs │ │ ├── punch_hello.rs │ │ ├── punch_me_now.rs │ │ ├── remove_address.rs │ │ ├── 
reset_stream.rs │ │ ├── retire_connection_id.rs │ │ ├── stop_sending.rs │ │ ├── stream.rs │ │ ├── stream_data_blocked.rs │ │ └── streams_blocked.rs │ ├── frame.rs │ ├── handshake.rs │ ├── lib.rs │ ├── metric.rs │ ├── net/ │ │ ├── addr.rs │ │ ├── nat.rs │ │ ├── route.rs │ │ └── tx.rs │ ├── net.rs │ ├── packet/ │ │ ├── decrypt.rs │ │ ├── encrypt.rs │ │ ├── error.rs │ │ ├── header/ │ │ │ ├── long.rs │ │ │ └── short.rs │ │ ├── header.rs │ │ ├── io.rs │ │ ├── keys.rs │ │ ├── number.rs │ │ ├── signal.rs │ │ ├── type/ │ │ │ ├── long/ │ │ │ │ └── v1.rs │ │ │ ├── long.rs │ │ │ └── short.rs │ │ └── type.rs │ ├── packet.rs │ ├── param/ │ │ ├── core.rs │ │ ├── error.rs │ │ ├── handy.rs │ │ ├── io.rs │ │ └── preferred_address.rs │ ├── param.rs │ ├── role.rs │ ├── sid/ │ │ ├── handy.rs │ │ ├── local_sid.rs │ │ └── remote_sid.rs │ ├── sid.rs │ ├── time.rs │ ├── token.rs │ ├── util/ │ │ ├── async_deque.rs │ │ ├── bound_queue.rs │ │ ├── data.rs │ │ ├── index_deque.rs │ │ ├── unique_id.rs │ │ └── wakers.rs │ ├── util.rs │ └── varint.rs ├── qcongestion/ │ ├── Cargo.toml │ └── src/ │ ├── algorithm/ │ │ ├── bbr/ │ │ │ ├── delivery_rate.rs │ │ │ ├── min_max.rs │ │ │ ├── model.rs │ │ │ ├── parameters.rs │ │ │ └── state.rs │ │ ├── bbr.rs │ │ └── new_reno.rs │ ├── algorithm.rs │ ├── congestion.rs │ ├── lib.rs │ ├── pacing.rs │ ├── packets.rs │ ├── rtt.rs │ └── status.rs ├── qconnection/ │ ├── Cargo.toml │ └── src/ │ ├── builder.rs │ ├── events.rs │ ├── handshake.rs │ ├── lib.rs │ ├── path/ │ │ ├── aa.rs │ │ ├── burst.rs │ │ ├── drive.rs │ │ ├── error.rs │ │ ├── paths.rs │ │ ├── util.rs │ │ └── validate.rs │ ├── path.rs │ ├── space/ │ │ ├── data.rs │ │ ├── handshake.rs │ │ └── initial.rs │ ├── space.rs │ ├── state.rs │ ├── termination.rs │ ├── tls/ │ │ ├── agent.rs │ │ └── client_auth.rs │ ├── tls.rs │ ├── traversal.rs │ └── tx.rs ├── qdatagram/ │ ├── Cargo.toml │ └── src/ │ ├── lib.rs │ ├── reader.rs │ └── writer.rs ├── qevent/ │ ├── Cargo.toml │ └── src/ │ ├── legacy/ │ │ ├── exporter.rs 
│ │ └── quic.rs │ ├── legacy.rs │ ├── lib.rs │ ├── loglevel.rs │ ├── macro_support.rs │ ├── macros.rs │ ├── packet.rs │ ├── quic/ │ │ ├── connectivity.rs │ │ ├── recovery.rs │ │ ├── security.rs │ │ └── transport.rs │ ├── quic.rs │ ├── telemetry/ │ │ ├── filter.rs │ │ ├── handy.rs │ │ ├── macro_support.rs │ │ └── macros.rs │ └── telemetry.rs ├── qinterface/ │ ├── Cargo.toml │ ├── examples/ │ │ └── interface-monitor.rs │ ├── src/ │ │ ├── bind_uri.rs │ │ ├── component/ │ │ │ ├── alive.rs │ │ │ ├── location.rs │ │ │ ├── route/ │ │ │ │ ├── handler.rs │ │ │ │ ├── packet.rs │ │ │ │ └── queue.rs │ │ │ └── route.rs │ │ ├── component.rs │ │ ├── device.rs │ │ ├── iface.rs │ │ ├── io/ │ │ │ ├── factory.rs │ │ │ └── handy.rs │ │ ├── io.rs │ │ ├── lib.rs │ │ └── manager.rs │ └── tests/ │ ├── auto_rebind.rs │ ├── common/ │ │ └── mod.rs │ ├── components.rs │ ├── lifecycle.rs │ ├── locations.rs │ └── rebind.rs ├── qmacro/ │ ├── Cargo.toml │ └── src/ │ ├── derive.rs │ └── lib.rs ├── qprotocol/ │ ├── Cargo.toml │ └── src/ │ ├── dns.rs │ ├── forward.rs │ ├── io.rs │ ├── lib.rs │ ├── quic.rs │ ├── stun/ │ │ └── msg.rs │ └── stun.rs ├── qrecovery/ │ ├── Cargo.toml │ └── src/ │ ├── crypto.rs │ ├── journal/ │ │ ├── rcvd.rs │ │ └── sent.rs │ ├── journal.rs │ ├── lib.rs │ ├── recv/ │ │ ├── incoming.rs │ │ ├── rcvbuf.rs │ │ ├── reader.rs │ │ └── recver.rs │ ├── recv.rs │ ├── reliable.rs │ ├── send/ │ │ ├── outgoing.rs │ │ ├── sender.rs │ │ ├── sndbuf.rs │ │ └── writer.rs │ ├── send.rs │ ├── streams/ │ │ ├── error.rs │ │ ├── io.rs │ │ ├── listener.rs │ │ └── raw.rs │ └── streams.rs ├── qresolve/ │ ├── Cargo.toml │ └── src/ │ └── lib.rs ├── qtraversal/ │ ├── Cargo.toml │ ├── README.md │ ├── examples/ │ │ ├── stun_client.rs │ │ └── stun_server.rs │ ├── src/ │ │ ├── addr.rs │ │ ├── future.rs │ │ ├── lib.rs │ │ ├── nat/ │ │ │ ├── client.rs │ │ │ ├── iface.rs │ │ │ ├── msg.rs │ │ │ ├── router.rs │ │ │ ├── server.rs │ │ │ └── tx.rs │ │ ├── nat.rs │ │ ├── packet.rs │ │ ├── punch/ │ │ │ ├── 
predictor.rs │ │ │ ├── puncher.rs │ │ │ ├── scheduler.rs │ │ │ └── tx.rs │ │ ├── punch.rs │ │ └── route.rs │ ├── tests/ │ │ └── detect.rs │ └── tools/ │ ├── build_nat.sh │ ├── clear_nat.sh │ ├── dockerfile │ └── run_stun.sh ├── qudp/ │ ├── Cargo.toml │ ├── examples/ │ │ ├── receive.rs │ │ └── send.rs │ └── src/ │ ├── lib.rs │ ├── unix.rs │ └── windows.rs └── tests/ └── keychain/ ├── gen_key.sh ├── localhost/ │ ├── ca.cert │ ├── ca.key │ ├── ca.srl │ ├── client.cert │ ├── client.key │ ├── server.cert │ └── server.key ├── quic.test.net/ │ ├── quic-test-net-ECC.crt │ ├── quic-test-net-ECC.key │ └── quic-test-net.csr ├── root/ │ ├── rootCA-ECC.crt │ ├── rootCA-ECC.key │ └── rootCA-ECC.srl └── start-quic-server.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: '' assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: 1. Go to '...' 2. Click on '....' 3. Scroll down to '....' 4. See error **Expected behavior** A clear and concise description of what you expected to happen. **Screenshots** If applicable, add screenshots to help explain your problem. **Desktop (please complete the following information):** - OS: [e.g. iOS] - Browser [e.g. chrome, safari] - Version [e.g. 22] **Smartphone (please complete the following information):** - Device: [e.g. iPhone6] - OS: [e.g. iOS8.1] - Browser [e.g. stock browser, safari] - Version [e.g. 22] **Additional context** Add any other context about the problem here. 
================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: '' assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. ================================================ FILE: .github/ISSUE_TEMPLATE/r2cn.md ================================================ --- name: r2cn about: r2cn 任务模板 title: "[r2cn] " labels: r2cn assignees: "" --- [__任务__] [__任务分值__] 4 分 [__背景描述__] [__需求描述__] [__代码标准__] 1. 所有 **PR** 提交必须签署 `Signed-off-by` 和 使用 `GPG` 签名,即提交代码时(使用 `git commit` 命令时)至少使用 `-s -S` 两个参数,参考 [Contributing Guide](https://github.com/genmeta/dquic/blob/main/docs/contributing.md); 2. 所有 **PR** 提交必须通过 `GitHub Actions` 自动化测试,提交 **PR** 后请关注 `GitHub Actions` 结果; 3. 代码注释均需要使用英文; [__PR 提交地址__] 提交到 [dquic](https://github.com/genmeta/dquic) 仓库的 `main` 分支 `` 目录; [__开发指导__] 1. 认领任务参考 [r2cn 开源实习计划 - 任务认领与确认](https://r2cn.dev/docs/student/assign); [__导师及邮箱__] 请申请此题目的同学使用邮件联系导师,或加入到 [R2CN Discord](https://discord.gg/WRp4TKv6rh) 后在 `#p-meta` 频道和导师交流。 1. Peng Zhang [__备注__] 1. 
**认领实习任务的同学,必须完成测试任务和注册流程,请参考:** [r2cn 开源实习计划 - 测试任务](https://r2cn.dev/docs/student/pre-task) 和 [r2cn 开源实习计划 - 学生注册与审核](https://r2cn.dev/docs/student/signup) ================================================ FILE: .github/dependabot.yml ================================================ # To get started with Dependabot version updates, you'll need to specify which # package ecosystems to update and where the package manifests are located. # Please see the documentation for all configuration options: # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates version: 2 updates: - package-ecosystem: "cargo" # See documentation for possible values directory: "/" # Location of package manifests schedule: interval: "weekly" ================================================ FILE: .github/workflows/benchmark.yml ================================================ name: Benchmarks on: workflow_dispatch: # Allows manual triggering schedule: - cron: '0 2 * * *' # UTC 2:00 AM = Beijing 10:00 AM jobs: prepare-matrix: runs-on: ubuntu-latest outputs: runners: ${{ steps.prepare-runners.outputs.runners }} steps: - uses: actions/checkout@v4 - id: prepare-runners run: | runners="$(python3 benchmark/launch.py runners -q)" echo "runners=$runners" >> $GITHUB_OUTPUT run-benchmarks: strategy: fail-fast: false matrix: runner: ${{ fromJson(needs.prepare-matrix.outputs.runners) }} target: [ubuntu,macos,] runs-on: ${{ matrix.target }}-latest needs: prepare-matrix steps: - uses: actions/checkout@v4 - name: Install latest rust stable toolchain uses: actions-rust-lang/setup-rust-toolchain@v1 with: rustflags: "" # tquic use deprecated function, and this action set rustflags to "-D warnings" by default - name: Install go for macos runner if: matrix.target=='macos' && matrix.runner=='quic-go' run: brew install go - name: Run benchmarks run: | which openssl python3 benchmark/launch.py run ${{ matrix.runner }} --no-plot - name: Rename benchmark results dir run: 
mv benchmark/output benchmark-output-${{ matrix.target }}-${{ matrix.runner }} - name: Upload benchmark results uses: actions/upload-artifact@v4 with: path: benchmark-output-${{ matrix.target }}-${{ matrix.runner }} name: benchmark-output-${{ matrix.target }}-${{ matrix.runner }} summary-results: runs-on: ubuntu-latest needs: [run-benchmarks] strategy: fail-fast: false matrix: target: [ubuntu, macos] steps: - uses: actions/checkout@v4 - name: Install matplotlib run: | sudo apt update sudo apt install -y python3-matplotlib - name: Download outputs uses: actions/download-artifact@v4 with: pattern: benchmark-output-${{ matrix.target }}-* - name: Summary ${{ matrix.target }} run: | # Collect all results.json paths and create a space-separated list results_files=$(find . -name "results.json" | tr '\n' ' ') echo "Results files: $results_files" # Pass all results files to the plot command python3 benchmark/launch.py plot $results_files # Collect logs cp -r */logs benchmark/output/ mv benchmark/output/ benchmark-output-${{ matrix.target }} - name: Upload benchmark results uses: actions/upload-artifact@v4 with: path: benchmark-output-${{ matrix.target }} name: benchmark-output-${{ matrix.target }} ================================================ FILE: .github/workflows/codecov.yml ================================================ name: Coverage on: push: branches: ["main"] pull_request: branches: ['main'] jobs: coverage: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: taiki-e/install-action@cargo-llvm-cov # Limit test parallelism to 1 thread to avoid resource contention - run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info -- --test-threads=1 - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} files: lcov.info fail_ci_if_error: true ================================================ FILE: .github/workflows/commitlint.yml 
================================================ name: Commitlint on: push: branches: ["main"] pull_request: branches: ["main"] jobs: commitlint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: wagoid/commitlint-github-action@v5 ================================================ FILE: .github/workflows/feishu-bot.yml ================================================ name: feishu bot on: branch_protection_rule: types: [created, deleted] check_run: types: [rerequested, completed] check_suite: types: [completed] create: delete: deployment_status: discussion: types: [created, edited, answered] discussion_comment: types: [created, deleted] fork: gollum: issues: types: [opened, edited, milestoned, pinned, reopened] issue_comment: types: [created, deleted] label: types: [created, deleted] merge_group: types: [checks_requested] milestone: types: [opened, deleted] page_build: project: types: [created, deleted, reopened] project_card: types: [created, deleted] project_column: types: [created, deleted] public: pull_request: branches: ["main"] types: [opened, reopened] pull_request_review: types: [edited, dismissed, submitted] pull_request_review_comment: types: [created, edited, deleted] pull_request_target: types: [assigned, opened, synchronize, reopened] push: branches: ["main"] registry_package: types: [published] release: types: [published] status: watch: types: [started] # schedule: # - cron: "30 2 * * *" jobs: send-event: name: Webhook runs-on: ubuntu-latest steps: - uses: KaminariOS/feishu-bot-webhook-action@main with: webhook: ${{ secrets.FEISHU_BOT_WEBHOOK }} signkey: ${{ secrets.FEISHU_BOT_SIGNKEY }} ================================================ FILE: .github/workflows/rust.yml ================================================ name: Rust on: push: branches: ["main"] pull_request: branches: ["main"] env: CARGO_TERM_COLOR: always jobs: build: strategy: matrix: target: [ubuntu, macos, windows] fail-fast: false runs-on: ${{ matrix.target }}-latest 
steps: - uses: actions/checkout@v4 - name: Install latest rust stable toolchain uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Build run: cargo build --verbose - name: Run tests # Limit test parallelism to 1 thread to avoid resource contention # GitHub runners have limited cores (ubuntu/windows=2, macos=3) run: cargo test --workspace --verbose -- --test-threads=1 format: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Install latest rust nightly toolchain and rustfmt uses: actions-rust-lang/setup-rust-toolchain@v1 with: toolchain: nightly components: rustfmt - name: Run rustfmt run: cargo +nightly fmt --all -- --check clippy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Install latest rust nightly toolchain with clippy uses: actions-rust-lang/setup-rust-toolchain@v1 with: toolchain: nightly components: clippy - name: Run clippy run: cargo +nightly clippy --all-targets --all-features -- -Dwarnings doc: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Install latest rust nightly toolchain uses: actions-rust-lang/setup-rust-toolchain@v1 with: toolchain: nightly - name: Run doc run: RUSTDOCFLAGS="-D warnings" cargo +nightly doc --no-deps msrv: strategy: matrix: target: [ubuntu, macos, windows] fail-fast: false runs-on: ${{ matrix.target }}-latest steps: - uses: actions/checkout@v4 - name: Install msrv toolchain uses: actions-rust-lang/setup-rust-toolchain@v1 with: toolchain: 1.88.0 - name: Build with msrv run: cargo build --workspace --release ================================================ FILE: .github/workflows/traversal.yml ================================================ name: Traversal on: push: branches: ["main", "build/*"] pull_request: branches: ["main"] workflow_dispatch: env: CARGO_TERM_COLOR: always FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true jobs: nat-detection: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: 
Build Docker image with cache uses: docker/build-push-action@v6 with: context: . file: qtraversal/tools/dockerfile tags: dquic-traversal-test:latest load: true cache-from: type=gha cache-to: type=gha,mode=max - name: Create cargo cache volume run: docker volume create cargo-cache - name: Compile tests and get NAT detection test list run: | docker run --rm --privileged \ -v ${{ github.workspace }}:/dquic \ -v cargo-cache:/usr/local/cargo/registry \ dquic-traversal-test:latest \ /bin/bash -c " set -e cd /dquic cargo build --example stun_server --release cargo test --package qtraversal test_detect --no-run cargo test --package qtraversal test_detect -- --list " | grep ": test$" | awk '{print $1}' | sed 's/:$//' > /tmp/nat_tests.txt cat /tmp/nat_tests.txt - name: Run NAT detection tests serially run: | mapfile -t NAT_TESTS < /tmp/nat_tests.txt for test in "${NAT_TESTS[@]}"; do if [ -z "$test" ]; then continue fi echo "========================================" echo "Running NAT detection: $test" echo "========================================" docker run --rm --privileged \ -v ${{ github.workspace }}:/dquic \ -v cargo-cache:/usr/local/cargo/registry \ dquic-traversal-test:latest \ /bin/bash -c " set -e cd /dquic bash qtraversal/tools/run_stun.sh echo 'DEBUG: Running test [$test]' ip netns exec nsa cargo test --package qtraversal '$test' -- --nocapture --include-ignored " echo "Completed: $test" echo "" done punch: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Build Docker image with cache uses: docker/build-push-action@v6 with: context: . 
file: qtraversal/tools/dockerfile tags: dquic-traversal-test:latest load: true cache-from: type=gha cache-to: type=gha,mode=max - name: Create cargo cache volume run: docker volume create cargo-cache - name: Compile tests and get hole punching test list run: | docker run --rm --privileged \ -v ${{ github.workspace }}:/dquic \ -v cargo-cache:/usr/local/cargo/registry \ dquic-traversal-test:latest \ /bin/bash -c " set -e cd /dquic cargo build --example stun_server --release cargo test --test traversal --no-run cargo test --test traversal -- --list " | grep ": test$" | awk '{print $1}' | sed 's/:$//' > /tmp/hp_tests.txt cat /tmp/hp_tests.txt - name: Run hole punching tests serially run: | mapfile -t HP_TESTS < /tmp/hp_tests.txt for test in "${HP_TESTS[@]}"; do if [ -z "$test" ]; then continue fi echo "========================================" echo "Running hole punching: $test" echo "========================================" docker run --rm --privileged \ -v ${{ github.workspace }}:/dquic \ -v cargo-cache:/usr/local/cargo/registry \ dquic-traversal-test:latest \ /bin/bash -c " set -e cd /dquic bash qtraversal/tools/run_stun.sh echo 'DEBUG: Running test [$test]' ip netns exec nsa cargo test --test traversal '$test' -- --include-ignored --nocapture " echo "Completed: $test" echo "" done ================================================ FILE: .gitignore ================================================ # Generated by Cargo # will have compiled files and executables debug/ target/ # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk # Exclude benchmark temp files /benchmark/* !/benchmark/launch.py # cargo-tarpaulin (coverage tool) generates this tarpaulin-report.html # MSVC Windows builds of rustc generate these, which store debugging information *.pdb .vscode/* .idea/ *.log log
.DS_Store *.sqlog .cargo/config.toml # Local agent instructions AGENTS.md ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - hooks: - id: commitizen stages: - commit-msg repo: https://github.com/commitizen-tools/commitizen rev: v2.24.0 - hooks: - id: fmt #- id: cargo-check #- id: clippy repo: https://github.com/doublify/pre-commit-rust rev: v1.0 - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook rev: v9.5.0 hooks: - id: commitlint stages: [commit-msg] additional_dependencies: ["@commitlint/config-conventional"] ================================================ FILE: .rustfmt.toml ================================================ imports_granularity = "Crate" group_imports = "StdExternalCrate" style_edition = "2024" ================================================ FILE: .rusty-hook.toml ================================================ [hooks] #pre-commit = "cargo check && cargo clippy --all-targets --all -- -D warnings" #pre-push = "cargo check && cargo clippy --all-targets --all -- -D warnings && cargo test -- --test-threads=1" pre-push = "cargo build" #post-commit = "echo yay" [logging] verbose = true ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 
## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at [quic_team@genmeta.net](mailto:quic_team@genmeta.net). All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4.
Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to dquic We welcome all feedback and PRs, including bug reports, feature requests, documentation improvements, and code refactoring. However, please note that dquic has extremely strict quality requirements for code and documentation. The quality of code and documentation will undergo rigorous review before being merged. Contributors must understand and patiently address all feedback before merging. If you are unsure about the reasonableness of a feature or its implementation, please first create an issue in the [issue list](https://github.com/genmeta/dquic/issues) for discussion to ensure that the feature is reasonable and has a good implementation plan.
================================================ FILE: Cargo.toml ================================================ [workspace] resolver = "2" members = [ "qmacro", "qbase", "qevent", "qrecovery", "qcongestion", "qudp", "qinterface", "qprotocol", "qdatagram", "qconnection", "dquic", "h3-shim", "qtraversal", "qresolve", ] default-members = [ "qmacro", "qbase", "qevent", "qrecovery", "qcongestion", "qinterface", "qprotocol", "qconnection", "dquic", "h3-shim", "qtraversal", ] [workspace.package] version = "0.5.0" edition = "2024" readme = "README.md" repository = "https://github.com/genmeta/dquic" license = "Apache-2.0" keywords = ["async", "quic", "http3"] categories = ["network-programming", "asynchronous"] rust-version = "1.88.0" [workspace.dependencies] arc-swap = "1" async-trait = "0.1.88" bitflags = "2" bon = "3" bytes = "1" cfg-if = "1" dashmap = "6" derive_builder = "0.20" derive_more = "2" enum_dispatch = "0.3" futures = "0.3" getset = "0.1" netdev = "0.42" nom = "8" netwatcher = "0.4" pin-project-lite = "0.2" rand = "0.10" ring = "0.17" rustls = { version = "0.23", default-features = false, features = ["std"] } serde = { version = "1", features = ["derive"] } serde_json = "1" serde_with = "3" smallvec = { version = "1", features = [ "union", "const_generics", "const_new", ] } socket2 = { version = "0.6", features = ["all"] } snafu = "0.8" thiserror = "2" tokio = { version = "1" } tokio-util = { version = "0.7" } tracing = "0.1" x509-parser = "0.18" url = "2.5.7" # h3 for h3-shim only , windows-sys, nix and libc for qudp only # they are not the default members of the workspace # windows-sys = "?" # libc = "0.2" # nix = "?" 
# dev-dependencies, for examples clap = { version = "4", features = ["derive"] } h3 = "0.0.8" h3-datagram = "0.0.2" http = "1" indicatif = { version = "0.18", features = ["tokio"] } parking_lot = "0.12" postcard = { version = "1", features = ["use-std"] } rustls-native-certs = "0.8" tracing-subscriber = "0.3" tracing-appender = "0.2" # members qmacro = { path = "./qmacro", version = "0.5.0" } qbase = { path = "./qbase", version = "0.5.0" } qevent = { path = "./qevent", version = "0.5.0" } qudp = { path = "./qudp", version = "0.5.0" } qinterface = { path = "./qinterface", version = "0.5.0" } qdatagram = { path = "./qdatagram", version = "0.5.0" } qresolve = { path = "./qresolve", version = "0.5.0" } qrecovery = { path = "./qrecovery", version = "0.5.0" } qtraversal = { path = "./qtraversal", version = "0.5.0" } qcongestion = { path = "./qcongestion", version = "0.5.0" } qconnection = { path = "./qconnection", version = "0.5.0" } dquic = { path = "./dquic", version = "0.5.0" } h3-shim = { path = "./h3-shim", version = "0.5.0" } [profile.bench] debug = true [profile.release] debug = true ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # dquic [![License: Apache-2.0](https://img.shields.io/github/license/genmeta/dquic)](https://www.apache.org/licenses/LICENSE-2.0) [![Build Status](https://img.shields.io/github/actions/workflow/status/genmeta/dquic/rust.yml)](https://github.com/genmeta/dquic/actions/workflows/rust.yml) [![codecov](https://codecov.io/gh/genmeta/dquic/graph/badge.svg)](https://codecov.io/gh/genmeta/dquic) [![crates.io](https://img.shields.io/crates/v/dquic.svg)](https://crates.io/crates/dquic) [![Documentation](https://docs.rs/dquic/badge.svg)](https://docs.rs/dquic/) [![Dependencies](https://img.shields.io/deps-rs/repo/github/genmeta/dquic)](https://github.com/genmeta/dquic/network/dependencies) ![MSRV](https://img.shields.io/crates/msrv/dquic) English | [中文](README_CN.md) The QUIC protocol is an important infrastructure for the next generation Internet, and `dquic` is a native asynchronous Rust implementation of the QUIC protocol, an efficient and scalable [RFC 9000][1] implementation with excellent engineering quality. `dquic` not only implements the standard QUIC protocol but also includes additional extensions such as [RFC 9221 (Unreliable Datagram Extension)][3] and [qlog (QUIC event logging)][2]. 
As widely recognized, QUIC possesses numerous advanced features and unparalleled security, making it highly suitable for applications in: **High-performance data transmission:** - Achieves 0-RTT connection establishment to minimize latency. - Utilizes multiplexed streams to eliminate head-of-line blocking and improve throughput. - Multi-path transmission to improve transmission capacity. - Efficient transmission control algorithms such as BBR ensure low latency and high bandwidth utilization. **Data privacy and security:** - Integrates TLS 1.3 encryption by default for end-to-end security. - Implements forward-secure keys and authenticated packet headers to resist tampering. **IoT and edge computing:** - Supports connection migration to maintain sessions across network changes (e.g., Wi-Fi to cellular). - Enables lightweight communication with unreliable datagrams (RFC 9221) for real-time IoT scenarios. These characteristics position QUIC as a transformative protocol for modern networks, combining performance optimizations with robust cryptographic guarantees. ## Design The QUIC protocol is a rather complex, IO-intensive protocol, making it an excellent fit for asynchronous programming. The basic events in asynchronous IO are read, write, and timers. However, throughout the implementation of the QUIC protocol, the internal events are intricate and dazzling. If you look at the protocol carefully, you will find that certain structures become evident, revealing that the core of the QUIC protocol is driven by layers of underlying IO events progressively influencing the application layer behavior. For example, when the received data of a stream is contiguous, it constitutes an event that awakens the corresponding application layer to read; similarly, when the Initial data exchange completes and the Handshake keys are obtained, this is another event that awakens the task processing the Handshake data packet. These events illustrate the classic Reactor pattern. 
`dquic` refines and encapsulates these various internal Reactors of QUIC, making each module more independent, clarifying the cooperation between the system's modules, and thereby making the overall design more user-friendly. It is noticeable that the QUIC protocol has multiple layers. In the transport layer, there are many functions such as opening new connections, receiving, sending, reading, writing, and accepting new connections, most of which are asynchronous. Here, we refer to these functions as various functors, with each layer having its own functor. With these layers in place, it becomes clear that the `Accept Functor` and the `Read Functor`, or the `Write Functor`, do not belong to the same layer, which is quite interesting. ![image](https://github.com/genmeta/dquic/blob/main/images/arch.png?raw=true) ## Overview - **qbase**: Core structure of the QUIC protocol, including variable integer encoding (VarInt), connection ID management, stream ID, various frame and packet type definitions, and asynchronous keys. - **qrecovery**: The reliable transport part of QUIC, encompassing the state machine evolution of the sender/receiver, and the internal logic interaction between the application layer and the transport layer. - **qcongestion**: Congestion control in QUIC, which abstracts a unified congestion control interface and implements BBRv1. In the future, it will also implement more transport control algorithms such as Cubic and others. - **qinterface**: QUIC's packet routing and definition of the underlying I/O interface (`QuicIO`) enable dquic to run in various environments. Contains an optional qudp-based `QuicIO` implementation. - **qdatagram**: The extension for unreliable datagram transmission based on QUIC offers transmission control mechanisms and enhanced security compared to directly sending unreliable datagrams over UDP. See [RFC 9221][3]. 
- **qconnection**: Encapsulation of QUIC connections, linking the necessary components and tasks within a QUIC connection to ensure smooth operation. - **dquic**: The top-level encapsulation of the QUIC protocol, including interfaces for both the QUIC client and server. - **qudp**: High-performance UDP encapsulation for QUIC. Ordinary UDP incurs a system call for each packet sent or received, resulting in poor performance; qudp uses techniques such as GSO and GRO to amortize this per-packet cost. - **qevent**: The implementation of [qlog][2] supports logging internal activities of individual QUIC connections in JSON format, maintains compatibility with qlog 3, and enables visualization analysis through [qvis][4]. However, it is important to note that enabling qlog can significantly impact performance despite its utility in troubleshooting. ![image](https://github.com/genmeta/dquic/blob/main/images/qvis.png?raw=true) ## Usage #### Demos Run an h3 server: ```shell cargo run --example h3-server --package h3-shim -- --dir ./h3-shim ``` Send an h3 request: ```shell cargo run --example h3-client --package h3-shim -- https://localhost:4433/examples/h3-server.rs ``` For more complete examples, please refer to the `examples` folders under the `h3-shim` and `dquic` folders. #### API `dquic` provides user-friendly interfaces for creating client and server connections, while also supporting additional features that meet modern network requirements. In addition to binding an IP address + port, `dquic` can also bind a network interface, dynamically adapting to actual address changes, which provides good mobility for dquic. The QUIC client not only provides configuration options specified by the QUIC protocol's Parameters and optional 0-RTT functionality, but also includes some additional advanced options. For example, the QUIC client can set its own certificate for server verification, and can also set its own Token manager to manage Tokens issued by various servers for future connections with these servers. 
The QUIC client supports multipath handshaking, it can simultaneously connect to server's IPv4 and IPv6 addresses. Even if some paths are unreachable, as long as one path is reachable, the connection can be established. The following is a simple example, please refer to the documentation for more details. ```rust use std::path::PathBuf; use std::sync::Arc; use dquic::prelude::{handy::ToCertificate, *}; async fn client() -> Result<(), Box> { // Set up root certificate store let mut roots = rustls::RootCertStore::empty(); // Load system certificates roots.add_parsable_certificates(rustls_native_certs::load_native_certs().certs); // Load custom certificates (can be used independently of system certificates) roots.add_parsable_certificates(PathBuf::from("/path/to/ca.cert").to_certificate()); // Load at runtime // roots.add_parsable_certificates(include_bytes!("/path/to/ca.cert").to_certificate()); // Embed at compile time // Build the QUIC client let quic_client = Arc::new(QuicClient::builder() .with_root_certificates(roots) .without_cert() // Client certificate verification is typically not required // .with_parameters(your_parameters) // Custom transport parameters // .bind(["iface://v4.eth0:0", "iface://v6.eth0:0"]) // Bind to specific network interfaces // .enable_0rtt() // Enable 0-RTT // .enable_sslkeylog() // Enable SSL key logging // .with_qlog(Arc::new(handy::LegacySeqLogger::new( // PathBuf::from("/path/to/qlog_dir"), // ))) // Enable qlog for visualization with qvis tool .build()); // Connect to the server let connection = quic_client.connect("localhost").await?; // Start using the QUIC connection! // For more usage examples, see dquic/examples and h3-shim/examples Ok(()) } ``` The QUIC server is represented as `QuicListeners`, supporting SNI (Server Name Indication), allowing multiple Servers to be started in one process, each with their own certificates and keys. 
Each server can also bind to multiple addresses, and multiple Servers can bind to the same address. Clients must correctly connect to the corresponding interface of the corresponding Server, otherwise the connection will be rejected. QuicListeners supports verifying client identity through various methods, including through `client_name` transport parameters, verifying client certificate content, etc. QuicListeners also supports anti-port scanning functionality, only responding after preliminary verification of client identity. ```rust use std::path::PathBuf; use dquic::prelude::*; async fn server() -> Result<(), Box> { let quic_listeners = QuicListeners::builder() .without_client_cert_verifier() // Client certificate verification is typically not required // .with_parameters(your_parameters) // Custom transport parameters // .enable_0rtt() // Enable 0-RTT for servers // .enable_anti_port_scan() // Anti-port scanning protection .listen(8192)?; // Start listening with backlog (similar to Unix listen) // Add a server that can be connected quic_listeners.add_server( "localhost", // Certificate and key files as byte arrays or paths PathBuf::from("/path/to/server.cert").as_path(), PathBuf::from("/path/to/server.key").as_path(), [ "192.168.1.108:4433", // Bind to the IPv4 address "iface://v6.eth0:4433", // Bind to the eth0's IPv6 address ], None, // ocsp ).await?; // Continue calling `quic_listeners.add_server()` to add more servers // Call `quic_listeners.remove_server()` to remove a server // Accept trusted new connections while let Ok((connection, server_name, pathway, link)) = quic_listeners.accept().await { // Handle the incoming QUIC connection! // You can refer to examples in dquic/examples and h3-shim/examples } Ok(()) } ``` There is an asynchronous interface for creating unidirectional or bidirectional QUIC streams from a QUIC Connection, or for listening to incoming streams from the other side of a QUIC Connection. 
This interface is almost identical to the one in [`hyperium/h3`](https://github.com/hyperium/h3/blob/master/docs/PROPOSAL.md#5-quic-transport). For reading and writing data from QUIC streams, the standard **`AsyncRead`** and **`AsyncWrite`** interfaces are implemented for QUIC streams, making them very convenient to use. ## Performance GitHub Actions periodically runs [benchmark tests][5]. The results show that dquic, quiche, tquic and quinn all deliver excellent performance, with each excelling in different benchmark testing scenarios. It should be noted that transmission performance is also greatly related to congestion control algorithms. dquic's performance will continue to be optimized in the coming period. If you want higher performance, dquic provides abstract interfaces that can use DPDK or XDP to replace UdpSocket! ## Contribution All feedback and PRs are welcome, including bug reports, feature requests, documentation improvements, code refactoring, and more. If you are unsure whether a feature or its implementation is reasonable, please first create an issue in the [issue list](https://github.com/genmeta/dquic/issues) for discussion. This ensures the feature is reasonable and has a solid implementation plan. ## Community - [Official Community](https://github.com/genmeta/dquic/discussions) - chat group:[send email](mailto:quic_team@genmeta.net) to introduce your contribution, and we will reply to your email with an invitation link and QR code to join the group. 
[1]: https://www.rfc-editor.org/rfc/rfc9000.html [2]: https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/ [3]: https://datatracker.ietf.org/doc/html/rfc9221 [4]: https://qvis.quictools.info/#/files [5]: https://github.com/genmeta/dquic/actions ================================================ FILE: README_CN.md ================================================ # dquic [![License: Apache-2.0](https://img.shields.io/github/license/genmeta/dquic)](https://www.apache.org/licenses/LICENSE-2.0) [![Build Status](https://img.shields.io/github/actions/workflow/status/genmeta/dquic/rust.yml)](https://github.com/genmeta/dquic/actions/workflows/rust.yml) [![codecov](https://codecov.io/gh/genmeta/dquic/graph/badge.svg)](https://codecov.io/gh/genmeta/dquic) [![crates.io](https://img.shields.io/crates/v/dquic.svg)](https://crates.io/crates/dquic) [![Documentation](https://docs.rs/dquic/badge.svg)](https://docs.rs/dquic/) [![Dependencies](https://img.shields.io/deps-rs/repo/github/genmeta/dquic)](https://github.com/genmeta/dquic/network/dependencies) ![MSRV](https://img.shields.io/crates/msrv/dquic) [English](README.md) | 中文 QUIC协议是下一代互联网重要的基础设施,而`dquic`则是一个原生异步Rust的QUIC协议实现,一个高效的、可扩展的[RFC 9000][1]实现,同时工程质量优良。 `dquic`不仅实现了标准QUIC协议,还额外实现了[RFC 9221 (Unreliable Datagram Extension)][3]、[qlog (QUIC event logging)][2]等扩展。 众所周知,QUIC拥有许多优良特性,以及极致的安全性,十分适合在高性能传输、数据隐私安全、物联网领域推广使用: **高性能数据传输:** - 0-RTT握手,最小化建连时延 - 流的多路复用,消除了头端阻塞,提升吞吐率 - 多路径传输,提升传输能力 - BBR等高效的传输控制算法,保证低时延、高带宽利用率 **数据隐私安全:** - 默认集成TLS 1.3端到端加密 - 实现前向安全密钥和经过身份验证的数据包头,以抵御篡改。 **IoT和边缘计算:** - 支持连接迁移,以便在网络变化(例如从Wi-Fi切换到蜂窝网络)时保持会话。 - 实现轻量级通信,支持不可靠数据报(RFC 9221),适用于实时物联网场景。 ## 设计原则 
QUIC协议可谓一个相当复杂的、IO密集型的协议,因此正是适合异步大显身手的地方。异步IO中最基本的事件有数据可读、可写,以及定时器,但纵观整个QUIC协议实现,内部的事件错综复杂、令人眼花缭乱。然而,仔细探查之下还是能发现一些结构,会发现QUIC协议核心是由一层层底层IO事件逐步向上驱动应用层行为的。比如当一个流接收数据至连续,这也是一个事件,将唤醒对应的应用层来读;再比如,当Initial数据交互完毕获得Handshake密钥之后,这也是一个事件,将唤醒Handshake数据包任务的处理。以上这些事件就是经典的Reactor模式,`dquic`正是对这些QUIC内部形形色色的Reactor的拆分细化和封装,让各个模块更加独立,让整个系统各模块配合的更加清晰,进而整体设计也更加人性化。 注意到QUIC协议内部,还能分出很多层。在传输层,有很多功能比如打开新连接、接收、发送、读取、写入、Accept新连接,它们大都是异步的,在这里称之为各种“算子”,且每层都有自己的算子,有了这些分层之后,就会发现,其实Accept算子和Read算子、Write算子根本不在同一层,很有意思。 ![image](https://github.com/genmeta/dquic/blob/main/images/arch.png) ## 概览 - **qbase**: QUIC协议的基础结构,包括可变整型编码VarInt、连接ID管理、流ID、各种帧以及包类型定义、异步密钥等 - **qrecovery**: QUIC的可靠传输部分,包括发送端/接收端的状态机演变、应用层与传输层的内部逻辑交互等 - **qcongestion**: QUIC的拥塞控制,抽象了统一的拥塞控制接口,并实现了BBRv1,未来还会实现Cubic、ETC等更多的传输控制算法 - **qinterface**: QUIC的数据包路由和对底层I/O接口(`QuicIO`)的定义,令dquic可以运行在各种环境。内含一个可选的基于qudp的`QuicIO`实现 - **qconnection**: QUIC连接封装,将QUIC连接内部所需的各组件、任务串联起来,最终能够完美运行 - **dquic**: QUIC协议的顶层封装,包括QUIC客户端和服务端2部分的接口 - **qudp**: QUIC的高性能UDP封装,使用GSO、GRO等手段极致优化UDP的性能 - **qdatagram**: 基于QUIC的不可靠数据报传输的扩展,相比于直接用UDP发送不可靠数据报,该扩展拥有QUIC的传输控制和极致安全性。详情参考[RFC 9221][3] - **qevent**: [qlog][2]的实现,支持以json形式记录单个quic连接内部活动,兼容qlog 3,支持[qvis][4]可视化分析。请注意,开启qlog虽有助于分析问题,但相当影响性能 ![image](https://github.com/genmeta/dquic/blob/main/images/qvis.png?raw=true) ## 使用方式 #### 样例演示 本仓库提供了三组样例: - `echo-client`和`echo-server`: 位于`dquic/examples/`文件夹下,展示了dquic的基本使用方法。 - `http-client`和`http-server`: 位于`dquic/examples/`文件夹下,展示了在dquic上运行HTTP/0.9协议。 - `h3-client`和`h3-server`: 位于`h3-shim/examples/`文件夹下,展示了在dquic上运行HTTP/3协议。 以H3为例,运行一个H3服务器: ```shell cargo run --example h3-server --package h3-shim -- --dir ./h3-shim ``` 发起一个H3请求: ```shell cargo run --example h3-client --package h3-shim -- https://localhost:4433/examples/h3-server.rs ``` #### API简介 `dquic`提供了人性化的接口创建客户端和服务端的连接,同时还支持一些符合现代网络需求的附加功能设置。 除了可以绑定到ip地址+端口,`dquic`还支持绑定到网络接口上,以动态地适应实际地址变化,这使得`dquic`拥有了良好的移动性。 
QUIC客户端不仅提供了QUIC协议所规定的Parameters选项配置,可选的0RTT功能,还有一些额外的高级选项,比如QUIC客户端可设置自己的证书以供服务端验证,也可设置自己的Token管理器,管理着各服务器颁发的Token,以便未来和这些服务器再次连接时用的上。 QUIC客户端支持多路径握手,即同时尝试连接到服务器的IPv4和IPv6地址,即使某些路径不可达,但只要有一条路径能够联通,连接就可以建立。如果对端的实现同样是dquic,则还支持多路径传输。 以下为简单示例,更多细节请参阅文档。 ```rust use std::path::PathBuf; use std::sync::Arc; use dquic::prelude::{handy::ToCertificate, *}; async fn client() -> Result<(), Box> { // 设置根证书存储 let mut roots = rustls::RootCertStore::empty(); // 加载系统证书 roots.add_parsable_certificates(rustls_native_certs::load_native_certs().certs); // 加载自定义证书(可与系统证书独立使用) roots.add_parsable_certificates(PathBuf::from("path/to/your/cert.pem").to_certificate()); // 运行时加载 // roots.add_parsable_certificates(include_bytes!("path/to/your/cert.pem").to_certificate()); // 编译时嵌入 // 构建QUIC客户端 let quic_client = Arc::new(QuicClient::builder() .with_root_certificates(roots) .without_cert() // 通常不需要客户端证书验证 // .with_parameters(your_parameters) // 自定义传输参数 // .bind(["iface://v4.eth0:0", "iface://v6.eth0:0"]) // 绑定到指定网络接口eth0的IPv4和IPv6地址 // .enable_0rtt() // 启用0-RTT // .enable_sslkeylog() // 启用SSL密钥日志 // .with_qlog(Arc::new(handy::LegacySeqLogger::new( // PathBuf::from("/path/to/qlog_dir"), // ))) // 启用qlog,可用qvis工具可视化 .build()); // 连接到服务器 let connection = quic_client.connect("localhost").await?; // 开始使用QUIC连接! 
// 更多使用示例请参考 dquic/examples 和 h3-shim/examples Ok(()) } ``` QUIC服务端表现为`QuicListeners`,支持SNI(Server Name Indication),在一个进程启动多个Server,分别有自己的证书和密钥,每个服务端又可以绑定到多个地址上,支持多个Server绑定同一个地址。Client必须正确连接到对应的Server的对应接口上,否则连接会被自动拒绝。 QuicListeners支持通过多种方法验证客户端的身份,包括通过`client_name`传输参数,验证客户端证书的内容等。QuicListeners还支持抗端口扫描功能,只有在初步验证客户端的身份后才会做出响应。 ```rust // 创建QUIC监听器(每个程序只能有一个实例) use std::path::PathBuf; use dquic::prelude::*; async fn server() -> Result<(), Box> { let quic_listeners = QuicListeners::builder() .without_client_cert_verifier() // 通常不需要客户端证书验证 // .with_parameters(your_parameters) // 自定义传输参数 // .enable_0rtt() // 为服务器启用0-RTT // .enable_anti_port_scan() // 抗端口扫描保护 .listen(8192)?; // 开始监听,设置积压队列(类似Unix listen) // 添加可连接的服务器 quic_listeners.add_server( "localhost", // 证书和密钥文件的字节数组或路径 PathBuf::from("/path/to/server.crt").as_path(), PathBuf::from("/path/to/server.key").as_path(), [ "192.168.1.106:4433", // 绑定到此IPv4地址 "iface://v6.eth0:4433", // 绑定到eth0的IPv6地址 ], None, // ocsp ).await?; // 继续调用 `quic_listeners.add_server()` 来添加更多Server // 调用 `quic_listeners.remove_server()` 来移除一个Server // 接受可信的新连接 while let Ok((connection, server_name, pathway, link)) = quic_listeners.accept().await { // 处理传入的QUIC连接! // 可以参考 dquic/examples 和 h3-shim/examples 中的示例 } Ok(()) } ``` 关于如何从QUIC Connection中创建单向QUIC流,或者双向QUIC流,抑或是从QUIC Connection监听来自对方的流,都有一套异步的接口,这套接口几乎与[`hyperium/h3`](https://github.com/hyperium/h3/blob/master/docs/PROPOSAL.md#5-quic-transport)的接口相同。 至于如何从QUIC流中读写数据,则为QUIC流实现了标准的 **`AsyncRead`** 、 **`AsyncWrite`** 接口,可以很方便地使用。 ## 性能 GitHub Actions会定期运行[基准测试][5],效果如下。dquic和quiche、tquic、quinn都具备优良性能,在三种基准测试场景下互有千秋。须知传输性能跟传输控制算法也有很大关系,dquic的性能在未来一段时间还会持续优化,如果想获得更高性能,dquic提供了抽象接口,可使用DPDK或者XDP代替UdpSocket! 
## 贡献 欢迎所有反馈和PR,包括bug反馈、功能请求、文档修缮、代码重构等。 如果不确定一个功能或者其实现是否合理,请首先在[issue列表](https://github.com/genmeta/dquic/issues)中创建一个issue,大家一起讨论,以确保功能是合理的,并有一个良好的实现方案。 ## 社区交流 - [用户论坛](https://github.com/genmeta/dquic/discussions) - 聊天群:[发送邮件](mailto:quic_team@genmeta.net)介绍一下您的贡献,我们将邮件回复您加群链接及群二维码。 [1]: https://www.rfc-editor.org/rfc/rfc9000.html [2]: https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/ [3]: https://datatracker.ietf.org/doc/html/rfc9221 [4]: https://qvis.quictools.info/#/files [5]: https://github.com/genmeta/dquic/actions ================================================ FILE: SECURITY.md ================================================ # Security Policy ## Supported Versions Use this section to tell people about which versions of your project are currently being supported with security updates. | Version | Supported | | ------- | ------------------ | | 0.5.x | :white_check_mark: | ## Reporting a Vulnerability Use this section to tell people how to report a vulnerability. Tell them where to go, how often they can expect to get an update on a reported vulnerability, what to expect if the vulnerability is accepted or declined, etc. 
#!/usr/bin/env python3
"""Benchmark harness comparing HTTP/3 (QUIC) server implementations.

Builds each server implementation, launches it as a background process
serving a shared pool of random payload files, drives it with the dquic
`h3-client` example, and records success counts / durations so QPS can
be saved to JSON and plotted.
"""
import os
import subprocess
import re
import json
import logging
import shutil
import argparse


class ServerRunner:
    """Describes how to launch one server implementation under test."""

    name: str                   # implementation display name
    launch_server: list[str]    # argv used to start the server
    listen_port: int            # port the client should connect to

    def __init__(self, impl_name: str, launch_server: list[str], listen_port: int):
        self.name = impl_name
        self.listen_port = listen_port
        self.launch_server = launch_server

    def run(self, log) -> subprocess.Popen:
        # Run the server in the background. It serves files out of the
        # random-files directory (cwd); stdout/stderr go to `log`, and
        # RUST_LOG=off silences the Rust servers so logging does not
        # skew the measurement.
        return subprocess.Popen(
            self.launch_server,
            cwd=rand_files.path,
            stdout=log,
            stderr=log,
            env={**os.environ, "RUST_LOG": "off"}
        )


class Result:
    """One benchmark measurement: successful queries over a duration."""

    success: int
    duration: float
    qps: float

    def __init__(self, success: int, duration: float):
        self.success, self.duration = success, duration
        # Guard against division by zero for timed-out / failed runs,
        # which are recorded as Result(0, 0).
        self.qps = success / duration if duration > 0 else 0

    def __str__(self):
        return f"success={self.success}, duration={self.duration}, qps={self.qps}"

    def __repr__(self):
        return self.__str__()

    @staticmethod
    def average(results: list['Result']) -> 'Result':
        """Aggregate several runs into one Result.

        Sums successes and durations; the aggregate's qps is therefore
        the overall throughput across all runs, not a mean of means.
        """
        total_success = sum(result.success for result in results)
        total_duration = sum(result.duration for result in results)
        return Result(total_success, total_duration)


# Directory containing this script; all generated artifacts live under it.
root = os.path.join(os.path.dirname(__file__))


class RandomFiles:
    """Pool of random payload files served by the benchmarked servers."""

    path = os.path.join(root, "rand_files")

    def __init__(self):
        if not os.path.exists(self.path):
            os.makedirs(self.path)

    def gen(self, file_size: int) -> str:
        """Ensure a random file of `file_size` KiB exists; return its basename."""
        file_name = f"rand_file_{file_size}.bin"
        logging.info(f"Generating {file_name}...")
        file_path = os.path.join(self.path, file_name)
        if not os.path.exists(file_path):
            with open(file_path, "wb") as f:
                f.write(os.urandom(int(file_size) * 1024))
        return file_name


class Certs:
    """Self-signed ECC (prime256v1) certificate chain shared by every server."""

    path = os.path.join(root, "certs")
    root_cert = os.path.join(path, "root_cert.pem")
    root_key = os.path.join(path, "root_key.pem")
    server_cert = os.path.join(path, "server_cert.pem")
    server_key = os.path.join(path, "server_key.pem")
    server_csr = os.path.join(path, "server_csr.pem")
    server_ext = os.path.join(path, "server.ext")
    server_cert_der = os.path.join(path, "server_cert.der")
    server_key_der = os.path.join(path, "server_key.der")

    def __init__(self):
        pass

    def gen(self):
        """Generate CA and server certificates once.

        Skipped entirely when the certs directory already exists; run the
        `clean` command to force regeneration.
        """
        if not os.path.exists(self.path):
            logging.info("Generating certs...")
            os.makedirs(self.path)
            # CA
            subprocess.run(
                ["openssl", "ecparam", "-name", "prime256v1",
                 "-genkey", "-out", self.root_key],
                check=True)
            subprocess.run(
                ["openssl", "req", "-new", "-x509", "-key", self.root_key,
                 "-out", self.root_cert, "-days", "3650",
                 "-subj", "/CN=localhost",
                 "-addext", "subjectAltName=DNS:localhost"],
                check=True)
            # Server
            subprocess.run(
                ["openssl", "ecparam", "-name", "prime256v1",
                 "-genkey", "-out", self.server_key],
                check=True)
            subprocess.run(
                ["openssl", "req", "-new", "-key", self.server_key,
                 "-out", self.server_csr, "-subj", "/CN=localhost",
                 "-addext", "subjectAltName=DNS:localhost"],
                check=True)
            # use server ext to add subjectAltName, openssl binary on macos CI
            # doesnot support `-copy-extensions copy` parameter
            with open(self.server_ext, "w") as f:
                f.write("subjectAltName=DNS:localhost\n")
                f.flush()
            subprocess.run(
                ["openssl", "x509", "-req", "-in", self.server_csr,
                 "-CA", self.root_cert, "-CAkey", self.root_key,
                 "-CAcreateserial", "-out", self.server_cert,
                 "-days", "365", "-extfile", self.server_ext],
                check=True)
            # Convert pem to der (quinn's example server consumes DER)
            subprocess.run(
                ["openssl", "x509", "-in", self.server_cert,
                 "-outform", "der", "-out", self.server_cert_der],
                check=True)
            subprocess.run(
                ["openssl", "ec", "-in", self.server_key,
                 "-outform", "der", "-out", self.server_key_der],
                check=True)


rand_files = RandomFiles()
ecc_certs = Certs()

quic_go_dir = os.path.join(root, "go-quic-demo")
dquic_dir = os.path.join(root, "..")
tquic_dir = os.path.join(root, "tquic")
quinn_dir = os.path.join(root, "h3")
quiche_dir = os.path.join(root, "quiche")


def git_clone(owner: str, repo: str, branch: str) -> None:
    """Clone `owner/repo` at `branch` under `root` unless already present."""
    if not os.path.exists(os.path.join(root, repo)):
        logging.info(f"Cloning {owner}/{repo}...")
        subprocess.run(
            ["git", "clone", "--recursive", "--branch", branch,
             f"https://github.com/{owner}/{repo}"],
            cwd=root,
            # Fail loudly: a silently ignored clone failure would surface
            # later as a confusing build error in the runner functions.
            check=True,
        )


def go_quic_runner() -> ServerRunner:
    """Build the quic-go demo server and describe how to launch it."""
    logging.info("Building quic-go server...")
    git_clone("eareimu", "go-quic-demo", "main")
    # Build
    subprocess.run(
        ["go", "get", "example/quic-server",],
        cwd=quic_go_dir, check=True
    )
    subprocess.run(
        ["go", "build", "-ldflags=-s -w", "-trimpath", "-o", "quic_server"],
        cwd=quic_go_dir, check=True
    )
    binary = os.path.join(quic_go_dir, "quic_server")
    launch = [
        binary,
        "-c", ecc_certs.server_cert,
        "-k", ecc_certs.server_key,
        "-a", "[::1]:4430",
    ]
    return ServerRunner('quic-go', launch, 4430)


def dquic_runner() -> ServerRunner:
    """Build this repository's h3-server example (single listen address)."""
    logging.info("Building dquic server...")
    # git_clone("genmeta", "dquic", "main")
    # Build
    subprocess.run(
        ["cargo", "build", "--release", "--package", "h3-shim",
         "--example", "h3-server"],
        cwd=dquic_dir, check=True
    )
    launch = [
        os.path.join(dquic_dir, "target", "release", "examples", "h3-server"),
        "-c", ecc_certs.server_cert,
        "-k", ecc_certs.server_key,
        "-b", "4096",  # backlog size
        "-l", "[::1]:4431"
    ]
    return ServerRunner('dquic', launch, 4431)


def dquic_multi_path_runner() -> ServerRunner:
    """Same h3-server example, listening on two addresses (multi-path)."""
    logging.info("Building dquic server...")
    # git_clone("genmeta", "dquic", "main")
    # Build
    subprocess.run(
        ["cargo", "build", "--release", "--package", "h3-shim",
         "--example", "h3-server"],
        cwd=dquic_dir, check=True
    )
    launch = [
        os.path.join(dquic_dir, "target", "release", "examples", "h3-server"),
        "-c", ecc_certs.server_cert,
        "-k", ecc_certs.server_key,
        "-b", "4096",  # backlog size
        "-l", "[::1]:4435",
        "-l", "127.0.0.1:4435"
    ]
    return ServerRunner('dquic-multi-path', launch, 4435)


def tquic_runner() -> ServerRunner:
    """Build Tencent tquic's example server."""
    logging.info("Building tquic server...")
    git_clone("Tencent", "tquic", "v1.6.0")
    subprocess.run(
        ["cargo", "build", "--release", "--package", "tquic_tools",
         "--bin", "tquic_server"],
        cwd=tquic_dir, check=True
    )
    launch = [
        os.path.join(tquic_dir, "target", "release", "tquic_server"),
        "-c", ecc_certs.server_cert,
        "-k", ecc_certs.server_key,
        "-l", "[::1]:4432",
        "--log-level", "OFF",
    ]
    return ServerRunner('tquic', launch, 4432)


def quinn_runner() -> ServerRunner:
    """Build the quinn-based h3 example server (consumes DER cert/key)."""
    logging.info("Building quinn server...")
    git_clone("hyperium", "h3", "h3-quinn-v0.0.9")
    subprocess.run(
        ["cargo", "build", "--release", "--example", "server"],
        cwd=quinn_dir, check=True
    )
    launch = [
        os.path.join(quinn_dir, "target", "release", "examples", "server"),
        "-c", ecc_certs.server_cert_der,
        "-k", ecc_certs.server_key_der,
        "-l", "[::1]:4433",
        "-d", "."  # this is actually the rand-files dir (the server's cwd)
    ]
    return ServerRunner('quinn', launch, 4433)


def cf_quiche_runner() -> ServerRunner:
    """Build Cloudflare quiche's example server."""
    logging.info("Building cloudflare-quiche server...")
    git_clone("cloudflare", "quiche", "0.23.4")
    subprocess.run(
        ["cargo", "build", "--release", "--bin", "quiche-server"],
        cwd=quiche_dir, check=True
    )
    launch = [
        os.path.join(quiche_dir, "target", "release", "quiche-server"),
        "--key", ecc_certs.server_key,
        "--cert", ecc_certs.server_cert,
        "--listen", "[::1]:4434",
        "--root", ".",
        "--no-retry"
    ]
    return ServerRunner('cloudflare quiche', launch, 4434)


class H3Client:
    """Wraps the dquic `h3-client` example used as the load generator."""

    stress: int      # total stress budget in KiB; conns = stress / file_size
    requests: int    # requests issued per connection
    progress: bool   # whether the client renders a progress bar

    def __init__(self, stress: int = 1024*30, requests: int = 8, progress: bool = False):
        logging.info("Building dquic client")
        subprocess.run(
            [
                "cargo", "build", "--package", "h3-shim",
                "--release", "--example", "h3-client",
            ],
            check=True
        )
        self.stress = stress
        self.requests = requests
        self.progress = progress

    def run_once(self, server_runner: ServerRunner, file_size: int, seq: int = 0) -> Result:
        """Run one benchmark round against `server_runner`.

        Payloads are `file_size` KiB; `seq` distinguishes the log files of
        repeated runs. Returns a zero Result on client timeout or when the
        client output cannot be parsed.
        """
        logging.info(f"Launch {server_runner.name} server and client")
        # Start the server in the background, logging to per-run files.
        log_dir = os.path.join(output_dir, "logs")
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        client_log = f"client_{server_runner.name}_{file_size}KB_{seq}.log"
        client_log = open(os.path.join(log_dir, client_log), "w+")
        server_log = f"server_{server_runner.name}_{file_size}KB_{seq}.log"
        server_log = open(os.path.join(log_dir, server_log), "w+")
        server = server_runner.run(server_log)
        launch_client = [
            "cargo", "run", "--package", "h3-shim", "--release",
            "--example", "h3-client", "--",
            # Keep the total transferred volume roughly constant: fewer
            # connections for bigger payloads.
            "--conns", str(int(self.stress / file_size)),
            "--reqs", str(self.requests),
            "--roots", ecc_certs.root_cert,
            "--progress", "true" if self.progress else "false",
            "--ansi", "false",
            f'https://localhost:{server_runner.listen_port}/{rand_files.gen(file_size)}'
        ]
        try:
            subprocess.run(
                launch_client, cwd=dquic_dir,
                # RUST_LOG=counting makes the client emit the
                # success_queries/total_time line parsed below.
                env={**os.environ, "RUST_LOG": "counting"},
                stdout=client_log, text=True, timeout=15
            )
        except subprocess.TimeoutExpired:
            server.kill()
            server.wait()
            logging.warning(
                f"Timeout expired for running {server_runner.name} {file_size}KB")
            client_log.close()
            server_log.close()
            return Result(success=0, duration=0)
        server.kill()
        server.wait()
        client_log.seek(0)
        output = client_log.read()
        client_log.close()
        server_log.close()
        # Extract total_time and success_queries using regex
        match = re.search(
            r"success_queries=(\d+).*?total_time=(\d+\.?\d*)", output)
        if match:
            success_queries = int(match.group(1))
            total_time = float(match.group(2))
            return Result(success=int(success_queries), duration=total_time)
        else:
            logging.error(f"Failed to parse benchmark output: {output}")
            return Result(success=0, duration=0)

    def run_many(self, server_runner: ServerRunner, file_size: int, times: int = 3) -> list[Result]:
        """Repeat `run_once` `times` times and collect every Result."""
        results = []
        for seq in range(0, times):
            once = self.run_once(server_runner, file_size, seq)
            logging.info(
                f"Run {server_runner.name} {file_size}KB complete: {once}")
            results.append(once)
        return results


def run(*runners: ServerRunner) -> dict[str, dict[int, list[Result]]]:
    """Benchmark every runner across the standard file sizes.

    Returns {implementation name: {file size in KiB: per-run Results}}.
    (The annotation previously read `dict[dict[str, list[Result]]]`,
    which is not a valid `dict` parameterization.)
    """
    ecc_certs.gen()
    client = H3Client(stress=2048*15, requests=8, progress=True)
    return {
        runner.name: {
            file_size: client.run_many(runner, file_size, times=10)
            for file_size in [15, 30, 2048]
        }
        for runner in runners
    }


output_dir = os.path.join(root, "output")


def plot_results(results: dict[str, dict[str, list[Result]]]):
    """Render bar charts of QPS per implementation and per repeated run."""
    import matplotlib.pyplot as plt
    # Shape: {implementation name: {file size: results of repeated runs}}
    plot_out_dir = os.path.join(output_dir, "plots")
    if not os.path.exists(plot_out_dir):
        os.makedirs(plot_out_dir)
    implementations = sorted(results.keys())
    file_sizes = sorted(results[implementations[0]].keys())
    for file_size in file_sizes:
        plt.figure(figsize=(10, 6))
        # Chart of per-implementation average QPS
        qps_values = [
            Result.average(results[impl][file_size]).qps
            for impl in implementations
        ]
        bars = plt.bar(implementations, qps_values)
        plt.title(f"file size {file_size}KB")
        plt.xlabel("Implementations")
        plt.ylabel("QPS")
        plt.xticks(rotation=45)
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2, height,
                     round(height, 2), ha='center', va='bottom')
        plt.tight_layout()
        plt.savefig(os.path.join(plot_out_dir,
                                 f"benchmark_{file_size}KB.png"))
        plt.close()
        # One chart per implementation showing each repeated run
        for impl in implementations:
            plt.figure(figsize=(10, 6))
            qps_values = [result.qps for result in results[impl][file_size]]
            bars = plt.bar([i for i in range(len(qps_values))], qps_values)
            plt.title(f"{impl} file size {file_size}KB")
            plt.xlabel("Runs")
            plt.ylabel("QPS")
            plt.xticks(rotation=45)
            for bar in bars:
                height = bar.get_height()
                plt.text(bar.get_x() + bar.get_width()/2, height,
                         round(height, 2), ha='center', va='bottom')
            plt.tight_layout()
            plt.savefig(
                os.path.join(plot_out_dir, f"{impl}_{file_size}KB.png"))
            plt.close()


def save_results(results: dict[str, dict[str, list[Result]]]):
    """save results to json file"""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    with open(os.path.join(output_dir, "results.json"), "w") as f:
        json.dump({
            impl: {
                size: [r.__dict__ for r in results_list]
                for size, results_list in sizes_results.items()
            }
            for impl, sizes_results in results.items()
        }, f, indent=2)


def load_results(*paths: str) -> dict[str, dict[str, list[Result]]]:
    """load and merge results from multiple json files"""
    # NOTE: JSON object keys are strings, so file sizes come back as str
    # here (vs int when produced by run()); each code path sorts/plots its
    # own results, so the two never mix in one dict.
    merged = {}
    for path in paths:
        with open(path, "r") as f:
            results = {
                impl: {
                    size: [Result(r["success"], r["duration"]) for r in runs]
                    for size, runs in sizes_results.items()
                }
                for impl, sizes_results in json.load(f).items()
            }
        for impl, sizes in results.items():
            if impl not in merged:
                merged[impl] = {}
            for size, runs in sizes.items():
                if size not in merged[impl]:
                    merged[impl][size] = []
                merged[impl][size].extend(runs)
    return merged


if __name__ == "__main__":
    logging.root.setLevel(logging.INFO)
    parser = argparse.ArgumentParser(
        description='QUIC implementation benchmark')
    subparsers = parser.add_subparsers(dest='command', required=True)
    runners = {
        'quic-go': go_quic_runner,
        'dquic': dquic_runner,
        'tquic': tquic_runner,
        'quinn': quinn_runner,
        'cf-quiche': cf_quiche_runner,
        'dquic-multi-path': dquic_multi_path_runner,
    }
    # Display runners
    runners_parser = subparsers.add_parser(
        'runners', help='List available implementations')
    runners_parser.add_argument('-q', '--quiet', action='store_true',
                                help='Only display implementation names')
    # Run command
    run_parser = subparsers.add_parser(
        'run', help='Run benchmark and save results')
    run_parser.add_argument('implementations', nargs='*',
                            choices=list(runners.keys()),
                            help='Implementations to benchmark')
    run_parser.add_argument('--no-plot', action='store_true',
                            help='Skip plotting results')
    # plot command
    plot_parser = subparsers.add_parser(
        'plot', help='Load and plot results from files')
    plot_parser.add_argument('files', nargs='+',
                             default=[os.path.join(output_dir, "results.json")],
                             help='Results JSON file paths')
    # Clean command
    clean_parser = subparsers.add_parser('clean', help='Clean generated files')
    clean_parser.add_argument('--all', action='store_true',
                              help='Also remove git cloned implementations')
    args = parser.parse_args()
    if args.command == 'runners':
        if not args.quiet:
            print("Available implementations:")
            for impl in runners.keys():
                print(f"- {impl}")
        else:
            print(
                '[' + ', '.join(f'"{impl}"' for impl in runners.keys()) + ']'
            )
        exit(0)
    elif args.command == 'run':
        # Empty positional list means "benchmark everything".
        selected_runners = [
            runners[impl]() for impl in args.implementations
        ] if args.implementations else [r() for r in runners.values()]
results = run(*selected_runners) save_results(results) if args.no_plot: exit(0) elif args.command == 'plot': results = load_results(*args.files) elif args.command == 'clean': paths = [rand_files.path, ecc_certs.path, output_dir] if args.all: paths.extend([quic_go_dir, tquic_dir, quinn_dir, quiche_dir]) for path in paths: if os.path.exists(path): shutil.rmtree(path) exit(0) plot_results(results) print(results) ================================================ FILE: codecov.yml ================================================ coverage: status: patch: off project: off range: "70..100" ================================================ FILE: commitlint.config.js ================================================ module.exports = { extends: ['@commitlint/config-conventional'], rules: { 'header-max-length': [2, 'always', 160], 'body-max-line-length': [2, 'always', 160], 'footer-max-line-length': [2, 'always', 160], }, } ================================================ FILE: dquic/Cargo.toml ================================================ [package] name = "dquic" version = "0.5.0" edition.workspace = true description = "An IETF quic transport protocol implemented natively using async Rust" readme = "README.md" repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true rust-version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] arc-swap = { workspace = true } bytes = { workspace = true } dashmap = { workspace = true } derive_more = { workspace = true, features = ["deref"] } futures = { workspace = true } qconnection = { workspace = true } qresolve = { workspace = true } rustls = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true, features = ["rt"] } tracing = { workspace = true } [dev-dependencies] clap = { workspace = true } http = { workspace = true } indicatif = { workspace = true } 
postcard = { workspace = true } qevent = { workspace = true, features = ["telemetry"] } qtraversal = { workspace = true, features = ["test-ttl"] } rustls = { workspace = true, features = ["ring"] } rustls-native-certs = { workspace = true } serde = { workspace = true } tokio = { workspace = true, features = ["fs", "io-std", "rt-multi-thread"] } tokio-util = { workspace = true, features = ["rt"] } tracing-appender = { workspace = true } x509-parser = { workspace = true } # console-subscriber = "0.4" [features] default = ["datagram"] telemetry = ["qconnection/telemetry"] datagram = ["qconnection/datagram"] [dev-dependencies.tracing-subscriber] workspace = true features = ["env-filter", "time"] ================================================ FILE: dquic/examples/echo-client.rs ================================================ use std::{ borrow::Cow, path::{Path, PathBuf}, sync::Arc, time::Duration, }; use clap::Parser; use dquic::prelude::{handy::ToCertificate, *}; use http::uri::Authority; use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; use qevent::telemetry::handy::{LegacySeqLogger, NoopLogger}; use rustls::RootCertStore; use tokio::{ fs, io::{self, AsyncBufReadExt, AsyncWrite, AsyncWriteExt}, task::JoinSet, }; use tracing_subscriber::prelude::*; #[derive(Parser, Debug)] #[command(name = "server")] struct Options { #[arg(long, help = "Save the qlog to a dir", value_name = "PATH")] qlog: Option, #[arg( long, short, value_delimiter = ',', default_value = "tests/keychain/localhost/ca.cert", help = "Certificates of CA who issues the server certificate" )] roots: Vec, #[arg( long, short, value_delimiter = ',', help = "files that will be sent to server, if not present, stdin will be used" )] files: Vec, #[arg( long, short = 'p', action = clap::ArgAction::Set, help = "enable progress bar", default_value = "false", value_enum )] progress: bool, #[arg( long, default_value = "true", action = clap::ArgAction::Set, help = "Enable ANSI color 
output in logs" )] ansi: bool, #[arg(default_value = "localhost:4433", help = "Host and port to connect to")] auth: Authority, } #[tokio::main] async fn main() { let options = Options::parse(); let (non_blocking, _guard) = tracing_appender::non_blocking(std::io::stdout()); tracing_subscriber::registry() // .with( // console_subscriber::ConsoleLayer::builder() // .server_addr("127.0.0.1:6670".parse::().unwrap()) // .spawn(), // ) .with( tracing_subscriber::fmt::layer() .with_writer(non_blocking) .with_ansi(options.ansi) .with_filter( tracing_subscriber::EnvFilter::builder() .with_default_directive(match options.progress { true => tracing::level_filters::LevelFilter::OFF.into(), false => tracing::level_filters::LevelFilter::INFO.into(), }) .from_env_lossy(), ), ) .init(); if let Err(error) = run(options).await { tracing::error!(?error); std::process::exit(1); }; } type Error = Box; async fn run(options: Options) -> Result<(), Error> { let qlogger: Arc = match options.qlog { Some(dir) => Arc::new(LegacySeqLogger::new(dir)), None => Arc::new(NoopLogger), }; let mut roots = RootCertStore::empty(); roots.add_parsable_certificates(rustls_native_certs::load_native_certs().certs); roots.add_parsable_certificates(options.roots.iter().flat_map(|path| path.to_certificate())); let client = Arc::new( QuicClient::builder() .with_root_certificates(roots) .without_cert() .with_parameters(handy::client_parameters()) .with_qlog(qlogger) .defer_idle_timeout(Duration::from_secs(60)) .enable_sslkeylog() .enable_0rtt() .build(), ); match options.files { files if files.is_empty() => process(&client, &options.auth, options.progress).await, files => { let files = files.iter().map(|p| p.as_path()); send_and_verify_files(&client, options.auth, files, options.progress).await } } } async fn send_and_verify_files( client: &Arc, auth: Authority, files: impl Iterator, progress: bool, ) -> Result<(), Error> { let pbs = MultiProgress::new(); if !progress { 
pbs.set_draw_target(ProgressDrawTarget::hidden()); } let total_tx = pbs.add(new_pb("总↑", 0)); let total_rx = pbs.add(new_pb("总↓️", 0)); let mut echos = JoinSet::new(); for path in files { let data = fs::read(path).await?; let (total_tx, total_rx) = (total_tx.clone(), total_rx.clone()); total_tx.inc_length(data.len() as u64); total_rx.inc_length(data.len() as u64); let client = client.clone(); let auth = auth.clone(); let tx_pb = pbs.insert_before(&total_tx, new_pb("↑", data.len() as u64)); let rx_pb = pbs.insert_before(&total_rx, new_pb("↓", data.len() as u64)); echos.spawn(async move { let mut back = vec![]; send_and_verify_echo(&client, &auth, &data, tx_pb, rx_pb, &mut back).await?; assert_eq!(back, data); total_tx.inc(data.len() as u64); total_rx.inc(data.len() as u64); Result::<(), Error>::Ok(()) }); } echos .join_all() .await .into_iter() .collect::>()?; total_tx.finish(); total_rx.finish(); Ok(()) } async fn process(client: &Arc, auth: &Authority, progress: bool) -> Result<(), Error> { eprintln!( "Enter interactive mode. Input anything, enter, then server will echo it back. Input `exit` or `quit` to quit." 
); let mut stdin = io::BufReader::new(io::stdin()); let mut stdout = io::stdout(); loop { stdout.write_all(b"\n>").await?; stdout.flush().await?; let mut line = String::new(); stdin.read_line(&mut line).await?; let line = line.trim(); if line == "exit" || line == "quit" { break Ok(()); } let tx_pb = new_pb("↑", line.len() as u64); let rx_pb = new_pb("↓️", line.len() as u64); if !progress { tx_pb.set_draw_target(ProgressDrawTarget::hidden()); rx_pb.set_draw_target(ProgressDrawTarget::hidden()); } send_and_verify_echo(client, auth, line.as_bytes(), tx_pb, rx_pb, &mut stdout).await?; } } fn new_pb(prefix: impl Into>, len: u64) -> ProgressBar { let style = ProgressStyle::default_bar() .template("{prefix} {wide_bar} {percent_precise}% {decimal_bytes_per_sec} ETA: {eta} {msg}") .unwrap(); ProgressBar::new(len).with_style(style).with_prefix(prefix) } async fn send_and_verify_echo( client: &Arc, auth: &Authority, data: &[u8], tx_pb: ProgressBar, rx_pb: ProgressBar, dst: &mut (impl AsyncWrite + Unpin), ) -> Result<(), Error> { let connection = client.connect(auth.host()).await?; let (sid, (reader, writer)) = connection.open_bi_stream().await?.unwrap(); tracing::debug!(%sid, "opened stream"); let mut reader = rx_pb.wrap_async_read(reader); let mut writer = tx_pb.wrap_async_write(writer); tokio::try_join!( async { writer.write_all(data).await?; writer.shutdown().await?; tx_pb.finish(); Result::<(), Error>::Ok(()) }, async { io::copy(&mut reader, dst).await?; dst.flush().await?; rx_pb.finish(); Result::<(), Error>::Ok(()) } ) .map(|_| ()) } ================================================ FILE: dquic/examples/echo-server.rs ================================================ use std::{path::PathBuf, sync::Arc, time::Duration}; use clap::Parser; use dquic::{prelude::*, qinterface::io::IO}; use qevent::telemetry::handy::{LegacySeqLogger, NoopLogger}; use tokio::io::{self, AsyncWriteExt}; use tracing::info; use tracing_subscriber::prelude::*; #[derive(Parser, Debug)] #[command(name 
= "server")] struct Options { #[arg(long, help = "Save the qlog to a dir", value_name = "PATH")] qlog: Option, #[arg( short, long, value_delimiter = ',', default_values = ["127.0.0.1:4433", "[::1]:4433"], help = "What BindUris to listen for new connections", )] listen: Vec, #[arg( long, short, default_value = "4096", help = "Maximum number of requests in the backlog. \ If the backlog is full, new connections will be refused." )] backlog: usize, #[arg( long, default_value = "true", action = clap::ArgAction::Set, help = "Enable ANSI color output in logs" )] ansi: bool, #[command(flatten)] certs: Certs, } #[derive(Parser, Debug)] struct Certs { #[arg(long, short, default_value = "localhost", help = "Server name.")] server_name: String, #[arg( long, short, default_value = "tests/keychain/localhost/server.cert", help = "Certificate for TLS. If present, `--key` is mandatory." )] cert: PathBuf, #[arg( long, short, default_value = "tests/keychain/localhost/server.key", help = "Private key for the certificate." 
)] key: PathBuf, } #[tokio::main] async fn main() { let options = Options::parse(); let (non_blocking, _guard) = tracing_appender::non_blocking(std::io::stdout()); tracing_subscriber::registry() // .with(console_subscriber::spawn()) .with( tracing_subscriber::fmt::layer() .with_writer(non_blocking) .with_ansi(options.ansi) .with_filter( tracing_subscriber::EnvFilter::builder() .with_default_directive(tracing::level_filters::LevelFilter::INFO.into()) .from_env_lossy(), ), ) .init(); if let Err(error) = run(options).await { tracing::info!(?error); std::process::exit(1); } } async fn run(options: Options) -> Result<(), Box> { let qlogger: Arc = match options.qlog { Some(dir) => Arc::new(LegacySeqLogger::new(dir)), None => Arc::new(NoopLogger), }; let listeners = QuicListeners::builder() .without_client_cert_verifier() .with_parameters(handy::server_parameters()) .with_qlog(qlogger) .defer_idle_timeout(Duration::from_secs(0)) .enable_0rtt() .listen(options.backlog)?; listeners .add_server( options.certs.server_name.as_str(), options.certs.cert.as_path(), options.certs.key.as_path(), options.listen, None, ) .await?; tracing::info!( "Listening on {}", listeners .get_server(options.certs.server_name.as_str()) .unwrap() .bind_interfaces() .iter() .next() .unwrap() .1 .borrow() .bound_addr()? ); serve_echo(listeners).await?; Ok(()) } async fn serve_echo(listeners: Arc) -> Result<(), ListenersShutdown> { async fn handle_stream(mut reader: StreamReader, mut writer: StreamWriter) -> io::Result<()> { io::copy(&mut reader, &mut writer).await?; writer.shutdown().await?; tracing::debug!("stream copy done"); io::Result::Ok(()) } loop { let (connection, _server, pathway, ..) 
= listeners.accept().await?; info!(source = ?pathway.remote(), "accepted new connection"); tokio::spawn(async move { while let Ok((_sid, (reader, writer))) = connection.accept_bi_stream().await { tokio::spawn(handle_stream(reader, writer)); } }); } } ================================================ FILE: dquic/examples/http-client.rs ================================================ use std::{path::PathBuf, sync::Arc}; use clap::Parser; use dquic::prelude::{handy::ToCertificate, *}; use http::{Uri, uri::Parts}; use qevent::telemetry::handy::{LegacySeqLogger, NoopLogger}; use tokio::{ fs, io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader}, }; use tracing_subscriber::prelude::*; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Options { #[arg(long, help = "Save the qlog to a dir", value_name = "PATH")] qlog: Option, #[arg( long, short, value_delimiter = ',', default_value = "tests/keychain/localhost/ca.cert", help = "Certificates of CA who issues the server certificate" )] roots: Vec, #[arg(long, help = "Skip verification of server certificate")] skip_verify: bool, #[arg( long, short, value_delimiter = ',', default_value = "quic", help = "ALPNs to use for the connection" )] alpns: Vec>, #[arg( long, default_value = "true", action = clap::ArgAction::Set, help = "Enable ANSI color output in logs" )] ansi: bool, #[arg(long, help = "Save the response to a dir", value_name = "PATH")] save: Option, #[arg( value_delimiter = ',', default_value = "http://localhost:4433/", help = "Uri to request. 
If only one uri is present and path is not specified, enter process mode" )] uris: Vec, } #[tokio::main] async fn main() { let options = Options::parse(); let (non_blocking, _guard) = tracing_appender::non_blocking(std::io::stdout()); tracing_subscriber::registry() // .with( // console_subscriber::ConsoleLayer::builder() // .server_addr("127.0.0.1:6670".parse::().unwrap()) // .spawn(), // ) .with( tracing_subscriber::fmt::layer() .with_writer(non_blocking) .with_ansi(options.ansi) .with_filter( tracing_subscriber::EnvFilter::builder() .with_default_directive(tracing::level_filters::LevelFilter::INFO.into()) .from_env_lossy(), ), ) .init(); if let Err(error) = run(options).await { tracing::error!(?error); std::process::exit(1); } } type Error = Box; async fn run(options: Options) -> Result<(), Error> { if options.uris.is_empty() { return Err("no uri specified".into()); } let qlogger: Arc = match options.qlog { Some(dir) => Arc::new(LegacySeqLogger::new(dir)), None => Arc::new(NoopLogger), }; let client_builder = if options.skip_verify { tracing::warn!("skip server verify"); QuicClient::builder().without_verifier() } else { tracing::info!("load ca certs"); let mut roots = rustls::RootCertStore::empty(); roots.add_parsable_certificates(rustls_native_certs::load_native_certs().certs); roots .add_parsable_certificates(options.roots.iter().flat_map(|path| path.to_certificate())); QuicClient::builder().with_root_certificates(roots) }; let client = Arc::new( client_builder .with_qlog(qlogger) .without_cert() .with_parameters(handy::client_parameters()) .with_alpns(options.alpns) .enable_sslkeylog() .build(), ); if options.uris.len() == 1 && options.uris[0].path() == "/" { return process(&client, &options.uris[0], options.save).await; } else { for uri in options.uris { download(&client, uri, options.save.as_ref()).await?; } } Ok(()) } async fn process( client: &Arc, base_uri: &Uri, save: Option, ) -> Result<(), Error> { let mut stdin = BufReader::new(io::stdin()); 
eprintln!( "Enter interactive mode. Input content to request (e.g: Cargo.toml), input `exit` or `quit` to quit." ); loop { let mut input = String::new(); _ = stdin.read_line(&mut input).await?; let content = input.trim(); if content.is_empty() { continue; } if content == "exit" || content == "quit" { return Ok(()); } let mut uri_parts = Parts::default(); uri_parts.scheme = base_uri.scheme().cloned(); uri_parts.authority = base_uri.authority().cloned(); uri_parts.path_and_query = Some(format!("/{content}").parse()?); download(client, Uri::from_parts(uri_parts)?, save.as_ref()).await?; } } async fn download(client: &Arc, uri: Uri, save: Option<&PathBuf>) -> Result<(), Error> { let authority = uri.authority().ok_or("authority must be present in uri")?; let file_path = uri.path().strip_prefix('/'); let file_path = file_path.ok_or_else(|| format!("invalid path `{}`", uri.path()))?; let connection = client.connect(authority.host()).await?; let (_sid, (mut response, mut request)) = connection .open_bi_stream() .await? .expect("very very hard to exhaust the available stream ids"); request .write_all(format!("GET /{file_path}").as_bytes()) .await?; request.shutdown().await?; match save.map(|dir| dir.join(file_path)) { Some(path) => io::copy(&mut response, &mut fs::File::create(path).await?).await?, None => io::copy(&mut response, &mut io::stdout()).await?, }; _ = connection.close("Bye bye", 0); tracing::info!("Saved to file {file_path}"); Ok(()) } ================================================ FILE: dquic/examples/http-server.rs ================================================ use std::{path::PathBuf, sync::Arc}; use clap::Parser; use dquic::{prelude::*, qinterface::io::IO}; use tokio::{ fs, io::{self, AsyncReadExt, AsyncWriteExt}, }; use tracing_subscriber::prelude::*; #[derive(Parser, Debug)] #[command(name = "server")] struct Options { #[arg( name = "dir", short, long, help = "Root directory of the files to serve. 
\ If omitted, server will respond OK.", default_value = "./" )] root: PathBuf, #[arg(long, help = "Save the qlog to a dir", value_name = "PATH")] qlog: Option, #[arg( short, long, value_delimiter = ',', default_values = ["127.0.0.1:4433", "[::1]:4433"], help = "What BindUris to listen for new connections" )] listen: Vec, #[arg( long, short, value_delimiter = ',', default_value = "quic", help = "ALPNs to use for the connection" )] alpns: Vec>, #[arg( long, short, default_value = "4096", help = "Maximum number of requests in the backlog. \ If the backlog is full, new connections will be refused." )] backlog: usize, #[arg( long, default_value = "true", action = clap::ArgAction::Set, help = "Enable ANSI color output in logs" )] ansi: bool, #[command(flatten)] certs: Certs, } #[derive(Parser, Debug)] struct Certs { #[arg(long, short, default_value = "localhost", help = "Server name.")] server_name: String, #[arg( long, short, default_value = "tests/keychain/localhost/server.cert", help = "Certificate for TLS. If present, `--key` is mandatory." )] cert: PathBuf, #[arg( long, short, default_value = "tests/keychain/localhost/server.key", help = "Private key for the certificate." 
)] key: PathBuf, } type Error = Box; fn main() { let options = Options::parse(); let (non_blocking, _guard) = tracing_appender::non_blocking(std::io::stdout()); tracing_subscriber::registry() // .with(console_subscriber::spawn()) .with( tracing_subscriber::fmt::layer() .with_writer(non_blocking) .with_ansi(options.ansi) .with_filter( tracing_subscriber::EnvFilter::builder() .with_default_directive(tracing::level_filters::LevelFilter::INFO.into()) .from_env_lossy(), ), ) .init(); let rt = tokio::runtime::Builder::new_current_thread() .enable_all() // default value 512 out of macos ulimit .max_blocking_threads(256) .build() .expect("failed to build tokio runtime"); if let Err(error) = rt.block_on(run(options)) { tracing::info!(?error); std::process::exit(1); } } async fn run(options: Options) -> Result<(), Error> { let qlogger: Arc = match options.qlog { Some(dir) => Arc::new(handy::LegacySeqLogger::new(dir)), None => Arc::new(handy::NoopLogger), }; let listeners = QuicListeners::builder() .with_qlog(qlogger) .without_client_cert_verifier() .with_parameters(handy::server_parameters()) .with_alpns(options.alpns) .listen(options.backlog)?; listeners .add_server( options.certs.server_name.as_str(), options.certs.cert.as_path(), options.certs.key.as_path(), options.listen, None, ) .await?; tracing::info!( "Listening on {}", listeners .get_server(options.certs.server_name.as_str()) .unwrap() .bind_interfaces() .iter() .next() .unwrap() .1 .borrow() .bound_addr()? 
); loop { let (connection, _server, _pathway, _link) = listeners.accept().await?; tokio::spawn(serve_files(connection)); } } async fn serve_files(connection: Connection) -> Result<(), Error> { async fn serve_file(mut reader: StreamReader, mut writer: StreamWriter) -> Result<(), Error> { let mut request = String::new(); reader.read_to_string(&mut request).await?; tracing::info!("received request: {request}"); // HTTP/0.9 is very simple - just a GET request with a path let serve = async { match request.trim().strip_prefix("GET /") { Some(path) => { tracing::debug!(?path, "Received HTTP/0.9 request"); let mut file = fs::File::open(PathBuf::from_iter(["./", path])).await?; io::copy(&mut file, &mut writer).await.map(|_| ()) } None => Err(io::Error::other(format!( "Invalid HTTP/0.9 request: {request}", ))), } }; if let Err(error) = serve.await { tracing::warn!("failed to serve request: {}", error); } _ = writer.shutdown().await; Ok(()) } loop { let (_sid, (reader, writer)) = connection.accept_bi_stream().await?; tokio::spawn(serve_file(reader, writer)); } } ================================================ FILE: dquic/examples/traversal-client.rs ================================================ // use std::{io, net::SocketAddr}; // use clap::Parser; // use dquic::{ // prelude::{ // Connection, EndpointAddr, ParameterId, QuicClient, EndpointAddr, handy::ToCertificate, // }, // qbase::param::ClientParameters, // qtraversal::iface::TraversalFactory, // }; // use rustls::RootCertStore; // use tokio::{ // io::{AsyncReadExt, AsyncWriteExt}, // task::JoinSet, // }; // use tracing::{info, warn}; // use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; // #[derive(Parser)] // struct Options { // #[arg(long)] // bind1: SocketAddr, // #[arg(long)] // bind2: SocketAddr, // #[arg(long)] // server_outer: SocketAddr, // #[arg(long)] // server_agent: SocketAddr, // #[arg(long, default_value = "nat.genmeta.net:20004")] // stun_server: String, // } // pub type Error = 
Box; // #[tokio::main] // pub async fn main() -> io::Result<()> { // init_logger()?; // let default_panic = std::panic::take_hook(); // std::panic::set_hook(Box::new(move |info| { // default_panic(info); // info!("panic: {}", info); // std::process::exit(1); // })); // let ops = Options::parse(); // let server_ep = EndpointAddr::Agent { // agent: ops.server_agent, // outer: ops.server_outer, // }; // let mut roots = RootCertStore::empty(); // roots.add_parsable_certificates( // include_bytes!("../../../tests/keychain/localhost/ca.cert").to_certificate(), // ); // let stun_servers: Vec = tokio::net::lookup_host(&ops.stun_server).await?.collect(); // if stun_servers.is_empty() { // return Err(io::Error::other("failed to resolve stun server")); // } // let factory = TraversalFactory::initialize_global(stun_servers).unwrap(); // let client = QuicClient::builder() // .with_root_certificates(roots) // .without_cert() // .enable_sslkeylog() // // .with_qlog(Arc::new(DefaultSeqLogger::new(PathBuf::from("qlog")))) // .with_iface_factory(factory.as_ref().clone()) // .with_parameters(client_stream_unlimited_parameters()) // .bind(&[ops.bind1, ops.bind2][..]) // .await // .build(); // let mut handle_set = JoinSet::new(); // for _ in 0..1 { // info!( // "server ep {:?}, bind {} {}", // server_ep, ops.bind1, ops.bind2 // ); // let connection = client // .connected_to("localhost", [server_ep]) // .await // .map_err(io::Error::other)?; // const DATA: &[u8] = include_bytes!("./client.rs"); // handle_set.spawn(async move { // send_and_verify_echo(&connection, DATA).await.unwrap(); // // 等待打洞结束 // tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; // warn!("finish one connection"); // }); // } // let _et = handle_set.join_all().await; // Ok(()) // } // async fn send_and_verify_echo(connection: &Connection, data: &[u8]) -> Result<(), Error> { // let (_sid, (mut reader, mut writer)) = connection.open_bi_stream().await?.unwrap(); // tracing::debug!("stream opened"); // let 
mut back = Vec::new(); // tokio::try_join!( // async { // writer.write_all(data).await?; // writer.shutdown().await?; // tracing::info!("xxxxx write done"); // Result::<(), Error>::Ok(()) // }, // async { // reader.read_to_end(&mut back).await?; // assert_eq!(back, data); // tracing::info!("xxxx read done"); // Result::<(), Error>::Ok(()) // } // ) // .map(|_| ()) // } // fn client_stream_unlimited_parameters() -> ClientParameters { // let mut params = ClientParameters::default(); // _ = params.set(ParameterId::ActiveConnectionIdLimit, 10u32); // _ = params.set(ParameterId::InitialMaxData, 1u32 << 20); // _ = params.set(ParameterId::InitialMaxStreamDataBidiLocal, 1u32 << 20); // _ = params.set(ParameterId::InitialMaxStreamDataBidiRemote, 1u32 << 20); // _ = params.set(ParameterId::InitialMaxStreamDataUni, 1u32 << 20); // _ = params.set(ParameterId::InitialMaxStreamsBidi, 100u32); // _ = params.set(ParameterId::InitialMaxStreamsUni, 100u32); // params // } // pub fn init_logger() -> std::io::Result<()> { // let filter = tracing_subscriber::filter::filter_fn(|metadata| { // !metadata.target().contains("netlink_packet_route") // }); // let _ = tracing_subscriber::registry() // .with(tracing_subscriber::Layer::with_filter( // tracing_subscriber::fmt::layer() // .with_target(true) // .with_ansi(false) // .with_file(true) // .with_line_number(true), // filter, // )) // .try_init(); // Ok(()) // } fn main() {} ================================================ FILE: dquic/examples/traversal-server.rs ================================================ // use std::{io, net::SocketAddr, sync::Arc}; // use clap::Parser; // use dquic::{ // prelude::{Connection, ParameterId, QuicListeners, StreamReader, StreamWriter}, // qbase::param::ServerParameters, // qtraversal, // }; // use qtraversal::iface::TraversalFactory; // use tokio::io::AsyncWriteExt; // use tracing::{Instrument, info, info_span}; // use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; // 
#[derive(clap::Parser)] // struct Options { // #[arg(long, default_value = "192.168.1.4:6000")] // bind1: SocketAddr, // #[arg(long, default_value = "[2409:8a00:1850:be40:1037:3cbd:ec40:11c6]:6000")] // bind2: SocketAddr, // #[arg(long, default_value = "nat.genmeta.net:20004")] // stun_server: String, // } // #[tokio::main] // pub async fn main() -> io::Result<()> { // init_logger()?; // let default_panic = std::panic::take_hook(); // std::panic::set_hook(Box::new(move |info| { // default_panic(info); // info!("panic: {}", info); // std::process::exit(1); // })); // let ops = Options::parse(); // let stun_servers: Vec = tokio::net::lookup_host(&ops.stun_server).await?.collect(); // if stun_servers.is_empty() { // return Err(io::Error::other("failed to resolve stun server")); // } // let factory = TraversalFactory::initialize_global(stun_servers).unwrap(); // let server = QuicListeners::builder()? // // .with_single_cert( // // include_bytes!("../../../tests/keychain/localhost/server.cert"), // // include_bytes!("../../../tests/keychain/localhost/server.key"), // // ) // .with_iface_factory(factory.as_ref().clone()) // .with_parameters(server_stream_unlimited_parameters()) // .without_client_cert_verifier() // .listen(1000); // server // .add_server( // "localhost", // include_bytes!("../../../tests/keychain/localhost/server.cert"), // include_bytes!("../../../tests/keychain/localhost/server.key"), // [ops.bind1], // None, // ) // .await?; // launch(server).await?; // Ok(()) // } // pub fn server_stream_unlimited_parameters() -> ServerParameters { // let mut params = ServerParameters::default(); // _ = params.set(ParameterId::ActiveConnectionIdLimit, 10u32); // _ = params.set(ParameterId::InitialMaxData, 1u32 << 20); // _ = params.set(ParameterId::InitialMaxStreamDataBidiLocal, 1u32 << 20); // _ = params.set(ParameterId::InitialMaxStreamDataBidiRemote, 1u32 << 20); // _ = params.set(ParameterId::InitialMaxStreamDataUni, 1u32 << 20); // _ = 
params.set(ParameterId::InitialMaxStreamsBidi, 100u32); // _ = params.set(ParameterId::InitialMaxStreamsUni, 100u32); // params // } // pub async fn launch(server: Arc) -> io::Result<()> { // async fn handle_connection(conn: Arc) -> io::Result<()> { // loop { // let (sid, (reader, writer)) = conn.accept_bi_stream().await?; // tokio::spawn( // handle_stream(reader, writer).instrument(info_span!("handle_stream",%sid)), // ); // } // } // async fn handle_stream(mut reader: StreamReader, mut writer: StreamWriter) -> io::Result<()> { // tokio::io::copy(&mut reader, &mut writer).await?; // writer.shutdown().await?; // tracing::info!("stream copy done"); // io::Result::Ok(()) // } // loop { // let (connection, _name, pathway, _link) = server // .accept() // .await // .map_err(|_e| io::Error::other("accept error"))?; // info!(source = ?pathway.remote(), "accepted new connection"); // tokio::spawn(handle_connection(Arc::new(connection))); // } // } // pub fn init_logger() -> std::io::Result<()> { // let filter = tracing_subscriber::filter::filter_fn(|metadata| { // !metadata.target().contains("netlink_packet_route") // }); // let _ = tracing_subscriber::registry() // .with(tracing_subscriber::Layer::with_filter( // tracing_subscriber::fmt::layer() // .with_target(true) // .with_ansi(false) // .with_file(true) // .with_line_number(true), // filter, // )) // .try_init(); // Ok(()) // } fn main() {} ================================================ FILE: dquic/src/cert.rs ================================================ use std::path::Path; use rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject}; pub trait ToCertificate { fn to_certificate(self) -> Vec>; } impl ToCertificate for Vec> { fn to_certificate(self) -> Vec> { self } } impl ToCertificate for &[CertificateDer<'static>] { fn to_certificate(self) -> Vec> { self.to_vec().to_certificate() } } impl ToCertificate for [CertificateDer<'static>; N] { fn to_certificate(self) -> Vec> { self.to_vec().to_certificate() 
} } impl ToCertificate for CertificateDer<'static> { fn to_certificate(self) -> Vec> { vec![self] } } impl ToCertificate for &Path { fn to_certificate(self) -> Vec> { let data = std::fs::read(self).expect("Failed to read certificate file"); if let Ok(certs) = CertificateDer::pem_slice_iter(&data).collect::, _>>() && !certs.is_empty() { return certs; } vec![CertificateDer::from(data)] } } impl ToCertificate for &[u8] { fn to_certificate(self) -> Vec> { if let Ok(certs) = CertificateDer::pem_slice_iter(self).collect::, _>>() && !certs.is_empty() { return certs; } vec![CertificateDer::from(self.to_vec())] } } impl ToCertificate for &[u8; N] { fn to_certificate(self) -> Vec> { <&[u8]>::to_certificate(self) } } pub trait ToPrivateKey { fn to_private_key(self) -> PrivateKeyDer<'static>; } impl ToPrivateKey for PrivateKeyDer<'static> { fn to_private_key(self) -> PrivateKeyDer<'static> { self } } impl ToPrivateKey for &PrivateKeyDer<'static> { fn to_private_key(self) -> PrivateKeyDer<'static> { self.clone_key() } } impl ToPrivateKey for &Path { fn to_private_key(self) -> PrivateKeyDer<'static> { let data = std::fs::read(self).expect("failed to read private key file"); if let Ok(key) = PrivateKeyDer::from_pem_slice(&data) { return key; } PrivateKeyDer::try_from(data) .expect("failed to parse private key file as pem or der format") } } impl ToPrivateKey for &[u8] { fn to_private_key(self) -> PrivateKeyDer<'static> { if let Ok(key) = PrivateKeyDer::from_pem_slice(self) { return key; } PrivateKeyDer::try_from(self.to_vec()) .expect("failed to parse private key file as pem or der format") } } impl ToPrivateKey for &[u8; N] { fn to_private_key(self) -> PrivateKeyDer<'static> { <&[u8]>::to_private_key(self) } } ================================================ FILE: dquic/src/client.rs ================================================ use std::{ collections::HashMap, io, net::SocketAddr, str::FromStr, sync::{ Arc, atomic::{AtomicBool, Ordering}, }, time::Duration, }; use 
dashmap::DashMap; use futures::StreamExt; use qbase::{net::Family, param::ClientParameters, token::TokenSink}; use qconnection::{ self, qbase::net::AddrFamily, qinterface::{component::location::Locations, io::IO}, }; use qevent::telemetry::QLog; use qinterface::{ BindInterface, Interface, bind_uri::BindUri, component::route::QuicRouter, device::Devices, io::ProductIO, manager::InterfaceManager, }; use qresolve::Source; use rustls::{ ConfigBuilder, WantsVerifier, client::{ResolvesClientCert, WantsClientCert}, }; use thiserror::Error; use crate::{prelude::*, *}; type TlsClientConfig = rustls::ClientConfig; type TlsClientConfigBuilder = ConfigBuilder; /// A QUIC client for initiating connections to servers. /// /// ## Creating Clients /// /// Use [`QuicClient::builder`] to configure and create a client instance. /// Configure interfaces, TLS settings, and connection behavior before building. /// /// ## Interface Management /// /// - **Automatic binding**: If no interfaces are bound, the client automatically binds to system-assigned addresses /// - **Manual binding**: Use [`QuicClientBuilder::bind`] to bind specific interfaces /// /// ## Connection Handling /// /// Call [`QuicClient::connect`] to establish connections. 
The client supports: /// - **Automatic interface selection**: Matches interface with server endpoint address #[derive(Clone)] pub struct QuicClient { network: common::Network, bind_ifaces: DashMap, manual_bind: Arc, // quic config(in initialize order) _prefer_versions: Vec, token_sink: Arc, parameters: ClientParameters, tls_config: TlsClientConfig, stream_strategy_factory: Arc, defer_idle_timeout: Duration, qlogger: Arc, } #[derive(Debug, Error)] pub enum ConnectServerError { #[error("DNS lookup failed")] Dns { #[from] source: io::Error, }, #[error("Failed to bind interface for client connection")] BindInterface { #[from] source: BindInterfaceError, }, } #[derive(Debug, Error)] #[error( "Failed to bind interface `{}` for client connection", bind_uri.as_ref().map_or(String::from(""), |bind_uri| bind_uri.to_string()) )] pub struct BindInterfaceError { bind_uri: Option, #[source] bind_error: io::Error, } impl QuicClient { #[inline] pub fn bind_ifaces(&self) -> HashMap { self.bind_ifaces .iter() .map(|entry| (entry.key().clone(), entry.value().clone())) .collect() } pub async fn bind(&self, bind_uri: impl Into) -> BindInterface { let bind_interface = self.network.bind(bind_uri.into()).await; self.bind_ifaces .insert(bind_interface.bind_uri(), bind_interface.clone()); self.manual_bind.store(true, Ordering::Relaxed); bind_interface } #[inline] pub fn unbind(&self, bind_uri: &BindUri) -> Option { self.bind_ifaces.remove(bind_uri).map(|(_, iface)| iface) } /// Creates a new QUIC connection to the specified server without any initial paths. /// /// This method initializes the connection state but does not start the handshake /// because no network paths are established yet. You must manually add paths /// using [`Connection::add_path`] to initiate communication. /// /// This is useful for advanced scenarios where you need fine-grained control /// over which interfaces and paths are used for the connection. 
pub fn new_connection(&self, server_name: impl Into) -> Connection { Connection::new_client(server_name.into(), self.token_sink.clone()) .with_parameters(self.parameters.clone()) .with_tls_config(self.tls_config.clone()) .with_streams_concurrency_strategy(self.stream_strategy_factory.as_ref()) .with_zero_rtt(self.tls_config.enable_early_data) .with_iface_factory(self.network.iface_factory.clone()) .with_iface_manager(self.network.iface_manager.clone()) .with_quic_router(self.network.quic_router.clone()) .with_locations(self.network.locations.clone()) // todo // .with_stun_servers() .with_defer_idle_timeout(self.defer_idle_timeout) .with_cids(ConnectionId::random_gen(8)) .with_qlog(self.qlogger.clone()) .run() } /// Builds a [`BindUri`] from the DNS [`Source`] and endpoint address. /// /// - For [`Source::Mdns`]: binds to the discovering NIC (e.g., `iface://v4.en0:0`). /// - For other sources: binds to a wildcard address matching the endpoint family. fn bind_uri_for(source: &Source, ep: &EndpointAddr) -> BindUri { match source { Source::Mdns { nic, family } => { let f = match family { Family::V4 => "v4", Family::V6 => "v6", }; BindUri::from_str(&format!("iface://{f}.{nic}:0")) .expect("iface URI should be valid") .alloc_port() } _ => match ep.family() { Family::V4 => BindUri::from_str("inet://0.0.0.0:0") .expect("URL should be valid") .alloc_port(), Family::V6 => BindUri::from_str("inet://[::]:0") .expect("URL should be valid") .alloc_port(), }, } } /// Ensures at least one interface exists for the given endpoint. async fn ensure_iface_for(&self, source: &Source, ep: &EndpointAddr) { if self.manual_bind.load(Ordering::Relaxed) { return; } if self.bind_ifaces.is_empty() { let bind_uri = Self::bind_uri_for(source, ep); let iface = self.network.bind(bind_uri).await; self.bind_ifaces.insert(iface.bind_uri(), iface); } } /// Returns matching bound interfaces or auto-binds a new one. 
async fn select_or_bind_ifaces( &self, source: &Source, ep: &EndpointAddr, ) -> Result, BindInterfaceError> { let iface_matches_source = |iface: &Interface| match source { Source::Mdns { nic, family } => iface.bind_uri().as_iface_bind_uri().is_some_and( |(iface_family, iface_name, _)| { iface_family == *family && iface_name == nic.as_ref() }, ), _ => true, }; if self.manual_bind.load(Ordering::Relaxed) { let ifaces = self .bind_ifaces .iter() .map(|entry| entry.value().borrow()) .filter(|iface| iface_matches_source(iface)) .filter_map(|iface| Some((iface.bound_addr().ok()?, iface))) .filter(|(addr, _)| addr.family() == ep.family()) .collect::>(); Ok(ifaces) } else { let ifaces = self .bind_ifaces .iter() .map(|entry| entry.value().borrow()) .filter(|iface| iface_matches_source(iface)) .filter_map(|iface| Some((iface.bound_addr().ok()?, iface))) .filter(|(addr, _)| addr.family() == ep.family()) .collect::>(); if !ifaces.is_empty() { return Ok(ifaces); } let bind_uri = Self::bind_uri_for(source, ep); let iface = self.network.bind(bind_uri.clone()).await.borrow(); let bound_addr = iface.bound_addr().map_err(|source| BindInterfaceError { bind_uri: Some(bind_uri), bind_error: source, })?; Ok(vec![(bound_addr, iface)]) } } /// Probes and generates potential network paths to the given server endpoints. /// /// Each endpoint is paired with its DNS [`Source`] so that the correct network /// interface can be selected: /// /// - **Direct endpoints**: selects matching bound interfaces or auto-binds a new one, /// then constructs [`Link`] and [`Pathway`] for each. /// - **Agent endpoints**: ensures an interface exists but does **not** build a path — /// the puncher system handles Agent paths after STUN discovery. /// /// Returns a list of `(Interface, Link, Pathway)` tuples for Direct endpoints only. 
/// /// ### Example /// /// ```no_run /// # use dquic::prelude::*; /// # use dquic::qresolve::Source; /// # async fn example(quic_client: &QuicClient) -> Result<(), Box> { /// let server_addresses: Vec<_> = tokio::net::lookup_host("genmeta.net:443") /// .await? /// .map(|addr| (Source::System, addr.into())) /// .collect(); /// let paths = quic_client.probe(server_addresses).await?; /// let connection = quic_client.new_connection("genmeta.net"); /// for (iface, link, pathway) in paths { /// connection.add_path(iface.bind_uri(), link, pathway)?; /// } /// # Ok(()) /// # } /// ``` pub async fn probe( &self, server_eps: impl IntoIterator, ) -> Result, BindInterfaceError> { let server_eps = server_eps.into_iter().collect::>(); let mut paths = vec![]; for (source, server_ep) in server_eps { if matches!(server_ep, EndpointAddr::Agent { .. }) { self.ensure_iface_for(&source, &server_ep).await; } else { let ifaces = self.select_or_bind_ifaces(&source, &server_ep).await?; paths.extend(ifaces.into_iter().map(move |(bound_addr, iface)| { let dst = *server_ep; let link = Link::new(bound_addr, dst); let pathway = Pathway::new(bound_addr.into(), server_ep); (iface, link, pathway) })); } } Ok(paths) } /// Processes a single server endpoint for the given connection: /// 1. Registers the peer endpoint (with its DNS source) in the connection's address book. /// 2. Probes for immediate paths (Direct endpoints) or ensures an interface /// is bound (Agent endpoints). See [`Self::probe`] for details. /// 3. Adds any resulting Direct paths to the connection. /// /// Returns `true` if at least one Direct path was added. async fn setup_server_endpoint( &self, connection: &Connection, source: Source, server_ep: EndpointAddr, ) -> Result { // Register the peer endpoint with its DNS source — the puncher will // only auto-create paths with local endpoints matching the source constraint // (e.g. mDNS endpoints are restricted to the discovering NIC). 
_ = connection.add_peer_endpoint(server_ep, source.clone()); // probe() handles both Direct and Agent uniformly: // Direct → select/bind interface, construct Link & Pathway, return paths. // Agent → ensure an interface is bound, return empty paths. let paths = self.probe([(source, server_ep)]).await?; let has_direct_path = !paths.is_empty(); for (iface, link, pathway) in paths { _ = connection.add_path(iface.bind_uri(), link, pathway); } Ok(has_direct_path) } /// Connects to a server using specific endpoint addresses. /// /// This method combines [`QuicClient::probe`] and [`QuicClient::new_connection`]. /// It creates a connection and automatically adds paths for all the provided /// server endpoints. /// /// The returned [`Connection`] may not have completed the handshake yet. /// However, any asynchronous operations on the connection (like opening streams) /// will automatically wait for the handshake to complete. /// /// If `server_eps` is empty, this is equivalent to calling [`QuicClient::new_connection`] /// and the connection will remain idle until paths are added. /// /// This variant preserves the DNS [`Source`] so that the correct network interface /// is selected for each endpoint (e.g., mDNS endpoints bind to the discovering NIC). pub async fn connected_to_with_source( &self, server_name: impl Into, server_eps: impl IntoIterator, ) -> Result { let connection = self.new_connection(server_name); _ = connection.subscribe_local_address(); for (source, server_ep) in server_eps { self.setup_server_endpoint(&connection, source, server_ep) .await .map_err(|source| ConnectServerError::BindInterface { source })?; } Ok(connection) } /// Connects to a server by its hostname and optional port. /// /// This is the most convenient way to establish a connection. It performs the following steps: /// 1. Parses the server string (e.g., "example.com" or "example.com:443"). /// Defaults to port 443 if not specified. /// 2. 
Performs an asynchronous DNS lookup to resolve the hostname to IP addresses. /// 3. Calls [`QuicClient::connected_to_with_source`] with the resolved addresses. /// /// The returned [`Connection`] may not have completed the handshake yet. /// Asynchronous operations on the connection will wait for the handshake. pub async fn connect(self: &Arc, server: &str) -> Result { let mut server_eps = self .network .resolver .lookup(server) .await .map_err(|source| ConnectServerError::Dns { source })?; let connection = self.new_connection(server); if connection.subscribe_local_address().is_err() { // connection already closed, return immediately (not connect error) return Ok(connection); } let mut last_error: Option = None; // Consume the DNS stream until we get at least one Direct path, // or exhaust all endpoints (Agent-only is acceptable). // // `last_error` doubles as a "no viable endpoint yet" sentinel: // - On `Ok(false)` (Agent registered): clear it — we have a viable fallback. // - On `Err`: set/keep it — probe failure, keep looking. // - On stream exhaustion: if still `Some`, nothing viable → propagate error. while let Some((source, server_ep)) = server_eps.next().await { match self .setup_server_endpoint(&connection, source, server_ep) .await { Ok(true) => { last_error = None; // Got a Direct path, proceed. break; } Ok(false) => { // Agent endpoint registered — even if later Direct probes fail, // the puncher can still establish paths asynchronously. last_error = None; } Err(error) => { last_error.get_or_insert(error.into()); } } } if let Some(error) = last_error { return Err(error); } // Background task: keep consuming the DNS stream for late-arriving endpoints. tokio::spawn({ let connection = connection.clone(); let client = self.clone(); async move { while let Some((source, server_ep)) = server_eps.next().await { _ = client .setup_server_endpoint(&connection, source, server_ep) .await; } } }); Ok(connection) } } /// Builder for [`QuicClient`]. 
#[derive(Clone)] pub struct QuicClientBuilder { network: common::Network, // client bind_ifaces: DashMap, manual_bind: bool, // client: quic config(in initialize order) prefer_versions: Vec, token_sink: Arc, parameters: ClientParameters, tls_config: T, stream_strategy_factory: Arc, defer_idle_timeout: Duration, qlogger: Arc, } impl QuicClient { /// Create a new [`QuicClient`] builder. pub fn builder() -> QuicClientBuilder> { Self::builder_with_tls(TlsClientConfig::builder_with_protocol_versions(&[ &rustls::version::TLS13, ])) } /// Create a [`QuicClient`] builder with custom crypto provider. pub fn builder_with_crypto_provider( provider: Arc, ) -> QuicClientBuilder> { Self::builder_with_tls( TlsClientConfig::builder_with_provider(provider) .with_protocol_versions(&[&rustls::version::TLS13]) .unwrap(), ) } /// Start to build a QuicClient with the given TLS configuration. /// /// This is useful when you want to customize the TLS configuration, or integrate qm-quic with other crates. pub fn builder_with_tls(tls_config: T) -> QuicClientBuilder { QuicClientBuilder { // network network: common::Network::default(), // client bind_ifaces: DashMap::new(), manual_bind: false, // client: quic config(in initialize order) prefer_versions: vec![1], token_sink: Arc::new(handy::NoopTokenRegistry), parameters: handy::client_parameters(), tls_config, stream_strategy_factory: Arc::new(handy::ConsistentConcurrency::new), defer_idle_timeout: Duration::ZERO, qlogger: Arc::new(handy::NoopLogger), } } } impl QuicClientBuilder { pub fn with_resolver(mut self, resolver: Arc) -> Self { self.network.resolver = resolver; self } pub fn physical_ifaces(mut self, physical_ifaces: &'static Devices) -> Self { self.network.devices = physical_ifaces; self } /// Specify how client bind interfaces. /// /// The given factory will be used by [`Self::bind`], /// and/or [`QuicClient::connect`] if no interface bound when client built. 
/// /// The default quic interface is provided by [`handy::DEFAULT_IO_FACTORY`]. /// For Unix and Windows targets, this is a high performance UDP library supporting GSO and GRO /// provided by `qudp` crate. For other platforms, please specify you own factory. pub fn with_iface_factory(mut self, iface_factory: Arc) -> Self { self.network.iface_factory = iface_factory; self } /// Specify the interfaces manager for the client. pub fn with_iface_manager(mut self, iface_manager: Arc) -> Self { self.network.iface_manager = iface_manager; self } pub fn with_router(mut self, router: Arc) -> Self { self.network.quic_router = router; self } pub fn with_stun(mut self, server: impl Into>) -> Self { self.network.stun_server = Some(server.into()); self } /// Specify the locations for interface sharing. /// /// The given locations is shared by all connections created by this client. pub fn with_locations(mut self, locations: Arc) -> Self { self.network.locations = locations; self } /// Create quic interfaces bound on given address. /// /// If the bind failed, the error will be returned immediately. /// /// The default quic interface is provided by [`handy::DEFAULT_IO_FACTORY`]. /// For Unix and Windows targets, this is a high performance UDP library supporting GSO and GRO /// provided by `qudp` crate. For other platforms, please specify you own factory with /// [`QuicClientBuilder::with_iface_factory`]. /// /// If you dont bind any address, each time the client initiates a new connection, /// the client will use bind a new interface on address and port that dynamic assigned by the system. /// /// To know more about how the client selects the interface when initiates a new connection, /// read [`QuicClient::connect`]. /// /// If you call this multiple times, only the last set of interface will be used, /// previous bound interface will be freed immediately. /// /// If all interfaces are closed, clients will no longer be able to initiate new connections. 
pub async fn bind(mut self, bind_uris: impl IntoIterator>) -> Self { self.bind_ifaces = self .network .bind_many(bind_uris) .await .map(|bind_iface| (bind_iface.bind_uri(), bind_iface)) .collect() .await; self.manual_bind = true; self } /// (WIP)Specify the quic versions that the client prefers. /// /// If you call this multiple times, only the last call will take effect. pub fn prefer_versions(mut self, versions: impl IntoIterator) -> Self { self.prefer_versions.clear(); self.prefer_versions.extend(versions); self } /// Specify the token sink for the client. /// /// The token sink is used to storage the tokens that the client received from the server. The client will use the /// tokens to prove it self to the server when it reconnects to the server. read [address verification] in quic rfc /// for more information. /// /// [address verification](https://www.rfc-editor.org/rfc/rfc9000.html#name-address-validation) pub fn with_token_sink(self, token_sink: Arc) -> Self { Self { token_sink, ..self } } /// Specify the [transport parameters] for the client. /// /// If you call this multiple times, only the last `parameters` will be used. /// /// Usually, you don't need to call this method, because the client will use a set of default parameters. 
/// /// [transport parameters](https://www.rfc-editor.org/rfc/rfc9000.html#name-transport-parameter-definit) pub fn with_parameters(self, parameters: ClientParameters) -> Self { Self { parameters, ..self } } fn map_tls(self, f: impl FnOnce(T) -> T1) -> QuicClientBuilder { QuicClientBuilder { network: self.network, bind_ifaces: self.bind_ifaces, manual_bind: self.manual_bind, prefer_versions: self.prefer_versions, token_sink: self.token_sink, parameters: self.parameters, tls_config: f(self.tls_config), stream_strategy_factory: self.stream_strategy_factory, defer_idle_timeout: self.defer_idle_timeout, qlogger: self.qlogger, } } pub fn with_name(mut self, name: impl Into) -> Self { self.parameters .set(ParameterId::ClientName, name.into()) .expect("parameter 0xffee belong_to client and has type String"); self } /// Provide an option to defer an idle timeout. /// /// This facility could be used when the application wishes to avoid losing /// state that has been associated with an open connection but does not expect /// to exchange application data for some time. /// /// See [Deferring Idle Timeout](https://datatracker.ietf.org/doc/html/rfc9000#name-deferring-idle-timeout) /// of [RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000) /// for more information. pub fn defer_idle_timeout(mut self, duration: Duration) -> Self { self.defer_idle_timeout = duration; self } /// Specify the streams concurrency strategy controller for the client. /// /// The streams controller is used to control the concurrency of data streams. `controller` is a closure that accept /// (initial maximum number of bidirectional streams, initial maximum number of unidirectional streams) configured in /// [transport parameters] and return a `ControlConcurrency` object. /// /// If you call this multiple times, only the last `controller` will be used. 
/// /// [transport parameters](https://www.rfc-editor.org/rfc/rfc9000.html#name-transport-parameter-definit) pub fn with_streams_concurrency_strategy( self, stream_strategy_factory: Arc, ) -> Self { Self { stream_strategy_factory, ..self } } /// Specify qlog collector for server connections. /// /// If you call this multiple times, only the last `logger` will be used. /// /// Pre-implemented loggers: /// - [`LegacySeqLogger`]: Generates qlog files compatible with [qvis] visualization. /// - `LegacySeqLogger::new(PathBuf::from("/dir"))`: Write to files `{connection_id}_{role}.sqlog` in `dir` /// - `LegacySeqLogger::new(tokio::io::stdout())`: Stream to stdout /// - `LegacySeqLogger::new(tokio::io::stderr())`: Stream to stderr /// /// Output format: JSON-SEQ ([RFC7464]), one JSON event per line. /// /// - [`handy::NoopLogger`] (default): Ignores all qlog events (default, recommended for production). /// /// [qvis]: https://qvis.quictools.info/ /// [RFC7464]: https://www.rfc-editor.org/rfc/rfc7464 /// [`LegacySeqLogger`]: qevent::telemetry::handy::LegacySeqLogger pub fn with_qlog(self, qlogger: Arc) -> Self { Self { qlogger, ..self } } } impl QuicClientBuilder> { /// Choose how to verify server certificates. /// /// Read [TlsClientConfigBuilder::with_root_certificates] for more information. pub fn with_root_certificates( self, root_store: impl Into>, ) -> QuicClientBuilder> { self.map_tls(|tls_config_builder| tls_config_builder.with_root_certificates(root_store)) } /// Choose how to verify server certificates using a webpki verifier. /// /// Read [TlsClientConfigBuilder::with_webpki_verifier] for more information. pub fn with_webpki_verifier( self, verifier: Arc, ) -> QuicClientBuilder> { self.map_tls(|tls_config_builder| tls_config_builder.with_webpki_verifier(verifier)) } /// Replace the default server certificate verifier with a custom one. /// /// This exposes rustls' low-level custom verifier hook. 
The provided /// verifier becomes fully responsible for server certificate validation, /// including any WebPKI, OCSP, pinning, or private PKI checks you require. pub fn with_custom_server_cert_verifier( self, verifier: Arc, ) -> QuicClientBuilder> { self.map_tls(|tls_config_builder| { tls_config_builder .dangerous() .with_custom_certificate_verifier(verifier) }) } /// Dangerously disable server certificate verification. pub fn without_verifier(self) -> QuicClientBuilder> { #[derive(Debug)] struct DangerousServerCertVerifier; impl rustls::client::danger::ServerCertVerifier for DangerousServerCertVerifier { fn verify_server_cert( &self, _: &rustls::pki_types::CertificateDer<'_>, _: &[rustls::pki_types::CertificateDer<'_>], _: &rustls::pki_types::ServerName<'_>, _: &[u8], _: rustls::pki_types::UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ rustls::SignatureScheme::RSA_PKCS1_SHA1, rustls::SignatureScheme::ECDSA_SHA1_Legacy, rustls::SignatureScheme::RSA_PKCS1_SHA256, rustls::SignatureScheme::ECDSA_NISTP256_SHA256, rustls::SignatureScheme::RSA_PKCS1_SHA384, rustls::SignatureScheme::ECDSA_NISTP384_SHA384, rustls::SignatureScheme::RSA_PKCS1_SHA512, rustls::SignatureScheme::ECDSA_NISTP521_SHA512, rustls::SignatureScheme::RSA_PSS_SHA256, rustls::SignatureScheme::RSA_PSS_SHA384, rustls::SignatureScheme::RSA_PSS_SHA512, rustls::SignatureScheme::ED25519, rustls::SignatureScheme::ED448, ] } } self.map_tls(|tls_config_builder| { tls_config_builder .dangerous() 
.with_custom_certificate_verifier(Arc::new(DangerousServerCertVerifier)) }) } } impl QuicClientBuilder> { /// Sets a single certificate chain and matching private key for use /// in client authentication. /// /// Read [TlsClientConfigBuilder::with_single_cert] for more information. pub fn with_cert( self, cert: impl handy::ToCertificate, key: impl handy::ToPrivateKey, ) -> QuicClientBuilder { self.map_tls(|tls_config_builder| { tls_config_builder .with_client_auth_cert(cert.to_certificate(), key.to_private_key()) .expect("The private key was wrong encoded or failed validation") }) } /// Do not support client auth. pub fn without_cert(self) -> QuicClientBuilder { self.map_tls(|tls_config_builder| tls_config_builder.with_no_client_auth()) } /// Sets a custom [`ResolvesClientCert`]. pub fn with_cert_resolver( self, cert_resolver: Arc, ) -> QuicClientBuilder { self.map_tls(|tls_config_builder| { tls_config_builder.with_client_cert_resolver(cert_resolver) }) } } impl QuicClientBuilder { /// Specify the [alpn-protocol-ids] that will be sent in `ClientHello`. /// /// By default, it's empty and the ALPN extension won't be sent. /// /// If you call this multiple times, all the `alpn_protocol` will be used. /// /// [alpn-protocol-ids](https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids) pub fn with_alpns(mut self, alpns: impl IntoIterator>>) -> Self { self.tls_config .alpn_protocols .extend(alpns.into_iter().map(Into::into)); self } /// Enable the `keylog` feature. /// /// This is useful when you want to debug the TLS connection. /// /// The keylog file will be in the file that environment variable `SSLKEYLOGFILE` points to. /// /// Read [`rustls::KeyLogFile`] for more information. 
pub fn enable_sslkeylog(mut self) -> Self { self.tls_config.key_log = Arc::new(rustls::KeyLogFile::new()); self } pub fn enable_0rtt(mut self) -> Self { self.tls_config.enable_early_data = true; self } /// Build the QuicClient, ready to initiates connect to the servers. pub fn build(self) -> QuicClient { QuicClient { network: self.network, bind_ifaces: self.bind_ifaces, manual_bind: Arc::new(AtomicBool::new(self.manual_bind)), _prefer_versions: self.prefer_versions, token_sink: self.token_sink, parameters: self.parameters, tls_config: self.tls_config, stream_strategy_factory: self.stream_strategy_factory, defer_idle_timeout: self.defer_idle_timeout, qlogger: self.qlogger, } } } ================================================ FILE: dquic/src/common.rs ================================================ use std::{net::SocketAddr, sync::Arc}; use futures::{Stream, StreamExt, stream}; use qconnection::{ prelude::{EndpointAddr, handy}, qinterface::{ BindInterface, Interface, bind_uri::BindUri, component::{ Components, alive::RebindOnNetworkChangedComponent, location::{Locations, LocationsComponent}, route::{QuicRouter, QuicRouterComponent}, }, device::Devices, io::ProductIO, manager::InterfaceManager, }, qtraversal::{ nat::{client::StunClientsComponent, router::StunRouterComponent}, route::{ForwardersComponent, ReceiveAndDeliverPacketComponent}, }, }; use qresolve::{Family, Resolve, SystemResolver}; #[derive(Clone)] pub struct Network { pub resolver: Arc, pub devices: &'static Devices, pub iface_factory: Arc, pub iface_manager: Arc, pub quic_router: Arc, pub stun_server: Option>, pub locations: Arc, } impl Default for Network { fn default() -> Self { Self { resolver: Arc::new(SystemResolver), devices: Devices::global(), iface_factory: Arc::new(handy::DEFAULT_IO_FACTORY), iface_manager: InterfaceManager::global().clone(), quic_router: QuicRouter::global().clone(), stun_server: None, locations: Arc::new(Locations::new()), } } } impl Network { /// 只取第一个可用的 STUN agent 即返回,后续由 
StunClientsComponent 自动补充到 MIN_AGENTS async fn lookup_first_agent( &self, stun_server: &str, family: Family, ) -> Option> { let stream = self.resolver.lookup(stun_server).await.ok()?; let mut stream = std::pin::pin!(stream); while let Some((_source, ep)) = stream.next().await { let EndpointAddr::Direct { addr } = ep else { continue; }; if match family { Family::V4 => addr.is_ipv4(), Family::V6 => addr.is_ipv6(), } { tracing::trace!("resolved first stun agent for {stun_server}: {addr}"); return Some(vec![addr]); } } None } fn init_iface_components( &self, bind_iface: &BindInterface, stun_agent: Option<(Arc, Vec)>, ) { bind_iface.with_components_mut(move |components: &mut Components, iface: &Interface| { // rebind interface on network changed components.init_with(|| RebindOnNetworkChangedComponent::new(iface, self.devices)); // quic packet router let quic_router = components .init_with(|| QuicRouterComponent::new(self.quic_router.clone())) .router(); let locations = components .init_with(|| LocationsComponent::new(iface.downgrade(), self.locations.clone())) .clone(); match stun_agent { // stun enabled: Some((stun_server, stun_agents)) => { // initial stun router let stun_router = components .init_with(|| StunRouterComponent::new(iface.downgrade())) .router(); // initial stun clients (后续会自动补充到 MIN_AGENTS) let clients = components .init_with(|| { StunClientsComponent::new( iface.downgrade(), stun_router.clone(), self.resolver.clone(), stun_server, stun_agents, Some(locations.clone()), ) }) .clone(); // initial forwarder let relay = bind_iface .bind_uri() .relay() .and_then(|r| r.parse::().ok()); let forwarder = if let Some(relay) = relay { components .init_with(|| ForwardersComponent::new_server(relay)) .forwarder() } else { components .init_with(|| ForwardersComponent::new_client(clients)) .forwarder() }; // initial receive and deliver packet component(quic, stun and forwarder) components.init_with(|| { ReceiveAndDeliverPacketComponent::builder(iface.downgrade()) 
.quic_router(quic_router) .stun_router(stun_router) .forwarder(forwarder) .init() }); } // no stun: receive and deliver quic only None => { components.init_with(|| { ReceiveAndDeliverPacketComponent::builder(iface.downgrade()) .quic_router(quic_router) .init() }); } }; }); } pub async fn bind(&self, bind_uri: BindUri) -> BindInterface { let stun_server = if let Some(server) = bind_uri.stun_server() { Some(Arc::from(server)) } else if let Some("false") = bind_uri.prop(BindUri::STUN_PROP).as_deref() { None } else { self.stun_server.clone() }; let family = bind_uri.family(); let stun_agents = match &stun_server { Some(server) => self .lookup_first_agent(server.as_ref(), family) .await .unwrap_or_default(), None => vec![], }; let factory = self.iface_factory.clone(); let bind_iface = self.iface_manager.bind(bind_uri, factory).await; self.init_iface_components(&bind_iface, stun_server.map(|s| (s, stun_agents))); bind_iface } pub async fn bind_many( &self, bind_uris: impl IntoIterator>, ) -> impl Stream { stream::iter(bind_uris).then(async |bind_uri| self.bind(bind_uri.into()).await) } } ================================================ FILE: dquic/src/lib.rs ================================================ #![doc=include_str!("../README.md")] pub mod prelude { pub use ::qconnection; pub use qconnection::prelude::*; pub use qresolve::Resolve; pub use crate::{ client::{BindInterfaceError, ConnectServerError, QuicClient}, server::{ListenError, ListenersShutdown, QuicListeners, Server, ServerError}, }; pub mod handy { pub use qconnection::prelude::handy::*; pub use qresolve::SystemResolver; pub use crate::cert::{ToCertificate, ToPrivateKey}; } } pub mod builder { pub use qconnection::builder::*; pub use crate::{client::QuicClientBuilder, server::QuicListenersBuilder}; } // Hidden modules used to integrate the code examples from the README into the cargo test mod doc { #[doc=include_str!("../README_CN.md")] mod zh {} // Omitted: Duplicate with crate documentation // 
#[doc=include_str!("../../README.md")] // mod en {} } pub use ::qconnection::{self, qbase, qdatagram, qevent, qinterface, qrecovery, qtraversal}; pub use ::qresolve; mod cert; mod client; mod common; mod server; ================================================ FILE: dquic/src/server.rs ================================================ use std::{ collections::HashMap, fmt::Debug, io, ops::{Deref, DerefMut}, pin::pin, sync::Arc, time::Duration, }; use arc_swap::ArcSwap; use dashmap::DashMap; use futures::StreamExt; use qbase::{ packet::{DataHeader, GetDcid, Packet, long::DataHeader as LongHeader}, param::ServerParameters, token::TokenProvider, util::BoundQueue, }; use qconnection::{ self, qinterface::{self, bind_uri::BindUri, component::location::Locations, device::Devices}, tls::AcceptAllClientAuther, }; use qevent::telemetry::QLog; use qinterface::{ BindInterface, component::route::{QuicRouter, Way}, io::ProductIO, manager::InterfaceManager, }; use rustls::{ ConfigBuilder, ServerConfig as TlsServerConfig, WantsVerifier, server::{NoClientAuth, ResolvesServerCert, danger::ClientCertVerifier}, sign::CertifiedKey, }; use thiserror::Error; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tracing::Instrument; use crate::{prelude::*, *}; /// Errors that can occur during server management operations. #[derive(Debug, thiserror::Error)] pub enum ServerError { /// The server with the specified name already exists. #[error("Server '{server}' already exists")] ServerAlreadyExists { server: String }, /// The server with the specified name was not found. #[error("Server '{server}' not found")] ServerNotFound { server: String }, /// Failed to load the private key for the server. #[error("Failed to load private key for server '{server}': {source}")] InvalidCertOrKey { server: String, #[source] source: rustls::Error, }, } impl From for io::Error { fn from(error: ServerError) -> Self { let kind = match &error { ServerError::ServerAlreadyExists { .. 
} => io::ErrorKind::AlreadyExists, ServerError::ServerNotFound { .. } => io::ErrorKind::NotFound, ServerError::InvalidCertOrKey { .. } => io::ErrorKind::InvalidInput, }; io::Error::new(kind, error) } } /// Errors that can occur during QuicListeners builder creation. #[derive(Debug, thiserror::Error)] pub enum ListenError { /// A QuicListeners instance is already running globally. #[error("A QuicListeners is already running on the router")] AlreadyRunning, } impl From for io::Error { fn from(error: ListenError) -> Self { match error { ListenError::AlreadyRunning => io::Error::new(io::ErrorKind::AlreadyExists, error), } } } type TlsServerConfigBuilder = ConfigBuilder; #[derive(Debug, Default)] pub struct VirtualHosts(Arc>); impl ResolvesServerCert for VirtualHosts { fn resolve(&self, client_hello: rustls::server::ClientHello) -> Option> { self.0 .get(client_hello.server_name()?) .map(|server| server.certified_key()) } } pub struct Server { network: common::Network, bind_ifaces: DashMap, // todo: [update] change to LocalAgent certified_key: ArcSwap, } impl std::fmt::Debug for Server { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Server") .field("bind_ifaces", &self.bind_ifaces) .field("certified_key", &self.certified_key()) .finish() } } impl Server { pub fn bind_interfaces(&self) -> HashMap { self.bind_ifaces .iter() .map(|entry| (entry.key().clone(), entry.value().clone())) .collect() } pub async fn bind(&self, bind_uris: impl IntoIterator>) { let mut bind_ifaces = pin!(self.network.bind_many(bind_uris).await); while let Some(bind_iface) = bind_ifaces.next().await { self.bind_ifaces.insert(bind_iface.bind_uri(), bind_iface); } } pub fn get_iface(&self, bind_uri: &BindUri) -> Option { self.bind_ifaces .get(bind_uri) .map(|iface| iface.value().clone()) } pub fn remove_iface(&self, bind_uri: &BindUri) -> Option { self.bind_ifaces.remove(bind_uri).map(|entry| entry.1) } pub fn certified_key(&self) -> Arc { 
self.certified_key.load_full() } pub fn update_ocsp(&self, ocsp: Option>) { self.certified_key.rcu(|current| CertifiedKey { cert: current.cert.clone(), key: current.key.clone(), ocsp: ocsp.clone(), }); } } type Incomings = BoundQueue<((Connection, String, Pathway, Link), OwnedSemaphorePermit)>; /// A QUIC listener that can serve multiple virtual servers, accepting incoming connections. /// /// ## Creating Listeners /// /// Use [`QuicListenersBuilder`] to configure the listener, then call [`QuicListenersBuilder::listen`] /// to start accepting connections. /// /// **Note**: Only one [`QuicListeners`] instance can run at a time globally. /// To stop the listeners, call [`QuicListeners::shutdown`] or drop all references to the [`Arc`]. /// /// ## Managing Servers /// /// Add multiple virtual servers by calling [`QuicListeners::add_server`] multiple times. /// Each server is identified by its server name (SNI) and handles connections independently. /// /// - Servers can share the same network interfaces /// - Servers can be added without initially binding to any interface /// /// ## Connection Handling /// /// Call [`QuicListeners::accept`] to receive incoming connections. 
The listener automatically: /// - Routes connections to the appropriate server based on SNI (Server Name Indication) /// - Rejects connections if the target server isn't listening on the receiving interface /// - Returns connections that may still be completing their QUIC handshake #[derive(Clone)] pub struct QuicListeners { network: common::Network, // server servers: Arc>, // must be empty while building incomings: Arc, // identify the building QuicListeners backlog: Arc, // limit the number of concurrent connections // server: quic config(in initialize order) _supported_versions: Vec, token_provider: Arc, parameters: ServerParameters, anti_port_scan: bool, client_auther: Arc, tls_config: TlsServerConfig, stream_strategy_factory: Arc, defer_idle_timeout: Duration, qlogger: Arc, } impl QuicListeners { /// Add a virtual server with its certificate chain and private key. /// /// Creates a new virtual host identified by its server name (SNI). The server will use the /// certificate chain and private key that matches the SNI in the client's `ClientHello` message. /// If no matching server is found, the connection will be rejected. /// /// A server can be added without binding to any interface initially, but will not accept /// connections until interfaces are added via [`bind`]. This allows flexible /// server configuration and hot-swapping of network bindings. /// /// [`bind`]: Server::bind pub async fn add_server( &self, server_name: impl Into, cert_chain: impl handy::ToCertificate, private_key: impl handy::ToPrivateKey, bind_uris: impl IntoIterator>, ocsp: impl Into>>, ) -> Result<(), ServerError> { let server = server_name.into(); let server_entry = match self.servers.entry(server.clone()) { dashmap::Entry::Vacant(entry) => entry, dashmap::Entry::Occupied(..) 
=> { return Err(ServerError::ServerAlreadyExists { server }); } }; let cert = cert_chain.to_certificate(); let key = self .tls_config .crypto_provider() .key_provider .load_private_key(private_key.to_private_key()) .map_err(|e| ServerError::InvalidCertOrKey { server: server.clone(), source: e, })?; let ocsp = ocsp.into(); let certified_key = CertifiedKey { cert, key, ocsp }; certified_key .keys_match() .map_err(|source| ServerError::InvalidCertOrKey { server: server.clone(), source, })?; let certified_key = Arc::new(certified_key); let bind_uris = bind_uris.into_iter(); let server = Server { network: self.network.clone(), bind_ifaces: DashMap::with_capacity(bind_uris.size_hint().0), certified_key: ArcSwap::new(certified_key), }; server.bind(bind_uris).await; server_entry.insert(server); Ok(()) } /// Remove a virtual server and all its associated interfaces. /// /// Completely removes a server from the listeners, including all network interfaces /// it was bound to (if the interface is not used by other servers). /// This is the inverse operation of [`add_server`] and provides a clean /// way to decommission a virtual host. /// /// Returns `true` if the server existed and was removed, `false` if no server with the /// specified name was found. You must remove an existing server before adding a new /// one with the same name. /// /// [`add_server`]: QuicListeners::add_server pub fn remove_server(&self, server_name: &str) -> bool { self.servers.remove(server_name).is_some() } /// Get the server by its name. pub fn get_server<'l>(&'l self, server_name: &str) -> Option + 'l> { self.servers.get(server_name) } /// Get a mutable reference to the server by its name. 
pub fn get_server_mut<'l>( &'l self, server_name: &str, ) -> Option + 'l> { self.servers.get_mut(server_name) } pub fn servers(&self) -> Vec { self.servers .iter() .map(|entry| entry.key().clone()) .collect() } } #[derive(Debug, Error, Clone, Copy)] #[error("Listeners shutdown")] pub struct ListenersShutdown; impl QuicListeners { /// Accept an incoming QUIC connection from the queue. /// /// Returns the connection, connected server name, and network path information. /// Connections are automatically routed based on SNI (Server Name Indication). /// /// The connection queue size is limited by the `backlog` parameter in [`QuicListenersBuilder::listen`]. /// When the queue is full, new incoming packets may be dropped at the network level. pub async fn accept(&self) -> Result<(Connection, String, Pathway, Link), ListenersShutdown> { self.incomings .recv() .await .ok_or(ListenersShutdown) .map(|(i, ..)| i) } /// Close the QuicListeners, stops accepting new connections. /// /// Unaccepted connections will be closed pub fn shutdown(&self) { self.incomings.close(); self.backlog.close(); } } impl Drop for QuicListeners { fn drop(&mut self) { self.shutdown(); } } struct ServerAuther { anti_port_scan: bool, iface: BindUri, servers: Arc>, } impl AuthClient for ServerAuther { fn verify_client_name( &self, server_agent: &LocalAgent, _: Option<&str>, ) -> ClientNameVerifyResult { match self .servers .get(server_agent.name()) .is_some_and(|server| server.bind_ifaces.contains_key(&self.iface)) { true => ClientNameVerifyResult::Accept, false if self.anti_port_scan => ClientNameVerifyResult::SilentRefuse("".to_owned()), false => ClientNameVerifyResult::Refuse("".to_owned()), } } fn verify_client_agent(&self, _: &LocalAgent, _: &RemoteAgent) -> ClientAgentVerifyResult { ClientAgentVerifyResult::Accept } } // internal methods impl QuicListeners { #[tracing::instrument( target = "quic_listeners", level = "debug", skip_all, fields(%bind_uri, %pathway, %link, odcid=tracing::field::Empty, 
server_name=tracing::field::Empty) )] pub(crate) fn try_accept_connection(&self, packet: Packet, (bind_uri, pathway, link): Way) { let origin_dcid = match &packet { Packet::Data(data_packet) => match &data_packet.header { DataHeader::Long(LongHeader::Initial(hdr)) => *hdr.dcid(), DataHeader::Long(LongHeader::ZeroRtt(hdr)) => *hdr.dcid(), _ => return, }, _ => return, }; tracing::Span::current().record("odcid", origin_dcid.to_string()); if origin_dcid.is_empty() { tracing::debug!(target: "quic_listeners", "Received an initial/0rtt packet with empty destination CID, ignoring it"); return; } // Acquire a permit from the backlog semaphore to limit the number of concurrent connections. let Ok(premit) = self.backlog.clone().try_acquire_owned() else { tracing::debug!(target: "quic_listeners", "Backlog full, dropping incoming packet"); return; }; let server_auther = ServerAuther { anti_port_scan: self.anti_port_scan, iface: bind_uri.clone(), servers: self.servers.clone(), }; let connection = Connection::new_server(self.token_provider.clone()) .with_parameters(self.parameters.clone()) .with_client_auther(Box::new((server_auther, self.client_auther.clone()))) .with_tls_config(self.tls_config.clone()) .with_streams_concurrency_strategy(self.stream_strategy_factory.as_ref()) .with_zero_rtt(self.tls_config.max_early_data_size == 0xffffffff) .with_defer_idle_timeout(self.defer_idle_timeout) .with_iface_factory(self.network.iface_factory.clone()) .with_iface_manager(self.network.iface_manager.clone()) .with_quic_router(self.network.quic_router.clone()) .with_locations(self.network.locations.clone()) // todo // .with_stun_servers() .with_cids(origin_dcid) .with_qlog(self.qlogger.clone()) .run(); let incomings = self.incomings.clone(); let quic_router = self.network.quic_router.clone(); let try_accept_connection = async move { quic_router.deliver(packet, (bind_uri, pathway, link)).await; match connection.server_name().await { Ok(server_name) => { 
tracing::Span::current().record("server_name", &server_name); _ = connection.subscribe_local_address(); let incoming = (connection, server_name, pathway, link); match incomings.send((incoming, premit)).await { Ok(..) => { tracing::debug!(target: "quic_listeners", "Accepted incoming connection") } Err(..) => { tracing::debug!(target: "quic_listeners", "Listeners is shutdown, closing incoming connection") } } } Err(error) => { tracing::debug!( target: "quic_listeners", "Failed to accept connection: {error}", ); } } }; // Task completes after a single accept-notify cycle; no explicit join needed. tokio::spawn(try_accept_connection.in_current_span()); } } /// The builder for the quic listeners. #[derive(Clone)] pub struct QuicListenersBuilder { // network network: common::Network, // server servers: Arc>, // must be empty while building incomings: Arc, // identify the building QuicListeners // server: quic config(in initialize order) supported_versions: Vec, token_provider: Arc, parameters: ServerParameters, anti_port_scan: bool, client_auther: Arc, tls_config: T, stream_strategy_factory: Arc, defer_idle_timeout: Duration, qlogger: Arc, } impl QuicListeners { /// Start to build a [`QuicListeners`]. pub fn builder() -> QuicListenersBuilder> { Self::builder_with_tls(TlsServerConfig::builder_with_protocol_versions(&[ &rustls::version::TLS13, ])) } /// Start to build a QuicServer with the given tls crypto provider. pub fn builder_with_crypto_provider( provider: Arc, ) -> Result>, rustls::Error> { Ok(Self::builder_with_tls( TlsServerConfig::builder_with_provider(provider) .with_protocol_versions(&[&rustls::version::TLS13])?, )) } /// Start to build a [`QuicListeners`] with the given TLS configuration. /// /// This is useful when you want to customize the TLS configuration, or integrate qm-quic with other crates. 
pub fn builder_with_tls(tls_config: T) -> QuicListenersBuilder { QuicListenersBuilder { // network network: common::Network::default(), // server servers: Arc::new(DashMap::new()), // must be empty while building incomings: Arc::new(BoundQueue::new(8)), // identify the building QuicListeners // server: quic config(in initialize order) supported_versions: vec![1], token_provider: Arc::new(handy::NoopTokenRegistry), parameters: handy::server_parameters(), anti_port_scan: false, client_auther: Arc::new(AcceptAllClientAuther), tls_config, stream_strategy_factory: Arc::new(handy::ConsistentConcurrency::new), defer_idle_timeout: Duration::ZERO, qlogger: Arc::new(handy::NoopLogger), } } } impl QuicListenersBuilder { pub fn with_resolver(mut self, resolver: Arc) -> Self { self.network.resolver = resolver; self } pub fn with_physical_ifaces(mut self, physical_ifaces: &'static Devices) -> Self { self.network.devices = physical_ifaces; self } /// Specify how hosts bind to the interface. /// /// If you call this multiple times, only the last `factory` will be used. /// /// The default quic interface is provided by [`handy::DEFAULT_IO_FACTORY`]. /// For Unix and Windows targets, this is a high performance UDP library supporting GSO and GRO /// provided by `qudp` crate. For other platforms, please specify your own factory. pub fn with_iface_factory(mut self, iface_factory: Arc) -> Self { self.network.iface_factory = iface_factory; self } pub fn with_iface_manager(mut self, iface_manager: Arc) -> Self { self.network.iface_manager = iface_manager; self } /// Specify the router to use for the listeners. /// /// Packets received from the interface bound to the server will be delivered to this router, /// connectless packets (maybe incoming client connection) will be delivered to QuicListeners. /// /// A router can only be listened to by one QuicListener, /// or the [`QuicListenersBuilder::listen`] will fail. 
pub fn with_router(mut self, router: Arc) -> Self { self.network.quic_router = router; self } pub fn with_stun(mut self, stun_server: impl Into>) -> Self { self.network.stun_server = Some(stun_server.into()); self } /// Specify the locations for interface sharing. /// /// The given locations is shared by all connections created by this listeners. pub fn with_locations(mut self, locations: Arc) -> Self { self.network.locations = locations; self } /// (WIP)Specify the supported quic versions. /// /// If you call this multiple times, only the last call will take effect. pub fn with_supported_versions(mut self, versions: impl IntoIterator) -> Self { self.supported_versions.clear(); self.supported_versions.extend(versions); self } /// Specify how server to create and verify the client's Token in [address verification]. /// /// If you call this multiple times, only the last `token_provider` will be used. /// /// [address verification](https://www.rfc-editor.org/rfc/rfc9000.html#name-address-validation) pub fn with_token_provider(self, token_provider: Arc) -> Self { Self { token_provider, ..self } } /// Specify the [transport parameters] for the server connections. /// /// If you call this multiple times, only the last `parameters` will be used. /// /// Usually, you don't need to call this method, because the server will use a set of default parameters. /// /// [transport parameters](https://www.rfc-editor.org/rfc/rfc9000.html#name-transport-parameter-definit) pub fn with_parameters(mut self, parameters: ServerParameters) -> Self { self.parameters = parameters; self } /// Enable anti-port scanning protection. /// /// When anti-port scanning protection is enabled, the server will silently drop connections /// that fail validation (e.g., invalid ClientHello, authentication failures) /// without sending any response packets. 
///
/// This security feature provides the following benefits:
/// - Prevents attackers from detecting server presence through port scanning
/// - Reduces the attack surface by not revealing server configuration details
/// - Protects against network reconnaissance and probing attacks
/// - Makes the server appear "offline" to unauthorized connection attempts
///
/// **Security Note:** This feature should be used carefully as it may make
/// debugging connection issues more difficult. Consider using it in production
/// environments where security is prioritized over observability.
///
/// **Tip:** For enhanced security, combine this with [`with_client_auther`] to implement
/// custom authentication logic while maintaining stealth behavior for failed connections.
///
/// Default: disabled
///
/// [`with_client_auther`]: QuicListenersBuilder::with_client_auther
pub fn enable_anti_port_scan(mut self) -> Self {
    // Simple flag flip; the stealth behavior itself is implemented where
    // incoming packets are accepted/rejected (not visible in this chunk).
    self.anti_port_scan = true;
    self
}

/// Specify custom client authentication handlers for the server.
///
/// Client authers are used to perform additional validation beyond standard TLS
/// certificate verification. They can verify server names, client parameters,
/// and client certificates according to custom business logic.
///
/// Each [`AuthClient`] implementation provides three verification methods:
/// - `verify_server_name()`: Validates the requested server name (SNI)
/// - `verify_client_params()`: Validates client QUIC transport parameters
/// - `verify_client_certs()`: Validates client certificate chains
///
/// All provided authers must approve the connection for it to be accepted.
/// If any auther rejects the connection, it will be dropped.
///
/// If you call this multiple times, only the last `client_auther` will be used.
///
/// **Security Enhancement:** When combined with [`enable_anti_port_scan`],
/// failed authentication attempts will be silently dropped without any response,
/// providing enhanced security against reconnaissance attacks.
///
/// **TLS Protocol Note:** Certificate verification failures during the TLS handshake
/// will still send error responses to clients, as the server has already sent
/// its `ServerHello` message at that point. The stealth behavior only applies to
/// earlier validation failures that occur before the TLS handshake begins.
///
/// **Built-in Validation:** The server automatically verifies that the interface
/// receiving the client connection is configured to listen for the requested
/// server name (SNI). This built-in validation ensures proper routing of
/// connections to their intended hosts.
///
/// Default: empty (only built-in host and interface validation)
///
/// [`AuthClient`]: qconnection::tls::AuthClient
/// [`enable_anti_port_scan`]: QuicListenersBuilder::enable_anti_port_scan
pub fn with_client_auther(mut self, client_auther: impl AuthClient + 'static) -> Self {
    // Wrap in Arc so the auther can be shared with each accepted connection.
    // NOTE(review): the field's generic arguments appear to have been stripped
    // by text extraction (likely `Arc<dyn AuthClient>`); confirm against upstream.
    self.client_auther = Arc::new(client_auther);
    self
}

// Rebuild the builder with a transformed TLS config, carrying every other
// field over unchanged. Used by the cert-verifier methods below to advance
// the rustls config-builder type state.
// NOTE(review): the generic parameter list looks stripped by extraction
// (presumably `fn map_tls<T1>(...) -> QuicListenersBuilder<T1>`); confirm.
fn map_tls(self, f: impl FnOnce(T) -> T1) -> QuicListenersBuilder {
    QuicListenersBuilder {
        network: self.network,
        servers: self.servers,
        incomings: self.incomings,
        supported_versions: self.supported_versions,
        token_provider: self.token_provider,
        parameters: self.parameters,
        anti_port_scan: self.anti_port_scan,
        client_auther: self.client_auther,
        // The only field that changes: the TLS config is mapped through `f`.
        tls_config: f(self.tls_config),
        stream_strategy_factory: self.stream_strategy_factory,
        defer_idle_timeout: self.defer_idle_timeout,
        qlogger: self.qlogger,
    }
}

/// Specify the factory which product the streams concurrency strategy controller for the server.
///
/// The streams controller is used to control the concurrency of data streams.
/// Take a look of [`ControlStreamsConcurrency`] for more information.
///
/// If you call this multiple times, only the last `controller` will be used.
pub fn with_streams_concurrency_strategy(
    self,
    // NOTE(review): inner type of `Arc` appears stripped by extraction; confirm.
    stream_strategy_factory: Arc,
) -> Self {
    Self {
        stream_strategy_factory,
        ..self
    }
}

/// Provide an option to defer an idle timeout.
///
/// This facility could be used when the application wishes to avoid losing
/// state that has been associated with an open connection but does not expect
/// to exchange application data for some time.
///
/// See [Deferring Idle Timeout](https://datatracker.ietf.org/doc/html/rfc9000#name-deferring-idle-timeout)
/// of [RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000)
/// for more information.
pub fn defer_idle_timeout(mut self, duration: Duration) -> Self {
    self.defer_idle_timeout = duration;
    self
}

/// Specify qlog collector for server connections.
///
/// If you call this multiple times, only the last `logger` will be used.
///
/// Pre-implemented loggers:
/// - [`LegacySeqLogger`]: Generates qlog files compatible with [qvis] visualization.
///   - `LegacySeqLogger::new(PathBuf::from("/dir"))`: Write to files `{connection_id}_{role}.sqlog` in `dir`
///   - `LegacySeqLogger::new(tokio::io::stdout())`: Stream to stdout
///   - `LegacySeqLogger::new(tokio::io::stderr())`: Stream to stderr
///
///   Output format: JSON-SEQ ([RFC7464]), one JSON event per line.
///
/// - [`handy::NoopLogger`] (default): Ignores all qlog events (recommended for production).
///
/// [qvis]: https://qvis.quictools.info/
/// [RFC7464]: https://www.rfc-editor.org/rfc/rfc7464
/// [`LegacySeqLogger`]: qevent::telemetry::handy::LegacySeqLogger
pub fn with_qlog(self, qlogger: Arc) -> Self {
    Self { qlogger, ..self }
}
}

// NOTE(review): the generic arguments of this impl header appear stripped by
// extraction (likely a rustls `ConfigBuilder<ServerConfig, WantsVerifier>`-style
// type-state parameter); confirm against upstream.
impl QuicListenersBuilder> {
    /// Choose how to verify client certificates.
    pub fn with_client_cert_verifier(
        self,
        client_cert_verifier: Arc,
    ) -> QuicListenersBuilder {
        // SNI-based routing: resolve certificates per virtual host from the
        // shared server table.
        let virtual_servers = Arc::new(VirtualHosts(self.servers.clone()));
        self.map_tls(|tls_config_builder| {
            tls_config_builder
                .with_client_cert_verifier(client_cert_verifier)
                .with_cert_resolver(virtual_servers)
        })
    }

    /// Disable client authentication.
    pub fn without_client_cert_verifier(self) -> QuicListenersBuilder {
        let virtual_servers = Arc::new(VirtualHosts(self.servers.clone()));
        self.map_tls(|tls_config_builder| {
            tls_config_builder
                // `NoClientAuth` accepts every client without requesting a certificate.
                .with_client_cert_verifier(Arc::new(NoClientAuth))
                .with_cert_resolver(virtual_servers)
        })
    }
}

impl QuicListenersBuilder {
    /// Specify the [alpn-protocol-ids] that the server supports.
    ///
    /// If you call this multiple times, all the `alpn_protocol` will be used.
    ///
    /// If you never call this method, we will not do ALPN with the client.
    ///
    /// [alpn-protocol-ids](https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids)
    pub fn with_alpns(mut self, alpn: impl IntoIterator>>) -> Self {
        // NOTE(review): the iterator item bound appears stripped by extraction
        // (likely `impl IntoIterator<Item: Into<Vec<u8>>>`); confirm.
        self.tls_config
            .alpn_protocols
            .extend(alpn.into_iter().map(Into::into));
        self
    }

    /// Advertise willingness to accept QUIC 0-RTT data from resuming clients.
    pub fn enable_0rtt(mut self) -> Self {
        // The TLS early_data extension in the NewSessionTicket message is defined to convey (in the
        // max_early_data_size parameter) the amount of TLS 0-RTT data the server is willing to accept. QUIC does not
        // use TLS early data. QUIC uses 0-RTT packets to carry early data. Accordingly, the max_early_data_size
        // parameter is repurposed to hold a sentinel value 0xffffffff to indicate that the server is willing to accept QUIC
        // 0-RTT data. To indicate that the server does not accept 0-RTT data, the early_data extension is omitted from
        // the NewSessionTicket. The amount of data that the client can send in QUIC 0-RTT is controlled by the
        // initial_max_data transport parameter supplied by the server.
        self.tls_config.max_early_data_size = 0xffffffff;
        self
    }

    /// Start listening for incoming connections.
    ///
    /// The `backlog` parameter has the same meaning as the backlog parameter of the UNIX listen function,
    /// which is the maximum number of pending connections that can be queued.
    /// If the queue is full, new initial packets may be dropped.
    ///
    /// Panic if `backlog` is 0.
    // NOTE(review): the success type's generic arguments appear stripped by
    // extraction (likely `Result<Arc<QuicListeners>, ListenError>`); confirm.
    pub fn listen(self, backlog: usize) -> Result, ListenError> {
        assert!(backlog > 0, "backlog must be greater than 0");
        // Servers are registered later via `add_server`; none should exist yet.
        debug_assert!(self.servers.is_empty());
        let quic_router = self.network.quic_router.clone();
        let quic_listeners = Arc::new(QuicListeners {
            network: self.network,
            servers: self.servers,
            incomings: self.incomings,
            // Semaphore permits model the pending-connection queue capacity.
            backlog: Arc::new(Semaphore::new(backlog)),
            _supported_versions: self.supported_versions,
            token_provider: self.token_provider,
            parameters: self.parameters,
            anti_port_scan: self.anti_port_scan,
            client_auther: self.client_auther,
            tls_config: self.tls_config,
            stream_strategy_factory: self.stream_strategy_factory,
            defer_idle_timeout: self.defer_idle_timeout,
            qlogger: self.qlogger,
        });
        // TODO: optimize init order
        let listeners = quic_listeners.clone();
        // Register the handler for packets that match no existing connection;
        // fails (returns false) if another listener is already installed.
        if !quic_router.on_connectless_packets(move |packet, way| {
            listeners.try_accept_connection(packet, way);
        }) {
            return Err(ListenError::AlreadyRunning);
        }
        Ok(quic_listeners)
    }
}

================================================
FILE: dquic/tests/auth.rs
================================================
use std::{future::Future, sync::Arc, time::Duration};

use dquic::{
    prelude::{handy::*, *},
    qbase,
    qresolve::Source,
};
use qbase::param::ServerParameters;
use qconnection::qinterface::{bind_uri::BindUri, component::route::QuicRouter};
use rustls::{
    pki_types::{CertificateDer, pem::PemObject},
    server::WebPkiClientVerifier,
};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    time,
};
use tokio_util::task::AbortOnDropHandle;

mod common;
use common::*;
mod echo_common;
use echo_common::*;

// A client built with `without_verifier()` (no server-cert validation) must
// still be able to connect and complete an echo round trip.
#[test]
fn client_without_verify() -> Result<(), BoxError> {
    run(async {
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_echo_server(router.clone(), server_parameters()).await?;
        // AbortOnDropHandle cancels the server task when the test ends.
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = {
            let parameters = client_parameters();
            let client = QuicClient::builder()
                .with_router(router)
                .without_verifier()
                .with_parameters(parameters)
                .without_cert()
                .with_qlog(qlogger())
                .enable_sslkeylog()
                .build();
            Arc::new(client)
        };
        let connection = client
            .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
            .await?;
        send_and_verify_echo(&connection, TEST_DATA).await?;
        listeners.shutdown();
        Ok(())
    })
}

// Test auther that accepts only clients named "client".
// NOTE(review): a const generic parameter appears stripped by extraction
// (likely `struct ClientNameAuther<const SILENT: bool>;` — `SILENT` is used
// below without a visible declaration); confirm against upstream.
struct ClientNameAuther;

impl AuthClient for ClientNameAuther {
    fn verify_client_name(
        &self,
        _: &LocalAgent,
        client_name: Option<&str>,
    ) -> ClientNameVerifyResult {
        match matches!(client_name, Some("client")) {
            true => ClientNameVerifyResult::Accept,
            // Non-silent mode: refuse with an explicit response.
            false if !SILENT => ClientNameVerifyResult::Refuse("".to_owned()),
            // Silent mode: drop without responding (stealth behavior).
            false => ClientNameVerifyResult::SilentRefuse("Client name ".to_owned()),
        }
    }

    fn verify_client_agent(&self, _: &LocalAgent, _: &RemoteAgent) -> ClientAgentVerifyResult {
        ClientAgentVerifyResult::Accept
    }
}

// Spin up an echo server that requires client certificates (WebPKI verifier
// against CA_CERT) plus the ClientNameAuther above.
// NOTE(review): generic arguments of `Arc`/`impl Future` and the
// `ClientNameAuther::` turbofish appear stripped by extraction; confirm.
async fn launch_client_auth_test_server(
    quic_router: Arc,
    server_parameters: ServerParameters,
) -> Result<(Arc, impl Future), BoxError> {
    let mut roots = rustls::RootCertStore::empty();
    roots.add_parsable_certificates(CertificateDer::pem_slice_iter(CA_CERT).map(Result::unwrap));
    let listeners = QuicListeners::builder()
        .with_router(quic_router)
        .with_client_cert_verifier(
            WebPkiClientVerifier::builder(Arc::new(roots))
                .build()
                .unwrap(),
        )
        .with_client_auther(ClientNameAuther::)
        .with_parameters(server_parameters)
        .with_qlog(qlogger())
        .listen(128)?;
    listeners
        .add_server(
            "localhost",
            SERVER_CERT,
            SERVER_KEY,
            // Port 0 → let the OS pick; `alloc_port` reserves it up front.
            [BindUri::from("inet://127.0.0.1:0").alloc_port()],
            None,
        )
        .await?;
    Ok((listeners.clone(), serve_echo(listeners)))
}

// Happy path: correct client name + valid client cert → echo succeeds.
#[test]
fn auth_client_name() -> Result<(), BoxError> {
    run(async {
        const SILENT_REFUSE: bool = false;
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_client_auth_test_server::(router.clone(), server_parameters())
                .await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = {
            let mut roots = rustls::RootCertStore::empty();
            roots.add_parsable_certificates(
                CertificateDer::pem_slice_iter(CA_CERT).map(Result::unwrap),
            );
            let client = QuicClient::builder()
                .with_router(router)
                .with_root_certificates(roots)
                .with_parameters(client_parameters())
                .with_cert(CLIENT_CERT, CLIENT_KEY)
                // The name the server-side ClientNameAuther expects.
                .with_name("client")
                .with_qlog(qlogger())
                .enable_sslkeylog()
                .build();
            Arc::new(client)
        };
        let connection = client
            .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
            .await?;
        send_and_verify_echo(&connection, TEST_DATA).await?;
        listeners.shutdown();
        Ok(())
    })
}

// Wrong client name + non-silent refusal → connection terminates with
// ConnectionRefused.
#[test]
fn auth_client_name_incorrect_name() -> Result<(), BoxError> {
    run(async {
        const SILENT_REFUSE: bool = false;
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_client_auth_test_server::(router.clone(), server_parameters())
                .await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = {
            let mut roots = rustls::RootCertStore::empty();
            roots.add_parsable_certificates(
                CertificateDer::pem_slice_iter(CA_CERT).map(Result::unwrap),
            );
            let client = QuicClient::builder()
                .with_router(router)
                .with_root_certificates(roots)
                .with_parameters(client_parameters())
                .with_cert(CLIENT_CERT, CLIENT_KEY)
                // Deliberately not "client" → the auther must refuse.
                .with_name("wrong_name")
                .with_qlog(qlogger())
                .enable_sslkeylog()
                .build();
            Arc::new(client)
        };
        let connection = client
            .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
            .await?;
        let error = connection.terminated().await;
        // TODO: occasionally ends with NoViablePath instead; needs investigation
        assert_eq!(error.kind(), ErrorKind::ConnectionRefused);
        listeners.shutdown();
        Ok(())
    })
}

// Client omits its name entirely (no `with_name`) → non-silent refusal.
#[test]
fn auth_client_refuse() -> Result<(), BoxError> {
    run(async {
        const SILENT_REFUSE: bool = false;
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_client_auth_test_server::(router.clone(), server_parameters())
                .await?;
        let _server_task =
            AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = {
            let parameters = client_parameters(); // no CLIENT_NAME
            let mut roots = rustls::RootCertStore::empty();
            roots.add_parsable_certificates(
                CertificateDer::pem_slice_iter(CA_CERT).map(Result::unwrap),
            );
            let client = QuicClient::builder()
                .with_router(router)
                .with_root_certificates(roots)
                .with_parameters(parameters)
                .with_cert(CLIENT_CERT, CLIENT_KEY)
                .with_qlog(qlogger())
                .enable_sslkeylog()
                .build();
            Arc::new(client)
        };
        let connection = client
            .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
            .await?;
        let error = connection.terminated().await;
        // TODO: occasionally ends with NoViablePath instead; needs investigation
        assert_eq!(error.kind(), ErrorKind::ConnectionRefused);
        listeners.shutdown();
        Ok(())
    })
}

// Same as above but with SILENT_REFUSE = true: the server must drop the
// connection without sending any response (anti-port-scan behavior).
#[test]
fn auth_client_refuse_silently() -> Result<(), BoxError> {
    run(async {
        const SILENT_REFUSE: bool = true;
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_client_auth_test_server::(router.clone(), server_parameters())
                .await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = {
            let parameters = client_parameters(); // no CLIENT_NAME
            let mut roots = rustls::RootCertStore::empty();
            roots.add_parsable_certificates(
                CertificateDer::pem_slice_iter(CA_CERT).map(Result::unwrap),
            );
            let client = QuicClient::builder()
                .with_router(router)
                .with_root_certificates(roots)
                .with_parameters(parameters)
                .with_cert(CLIENT_CERT, CLIENT_KEY)
                .with_qlog(qlogger())
                .enable_sslkeylog()
                .build();
            Arc::new(client)
        };
        let connection = client
            .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
            .await?;
        // Silent refuse means server doesn't send CCF, so client should either:
        // 1. Timeout waiting for handshake
        // 2.
        // Fail with NoViablePath when path times out
        let result = time::timeout(Duration::from_secs(1), connection.handshaked()).await;
        match result {
            Err(_timeout) => {} // Expected: timeout
            Ok(Err(e)) if e.kind() == ErrorKind::NoViablePath => {} // Also acceptable: path timeout
            Ok(other) => panic!("Expected timeout or NoViablePath, got {:?}", other),
        }
        listeners.shutdown();
        Ok(())
    })
}

// Wire format for the signed-echo tests: payload plus its signature, encoded
// with postcard.
// NOTE(review): element types of `Vec` appear stripped by extraction
// (presumably `Vec<u8>` for both fields); confirm against upstream.
#[derive(serde::Serialize, serde::Deserialize)]
struct Message {
    data: Vec,
    sign: Vec,
}

const SIGNATURE_SCHEME: rustls::SignatureScheme = rustls::SignatureScheme::ECDSA_NISTP256_SHA256;

// Client side of the signed echo: send data signed with the local agent,
// then verify the server's echoed signature with the remote agent.
async fn send_and_verify_echo_with_sign_verify(
    connection: &Connection,
    data: &[u8],
) -> Result<(), BoxError> {
    let local_agent = connection.local_agent().await.unwrap().unwrap();
    let remote_agent = connection.remote_agent().await.unwrap().unwrap();
    let (_sid, (mut reader, mut writer)) = connection.open_bi_stream().await?.unwrap();
    tracing::debug!("stream opened");
    let write = async {
        let data = data.to_vec();
        let sign = local_agent.sign(SIGNATURE_SCHEME, &data).unwrap();
        let message = postcard::to_stdvec(&Message { data, sign }).unwrap();
        writer.write_all(&message).await?;
        writer.shutdown().await?;
        tracing::info!("write done");
        Result::<(), BoxError>::Ok(())
    };
    let read = async {
        let mut message = Vec::new();
        reader.read_to_end(&mut message).await?;
        let message: Message = postcard::from_bytes(&message).unwrap();
        remote_agent
            .verify(SIGNATURE_SCHEME, &message.data, &message.sign)
            .unwrap();
        assert_eq!(message.data, data);
        tracing::info!("read done");
        Result::<(), BoxError>::Ok(())
    };
    // Read and write run concurrently on the same bidirectional stream.
    tokio::try_join!(read, write).map(|_| ())
}

// Server side of the signed echo: verify the client's signature, then echo
// the payload back re-signed with the server's local agent.
async fn echo_stream_with_sign_verify(
    local_agent: LocalAgent,
    remote_agent: RemoteAgent,
    mut reader: StreamReader,
    mut writer: StreamWriter,
) {
    let mut message = Vec::new();
    reader.read_to_end(&mut message).await.unwrap();
    let Message { data, sign } = postcard::from_bytes(&message).unwrap();
    remote_agent.verify(SIGNATURE_SCHEME, &data, &sign).unwrap();
    tracing::debug!("message received and verified");
    let sign = local_agent.sign(SIGNATURE_SCHEME, &data).unwrap();
    let message = postcard::to_stdvec(&Message { data, sign }).unwrap();
    writer.write_all(&message).await.unwrap();
    writer.shutdown().await.unwrap();
    tracing::debug!("signed echo sent");
}

// Accept loop: one task per connection, one task per bidirectional stream.
pub async fn serve_echo_with_sign_verify(listeners: Arc) {
    while let Ok((connection, server, pathway, _link)) = listeners.accept().await {
        assert_eq!(server, "localhost");
        let local_agent = connection.local_agent().await.unwrap().unwrap();
        let remote_agent = connection.remote_agent().await.unwrap().unwrap();
        tracing::info!(source = ?pathway.remote(),"accepted new connection");
        tokio::spawn(async move {
            while let Ok((_sid, (reader, writer))) = connection.accept_bi_stream().await {
                tokio::spawn(echo_stream_with_sign_verify(
                    local_agent.clone(),
                    remote_agent.clone(),
                    reader,
                    writer,
                ));
            }
        });
    }
}

// Like `launch_client_auth_test_server` but without the name auther; client
// certs are still required (WebPKI against CA_CERT).
// NOTE(review): generic arguments of `Arc`/`impl Future` appear stripped by
// extraction; confirm against upstream.
async fn launch_echo_with_sign_verify_server(
    quic_router: Arc,
    parameters: ServerParameters,
) -> Result<(Arc, impl Future), BoxError> {
    let mut roots = rustls::RootCertStore::empty();
    roots.add_parsable_certificates(CertificateDer::pem_slice_iter(CA_CERT).map(Result::unwrap));
    let listeners = QuicListeners::builder()
        .with_router(quic_router)
        .with_client_cert_verifier(
            WebPkiClientVerifier::builder(Arc::new(roots))
                .build()
                .unwrap(),
        )
        .with_parameters(parameters)
        .with_qlog(qlogger())
        .listen(128)?;
    listeners
        .add_server(
            "localhost",
            SERVER_CERT,
            SERVER_KEY,
            [BindUri::from("inet://127.0.0.1:0").alloc_port()],
            None,
        )
        .await?;
    Ok((listeners.clone(), serve_echo_with_sign_verify(listeners)))
}

// End-to-end: mutual TLS + application-level sign/verify round trip.
#[test]
fn sign_and_verify() -> Result<(), BoxError> {
    run(async {
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_echo_with_sign_verify_server(router.clone(), server_parameters()).await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = {
            let mut roots = rustls::RootCertStore::empty();
            roots.add_parsable_certificates(
                CertificateDer::pem_slice_iter(CA_CERT).map(Result::unwrap),
            );
            let client = QuicClient::builder()
                .with_router(router)
                .with_root_certificates(roots)
                .with_parameters(client_parameters())
                .with_cert(CLIENT_CERT, CLIENT_KEY)
                .with_name("client")
                .with_qlog(qlogger())
                .enable_sslkeylog()
                .build();
            Arc::new(client)
        };
        let connection = client
            .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
            .await?;
        send_and_verify_echo_with_sign_verify(&connection, TEST_DATA).await?;
        listeners.shutdown();
        Ok(())
    })
}

================================================
FILE: dquic/tests/common/mod.rs
================================================
// common is submod for both echo and auth tests
#![allow(unused)]

use std::{
    future::Future,
    net::SocketAddr,
    sync::{Arc, LazyLock, OnceLock},
    time::Duration,
};

use dquic::{
    prelude::{handy::*, *},
    qbase::{self, param::ClientParameters},
    qinterface::{component::route::QuicRouter, io::IO},
};
use qevent::telemetry::QLog;
use rustls::pki_types::{CertificateDer, pem::PemObject};
use tokio::time;
use tracing::level_filters::LevelFilter;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::{
    Layer, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt,
};

// Process-wide no-op qlog collector, shared by every test.
// NOTE(review): return type and `OnceLock` generic arguments appear stripped
// by extraction (presumably `Arc<dyn QLog ...>`); confirm against upstream.
pub fn qlogger() -> Arc {
    static QLOGGER: OnceLock> = OnceLock::new();
    QLOGGER.get_or_init(|| Arc::new(NoopLogger)).clone()
}

// NOTE(review): inner type of `Box` appears stripped by extraction
// (conventionally `Box<dyn std::error::Error + Send + Sync>`); confirm.
pub type BoxError = Box;

// Test harness: runs `future` on a shared multi-thread Tokio runtime with
// tracing initialized once, and a 60 s whole-test timeout.
// NOTE(review): the `<F: Future>` bound and static generic arguments appear
// stripped by extraction; confirm against upstream.
pub fn run(future: F) -> F::Output {
    static RT: LazyLock = LazyLock::new(|| {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap()
    });
    static TRACING: LazyLock = LazyLock::new(|| {
        // Keep the WorkerGuard alive for the whole process so buffered log
        // lines are flushed.
        let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout());
        tracing_subscriber::registry()
            // .with(console_subscriber::spawn())
            .with(
                tracing_subscriber::fmt::layer()
                    .with_writer(non_blocking)
                    .with_file(true)
                    .with_line_number(true)
                    .with_filter(LevelFilter::DEBUG),
            )
            .with(tracing_subscriber::filter::filter_fn(|metadata| {
                // Suppress noisy netlink route-table chatter.
                !metadata.target().contains("netlink_packet_route")
            }))
            .init();
        guard
    });
    RT.block_on(async move {
        LazyLock::force(&TRACING);
        match time::timeout(Duration::from_secs(60), future).await {
            Ok(output) => output,
            Err(_timedout) => panic!("test timed out"),
        }
    })
}

// Standard test client: trusts CA_CERT, presents no client certificate.
pub fn launch_test_client(
    quic_router: Arc,
    parameters: ClientParameters,
) -> Arc {
    let mut roots = rustls::RootCertStore::empty();
    roots.add_parsable_certificates(CertificateDer::pem_slice_iter(CA_CERT).map(Result::unwrap));
    let client = QuicClient::builder()
        .with_router(quic_router)
        .with_root_certificates(roots)
        .with_parameters(parameters)
        .without_cert()
        .with_qlog(qlogger())
        .enable_sslkeylog()
        .build();
    Arc::new(client)
}

// Resolve the actual socket address the "localhost" server bound to
// (port 0 in the bind URI means the OS picked the port).
pub fn get_server_addr(listeners: &QuicListeners) -> SocketAddr {
    let localhost = listeners
        .get_server("localhost")
        .expect("Server localhost must be registered");
    let localhost_bind_interface = localhost
        .bind_interfaces()
        .into_iter()
        .next()
        .map(|(_bind_uri, interface)| interface)
        .expect("Server should bind at least one address");
    localhost_bind_interface
        .borrow()
        .bound_addr()
        .expect("failed to get real addr")
}

// Shared test fixtures: a local CA plus server/client key pairs.
pub const CA_CERT: &[u8] = include_bytes!("../../../tests/keychain/localhost/ca.cert");
pub const SERVER_CERT: &[u8] = include_bytes!("../../../tests/keychain/localhost/server.cert");
pub const SERVER_KEY: &[u8] = include_bytes!("../../../tests/keychain/localhost/server.key");
pub const CLIENT_CERT: &[u8] = include_bytes!("../../../tests/keychain/localhost/client.cert");
pub const CLIENT_KEY: &[u8] = include_bytes!("../../../tests/keychain/localhost/client.key");
// This very source file doubles as the echo payload.
pub const TEST_DATA: &[u8] = include_bytes!("mod.rs");

================================================
FILE: dquic/tests/echo.rs
================================================
use std::{sync::Arc, time::Duration};

use dquic::{
    prelude::{handy::*, *},
    qbase::param::{ClientParameters, ServerParameters},
    qinterface::{bind_uri::BindUri, component::route::QuicRouter},
    qresolve::Source,
};
use
tokio::task::JoinSet;
use tokio_util::task::AbortOnDropHandle;
use tracing::Instrument;

mod common;
use common::*;
mod echo_common;
use echo_common::*;

// One connection, one bidirectional stream, small payload.
#[test]
fn single_stream() -> Result<(), BoxError> {
    run(async {
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_echo_server(router.clone(), server_parameters()).await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = launch_test_client(router, client_parameters());
        let connection = client
            .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
            .await?;
        send_and_verify_echo(&connection, TEST_DATA).await?;
        listeners.shutdown();
        Ok(())
    })
}

// Single stream with a larger payload (presumably "signal" is a typo for
// "single" — kept as-is since the test name is part of the public surface).
#[test]
fn signal_big_stream() -> Result<(), BoxError> {
    run(async {
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_echo_server(router.clone(), server_parameters()).await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = launch_test_client(router, client_parameters());
        let connection = client
            .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
            .await?;
        // Use 16x repeat (~58KB) instead of 1024x (~3.7MB) for CI stability
        send_and_verify_echo(&connection, &TEST_DATA.to_vec().repeat(16)).await?;
        listeners.shutdown();
        Ok(())
    })
}

// Edge case: a zero-length payload must still round-trip cleanly.
#[test]
fn empty_stream() -> Result<(), BoxError> {
    run(async {
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_echo_server(router.clone(), server_parameters()).await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = launch_test_client(router, client_parameters());
        let connection = client
            .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
            .await?;
        send_and_verify_echo(&connection, b"").await?;
        listeners.shutdown();
        Ok(())
    })
}

// Server closes the connection after serving a single stream; subsequent
// stream use on the client must fail.
#[test]
fn shutdown() -> Result<(), BoxError> {
    run(async {
        async fn serve_only_one_stream(listeners: Arc) {
            while let Ok((connection, server, pathway, _link)) = listeners.accept().await {
                assert_eq!(server, "localhost");
                tracing::info!(source = ?pathway.remote(), "accepted new connection");
                tokio::spawn(async move {
                    let (_sid, (reader, writer)) = connection.accept_bi_stream().await?;
                    echo_stream(reader, writer).await;
                    // Close right after the first echo completes.
                    _ = connection.close("Bye bye", 0);
                    Result::<(), BoxError>::Ok(())
                });
            }
        }
        let router = Arc::new(QuicRouter::default());
        let listeners = QuicListeners::builder()
            .with_router(router.clone())
            .without_client_cert_verifier()
            .with_parameters(server_parameters())
            .with_qlog(qlogger())
            .listen(128)?;
        listeners
            .add_server(
                "localhost",
                SERVER_CERT,
                SERVER_KEY,
                [BindUri::from("inet://127.0.0.1:0").alloc_port()],
                None,
            )
            .await?;
        let server_task = serve_only_one_stream(listeners.clone());
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = launch_test_client(router, client_parameters());
        let connection = client
            .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
            .await?;
        _ = connection.handshaked().await; // optional, either way works
        // First echo may still succeed (close races the second stream), but at
        // least one of two consecutive attempts must fail.
        assert!(
            send_and_verify_echo(&connection, b"").await.is_err()
                || send_and_verify_echo(&connection, b"").await.is_err()
        );
        connection.terminated().await;
        listeners.shutdown();
        Ok(())
    })
}

// With a 1 s server-side max idle timeout, an idle connection must terminate
// on its own.
#[test]
fn idle_timeout() -> Result<(), BoxError> {
    run(async {
        // Shadow the shared helper with a 1 s idle timeout.
        fn server_parameters() -> ServerParameters {
            let mut params = handy::server_parameters();
            params
                .set(ParameterId::MaxIdleTimeout, Duration::from_secs(1))
                .expect("unreachable");
            params
        }
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_echo_server(router.clone(), server_parameters()).await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = launch_test_client(router, client_parameters());
        let connection = client
            .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
            .await?;
        // No traffic is sent; the idle timeout alone should end the connection.
        connection.terminated().await;
        listeners.shutdown();
        Ok(())
    })
}

// Two sequential connections from the same client, echoing concurrently.
#[test]
fn double_connections() -> Result<(), BoxError> {
    run(async {
        // Use extended timeouts for parallel connection tests on slower CI
        fn client_parameters() -> ClientParameters {
            let mut params = handy::client_parameters();
            params
                .set(ParameterId::MaxIdleTimeout, Duration::from_secs(60))
                .expect("unreachable");
            params
        }
        fn server_parameters() -> ServerParameters {
            let mut params = handy::server_parameters();
            params
                .set(ParameterId::MaxIdleTimeout, Duration::from_secs(60))
                .expect("unreachable");
            params
        }
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_echo_server(router.clone(), server_parameters()).await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = launch_test_client(router, client_parameters());
        let mut connections = JoinSet::new();
        for conn_idx in 0..2 {
            let connection = client
                .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
                .await?;
            connections.spawn(
                async move { send_and_verify_echo(&connection, TEST_DATA).await }
                    .instrument(tracing::info_span!("stream", conn_idx)),
            );
        }
        // NOTE(review): the turbofish contents appear stripped by extraction
        // (presumably `collect::<Result<(), _>>()`); confirm against upstream.
        connections
            .join_all()
            .await
            .into_iter()
            .collect::>()?;
        listeners.shutdown();
        Ok(())
    })
}

// Fan-out sizes for the parallel tests below.
const PARALLEL_ECHO_CONNS: usize = 3;
const PARALLEL_ECHO_STREAMS: usize = 2;

// Multiple connections, multiple concurrent streams per connection.
#[test]
fn parallel_stream() -> Result<(), BoxError> {
    run(async {
        fn client_parameters() -> ClientParameters {
            let mut params = handy::client_parameters();
            params
                .set(ParameterId::MaxIdleTimeout, Duration::from_secs(60))
                .expect("unreachable");
            params
        }
        fn server_parameters() -> ServerParameters {
            let mut params = handy::server_parameters();
            params
                .set(ParameterId::MaxIdleTimeout, Duration::from_secs(60))
                .expect("unreachable");
            params
        }
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_echo_server(router.clone(), server_parameters()).await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = launch_test_client(router, client_parameters());
        let mut streams = JoinSet::new();
        for conn_idx in 0..PARALLEL_ECHO_CONNS {
            tracing::info!(conn_idx, "Starting connection");
            let connection = Arc::new(
                client
                    .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
                    .await?,
            );
            tracing::info!(conn_idx, "Connected");
            for stream_idx in 0..PARALLEL_ECHO_STREAMS {
                let connection = connection.clone();
                streams.spawn(
                    async move { send_and_verify_echo(&connection, TEST_DATA).await }
                        .instrument(tracing::info_span!("stream", conn_idx, stream_idx)),
                );
            }
        }
        streams
            .join_all()
            .await
            .into_iter()
            .collect::>()?;
        listeners.shutdown();
        Ok(())
    })
}

// Multiple connections, one larger payload each.
#[test]
fn parallel_big_stream() -> Result<(), BoxError> {
    run(async {
        fn client_parameters() -> ClientParameters {
            let mut params = handy::client_parameters();
            params
                .set(ParameterId::MaxIdleTimeout, Duration::from_secs(60))
                .expect("unreachable");
            params
        }
        fn server_parameters() -> ServerParameters {
            let mut params = handy::server_parameters();
            params
                .set(ParameterId::MaxIdleTimeout, Duration::from_secs(60))
                .expect("unreachable");
            params
        }
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_echo_server(router.clone(), server_parameters()).await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = launch_test_client(router, client_parameters());
        let mut big_streams = JoinSet::new();
        // Use 4x repeat (~14KB per connection) instead of 32x (~117KB) for CI stability
        let test_data = Arc::new(TEST_DATA.to_vec().repeat(4));
        for conn_idx in 0..PARALLEL_ECHO_CONNS {
            let connection = client
                .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
                .await?;
            let test_data =
                test_data.clone();
            big_streams.spawn(
                async move { send_and_verify_echo(&connection, &test_data).await }
                    .instrument(tracing::info_span!("stream", conn_idx)),
            );
        }
        big_streams
            .join_all()
            .await
            .into_iter()
            .collect::>()?;
        listeners.shutdown();
        Ok(())
    })
}

// Tight stream-count and flow-control limits; echo must still complete.
#[test]
fn limited_streams() -> Result<(), BoxError> {
    run(async {
        pub fn client_parameters() -> ClientParameters {
            let mut params = ClientParameters::default();
            for (id, value) in [
                (ParameterId::InitialMaxStreamsBidi, 2u32),
                (ParameterId::InitialMaxStreamsUni, 0u32),
                (ParameterId::InitialMaxData, 1u32 << 10),
                (ParameterId::InitialMaxStreamDataBidiLocal, 1u32 << 10),
                (ParameterId::InitialMaxStreamDataBidiRemote, 1u32 << 10),
                (ParameterId::InitialMaxStreamDataUni, 1u32 << 10),
            ] {
                params.set(id, value).expect("unreachable");
            }
            params
        }
        pub fn server_parameters() -> ServerParameters {
            let mut params = ServerParameters::default();
            for (id, value) in [
                (ParameterId::InitialMaxStreamsBidi, 2u32),
                (ParameterId::InitialMaxStreamsUni, 2u32),
                (ParameterId::InitialMaxData, 1u32 << 20),
                (ParameterId::InitialMaxStreamDataBidiLocal, 1u32 << 10),
                (ParameterId::InitialMaxStreamDataBidiRemote, 1u32 << 10),
                (ParameterId::InitialMaxStreamDataUni, 1u32 << 10),
            ] {
                params.set(id, value).expect("unreachable");
            }
            params
                .set(ParameterId::MaxIdleTimeout, Duration::from_secs(30))
                .expect("unreachable");
            params
        }
        let router = Arc::new(QuicRouter::default());
        let (listeners, server_task) =
            launch_echo_server(router.clone(), server_parameters()).await?;
        let _server_task = AbortOnDropHandle::new(tokio::spawn(server_task));
        let server_addr = get_server_addr(&listeners);
        let client = launch_test_client(router, client_parameters());
        let mut streams = JoinSet::new();
        // Halve the fan-out so it stays within the reduced stream limits.
        for conn_idx in 0..PARALLEL_ECHO_CONNS / 2 {
            let connection = Arc::new(
                client
                    .connected_to_with_source("localhost", [(Source::System, server_addr.into())])
                    .await?,
            );
            for stream_idx in 0..PARALLEL_ECHO_STREAMS / 2 {
                let connection = connection.clone();
                streams.spawn(
                    async move { send_and_verify_echo(&connection, TEST_DATA).await }
                        .instrument(tracing::info_span!("stream", conn_idx, stream_idx)),
                );
            }
        }
        streams
            .join_all()
            .await
            .into_iter()
            .collect::>()?;
        listeners.shutdown();
        Ok(())
    })
}

================================================
FILE: dquic/tests/echo_common/mod.rs
================================================
// common is submod for echo, auth and traversal
#![allow(unused)]

use std::sync::Arc;

use dquic::{prelude::*, qbase::param::ServerParameters, qinterface::component::route::QuicRouter};
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};

use crate::common::{BoxError, SERVER_CERT, SERVER_KEY, qlogger};

// Copy everything from reader back to writer, then close the write half.
pub async fn echo_stream(mut reader: StreamReader, mut writer: StreamWriter) {
    io::copy(&mut reader, &mut writer).await.unwrap();
    _ = writer.shutdown().await;
    tracing::debug!("stream copy done");
}

// Accept loop: one task per connection, one echo task per bidi stream.
pub async fn serve_echo(listeners: Arc) {
    while let Ok((connection, server, pathway, _link)) = listeners.accept().await {
        assert_eq!(server, "localhost");
        tracing::info!(source = ?pathway.remote(), "accepted new connection");
        tokio::spawn(async move {
            while let Ok((_sid, (reader, writer))) = connection.accept_bi_stream().await {
                tokio::spawn(echo_stream(reader, writer));
            }
        });
    }
}

// Client helper: write `data`, read it back, and assert byte equality.
pub async fn send_and_verify_echo(connection: &Connection, data: &[u8]) -> Result<(), BoxError> {
    let (_sid, (mut reader, mut writer)) = connection.open_bi_stream().await?.unwrap();
    tracing::debug!("stream opened");
    let mut back = Vec::new();
    // Write and read concurrently so large payloads don't deadlock on
    // flow-control limits.
    tokio::try_join!(
        async {
            writer.write_all(data).await?;
            writer.shutdown().await?;
            tracing::info!("write done");
            Result::<(), BoxError>::Ok(())
        },
        async {
            reader.read_to_end(&mut back).await?;
            assert_eq!(back, data);
            tracing::info!("read done");
            Result::<(), BoxError>::Ok(())
        }
    )
    .map(|_| ())
}

// Spin up the standard echo server on 127.0.0.1 with an OS-assigned port.
// NOTE(review): generic arguments of `Arc`/`impl Future` appear stripped by
// extraction; confirm against upstream.
pub async fn launch_echo_server(
    quic_router: Arc,
    parameters: ServerParameters,
) -> Result<(Arc, impl Future), BoxError> {
    let listeners = QuicListeners::builder()
        .with_router(quic_router)
        .without_client_cert_verifier()
        .with_parameters(parameters)
        .with_qlog(qlogger())
        .listen(128)
        .unwrap();
    listeners
        .add_server(
            "localhost",
            SERVER_CERT,
            SERVER_KEY,
            [BindUri::from("inet://127.0.0.1:0").alloc_port()],
            None,
        )
        .await?;
    Ok((listeners.clone(), serve_echo(listeners)))
}

================================================
FILE: dquic/tests/traversal.rs
================================================
use std::{
    collections::HashMap,
    io,
    net::SocketAddr,
    sync::{Arc, LazyLock},
    time::Duration,
};

use dquic::{
    prelude::{handy::*, *},
    qinterface::{component::location::Locations, manager::InterfaceManager},
    qresolve::Source,
    qtraversal::nat::client::{NatType, StunClientsComponent},
};
use futures::{
    FutureExt,
    future::{BoxFuture, Shared},
};
use rustls::RootCertStore;
use tokio::task::JoinSet;
use tracing::{info, warn};

mod common;
use common::*;
mod echo_common;
use echo_common::*;

// One simulated NAT endpoint: the address bound inside the namespace, the
// address seen from outside the NAT, and the NAT behavior being simulated.
#[derive(Debug, Clone, Copy)]
pub struct TestCase {
    pub bind_addr: &'static str,
    pub outer_addr: &'static str,
    pub nat_type: NatType,
}

pub const STUN_SERVERS: &str = "10.10.0.64:20002";

// First five entries are client-side endpoints (192.168.0.x), last five are
// server-side endpoints (172.16.0.x) — one per NAT type.
pub const CASES: [TestCase; 10] = [
    TestCase {
        bind_addr: "192.168.0.98:6001",
        outer_addr: "10.10.0.98:6001",
        nat_type: NatType::FullCone,
    },
    TestCase {
        bind_addr: "192.168.0.96:6002",
        outer_addr: "10.10.0.96:6002",
        nat_type: NatType::RestrictedCone,
    },
    TestCase {
        bind_addr: "192.168.0.88:6003",
        outer_addr: "10.10.0.88:6003",
        nat_type: NatType::RestrictedPort,
    },
    TestCase {
        bind_addr: "192.168.0.86:6004",
        outer_addr: "10.10.0.86:6004",
        nat_type: NatType::Dynamic,
    },
    TestCase {
        bind_addr: "192.168.0.84:6005",
        outer_addr: "10.10.0.84:6005",
        nat_type: NatType::Symmetric,
    },
    // server
    TestCase {
        bind_addr: "172.16.0.48:6006",
        outer_addr: "10.10.0.48:6006",
        nat_type: NatType::FullCone,
    },
    TestCase {
        bind_addr: "172.16.0.46:6007",
        outer_addr: "10.10.0.46:6007",
        nat_type: NatType::RestrictedCone,
    },
    TestCase {
        bind_addr: "172.16.0.38:6008",
        outer_addr: "10.10.0.38:6008",
        nat_type: NatType::RestrictedPort,
    },
    TestCase {
        bind_addr: "172.16.0.36:6009",
        outer_addr: "10.10.0.36:6009",
        nat_type: NatType::Dynamic,
    },
    TestCase {
        bind_addr: "172.16.0.34:6010",
        outer_addr: "10.10.0.34:6010",
        nat_type: NatType::Symmetric,
    },
];

// Lookup tables keyed by NAT type.
// NOTE(review): `LazyLock` generic arguments appear stripped by extraction
// (presumably `LazyLock<HashMap<NatType, TestCase>>`); confirm against upstream.
static CLIENT_CASES: LazyLock> = LazyLock::new(|| {
    CASES[0..5]
        .iter()
        .map(|case| (case.nat_type, *case))
        .collect()
});

static SERVER_CASES: LazyLock> = LazyLock::new(|| {
    CASES[5..10]
        .iter()
        .map(|case| (case.nat_type, *case))
        .collect()
});

// Expands a list of `async fn NAME = test_punch_case(CLIENT, SERVER)` entries
// into `#[test] #[ignore]` functions, recursing over the remaining tokens.
// Tests are `#[ignore]`d because they need the containerized NAT setup below.
macro_rules! test_punch_matrix {
    (async fn $test_name:ident = test_punch_case($client:expr, $server:expr) $($tt:tt)*) => {
        #[test]
        #[ignore]
        fn $test_name() {
            run(async move {
                let span = tracing::info_span!(
                    stringify!($test_name),
                    client = stringify!($client),
                    server = stringify!($server)
                );
                let _enter = span.enter();
                test_punch_case($client, $server).await
            });
        }
        test_punch_matrix!($($tt)*);
    };
    () => {}
}

/*
// in host:
sudo docker buildx build -f qtraversal/tools/dockerfile -t dquic-traversal-test:latest .
sudo docker run -it --rm --privileged -v .:/dquic dquic-traversal-test:latest
// in container:
cd /dquic && ./qtraversal/tools/run_stun.sh
ip netns exec nsa cargo test --test traversal -- --include-ignored --nocapture
*/

test_punch_matrix!
{ async fn test_punch_full_cone_to_full_cone = test_punch_case(NatType::FullCone, NatType::FullCone) async fn test_punch_full_cone_to_restricted_cone = test_punch_case(NatType::FullCone, NatType::RestrictedCone) async fn test_punch_full_cone_to_port_restricted = test_punch_case(NatType::FullCone, NatType::RestrictedPort) async fn test_punch_full_cone_to_dynamic = test_punch_case(NatType::FullCone, NatType::Dynamic) async fn test_punch_full_cone_to_symmetric = test_punch_case(NatType::FullCone, NatType::Symmetric) async fn test_punch_restricted_cone_to_full_cone = test_punch_case(NatType::RestrictedCone, NatType::FullCone) async fn test_punch_restricted_cone_to_restricted_cone = test_punch_case(NatType::RestrictedCone, NatType::RestrictedCone) async fn test_punch_restricted_cone_to_port_restricted = test_punch_case(NatType::RestrictedCone, NatType::RestrictedPort) async fn test_punch_restricted_cone_to_dynamic = test_punch_case(NatType::RestrictedCone, NatType::Dynamic) async fn test_punch_restricted_cone_to_symmetric = test_punch_case(NatType::RestrictedCone, NatType::Symmetric) async fn test_punch_port_restricted_to_full_cone = test_punch_case(NatType::RestrictedPort, NatType::FullCone) async fn test_punch_port_restricted_to_restricted_cone = test_punch_case(NatType::RestrictedPort, NatType::RestrictedCone) async fn test_punch_port_restricted_to_port_restricted = test_punch_case(NatType::RestrictedPort, NatType::RestrictedPort) async fn test_punch_port_restricted_to_dynamic = test_punch_case(NatType::RestrictedPort, NatType::Dynamic) async fn test_punch_port_restricted_to_symmetric = test_punch_case(NatType::RestrictedPort, NatType::Symmetric) async fn test_punch_dynamic_to_full_cone = test_punch_case(NatType::Dynamic, NatType::FullCone) async fn test_punch_dynamic_to_restricted_cone = test_punch_case(NatType::Dynamic, NatType::RestrictedCone) async fn test_punch_dynamic_to_port_restricted = test_punch_case(NatType::Dynamic, NatType::RestrictedPort) async fn 
test_punch_dynamic_to_dynamic = test_punch_case(NatType::Dynamic, NatType::Dynamic) async fn test_punch_dynamic_to_symmetric = test_punch_case(NatType::Dynamic, NatType::Symmetric) async fn test_punch_symmetric_to_full_cone = test_punch_case(NatType::Symmetric, NatType::FullCone) async fn test_punch_symmetric_to_restricted_cone = test_punch_case(NatType::Symmetric, NatType::RestrictedCone) async fn test_punch_symmetric_to_port_restricted = test_punch_case(NatType::Symmetric, NatType::RestrictedPort) async fn test_punch_symmetric_to_dynamic = test_punch_case(NatType::Symmetric, NatType::Dynamic) async fn test_punch_symmetric_to_symmetric = test_punch_case(NatType::Symmetric, NatType::Symmetric) } async fn launch_stun_test_server(server_case: TestCase) -> Arc { let server_addr: SocketAddr = server_case.bind_addr.parse().unwrap(); let locations = Arc::new(Locations::new()); let listeners = QuicListeners::builder() .with_parameters(server_parameters()) .without_client_cert_verifier() .with_stun(STUN_SERVERS) .with_router(Arc::default()) .with_locations(locations) .with_qlog(qlogger()) .listen(1000) .unwrap(); listeners .add_server("localhost", SERVER_CERT, SERVER_KEY, [server_addr], None) .await .unwrap(); info!("Server listening on {server_addr}"); tokio::spawn(serve_echo(listeners.clone())); listeners } static SERVERS: LazyLock>>>> = LazyLock::new(|| { SERVER_CASES .values() .map(|case| { let server = launch_stun_test_server(*case).boxed().shared(); (case.nat_type, server) }) .collect() }); async fn launch_stun_test_client(client_case: TestCase) -> Arc { let client_addr: SocketAddr = client_case.bind_addr.parse().unwrap(); let mut roots = RootCertStore::empty(); roots.add_parsable_certificates(CA_CERT.to_certificate()); let locations = Arc::new(Locations::new()); let client = QuicClient::builder() .with_root_certificates(roots) .without_cert() .enable_sslkeylog() .with_parameters(client_parameters()) .with_stun(STUN_SERVERS) .with_locations(locations) 
.bind([client_addr]) .await .with_qlog(qlogger()) .build(); info!("Client bound on {client_addr}"); Arc::new(client) } static CLIENTS: LazyLock>>>> = LazyLock::new(|| { CLIENT_CASES .values() .map(|case| { let client = launch_stun_test_client(*case).boxed().shared(); (case.nat_type, client) }) .collect() }); async fn test_punch_case(client_nat: NatType, server_nat: NatType) { let client_case = CLIENT_CASES[&client_nat]; let server_case = SERVER_CASES[&server_nat]; info!("Testing punch case: client {client_nat:?} <-> server {server_nat:?}",); if client_nat == NatType::Dynamic || server_nat == NatType::Dynamic { warn!("Skipping Dynamic NAT test case"); // TODO: Dynamic NAT 模拟有问题 return; } if client_nat == NatType::Symmetric && server_nat == NatType::Symmetric { warn!("Skipping Symmetric NAT to Symmetric NAT test case"); // Symmetric NAT 互穿不通 return; } let _server = SERVERS[&server_nat].clone().await; let server_iface = InterfaceManager::global() .borrow(&(server_case.bind_addr.parse::().unwrap().into())) .unwrap(); let server_ep = get_stun_data(server_iface).await[0].0; launch_client(client_case, server_ep).await; } async fn get_stun_data(server_iface: dquic::qinterface::Interface) -> Vec<(EndpointAddr, NatType)> { let mut outer_addresses = server_iface .with_component(|clients: &StunClientsComponent| { clients.with_clients(|clients| { // workaround. 
clippy issue: https://github.com/rust-lang/rust-clippy/issues/16428 #[allow(clippy::redundant_iter_cloned)] clients .values() .cloned() .map(|client| async move { let agent = client.agent_addr(); let outer = client.outer_addr().await?; let ep = EndpointAddr::with_agent(agent, outer); let nat_type = client.nat_type().await?; io::Result::Ok((ep, nat_type)) }) .collect::>() }) }) .expect("interface rebinded too quickly") .expect("traversal components missing"); let mut datas = vec![]; while let Some(join_result) = outer_addresses.join_next().await { let result = join_result.expect("detect panic"); let data = result.expect("detect outer addr or nat type failed"); datas.push(data); } datas } async fn launch_client(client_case: TestCase, server_ep: EndpointAddr) { let client = CLIENTS[&client_case.nat_type].clone().await; get_stun_data( InterfaceManager::global() .borrow(&client_case.bind_addr.parse::().unwrap().into()) .unwrap(), ) .await; // 不会进行绑定,不会出错 let connection = client .connected_to_with_source("localhost", [(Source::System, server_ep)]) .await .unwrap(); let odcid = connection.origin_dcid().expect("connection failed"); tracing::info!(%odcid, "connected to server"); let test_data = Arc::new(TEST_DATA.to_vec()); // 循环检查直连路径,每秒检查一次 // 如果没有直连路径,执行 echo 测试确保连接正常 // 总超时由 run() 函数的 60s 超时控制 loop { // 检查是否有直连路径 let paths = connection .path_context() .expect("connection failed") .paths::>() .into_iter() .map(|(p, _)| p) .collect::>(); let has_direct = paths .iter() .any(|pathway| matches!(pathway.local(), EndpointAddr::Direct { .. 
})); if has_direct { tracing::info!("Direct path established: {:?}", paths); return; } // 没有直连路径,执行 echo 测试确保连接正常 tracing::debug!("no direct path yet, verifying connection with echo test"); send_and_verify_echo(&connection, &test_data) .await .expect("echo test failed"); // 等待 1 秒后再次检查 tokio::time::sleep(Duration::from_secs(1)).await; } } pub type Error = Box; #[test] fn test_knock_ttl_is_1_in_tests() { assert_eq!(dquic::qtraversal::punch::puncher::KNOCK_TTL, 1); } ================================================ FILE: h3-shim/Cargo.toml ================================================ [package] name = "h3-shim" version = "0.5.0" edition.workspace = true description = "Shim libray between dquic and h3" readme.workspace = true repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true rust-version.workspace = true autoexamples = false [dependencies] h3 = { workspace = true } h3-datagram = { workspace = true, optional = true } bytes = { workspace = true } dashmap = { workspace = true } futures = { workspace = true } dquic = { workspace = true } tokio = { workspace = true } [features] datagram = ["dep:h3-datagram", "dquic/datagram"] telemetry = ["dquic/telemetry"] [dev-dependencies] base64 = "0.22" clap = { workspace = true, features = ["derive"] } crossterm = { version = "0.29", features = ["events", "event-stream"] } http = { workspace = true } indicatif = { workspace = true } libc = "0.2" qevent = { workspace = true, features = ["telemetry"] } rustls = { workspace = true, features = ["logging", "ring"] } rustls-native-certs = { workspace = true } rpassword = "7.3" serde = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true, features = ["io-std", "fs", "rt-multi-thread"] } tracing = { workspace = true } tracing-appender = { workspace = true } # console-subscriber = "0.4" [dev-dependencies.tracing-subscriber] workspace = true features = ["env-filter", "time"] [[example]] name = "h3-server" 
[[example]] name = "h3-client" ================================================ FILE: h3-shim/examples/README.md ================================================ # h3-shim测试 本测试所使用的密钥来自,`h3-server.rs`和`h3-client.rs`的源代码亦是在其基础上修改而来 你也可以自己签名密钥,并在运行server/client时通过命令行参数指定自己的密钥 > 我们还有一个对reqwest的[fork](https://github.com/genmeta/reqwest/tree/dquic),其quic实现被替换为为dquic。基于reqwest的client用例可以参考[此gist](https://gist.github.com/ealinmen/ed79f3bf95fa91e9475484560fb2744e) 运行之前,推荐设置环境变量`RUST_LOG=info`,以便查看更多的日志信息 ```shell # 非必需,但是建议 export RUST_LOG=info ``` ## 运行 所需命令行参数均已预设,你也可以通过`--help`查看帮助,自己指定参数 cd到`dquic`目录下,运行以下命令即可 ```shell cd path/to/dquic # 启动Server,默认会加载localhost的自签名证书,因此必须通过localhost来请求 # server会默认监听[127.0.0.1:4433, [::1]:4433]两个地址,请确保您的机器支持IPv6 # 如果不支持,请使用-b参数手动绑定监听地址 cargo run --example=h3-server --package=h3-shim -- --dir=./h3-shim # 启动Client cargo run --example=h3-client --package=h3-shim -- https://localhost:4433/examples/h3-server.rs --keylog ``` client默认会向`https://localhost:4433/Cargo.toml`发送一个Get请求,你可以通过命令行参数改变请求的url 如下,client会向`https://localhost:4433/examples/server.rs`发送一个Get请求 ```shell cargo run --example=h3-client --package=h3-shim -- https://localhost:4433/examples/server.rs ``` 你也可以指定服务的根目录,或者更改绑定端口 ```shell # 设置服务根目录 cargo run --example=h3-server --package=h3-shim -- --dir=/path/to/www # 更改绑定端口 cargo run --example=h3-server --package=h3-shim -- -l=127.0.0.1:123456 ``` ## 问题排查 ### 找不到文件 如果你遇到类似这样的错误 ``` failed to read CA certificate: Os { code: 2, kind: NotFound, message: "No such file or directory" } failed to read certificate file: Os { code: 2, kind: NotFound, message: "No such file or directory" } ``` 说明你并没有移动到`h3-shim`目录下,你可以移动到`h3-shim`目录下,再次运行;或者通过命令行参数指定证书文件,密钥文件的路径 ### 无法连接 首先检查你设置的ip和端口是否正确 client和server默认使用ipv6。如果在你的设备上localhost被解析为ipv4,你需要通过`-b`参数指定客户端和服务端使用ipv4地址 ```shell cargo run --example=h3-server --package=h3-shim -- -b=127.0.0.1:0 cargo run --example=h3-client --package=h3-shim -- -b=127.0.0.1 ``` ## 抓包 
如果你想使用Wireshark抓包,你需要设置环境变量`SSLKEYLOGFILE`,且在启动client时加上`--keylog`参数,以获得keylog文件 ```shell export SSLKEYLOGFILE= <指定一个地方> cargo run --example=h3-client --package=h3-shim -- --keylog ``` 然后,打开wireshark,Preferences -> Protocols-> TLS -> (Pre)-Master-Secret log filename 的地方填入上述keylog文件的路径,即可享受wireshark抓包并解密的便利。 ================================================ FILE: h3-shim/examples/h3-client.rs ================================================ use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Instant}; use clap::Parser; use dquic::prelude::{ handy::{ToCertificate, client_parameters}, *, }; use http::{ Uri, uri::{Authority, Parts, Scheme}, }; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use tokio::{ fs, io::{AsyncWrite, AsyncWriteExt}, task::JoinSet, }; use tracing::Instrument; use tracing_subscriber::prelude::*; #[derive(Parser, Clone)] struct Options { #[arg(long, help = "Save the qlog to a dir", value_name = "PATH")] qlog: Option, #[arg( long, help = "Certificate of CA who issues the server certificate", value_delimiter = ',', default_value = "tests/keychain/localhost/ca.cert" )] roots: Vec, #[arg( long, default_value = "false", help = "Skip verification of server certificate" )] skip_verify: bool, #[arg( long, short, value_delimiter = ',', default_value = "h3", help = "ALPNs to use for the connection" )] alpns: Vec>, #[arg( long, short = 'p', action = clap::ArgAction::Set, help = "enable progress bar", default_value = "true", value_enum )] progress: bool, #[arg( long, action = clap::ArgAction::Set, help = "enable ansi", default_value = "true", value_enum )] ansi: bool, #[arg( long, short = 'r', help = "number of requests per connection", default_value = "1" )] reqs: usize, #[arg( long, short = 'c', help = "number of connections client initiates", default_value = "1" )] conns: usize, #[arg(long, help = "Save the response to a dir", value_name = "PATH")] save: Option, #[arg( help = "URI to request", value_delimiter = ',', default_value = 
"https://localhost:4433/Cargo.lock" )] uris: Vec, } #[tokio::main] async fn main() { let options = Options::parse(); let (non_blocking, _guard) = tracing_appender::non_blocking(std::io::stdout()); tracing_subscriber::registry() // .with( // console_subscriber::ConsoleLayer::builder() // .server_addr("127.0.0.1:6670".parse::().unwrap()) // .spawn(), // ) .with( tracing_subscriber::fmt::layer() .with_writer(non_blocking) .with_ansi(options.ansi) .with_filter( tracing_subscriber::EnvFilter::builder() .with_default_directive(match options.progress { true => tracing::level_filters::LevelFilter::OFF.into(), false => tracing::level_filters::LevelFilter::INFO.into(), }) .from_env_lossy(), ), ) .init(); if let Err(error) = run(options).await { tracing::error!(?error); std::process::exit(1); }; } type Error = Box; async fn run(options: Options) -> Result<(), Error> { let qlogger: Arc = match options.qlog { Some(dir) => Arc::new(handy::LegacySeqLogger::new(dir)), None => Arc::new(handy::NoopLogger), }; let client_builder = if options.skip_verify { tracing::warn!("skip server verify"); QuicClient::builder().without_verifier() } else { tracing::info!("load ca certs"); let mut roots = rustls::RootCertStore::empty(); roots.add_parsable_certificates(rustls_native_certs::load_native_certs().certs); roots .add_parsable_certificates(options.roots.iter().flat_map(|path| path.to_certificate())); QuicClient::builder().with_root_certificates(roots) }; let client = Arc::new( client_builder .with_qlog(qlogger) .without_cert() .with_parameters(client_parameters()) .with_alpns(options.alpns) .enable_sslkeylog() .build(), ); let pbs = MultiProgress::new(); if !options.progress { pbs.set_draw_target(indicatif::ProgressDrawTarget::hidden()); } let conns_pb = pbs.add(ProgressBar::new(0).with_prefix("connections").with_style( ProgressStyle::with_template("{prefix} {wide_bar} {pos}/{len}")?, )); let total_pb = pbs.add(ProgressBar::new(0).with_prefix("requests").with_style( 
ProgressStyle::with_template("{prefix} {wide_bar} {pos}/{len} {per_sec} {eta}")?, )); let queries = options .uris .into_iter() // 根据 authority 分组 .fold(HashMap::<_, Vec<_>>::new(), |mut uris, uri| { let auth = uri.authority().expect("uri must have authority"); uris.entry(auth.to_string()) .or_default() .push(uri.path().to_owned()); uris }) .into_iter() // 压力测试:让uri变多 .map(|(auth, uris)| { let authority = auth.parse::().unwrap(); let totoal_reqs = uris.len() * options.reqs; let total_reqs = uris.into_iter().cycle().take(totoal_reqs); (authority, total_reqs) }); let start_time = Instant::now(); let mut connections = JoinSet::new(); for (authority, paths) in queries { for _conn_idx in 0..options.conns { conns_pb.inc_length(1); connections.spawn(download_files_with_progress( client.clone(), authority.clone(), paths.clone(), total_pb.clone(), options.save.clone(), )); } } let mut success_queries = 0; while let Some(res) = connections.join_next().await { match res { Ok(Ok(queries)) => { tracing::info!(target: "counting", queries, "connection finished"); success_queries += queries; conns_pb.inc(1); } Ok(Err(err)) => { tracing::error!(target: "counting", error=?err, "conenction failed"); conns_pb.dec_length(1); } Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()), Err(err) => panic!("{err}"), } } conns_pb.finish(); total_pb.finish(); let total_time = start_time.elapsed().as_secs_f64(); let qps = success_queries as f64 / total_time; tracing::info!(target: "counting", success_queries, total_time, qps, "done!"); Ok(()) } async fn download_files_with_progress( client: Arc, authority: Authority, paths: impl Iterator, total_pb: ProgressBar, save: Option, ) -> Result { let quic_connection = Arc::new(client.connect(authority.host()).await?); let odcid = quic_connection.origin_dcid()?; let span = tracing::info_span!("requests", %odcid, host = authority.host()); let (mut connection, send_request) = 
h3::client::new(h3_shim::QuicConnection::new(quic_connection.clone())) .instrument(span.clone()) .await?; tokio::spawn(async move { connection.wait_idle().await }.instrument(span.clone())); let mut requests = JoinSet::new(); for path in paths { total_pb.inc_length(1); let uri = { let mut parts = Parts::default(); parts.scheme = Some(Scheme::HTTPS); parts.authority = Some(authority.clone()); parts.path_and_query = Some(path.parse()?); Uri::from_parts(parts)? }; let save_to = save .as_ref() .map(|dir| dir.join(uri.path().strip_prefix('/').unwrap())); let request = http::Request::builder().uri(uri).body(())?; let mut send_request = send_request.clone(); requests.spawn( async move { let mut request_stream = send_request.send_request(request).await?; request_stream.finish().await?; let resp = request_stream.recv_response().await?; if resp.status() != http::StatusCode::OK { return Err(format!("response status: {}", resp.status()).into()); } let mut save_to: Box = match save_to { Some(path) => Box::new(fs::File::create(path).await?), None => Box::new(tokio::io::sink()), }; while let Some(mut data) = request_stream.recv_data().await? 
{ save_to.write_all_buf(&mut data).await?; } Result::<(), Error>::Ok(()) } .instrument(span.clone()), ); } let mut error = None; let mut success_queries = 0; tracing::info!(target: "counting", "Waiting for {} requests to finish", requests.len()); while let Some(res) = requests.join_next().await { match res { Ok(Ok(())) => { tracing::warn!(target: "counting", "Request success"); success_queries += 1; total_pb.inc(1); } Ok(Err(err)) => { tracing::warn!(target: "counting", ?err, "Request failed"); total_pb.dec_length(1); error = Some(err); } Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()), Err(err) => panic!("{err}"), } } tracing::info!(target: "counting", success_queries, "Requests completed"); if success_queries != 0 { Ok(success_queries) } else { Err(error.unwrap()) } } ================================================ FILE: h3-shim/examples/h3-server.rs ================================================ use std::{ops::Deref, path::PathBuf, sync::Arc}; use bytes::{Bytes, BytesMut}; use clap::Parser; use dquic::{ prelude::*, qinterface::{bind_uri::BindUri, io::IO}, }; use h3::{quic::BidiStream, server::RequestStream}; use http::{Request, StatusCode}; use tokio::{fs::File, io::AsyncReadExt}; use tracing::level_filters::LevelFilter; use tracing_subscriber::{EnvFilter, prelude::*}; #[derive(Parser, Debug)] #[command(name = "server")] struct Options { #[arg( name = "dir", short, long, help = "Root directory of the files to serve. 
\ If omitted, server will respond OK.", default_value = "./" )] root: PathBuf, #[arg(long, help = "Save the qlog to a dir", value_name = "PATH")] qlog: Option, #[arg( short, long, value_delimiter = ',', default_values = ["127.0.0.1:4433", "[::1]:4433"], help = "What BindUris to listen for new connections" )] listen: Vec, #[arg( long, short, value_delimiter = ',', default_value = "h3", help = "ALPNs to use for the connection" )] alpns: Vec>, #[arg( long, short, default_value = "4096", help = "Maximum number of requests in the backlog. \ If the backlog is full, new connections will be refused." )] backlog: usize, #[arg( long, action = clap::ArgAction::Set, default_value = "true", help = "Enable ANSI color output in logs" )] ansi: bool, #[command(flatten)] certs: Certs, } #[derive(Parser, Debug)] struct Certs { #[arg(long, short, default_value = "localhost", help = "Server name.")] server_name: String, #[arg( long, short, default_value = "tests/keychain/localhost/server.cert", help = "Certificate for TLS. If present, `--key` is mandatory." )] cert: PathBuf, #[arg( long, short, default_value = "tests/keychain/localhost/server.key", help = "Private key for the certificate." 
)] key: PathBuf, } fn main() { let options = Options::parse(); let (non_blocking, _guard) = tracing_appender::non_blocking(std::io::stdout()); tracing_subscriber::registry() // .with(console_subscriber::spawn()) .with( tracing_subscriber::fmt::layer() .with_writer(non_blocking) .with_ansi(options.ansi) .with_filter( EnvFilter::builder() .with_default_directive(LevelFilter::INFO.into()) .from_env_lossy(), ), ) .init(); // 测试日志是否工作 tracing::info!("tracing initialized successfully"); let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() // default value 512 out of macos ulimit .max_blocking_threads(256) .build() .expect("failed to build tokio runtime"); if let Err(error) = rt.block_on(run(options)) { tracing::info!(?error); std::process::exit(1); } } async fn run(options: Options) -> Result<(), Box> { tracing::info!("Serving {}", options.root.display()); let root = Arc::new(options.root); if !root.is_dir() { return Err(format!("{}: is not a readable directory", root.display()).into()); } let qlogger: Arc = match options.qlog { Some(dir) => Arc::new(handy::LegacySeqLogger::new(dir)), None => Arc::new(handy::NoopLogger), }; let Certs { server_name, cert, key, } = options.certs; let listeners = QuicListeners::builder() .with_qlog(qlogger) .without_client_cert_verifier() .with_parameters(handy::server_parameters()) .with_alpns(options.alpns) .listen(options.backlog)?; listeners .add_server( server_name.as_str(), cert.as_path(), key.as_path(), options.listen, None, ) .await?; tracing::info!( "Listening on {}", listeners .get_server(server_name.as_str()) .unwrap() .bind_interfaces() .iter() .next() .unwrap() .1 .borrow() .bound_addr()? 
); // handle incoming connections and requests while let Ok((new_conn, _server, _pathway, _link)) = listeners.accept().await { let h3_conn = match h3::server::Connection::new(h3_shim::QuicConnection::new(Arc::new(new_conn))) .await { Ok(h3_conn) => { tracing::info!("accept a new quic connection"); h3_conn } Err(error) => { tracing::error!("failed to establish h3 connection: {}", error); continue; } }; let root = root.clone(); tokio::spawn(handle_connection(root, h3_conn)); } Ok(()) } async fn handle_connection( serve_root: Arc, mut connection: h3::server::Connection, ) where T: h3::quic::Connection + 'static, >::BidiStream: h3::quic::BidiStream + Send + 'static, { loop { match connection.accept().await { Ok(Some(request_resolver)) => { let serve_root = serve_root.clone(); let handle_request = async move { let (request, stream) = request_resolver.resolve_request().await?; handle_request(request, stream, serve_root).await }; tokio::spawn(async move { if let Err(e) = handle_request.await { tracing::error!("handling request failed: {}", e); } }); } Ok(None) => break, Err(..) => break, } } } #[tracing::instrument(skip_all)] async fn handle_request( request: Request<()>, mut stream: RequestStream, serve_root: Arc, ) -> Result<(), Box> where T: BidiStream, { let (status, to_serve) = match serve_root.deref() { _ if request.uri().path().contains("..") => (StatusCode::NOT_FOUND, None), root => { let to_serve = root.join(request.uri().path().strip_prefix('/').unwrap_or("")); match File::open(&to_serve).await { Ok(file) => (StatusCode::OK, Some(file)), Err(e) => { tracing::error!("failed to open: \"{}\": {}", to_serve.to_string_lossy(), e); (StatusCode::NOT_FOUND, None) } } } }; let resp = http::Response::builder().status(status).body(())?; stream.send_response(resp).await?; if let Some(mut file) = to_serve { loop { let mut buf = BytesMut::with_capacity(4096 * 10); if file.read_buf(&mut buf).await? 
== 0 { break; } stream.send_data(buf.freeze()).await?; } } stream.finish().await?; Ok(()) } #[cfg(test)] mod tests { use super::*; #[test] fn test_name() {} } ================================================ FILE: h3-shim/src/conn.rs ================================================ use std::{ ops::Deref, pin::Pin, sync::Arc, task::{Context, Poll}, }; use dquic::prelude::{Connection, StreamId, StreamReader, StreamWriter}; use futures::Stream; use h3::quic::{ConnectionErrorIncoming, StreamErrorIncoming}; use crate::{ error::{self, convert_quic_error}, streams::{BidiStream, RecvStream, SendStream}, }; // 由于数据报的特性,接收流的特征,QuicConnection不允许被Clone pub struct QuicConnection { connection: Arc, accept_bi: AcceptBiStreams, accept_uni: AcceptUniStreams, open_bi: OpenBiStreams, open_uni: OpenUniStreams, } impl Deref for QuicConnection { type Target = Arc; fn deref(&self) -> &Self::Target { &self.connection } } impl QuicConnection { pub fn new(conn: Arc) -> Self { Self { accept_bi: AcceptBiStreams::new(conn.clone()), accept_uni: AcceptUniStreams::new(conn.clone()), open_bi: OpenBiStreams::new(conn.clone()), open_uni: OpenUniStreams::new(conn.clone()), connection: conn, } } } /// 首先,QuicConnection需能主动创建双向流和发送流,以及关闭连接. 
impl h3::quic::OpenStreams for QuicConnection { type BidiStream = BidiStream; type SendStream = SendStream; #[inline] fn poll_open_bidi( &mut self, cx: &mut Context<'_>, ) -> Poll> { // 以下代码的代价是,每次调用open_bi_stream()都是一个新的实现了Future的闭包 // 实际上应该是同一个,否则每次poll都会造成open_bi_stream()中的每个await点 // 都得重新执行一遍,这是有问题的。 // let mut fut = self.connection.open_bi_stream(); // let mut task = pin!(fut); // let result = ready!(task.as_mut().poll_unpin(cx)); // let bi_stream = result // .and_then(|o| o.ok_or_else(sid_exceed_limit_error)) // .map(|s| BidiStream::new(s)) // .map_err(Into::into); // Poll::Ready(bi_stream) // 以下代码的问题是:不可重入,切忌上个流未成功打开返回前,任何地方不可尝试打开流 self.open_bi.poll_open(cx) // 应该的做法是,与这个poll_open_bidi关联的一个open_bi_stream()返回的固定Future来poll } #[inline] fn poll_open_send( &mut self, cx: &mut Context<'_>, ) -> Poll> { self.open_uni.poll_open(cx) } #[inline] fn close(&mut self, code: h3::error::Code, reason: &[u8]) { let reason = unsafe { String::from_utf8_unchecked(reason.to_vec()) }; _ = self.connection.close(reason, code.into()); } } /// 其次,QuicConnection需能接收双向流和发送流. 
/// 欲实现`h3::quic::Connection`,必须先实现`h3::quic::OpenStreams` impl h3::quic::Connection for QuicConnection { type RecvStream = RecvStream; type OpenStreams = OpenStreams; #[inline] fn poll_accept_recv( &mut self, cx: &mut Context<'_>, ) -> Poll> { self.accept_uni.poll_accept(cx) } #[inline] fn poll_accept_bidi( &mut self, cx: &mut Context<'_>, ) -> Poll> { self.accept_bi.poll_accept(cx) } /// 为何要再来个这玩意?多次一举 /// 如果这个opener()的返回值只负责打开一条流,不可重用; /// 再打开流,要再次调用opener()来open,那还有点意思 #[inline] fn opener(&self) -> Self::OpenStreams { OpenStreams::new(self.connection.clone()) } } /// 多此一举,实在是多此一举 pub struct OpenStreams { connection: Arc, open_bi: OpenBiStreams, open_uni: OpenUniStreams, } impl OpenStreams { fn new(conn: Arc) -> Self { Self { open_bi: OpenBiStreams::new(conn.clone()), open_uni: OpenUniStreams::new(conn.clone()), connection: conn, } } } impl Clone for OpenStreams { fn clone(&self) -> Self { Self { open_bi: OpenBiStreams::new(self.connection.clone()), open_uni: OpenUniStreams::new(self.connection.clone()), connection: self.connection.clone(), } } } /// 跟QuicConnection::poll_open_bidi()的实现一样,重复 impl h3::quic::OpenStreams for OpenStreams { type BidiStream = BidiStream; type SendStream = SendStream; #[inline] fn poll_open_bidi( &mut self, cx: &mut Context<'_>, ) -> Poll> { self.open_bi.poll_open(cx) } #[inline] fn poll_open_send( &mut self, cx: &mut Context<'_>, ) -> Poll> { self.open_uni.poll_open(cx) } #[inline] fn close(&mut self, code: h3::error::Code, reason: &[u8]) { let reason = unsafe { String::from_utf8_unchecked(reason.to_vec()) }; _ = self.connection.close(reason, code.into()); } } type BoxStream = Pin + Send + Sync>>; fn sid_exceed_limit_error() -> ConnectionErrorIncoming { ConnectionErrorIncoming::Undefined(Arc::from(Box::from( "the stream IDs in the `dir` direction exceed 2^60, this is very very hard to happen.", )) as _) } #[allow(clippy::type_complexity)] struct OpenBiStreams( BoxStream>, ); impl OpenBiStreams { fn new(conn: Arc) -> Self { let stream 
= futures::stream::unfold(conn, |conn| async { let bidi = conn .open_bi_stream() .await .map_err(convert_quic_error) .and_then(|o| o.ok_or_else(sid_exceed_limit_error)); Some((bidi, conn)) }); Self(Box::pin(stream)) } /// TODO: 以此法实现的`poll_open`方法,不可重入,即A、B同时要打开一个流, /// 实际上只有一个能成功,后一个的waker会取代前一个的waker注册在stream中,导致前一个waker无法被唤醒 /// 以下同 fn poll_open( &mut self, cx: &mut Context<'_>, ) -> Poll, StreamErrorIncoming>> { self.0 .as_mut() .poll_next(cx) .map(Option::unwrap) .map_ok(|(sid, stream)| BidiStream::new(sid, stream)) .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming { connection_error: e, }) } } struct OpenUniStreams(BoxStream>); impl OpenUniStreams { fn new(conn: Arc) -> Self { let stream = futures::stream::unfold(conn, |conn| async { let send = conn .open_uni_stream() .await .map_err(convert_quic_error) .and_then(|o| o.ok_or_else(sid_exceed_limit_error)); Some((send, conn)) }); Self(Box::pin(stream)) } fn poll_open( &mut self, cx: &mut Context<'_>, ) -> Poll, StreamErrorIncoming>> { self.0 .as_mut() .poll_next(cx) .map(Option::unwrap) .map_ok(|(sid, writer)| SendStream::new(sid, writer)) .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming { connection_error: e, }) } } #[allow(clippy::type_complexity)] struct AcceptBiStreams( BoxStream>, ); impl AcceptBiStreams { fn new(conn: Arc) -> Self { let stream = futures::stream::unfold(conn, |conn| async { Some(( conn.accept_bi_stream() .await .map_err(error::convert_quic_error), conn, )) }); Self(Box::pin(stream)) } fn poll_accept( &mut self, cx: &mut Context<'_>, ) -> Poll, ConnectionErrorIncoming>> { self.0 .as_mut() .poll_next(cx) .map(Option::unwrap) .map_ok(|(sid, stream)| BidiStream::new(sid, stream)) } } struct AcceptUniStreams(BoxStream>); impl AcceptUniStreams { fn new(conn: Arc) -> Self { let stream = futures::stream::unfold(conn, |conn| async { let uni = conn .accept_uni_stream() .await .map_err(error::convert_quic_error); Some((uni, conn)) }); Self(Box::pin(stream)) } fn poll_accept( &mut 
self, cx: &mut Context<'_>, ) -> Poll> { self.0 .as_mut() .poll_next(cx) .map(Option::unwrap) .map_ok(|(sid, reader)| RecvStream::new(sid, reader)) } } ================================================ FILE: h3-shim/src/error.rs ================================================ use std::{error::Error, sync::Arc}; use dquic::qbase; use h3::quic::{ConnectionErrorIncoming, StreamErrorIncoming}; use qbase::frame::ResetStreamError; pub fn convert_quic_error(e: qbase::error::Error) -> ConnectionErrorIncoming { match e { qbase::error::Error::Quic(quic_error) => { ConnectionErrorIncoming::Undefined(Arc::new(quic_error)) } qbase::error::Error::App(app_error) => ConnectionErrorIncoming::ApplicationClose { error_code: app_error.error_code(), }, } } pub fn convert_stream_io_error(e: std::io::Error) -> StreamErrorIncoming { if let Some(reset_stream_error) = e .source() .and_then(|e| e.downcast_ref::()) { return StreamErrorIncoming::StreamTerminated { error_code: reset_stream_error.error_code(), }; } if let Some(quic_error) = e .source() .and_then(|e| e.downcast_ref::()) { return StreamErrorIncoming::ConnectionErrorIncoming { connection_error: convert_quic_error(quic_error.clone()), }; } StreamErrorIncoming::Unknown(e.into()) } ================================================ FILE: h3-shim/src/ext.rs ================================================ // See https://github.com/hyperium/h3/issues/307" // use std::{ // io, // ops::Deref, // task::{Context, Poll}, // }; // use bytes::{Buf, Bytes}; // use futures::future::BoxFuture; // use dquic::{DatagramReader, DatagramWriter}; // use h3_datagram::{ // ConnectionErrorIncoming, // datagram::EncodedDatagram, // quic_traits::{DatagramConnectionExt, RecvDatagram, SendDatagram, SendDatagramErrorIncoming}, // }; // use crate::{conn::QuicConnection, error::convert_connection_io_error}; // impl DatagramConnectionExt for QuicConnection { // type SendDatagramHandler = DatagramSender; // type RecvDatagramHandler = DatagramReceiver; // fn 
send_datagram_handler(&self) -> Self::SendDatagramHandler { // let conn = self.deref().clone(); // DatagramSender::Pending(Box::pin(async move { conn.datagram_writer().await })) // } // fn recv_datagram_handler(&self) -> Self::RecvDatagramHandler { // let conn = self.deref().clone(); // DatagramReceiver::Pending(Box::pin(async move { conn.datagram_reader() })) // } // } // pub enum DatagramSender { // Pending(BoxFuture<'static, io::Result>), // Ready(Result), // } // impl SendDatagram for DatagramSender { // fn send_datagram>>( // &mut self, // data: T, // ) -> Result<(), SendDatagramErrorIncoming> { // // let mut buf = bytes::BytesMut::new(); // // buf // // data.encode(&mut buf); // let mut datagram = >>::into(data); // self.0 // .send_bytes(datagram.copy_to_bytes(datagram.remaining())) // .map_err(|e| match e { // e if e.kind() == io::ErrorKind::InvalidInput => SendDatagramErrorIncoming::TooLarge, // e => SendDatagramErrorIncoming::ConnectionError(convert_connection_io_error(e)), // }) // } // } // pub enum DatagramReceiver { // Pending(BoxFuture<'static, io::Result>), // Ready(Result), // } // impl RecvDatagram for DatagramReceiver { // /// The buffer type // type Buffer = Bytes; // /// Poll the connection for incoming datagrams. // fn poll_incoming_datagram( // &mut self, // cx: &mut Context<'_>, // ) -> Poll> { // self.0.poll_recv(cx).map_err(convert_connection_io_error) // } // } ================================================ FILE: h3-shim/src/lib.rs ================================================ pub mod conn; mod error; pub mod pool; pub use conn::{OpenStreams, QuicConnection}; #[cfg(feature = "datagram")] pub mod ext; #[cfg(feature = "datagram")] #[allow(unused_imports)] pub use ext::*; pub mod streams; pub use dquic; pub use streams::{BidiStream, RecvStream, SendStream}; ================================================ FILE: h3-shim/src/pool.rs ================================================ //! 
TODO: unimplemented ================================================ FILE: h3-shim/src/streams.rs ================================================ use std::{ mem::MaybeUninit, pin::Pin, task::{Context, Poll, ready}, }; use bytes::Buf; use dquic::{ prelude::{CancelStream, StopSending, StreamReader, StreamWriter}, qbase, }; use h3::quic::StreamErrorIncoming; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::error::convert_stream_io_error; pub struct SendStream { writer: StreamWriter, data: Option>, send_id: h3::quic::StreamId, } impl SendStream { pub fn new(sid: qbase::sid::StreamId, writer: StreamWriter) -> Self { let sid = u64::from(sid); Self { writer, data: None, send_id: h3::quic::StreamId::try_from(sid).expect("unreachable"), } } } impl h3::quic::SendStream for SendStream { #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { let Some(buf) = self.data.as_mut() else { return Poll::Ready(Ok(())); }; loop { match ready!(Pin::new(&mut self.writer).poll_write(cx, buf.chunk())) { Ok(written) => { buf.advance(written); if buf.remaining() == 0 { self.data = None; return Poll::Ready(Ok(())); } } Err(e) => { self.data = None; return Poll::Ready(Err(convert_stream_io_error(e))); } } } } #[inline] fn send_data>>( &mut self, data: T, ) -> Result<(), StreamErrorIncoming> { assert!(self.data.is_none()); self.data = Some(data.into()); Ok(()) } #[inline] fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { assert!(self.data.is_none()); Pin::new(&mut self.writer) .poll_shutdown(cx) .map(|r| r.map_err(convert_stream_io_error)) } #[inline] fn reset(&mut self, reset_code: u64) { assert!(self.data.is_none()); self.writer.cancel(reset_code); } #[inline] fn send_id(&self) -> h3::quic::StreamId { self.send_id } } impl h3::quic::SendStreamUnframed for SendStream { #[inline] fn poll_send( &mut self, cx: &mut Context<'_>, buf: &mut D, ) -> Poll> { assert!(self.data.is_none()); Pin::new(&mut self.writer) .poll_write(cx, buf.chunk()) .map(|r| 
r.map_err(convert_stream_io_error)) } } pub struct RecvStream { reader: StreamReader, recv_id: h3::quic::StreamId, } impl RecvStream { pub(crate) fn new(sid: qbase::sid::StreamId, reader: StreamReader) -> Self { let sid = u64::from(sid); Self { reader, recv_id: h3::quic::StreamId::try_from(sid).expect("unreachable"), } } } impl h3::quic::RecvStream for RecvStream { type Buf = bytes::Bytes; #[inline] fn poll_data( &mut self, cx: &mut Context<'_>, ) -> Poll, StreamErrorIncoming>> { let mut uninit_buf = [MaybeUninit::uninit(); 4096]; let mut read_buf = ReadBuf::uninit(&mut uninit_buf); match ready!(Pin::new(&mut self.reader).poll_read(cx, &mut read_buf)) { Ok(()) => { if read_buf.filled().is_empty() { return Poll::Ready(Ok(None)); } let bytes = bytes::Bytes::copy_from_slice(read_buf.filled()); Poll::Ready(Ok(Some(bytes))) } Err(e) => Poll::Ready(Err(convert_stream_io_error(e))), } } #[inline] fn stop_sending(&mut self, error_code: u64) { self.reader.stop(error_code); } #[inline] fn recv_id(&self) -> h3::quic::StreamId { self.recv_id } } pub struct BidiStream { send: SendStream, recv: RecvStream, } impl BidiStream { pub(crate) fn new( sid: qbase::sid::StreamId, (reader, writer): (StreamReader, StreamWriter), ) -> Self { Self { send: SendStream::new(sid, writer), recv: RecvStream::new(sid, reader), } } } impl h3::quic::RecvStream for BidiStream { type Buf = bytes::Bytes; #[inline] fn poll_data( &mut self, cx: &mut Context<'_>, ) -> Poll, StreamErrorIncoming>> { self.recv.poll_data(cx) } #[inline] fn stop_sending(&mut self, error_code: u64) { self.recv.stop_sending(error_code); } #[inline] fn recv_id(&self) -> h3::quic::StreamId { self.recv.recv_id() } } impl h3::quic::SendStream for BidiStream { #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.send.poll_ready(cx) } #[inline] fn send_data>>( &mut self, data: T, ) -> Result<(), StreamErrorIncoming> { self.send.send_data(data) } #[inline] fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { 
self.send.poll_finish(cx) } #[inline] fn reset(&mut self, reset_code: u64) { self.send.reset(reset_code); } #[inline] fn send_id(&self) -> h3::quic::StreamId { self.send.send_id() } } impl h3::quic::SendStreamUnframed for BidiStream { #[inline] fn poll_send( &mut self, cx: &mut Context<'_>, buf: &mut D, ) -> Poll> { self.send.poll_send(cx, buf) } } impl h3::quic::BidiStream for BidiStream { type SendStream = SendStream; type RecvStream = RecvStream; #[inline] fn split(self) -> (Self::SendStream, Self::RecvStream) { (self.send, self.recv) } } ================================================ FILE: interop/Dockerfile ================================================ FROM docker.io/martenseemann/quic-network-simulator-endpoint:latest RUN env # download and build your QUIC implementation COPY . /dquic # setup rust RUN apt-get update && apt-get install -y curl gcc \ && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \ && . "$HOME/.cargo/env" \ # build the QUIC implementation && cd dquic \ && cargo build --release --example http-server \ && cargo build --release --example http-client \ && cargo build --release --example h3-client \ && cargo build --release --example h3-server \ # copy the binary && mv target/release/examples/http-server / \ && mv target/release/examples/http-client / \ && mv target/release/examples/h3-client / \ && mv target/release/examples/h3-server / \ # cleanup && cd / && rm -rf /dquic \ && rm -rf $HOME/.cargo/registry \ && rm -rf $HOME/.cargo/git \ && rustup self uninstall -y \ && apt-get remove -y curl gcc \ && apt-get autoremove -y \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* # copy run script and run it COPY interop/run_endpoint.sh . 
RUN chmod +x run_endpoint.sh ENTRYPOINT [ "./run_endpoint.sh" ] ================================================ FILE: interop/run_endpoint.sh ================================================ #!/bin/bash # Set up the routing needed for the simulation /setup.sh # The following variables are available for use: # - ROLE contains the role of this execution context, client or server # - SERVER_PARAMS contains user-supplied command line parameters # - CLIENT_PARAMS contains user-supplied command line parameters run_client() { binary="/http-client" case "$TESTCASE" in "handshake" | "transfer" | "rebind-port" | "rebind-addr" ) # do nothing ;; "multiconnect" ) CLIENT_PARAMS="$CLIENT_PARAMS" ;; "http3" ) binary="/h3-client" ;; *) echo "Unupported test case: $TESTCASE" exit 127 ;; esac # Start the client echo "Starting client with parameters: $CLIENT_PARAMS" RUST_LOG=debug $binary --alpns hq-interop --qlog $QLOGDIR \ --skip-verify --save /downloads $CLIENT_PARAMS $REQUESTS } run_server() { binary="/http-server" case "$TESTCASE" in "handshake" | "transfer" | "multiconnect" | "rebind-port" | "rebind-addr" ) # do nothing ;; "http3" ) binary="/h3-server" ;; *) echo "Unupported test case: $TESTCASE" exit 127 ;; esac # Start the server echo "Starting server with parameters: $SERVER_PARAMS" RUST_LOG=debug $binary --alpns hq-interop --qlog $QLOGDIR \ -c /certs/cert.pem -k /certs/server.key -d /www $SERVER_PARAMS } if [ "$ROLE" == "client" ]; then # Wait for the simulator to start up. 
/wait-for-it.sh sim:57832 -s -t 30 run_client elif [ "$ROLE" == "server" ]; then run_server fi ================================================ FILE: qbase/Cargo.toml ================================================ [package] name = "qbase" version = "0.5.0" edition.workspace = true description = "Core structure of the QUIC protocol, a part of dquic" readme.workspace = true repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] bitflags = { workspace = true } bytes = { workspace = true } derive_more = { workspace = true, features = [ "as_ref", "deref", "deref_mut", "display", "from", "into", "try_into", ] } enum_dispatch = { workspace = true } futures = { workspace = true } getset = { workspace = true } http = { workspace = true } netdev = { workspace = true } nom = { workspace = true } qmacro = { workspace = true } tracing = { workspace = true } rand = { workspace = true } rustls = { workspace = true } serde = { workspace = true, features = ["derive"] } smallvec = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["rt", "sync", "time"] } [dev-dependencies] tokio = { workspace = true, features = ["test-util", "macros"] } rustls = { workspace = true, features = ["ring"] } ================================================ FILE: qbase/src/cid/connection_id.rs ================================================ use std::{ hash::{Hash, Hasher}, ops::Deref, }; use nom::{IResult, bytes::streaming::take, number::streaming::be_u8}; use rand::RngExt; /// The connection id length must not exceed 20 bytes. See [`ConnectionId`]. pub const MAX_CID_SIZE: usize = 20; /// Connection ID in [QUIC RFC 9000](https://tools.ietf.org/html/rfc9000). /// /// In QUIC version 1, this value MUST NOT exceed 20 bytes. 
/// Endpoints that receive a version 1 long header with a value larger than
/// 20 MUST drop the packet.
/// See [Connection Id Length](https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.11).
///
/// See [connection id](https://tools.ietf.org/html/rfc9000#name-connection-id)
/// of [QUIC RFC 9000](https://www.rfc-editor.org/rfc/rfc9000.html)
/// for more details.
#[derive(Clone, Copy, Eq, Default)]
pub struct ConnectionId {
    // Number of meaningful bytes in `bytes`; never exceeds MAX_CID_SIZE.
    pub(crate) len: u8,
    // Fixed-size backing storage; only the first `len` bytes are significant.
    pub(crate) bytes: [u8; MAX_CID_SIZE],
}

// Hex-format only the meaningful `len`-byte prefix (obtained via `as_ref`,
// which goes through the `Deref<Target = [u8]>` impl below in the file).
impl core::fmt::LowerHex for ConnectionId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for &b in self.as_ref() {
            write!(f, "{b:02x}")?;
        }
        Ok(())
    }
}

// Display and Debug both render as lowercase hex, delegating to LowerHex.
impl core::fmt::Display for ConnectionId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        core::fmt::LowerHex::fmt(self, f)
    }
}

impl core::fmt::Debug for ConnectionId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        core::fmt::LowerHex::fmt(self, f)
    }
}

/// Parse a connection ID from the input buffer,
/// [nom](https://docs.rs/nom/latest/nom/) parser style.
///
/// ## Note:
///
/// The connection ID length is limited to 20 bytes, or it will return an error.
/// See [`ConnectionId`].
pub fn be_connection_id(input: &[u8]) -> IResult<&[u8], ConnectionId> {
    // Wire format: a 1-byte length prefix followed by that many ID bytes.
    let (remain, len) = be_u8(input)?;
    be_connection_id_with_len(remain, len as usize)
}

/// Parse a given `len` connection ID from the input buffer,
/// [nom](https://docs.rs/nom/latest/nom/) parser style.
///
/// ## Note:
///
/// The connection ID length is limited to 20 bytes, or it will return an error.
pub fn be_connection_id_with_len(input: &[u8], len: usize) -> IResult<&[u8], ConnectionId> {
    if len > MAX_CID_SIZE {
        // RFC 9000 §17.2: a version-1 CID longer than 20 bytes is invalid.
        return Err(nom::Err::Error(nom::error::make_error(
            input,
            nom::error::ErrorKind::TooLarge,
        )));
    }
    let (remain, bytes) = take(len)(input)?;
    Ok((remain, ConnectionId::from_slice(bytes)))
}

/// A BufMut extension trait, makes buffer more friendly to write connection ID.
pub trait WriteConnectionId: bytes::BufMut {
    /// Write a connection ID to the buffer: a 1-byte length prefix, then the bytes.
    fn put_connection_id(&mut self, cid: &ConnectionId);
}

// Fix: the generic parameter list was lost in extraction; this is the blanket
// implementation over any `BufMut` that `put_connection_id` call sites rely on.
impl<T: bytes::BufMut> WriteConnectionId for T {
    fn put_connection_id(&mut self, cid: &ConnectionId) {
        self.put_u8(cid.len);
        // `&*cid` derefs to the meaningful `len`-byte prefix.
        self.put_slice(cid);
    }
}

impl ConnectionId {
    /// Create a new connection ID from a given bytes slice.
    ///
    /// Panics in debug builds if the slice is longer than [`MAX_CID_SIZE`]
    /// (the `test_cid_from_large_slice` test relies on this).
    pub fn from_slice(bytes: &[u8]) -> Self {
        debug_assert!(bytes.len() <= MAX_CID_SIZE);
        let mut res = Self {
            len: bytes.len() as u8,
            bytes: [0; MAX_CID_SIZE],
        };
        res.bytes[..bytes.len()].copy_from_slice(bytes);
        res
    }

    /// Randomly generate a connection ID of the given length.
    /// The connection ID may not be unique, so it should be checked before use.
    pub fn random_gen(len: usize) -> Self {
        debug_assert!(len <= MAX_CID_SIZE);
        let mut bytes = [0; MAX_CID_SIZE];
        rand::rng().fill(&mut bytes[..len]);
        Self {
            len: len as u8,
            bytes,
        }
    }

    /// Generates a random connection ID like [`Self::random_gen`].
    /// Additionally forces bits of the first byte: bits cleared by `mask` are
    /// replaced by `mark` (`byte0 = (byte0 & mask) | mark`), which lets callers
    /// tag locally-generated CIDs (see `random_gen_with_mark(8, 0x80, 0x7F)`
    /// usage in the local-CID tests).
    pub fn random_gen_with_mark(len: usize, mark: u8, mask: u8) -> Self {
        // A zero-length CID has no first byte to mark, hence `len > 0` here.
        debug_assert!(len > 0 && len <= MAX_CID_SIZE);
        let mut bytes = [0; MAX_CID_SIZE];
        rand::rng().fill(&mut bytes[..len]);
        bytes[0] = (bytes[0] & mask) | mark;
        Self {
            len: len as u8,
            bytes,
        }
    }

    /// Get the encoding size of the connection ID.
    ///
    /// Includes 1-byte length encoding and connection ID bytes.
pub fn encoding_size(&self) -> usize { 1 + self.len as usize } } impl Deref for ConnectionId { type Target = [u8]; fn deref(&self) -> &Self::Target { &self.bytes[0..self.len as usize] } } impl PartialEq for ConnectionId { fn eq(&self, other: &ConnectionId) -> bool { self.len == other.len && self.bytes[..self.len as usize] == other.bytes[..self.len as usize] } } impl Hash for ConnectionId { fn hash(&self, state: &mut H) { self.len.hash(state); self.bytes[..self.len as usize].hash(state); } } #[cfg(test)] mod tests { use super::*; #[test] fn test_read_connection_id() { let buf = vec![0x04, 0x01, 0x02, 0x03, 0x04]; let (remain, cid) = be_connection_id(&buf).unwrap(); assert!(remain.is_empty()); assert_eq!(*cid, [0x01, 0x02, 0x03, 0x04],); let buf = vec![21, 0x01, 0x02, 0x03, 0x04]; assert_eq!( be_connection_id(&buf), Err(nom::Err::Error(nom::error::make_error( &buf[1..], nom::error::ErrorKind::TooLarge ))) ); } #[test] #[should_panic] fn test_cid_from_large_slice() { ConnectionId::from_slice(&[0; MAX_CID_SIZE + 1]); } #[test] fn test_write_connection_id() { use bytes::{Bytes, BytesMut}; let mut buf = BytesMut::new(); let cid = ConnectionId::from_slice(&[0x01, 0x02, 0x03, 0x04]); buf.put_connection_id(&cid); assert_eq!( buf.freeze(), Bytes::from_static(&[0x04, 0x01, 0x02, 0x03, 0x04]) ); } } ================================================ FILE: qbase/src/cid/local_cid.rs ================================================ use std::sync::{Arc, Mutex}; use super::{ConnectionId, GenUniqueCid, RetireCid}; use crate::{ error::{Error, ErrorKind, QuicError}, frame::{ FrameType, GetFrameType, NewConnectionIdFrame, RetireConnectionIdFrame, io::{ReceiveFrame, SendFrame}, }, token::ResetToken, util::IndexDeque, varint::{VARINT_MAX, VarInt}, }; /// Local connection ID management. #[derive(Debug)] struct LocalCids where ISSUED: GenUniqueCid + RetireCid + SendFrame, { // If the item in cid_deque is None, it means the connection ID has been retired. 
cid_deque: IndexDeque, VARINT_MAX>, // Each issued connection ID will be written into this issued_cids. // The frames in issued_cids should be able to enter the QUIC sending channel // and be reliably sent to the peer finally. issued_cids: ISSUED, // This is an integer value specifying the maximum number of active connection // IDs limited by peer. // While the client does not know the server's parameters at the beginning, // it can be set to None and will be reset. // If this transport parameter is absent, a default of 2 is assumed. active_cid_limit: Option, } impl LocalCids where ISSUED: GenUniqueCid + RetireCid + SendFrame, { /// Create a new local connection ID manager. fn new(scid: ConnectionId, issued_cids: ISSUED) -> Self { let mut cid_deque = IndexDeque::default(); cid_deque .push_back(Some((scid, ResetToken::default()))) .unwrap(); let new_cid = issued_cids.gen_unique_cid(); let new_cid_frame = NewConnectionIdFrame::new(new_cid, VarInt::from_u32(1), VarInt::from_u32(0)); issued_cids.send_frame([new_cid_frame]); cid_deque .push_back(Some(( *new_cid_frame.connection_id(), *new_cid_frame.reset_token(), ))) .unwrap(); Self { cid_deque, issued_cids, active_cid_limit: None, } } fn initial_scid(&self) -> Option { self.cid_deque.get(0)?.map(|(cid, _)| cid) } /// Set the maximum number of active connection IDs. /// /// The value of the active_connection_id_limit parameter MUST be at least 2. /// An endpoint that receives a value less than 2 MUST close the connection /// with an error of type TRANSPORT_PARAMETER_ERROR. 
fn set_limit(&mut self, active_cid_limit: u64) -> Result<(), Error> { debug_assert!(self.active_cid_limit.is_none()); if active_cid_limit < 2 { return Err(QuicError::new( ErrorKind::TransportParameter, FrameType::Crypto.into(), format!("active connection id limit {active_cid_limit} < 2"), ) .into()); } for _ in self.cid_deque.largest()..active_cid_limit { self.issue_new_cid(); } self.active_cid_limit = Some(active_cid_limit); Ok(()) } /// Issue a new connection ID, for internal used only. fn issue_new_cid(&mut self) { let seq = VarInt::from_u64(self.cid_deque.largest()).unwrap(); let retire_prior_to = VarInt::from_u64(self.cid_deque.offset()).unwrap(); let new_cid = self.issued_cids.gen_unique_cid(); let new_cid_frame = NewConnectionIdFrame::new(new_cid, seq, retire_prior_to); self.issued_cids.send_frame([new_cid_frame]); self.cid_deque.push_back(Some((*new_cid_frame.connection_id(), *new_cid_frame.reset_token()))) .expect("it's very very hard to issue a new connection ID whose sequence excceeds VARINT_MAX"); } /// Receive a [`RetireConnectionIdFrame`] from the peer, /// retire the connection IDs of the sequence in [`RetireConnectionIdFrame`]. fn recv_retire_cid_frame(&mut self, frame: RetireConnectionIdFrame) -> Result<(), Error> { let seq = frame.sequence(); if seq >= self.cid_deque.largest() { return Err(QuicError::new( ErrorKind::ConnectionIdLimit, frame.frame_type().into(), format!( "Sequence({seq}) in RetireConnectionIdFrame exceeds the largest one({}) issued by us", self.cid_deque.largest().saturating_sub(1) ), ).into()); } if let Some(value) = self.cid_deque.get_mut(seq) { if let Some((cid, _)) = value.take() { let n = self.cid_deque.iter().take_while(|v| v.is_none()).count(); self.cid_deque.advance(n); // generates a new connection ID while retiring an old one. 
self.issue_new_cid(); self.issued_cids.retire_cid(cid); } } Ok(()) } fn clear(&mut self) { for (cid, _reset_token) in self.cid_deque.drain_to(self.cid_deque.largest()).flatten() { self.issued_cids.retire_cid(cid); } } } impl Drop for LocalCids where ISSUED: GenUniqueCid + RetireCid + SendFrame, { fn drop(&mut self) { self.clear(); } } /// Shared local connection ID manager. Most times, you should use this struct. /// /// Responsible for generating and issuing connection IDs to the peer. /// The number of active connection IDs is limited by the peer's active_cid_limit. /// /// - `ISSUED`: is a struct that can generate unique connection id and finally send the new /// issued connection ID frame to the peer. /// It can be a channel, a queue, or a buffer. Whatever, it must be able to send the /// [`NewConnectionIdFrame`] to the peer. /// /// ## Note /// /// The generated connection ID will be added to the packet reception routing table, /// which is shared with other QUIC connections. /// Therefore, the generated connection ID must not duplicate other local connection IDs, /// including connection IDs of other connections, /// and those issued to the peer and have not been retired, /// otherwise routing conflicts will occur. #[derive(Debug, Clone)] pub struct ArcLocalCids(Arc>>) where ISSUED: GenUniqueCid + RetireCid + SendFrame; impl ArcLocalCids where ISSUED: GenUniqueCid + RetireCid + SendFrame, { /// Create a new share local connection ID manager. /// /// - `scid` is set initially, whether it is a client or a server, /// they both get their early `scid` externally. /// - `issued_cids` is responsible for generating CIDs that do not conflict /// in the packet reception routing table and will also be responsible for /// eventually sending the [`NewConnectionIdFrame`] to the peer. 
pub fn new(scid: ConnectionId, issued_cids: ISSUED) -> Self { let raw_local_cids = LocalCids::new(scid, issued_cids); Self(Arc::new(Mutex::new(raw_local_cids))) } /// Get the initial source connection ID. /// /// 0-RTT packets in the first flight use the same Destination Connection ID /// and Source Connection ID values as the client's first Initial packet. /// see [Section 7.2.6](https://datatracker.ietf.org/doc/html/rfc9000#section-7.2-6) /// of [RFC9000](https://datatracker.ietf.org/doc/html/rfc9000). /// /// Once a client has received a valid Initial packet from the server, /// it MUST discard any subsequent packet it receives on that connection /// with a different Source Connection ID, /// see [Section 7.2.7](https://datatracker.ietf.org/doc/html/rfc9000#section-7.2-7) /// of [RFC9000](https://datatracker.ietf.org/doc/html/rfc9000). /// /// Any further changes to the Destination Connection ID are only permitted /// if the values are taken from NEW_CONNECTION_ID frames; /// if subsequent Initial packets include a different Source Connection ID, /// they MUST be discarded, /// see [Section 7.2.8](https://datatracker.ietf.org/doc/html/rfc9000#section-7.2-8) /// of [RFC9000](https://datatracker.ietf.org/doc/html/rfc9000) for more details. /// /// It means that the initial source connection ID is the only one that can be used /// to send the Initial, 0Rtt and Handshake packets. /// Changing the scid is like issuing a new connection ID to the peer, /// without specifying a sequence number or Stateless Reset Token. /// Changing the scid during the Handshake phase is meaningless and harmful. /// /// For the server, even though the server provides the preferred address /// as the first connection ID, and even though the server can use this /// connection ID as the scid in the Handshake packet, it is not necessary. /// The client could not eliminate the zero connection ID before entering 1RTT. 
/// When the client actually eliminates the zero connection ID, /// it means that 1RTT packets have already started to be transmitted, /// and all subsequent transmissions should be through 1RTT packets. /// /// Return None if the initial source connection ID has been retired, /// which indicates that the connection has been established, /// and only the short header packet should be used. pub fn initial_scid(&self) -> Option { self.0.lock().unwrap().initial_scid() } /// Unilaterally no longer use all local connection IDs. /// /// No longer used means that packets sent by the peer to that connection ID are no /// longer accepted. This method is called when the Termination event occurs and `LocalCids` /// dropped, to clean up the state of the connection after the connection ends. /// /// In some rare cases, there are still connection IDs issued after the Termination event occurs, /// resulting in incomplete cleaning of the connection status. /// After externally receiving the Termination event, the connection instance should be dropped /// as early as possible to trigger another cleanup in the [`Drop`] implementation to /// completely clean up connection's state. pub fn clear(&self) { self.0.lock().unwrap().clear(); } /// Set the maximum number of active connection IDs. /// /// After fully obtaining the peer's connection parameters, extract the peer's /// active_cid_limit parameter and set it through this method. pub fn set_limit(&self, active_cid_limit: u64) -> Result<(), Error> { self.0.lock().unwrap().set_limit(active_cid_limit) } } impl ReceiveFrame for ArcLocalCids where ISSUED: GenUniqueCid + RetireCid + SendFrame, { type Output = (); /// Receive a [`RetireConnectionIdFrame`] from the peer, /// retire the connection IDs of the sequence in [`RetireConnectionIdFrame`]. 
fn recv_frame( &self, frame: RetireConnectionIdFrame, ) -> Result { self.0.lock().unwrap().recv_retire_cid_frame(frame) } } #[cfg(test)] mod tests { use std::{collections::HashMap, sync::MutexGuard}; use super::*; #[derive(Default)] struct IssuedCids { frames: Mutex>, active_cids: Mutex>, } impl IssuedCids { fn frames(&self) -> MutexGuard<'_, Vec> { self.frames.lock().unwrap() } fn active_cids(&self) -> MutexGuard<'_, HashMap> { self.active_cids.lock().unwrap() } } impl GenUniqueCid for IssuedCids { fn gen_unique_cid(&self) -> ConnectionId { let mut local_cids = self.active_cids.lock().unwrap(); let unique_cid = core::iter::from_fn(|| Some(ConnectionId::random_gen_with_mark(8, 0x80, 0x7F))) .find(|cid| !local_cids.contains_key(cid)) .unwrap(); local_cids.insert(unique_cid, ResetToken::default()); unique_cid } } impl RetireCid for IssuedCids { fn retire_cid(&self, cid: ConnectionId) { self.active_cids.lock().unwrap().remove(&cid); } } impl SendFrame for IssuedCids { fn send_frame>(&self, iter: I) { self.frames.lock().unwrap().extend(iter); } } #[test] fn test_issue_cid() { let initial_scid = ConnectionId::random_gen(8); let local_cids = ArcLocalCids::new(initial_scid, IssuedCids::default()); let mut local_cids = local_cids.0.lock().unwrap(); assert_eq!(local_cids.cid_deque.len(), 2); local_cids.set_limit(3).unwrap(); assert_eq!(local_cids.cid_deque.len(), 3); } #[test] fn test_recv_retire_cid_frame() { let initial_scid = ConnectionId::random_gen(8); let mut local_cids = LocalCids::new(initial_scid, IssuedCids::default()); assert_eq!(local_cids.cid_deque.len(), 2); assert_eq!(local_cids.issued_cids.frames().len(), 1); let issued_cid2 = *local_cids.issued_cids.frames()[0].connection_id(); let retire_frame = RetireConnectionIdFrame::new(VarInt::from_u32(1)); let cid2 = local_cids.recv_retire_cid_frame(retire_frame); assert!(cid2.is_ok()); assert!( !local_cids .issued_cids .active_cids() .contains_key(&issued_cid2) ); assert_eq!(local_cids.cid_deque.get(1), 
Some(&None)); // issued new cid while retiring an old one assert_eq!(local_cids.cid_deque.len(), 3); assert_eq!(local_cids.issued_cids.frames().len(), 2); let retire_frame = RetireConnectionIdFrame::new(VarInt::from_u32(0)); let cid1 = local_cids.recv_retire_cid_frame(retire_frame); assert!(cid1.is_ok()); assert!( !local_cids .issued_cids .active_cids() .contains_key(&initial_scid) ); assert_eq!(local_cids.cid_deque.get(0), None); // have been slided out assert_eq!(local_cids.cid_deque.len(), 2); assert_eq!(local_cids.issued_cids.frames().len(), 3); let retire_frame = RetireConnectionIdFrame::new(VarInt::from_u32(2)); let cid3 = local_cids.recv_retire_cid_frame(retire_frame); assert!(cid3.is_ok()); } } ================================================ FILE: qbase/src/cid/remote_cid.rs ================================================ use std::{ collections::VecDeque, ops::Deref, sync::{Arc, Mutex}, }; use super::ConnectionId; use crate::{ error::{Error, ErrorKind, QuicError}, frame::{ GetFrameType, NewConnectionIdFrame, RetireConnectionIdFrame, io::{ReceiveFrame, SendFrame}, }, net::tx::{ArcSendWaker, Signals}, token::ResetToken, util::IndexDeque, varint::{VARINT_MAX, VarInt}, }; /// RemoteCids is used to manage the connection IDs issued by the peer, /// and to send [`RetireConnectionIdFrame`] to the peer. // TODO: support 0RTT? #[derive(Debug)] struct RemoteCids where RETIRED: SendFrame + Clone, { // The cid issued by the peer, the sequence number maybe not continuous // since the disordered [`NewConnectionIdFrame`] cid_deque: IndexDeque, VARINT_MAX>, // The cell of the connection ID, which is ready in use ready_cells: IndexDeque, VARINT_MAX>, // The cell of the connection ID, which needs to be assigned or reassigned // They can be retired before being assigned or reassigned. 
pending_cells: VecDeque>, // The maximum number of connection IDs which is used to check if the // maximum number of connection IDs has been exceeded // when receiving a [`NewConnectionIdFrame`] active_cid_limit: u64, // The position of the cid to be used, and the position of the cell to be assigned. cursor: u64, // The retired cids, each needs send a [`RetireConnectionIdFrame`] to peer retired_cids: RETIRED, } impl RemoteCids where RETIRED: SendFrame + Clone, { /// Create a new RemoteCids with the maximum number of active cids, /// and the retired cids. /// /// As mentioned above, the retired cids can be a deque, a channel, or any buffer, /// as long as it can send those [`RetireConnectionIdFrame`] to the peer finally. /// See [`RemoteCids`] fn new(active_cid_limit: u64, retired_cids: RETIRED) -> Self { let cid_deque = IndexDeque::default(); Self { active_cid_limit, cid_deque, ready_cells: Default::default(), pending_cells: Default::default(), cursor: 0, retired_cids, } } fn apply_initial_dcid(&mut self, initial_dcid: ConnectionId, dcid_cell: &ArcCidCell) { assert!( self.cid_deque.is_empty() && self.cid_deque.offset() == 0 && self.cursor == 0, "NewConnectionIdFrame received before the first initial packet processed" ); self.cid_deque .push_back(Some((0, initial_dcid, ResetToken::default()))) .expect("Initial connection ID should be inserted at the offset 0"); let handshake_path = self .pending_cells .iter() .enumerate() .find_map(|(idx, cell)| Arc::ptr_eq(&cell.0, &dcid_cell.0).then_some(idx)) .expect("Initial path should be in pending_cells"); // Move the initial path to the front of the pending cells let handshake_path = self.pending_cells.remove(handshake_path).unwrap(); self.pending_cells.insert(0, handshake_path); self.arrange_idle_cid(); } /// Receive a [`NewConnectionIdFrame`] from peer. /// /// Add the new connection id to the deque, and retire the old cids before /// the retire_prior_to in the [`NewConnectionIdFrame`]. 
/// Try to arrange the idle cids to the hungry cid applys if exist. /// /// Return the reset token of this [`NewConnectionIdFrame`] if it is valid. fn recv_new_cid_frame( &mut self, frame: NewConnectionIdFrame, ) -> Result, Error> { let seq = frame.sequence(); let retire_prior_to = frame.retire_prior_to(); let active_len = seq.saturating_sub(retire_prior_to); if active_len > self.active_cid_limit { return Err(QuicError::new( ErrorKind::ConnectionIdLimit, frame.frame_type().into(), format!( "{active_len} exceed active_cid_limit {}", self.active_cid_limit ), ) .into()); } // Discard the frame if the sequence number is less than the current offset. if seq < self.cid_deque.offset() { return Ok(None); } let id = *frame.connection_id(); let token = *frame.reset_token(); self.cid_deque.insert(seq, Some((seq, id, token))).unwrap(); self.retire_prior_to(retire_prior_to); self.arrange_idle_cid(); Ok(Some(token)) } /// Arrange the idle cids to the front of the cid applys #[doc(hidden)] fn arrange_idle_cid(&mut self) { loop { let next_unalloced_cell = self.pending_cells.front(); if next_unalloced_cell.is_none() { break; } let next_unalloced_cell = next_unalloced_cell.unwrap(); let mut guard = next_unalloced_cell.0.lock().unwrap(); if guard.is_retired { drop(guard); self.pending_cells.pop_front(); continue; } let next_unused_cid = self.cid_deque.get(self.cursor); if let Some(Some((seq, cid, _))) = next_unused_cid { guard.assign(*seq, *cid); // Until an unused CID is allocated, the guard cannot be released early. drop(guard); let apply = self.pending_cells.pop_front().unwrap(); self.ready_cells .push_back(apply) .expect("Sequence of new connection ID should never exceed the limit"); self.cursor += 1; } else { break; } } } /// Eliminate the old cids and inform the peer with a /// [`RetireConnectionIdFrame`] for each retired connection ID. 
#[doc(hidden)] fn retire_prior_to(&mut self, tomb_seq: u64) { if tomb_seq <= self.ready_cells.offset() { return; } _ = self.cid_deque.drain_to(tomb_seq); // it is possible that the connection id that has not been used is directly retired, // and there is no chance to assign it, this phenomenon is called "jumping retire cid" self.cursor = self.cursor.max(tomb_seq); // reassign the cid that has been assigned to the Path but is facing retirement if self.ready_cells.is_empty() { // it is not necessary to resize the deque, because all elements will be drained // // self.cid_cells.resize(seq, ArcCidCell::default()).expect(""); self.retired_cids .send_frame((self.ready_cells.offset()..tomb_seq).map(|seq| { RetireConnectionIdFrame::new( VarInt::from_u64(seq) .expect("Sequence of connection id is very hard to exceed VARINT_MAX"), ) })); self.ready_cells.reset_offset(tomb_seq); } else { let actual_applied = self.ready_cells.largest(); let need_reassigned = actual_applied.min(tomb_seq); // retire the cids before seq, including the applied and unapplied for _ in self.ready_cells.offset()..need_reassigned { let (_, cell) = self.ready_cells.pop_front().unwrap(); if cell.is_retired() { continue; } self.pending_cells.push_back(cell); } if actual_applied < tomb_seq { self.ready_cells.reset_offset(tomb_seq); // even the cid that has not been applied is retired right now self.retired_cids .send_frame((actual_applied..tomb_seq).map(|seq| { RetireConnectionIdFrame::new( VarInt::from_u64(seq).expect( "Sequence of connection id is very hard to exceed VARINT_MAX", ), ) })); } } } /// Apply for a new connection ID, and return an [`ArcCidCell`], which may be not ready state. fn apply_dcid(&mut self) -> ArcCidCell { let cell = ArcCidCell::new(self.retired_cids.clone()); self.pending_cells.push_back(cell.clone()); self.arrange_idle_cid(); cell } } /// Shared remote connection ID manager. Most of the time, you should use this struct. /// /// These connection IDs will be assigned to the Path. 
/// Every new path needs to apply for a new connection ID from the RemoteCids. /// Each path may retire the old connection ID proactively, and apply for a new one. /// /// `RETIRED` stores the [`RetireConnectionIdFrame`], which need to be sent to the peer. /// It can be a deque, a channel, or any buffer, /// as long as it can send those [`RetireConnectionIdFrame`] to the peer finally. #[derive(Debug, Clone)] pub struct ArcRemoteCids(Arc>>) where RETIRED: SendFrame + Clone; impl ArcRemoteCids where RETIRED: SendFrame + Clone, { /// Create a new RemoteCids with the maximum number of active cids, /// and the retired cids. /// /// As mentioned above, the `retired_cids` can be a deque, a channel, or any buffer, /// as long as it can send those [`RetireConnectionIdFrame`] to the peer finally. pub fn new(active_cid_limit: u64, retired_cids: RETIRED) -> Self { Self(Arc::new(Mutex::new(RemoteCids::new( active_cid_limit, retired_cids, )))) } /// Apply initial dcid to handshake path. /// /// dquic implements multi-path handshake feature, the client creates many paths and sends initial packets. /// /// The client and server must negotiate a handshake path and assign the initial dcid to this path /// to prevent the unique connection ID from being obtained by an invalid path, causing the connection to fail. /// /// The client and server choose the path where they receive the first initial packet as the handshake path. /// The server will only return the initial packet on the handshake path to negotiate the handshake path. /// /// This method should only be called when the connection receives the first initial packet, or panic. /// The parameters are the Source Connection Id of the first initial packet received by the connection, /// and the [`ArcCidCell`] of the path that passed this packet. 
pub fn apply_initial_dcid(&self, initial_dcid: ConnectionId, dcid_cell: &ArcCidCell) { self.0 .lock() .unwrap() .apply_initial_dcid(initial_dcid, dcid_cell); } /// Apply for a new connection ID, which is used when the Path is created. /// /// Return an [`ArcCidCell`], which may be not ready state. pub fn apply_dcid(&self) -> ArcCidCell { self.0.lock().unwrap().apply_dcid() } /// Return the latest connection ID issued by the peer. /// /// The cid is used to assemble the packet that contains a connection close frame. When the /// connection is closed, the connection close frame will be sent to the peer. pub fn latest_dcid(&self) -> Option { self.0 .lock() .unwrap() .cid_deque .iter() .rev() .flatten() .next() .map(|(_, cid, _)| *cid) } } impl ReceiveFrame for ArcRemoteCids where RETIRED: SendFrame + Clone, { type Output = Option; fn recv_frame(&self, frame: NewConnectionIdFrame) -> Result { self.0.lock().unwrap().recv_new_cid_frame(frame) } } #[derive(Debug)] struct CidCell where RETIRED: SendFrame, { retired_cids: RETIRED, allocated_cids: VecDeque<(u64, ConnectionId)>, waker: Option, is_retired: bool, is_using: bool, } impl CidCell where RETIRED: SendFrame + Clone, { fn assign(&mut self, seq: u64, cid: ConnectionId) { assert!(!self.is_retired); self.allocated_cids.push_front((seq, cid)); if !self.is_using { while self.allocated_cids.len() > 1 { let (seq, _) = self.allocated_cids.pop_back().unwrap(); let sequence = VarInt::try_from(seq) .expect("Sequence of connection id is very hard to exceed VARINT_MAX"); self.retired_cids .send_frame([RetireConnectionIdFrame::new(sequence)]); } } if let Some(waker) = self.waker.take() { waker.wake_by(Signals::CONNECTION_ID); } } fn borrow_cid(&mut self, tx_waker: ArcSendWaker) -> Result, Signals> { if self.is_retired { return Ok(None); } if self.allocated_cids.is_empty() { self.waker = Some(tx_waker); Err(Signals::CONNECTION_ID) } else { let cid = self.allocated_cids[0].1; self.is_using = true; Ok(Some(cid)) } } fn renew(&mut 
self) { assert!(self.is_using); self.is_using = false; while self.allocated_cids.len() > 1 { let (seq, _) = self.allocated_cids.pop_back().unwrap(); let sequence = VarInt::try_from(seq) .expect("Sequence of connection id is very hard to exceed VARINT_MAX"); self.retired_cids .send_frame([RetireConnectionIdFrame::new(sequence)]); } } fn retire(&mut self) { if !self.is_retired { self.is_retired = true; while let Some((seq, _)) = self.allocated_cids.pop_front() { let sequence = VarInt::try_from(seq) .expect("Sequence of connection id is very hard to exceed VARINT_MAX"); self.retired_cids .send_frame([RetireConnectionIdFrame::new(sequence)]); } if let Some(waker) = self.waker.take() { waker.wake_by(Signals::CONNECTION_ID); } } } } /// Shared connection ID cell. Most of the time, you should use this struct. #[derive(Debug, Clone)] pub struct ArcCidCell(Arc>>) where RETIRED: SendFrame + Clone; impl ArcCidCell where RETIRED: SendFrame + Clone, { /// Create a new CidCell with the retired cids, the sequence number of the connection ID, /// and the state of the connection ID. /// /// It can be created only by the [`ArcRemoteCids::apply_dcid`] method. #[doc(hidden)] fn new(retired_cids: RETIRED) -> Self { Self(Arc::new(Mutex::new(CidCell { retired_cids, allocated_cids: VecDeque::with_capacity(2), waker: None, is_retired: false, is_using: false, }))) } fn is_retired(&self) -> bool { self.0.lock().unwrap().is_retired } /// Asynchronously get the connection ID, if it is not ready, return Pending. /// /// If the corresponding path which applied this cid is inactive, /// then this cid apply is retired. /// In this case, None will be returned. 
pub fn borrow_cid( &'_ self, tx_waker: ArcSendWaker, ) -> Result>, Signals> { self.0.lock().unwrap().borrow_cid(tx_waker).map(|cid| { cid.map(|cid| BorrowedCid { cid_cell: &self.0, cid, }) }) } /// When the Path is invalid, the connection id needs to be retired, and this Cell /// is marked as no longer in use, with a [`RetireConnectionIdFrame`] being sent to peer. pub fn retire(&self) { self.0.lock().unwrap().retire(); } } /// A borrowed connection ID, which will be returned back when it is dropped. /// /// While the connection ID is borrowed, the retired cids will not be truly retired. The retire will be delayed until /// the [`BorrowedCid`] is dropped, a [`RetireConnectionIdFrame`] will be sent to the peer. pub struct BorrowedCid<'a, RETIRED> where RETIRED: SendFrame + Clone, { cid: ConnectionId, cid_cell: &'a Mutex>, } impl Deref for BorrowedCid<'_, RETIRED> where RETIRED: SendFrame + Clone, { type Target = ConnectionId; fn deref(&self) -> &Self::Target { &self.cid } } impl Drop for BorrowedCid<'_, RETIRED> where RETIRED: SendFrame + Clone, { fn drop(&mut self) { self.cid_cell.lock().unwrap().renew(); } } #[cfg(test)] mod tests { use derive_more::Deref; use super::*; #[derive(Debug, Clone, Default, Deref)] struct RetiredCids(Arc>>); impl SendFrame for RetiredCids { fn send_frame>(&self, iter: I) { self.0.lock().unwrap().extend(iter); } } #[test] fn test_remote_cids() { let retired_cids = RetiredCids::default(); let mut remote_cids = RemoteCids::new(8, retired_cids); let initial_dcid = ConnectionId::random_gen(8); let cid_apply0 = remote_cids.apply_dcid(); remote_cids.apply_initial_dcid(initial_dcid, &cid_apply0); let waker = ArcSendWaker::new(); assert!(matches!( cid_apply0.borrow_cid(waker.clone()), Ok(Some(cid)) if *cid == initial_dcid )); // Will return Pending, because the peer hasn't issue any connection id let cid_apply1 = remote_cids.apply_dcid(); assert!(matches!( cid_apply1.borrow_cid(waker.clone()), Err(Signals::CONNECTION_ID) )); let new_dcid = 
ConnectionId::random_gen(8); let frame = NewConnectionIdFrame::new(new_dcid, VarInt::from_u32(1), VarInt::from_u32(0)); assert!(remote_cids.recv_new_cid_frame(frame).is_ok()); assert_eq!(remote_cids.cid_deque.len(), 2); assert!(matches!( cid_apply0.borrow_cid(waker.clone()), Ok(Some(cid)) if *cid == initial_dcid )); assert!(matches!( cid_apply1.borrow_cid(waker.clone()), Ok(Some(cid)) if *cid == new_dcid )); // Additionally, a new request will be made because if the peer-issued CID is // insufficient, it will still return Pending. remote_cids.retire_prior_to(1); let cid_apply2 = remote_cids.apply_dcid(); assert!(cid_apply2.borrow_cid(waker.clone()).is_err()); assert!(matches!( cid_apply0.borrow_cid(waker.clone()), Ok(Some(cid)) if *cid == initial_dcid )); } #[test] fn test_retire_in_remote_cids() { let retired_cids = RetiredCids::default(); let remote_cids = ArcRemoteCids::new(8, retired_cids); let initial_dcid = ConnectionId::random_gen(8); let cid_apply0 = remote_cids.apply_dcid(); remote_cids.apply_initial_dcid(initial_dcid, &cid_apply0); let mut guard = remote_cids.0.lock().unwrap(); let mut cids = vec![initial_dcid]; for seq in 1..8 { let cid = ConnectionId::random_gen(8); cids.push(cid); let frame = NewConnectionIdFrame::new(cid, VarInt::from_u32(seq), VarInt::from_u32(0)); _ = guard.recv_new_cid_frame(frame); } let cid_apply1 = guard.apply_dcid(); let waker = ArcSendWaker::new(); assert_eq!(cid_apply0.0.lock().unwrap().allocated_cids[0].0, 0); assert!(matches!( cid_apply0.borrow_cid(waker.clone()), Ok(Some(cid)) if *cid == cids[0] )); assert_eq!(cid_apply1.0.lock().unwrap().allocated_cids[0].0, 1); assert!(matches!( cid_apply1.borrow_cid(waker.clone()), Ok(Some(cid)) if *cid == cids[1] )); guard.retire_prior_to(4); assert_eq!(guard.cid_deque.offset(), 4); assert_eq!(guard.ready_cells.offset(), 4); // delay retire cid assert_eq!(guard.retired_cids.0.lock().unwrap().len(), 2); assert_eq!(cid_apply0.0.lock().unwrap().allocated_cids[0].0, 0); 
assert_eq!(cid_apply1.0.lock().unwrap().allocated_cids[0].0, 1); assert!(matches!( cid_apply0.borrow_cid(waker.clone()), Ok(Some(cid)) if *cid == cids[0] )); assert!(matches!( cid_apply1.borrow_cid(waker.clone()), Ok(Some(cid)) if *cid == cids[1] )); guard.arrange_idle_cid(); assert_eq!(guard.retired_cids.0.lock().unwrap().len(), 4); let retired_cids = [1, 0, 3, 2]; for seq in retired_cids { assert_eq!( // like a stack, the last in the first out guard.retired_cids.0.lock().unwrap().pop(), Some(RetireConnectionIdFrame::new(VarInt::from_u32(seq))) ); } assert!(matches!( cid_apply0.borrow_cid(waker.clone()), Ok(Some(entry)) if *entry == cids[4] )); assert!(matches!( cid_apply1.borrow_cid(waker.clone()), Ok(Some(entry)) if *entry == cids[5] )); cid_apply1.retire(); assert_eq!(guard.retired_cids.lock().unwrap().len(), 1); assert_eq!( guard.retired_cids.0.lock().unwrap().pop(), Some(RetireConnectionIdFrame::new(VarInt::from_u32(5))) ); } #[test] fn test_retire_without_apply() { let retired_cids = RetiredCids::default(); let remote_cids = ArcRemoteCids::new(8, retired_cids); let initial_dcid = ConnectionId::random_gen(8); let cid_apply0 = remote_cids.apply_dcid(); remote_cids.apply_initial_dcid(initial_dcid, &cid_apply0); let mut guard = remote_cids.0.lock().unwrap(); let mut cids = vec![initial_dcid]; for seq in 1..8 { let cid = ConnectionId::random_gen(8); cids.push(cid); let frame = NewConnectionIdFrame::new(cid, VarInt::from_u32(seq), VarInt::from_u32(0)); _ = guard.recv_new_cid_frame(frame); } guard.retire_prior_to(4); assert_eq!(guard.cid_deque.offset(), 4); assert_eq!(guard.ready_cells.offset(), 4); assert_eq!(guard.retired_cids.0.lock().unwrap().len(), 3); let cid_apply1 = guard.apply_dcid(); assert_eq!(cid_apply0.0.lock().unwrap().allocated_cids[0].0, 4); assert_eq!(cid_apply1.0.lock().unwrap().allocated_cids[0].0, 5); let waker = ArcSendWaker::new(); assert!(matches!( cid_apply0.borrow_cid(waker.clone()), Ok(Some(entry)) if *entry == cids[4] )); 
assert!(matches!( cid_apply1.borrow_cid(waker.clone()), Ok(Some(entry)) if *entry == cids[5] )); } } ================================================ FILE: qbase/src/cid.rs ================================================ mod connection_id; pub use connection_id::*; mod local_cid; pub use local_cid::*; mod remote_cid; pub use remote_cid::*; use crate::role::Role; /// When issuing a CID to the peer, be careful not to duplicate /// other local connection IDs, as this will cause routing conflicts. pub trait GenUniqueCid { /// Generate a unique connection ID. #[must_use] fn gen_unique_cid(&self) -> ConnectionId; } pub trait RetireCid { /// Retire a connection ID. fn retire_cid(&self, cid: ConnectionId); } /// Connection ID registry. /// /// - `local` represents the management of connection IDs issued by me to peer, /// - `remote` represents the reception of connection IDs issued by peer, /// which will be used by the path. #[derive(Debug, Clone)] pub struct Registry { pub local: LOCAL, pub remote: REMOTE, role: Role, origin_dcid: ConnectionId, } impl Registry { /// Create a new connection ID registry. pub fn new(role: Role, origin_dcid: ConnectionId, local: LOCAL, remote: REMOTE) -> Self { Self { role, origin_dcid, local, remote, } } pub fn role(&self) -> Role { self.role } pub fn origin_dcid(&self) -> ConnectionId { self.origin_dcid } } ================================================ FILE: qbase/src/error.rs ================================================ use std::{borrow::Cow, fmt::Display}; use derive_more::From; use thiserror::Error; use crate::{ frame::{ConnectionCloseFrame, FrameType}, varint::VarInt, }; /// QUIC transport error codes and application error codes. /// /// See [table-7](https://www.rfc-editor.org/rfc/rfc9000.html#table-7) /// and [error codes](https://www.rfc-editor.org/rfc/rfc9000.html#name-error-codes) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum ErrorKind { /// An endpoint uses this with CONNECTION_CLOSE to signal that /// the connection is being closed abruptly in the absence of any error. None, /// The endpoint encountered an internal error and cannot continue with the connection. Internal, /// The server refused to accept a new connection. ConnectionRefused, /// An endpoint received more data than it permitted in its advertised data limits. FlowControl, /// An endpoint received a frame for a stream identifier that /// exceeded its advertised stream limit for the corresponding stream type. StreamLimit, /// An endpoint received a frame for a stream that was not in a state that permitted that frame. StreamState, /// - An endpoint received a STREAM frame containing data that /// exceeded the previously established final size, /// - an endpoint received a STREAM frame or a RESET_STREAM frame containing a final size /// that was lower than the size of stream data that was already received, or /// - an endpoint received a STREAM frame or a RESET_STREAM frame containing a different /// final size to the one already established. FinalSize, /// An endpoint received a frame that was badly formatted. FrameEncoding, /// An endpoint received transport parameters that were badly formatted, included: /// - an invalid value, omitted a mandatory transport parameter /// - a forbidden transport parameter /// - otherwise in error. TransportParameter, /// The number of connection IDs provided by the peer exceeds /// the advertised active_connection_id_limit. ConnectionIdLimit, /// An endpoint detected an error with protocol compliance /// that was not covered by more specific error codes. ProtocolViolation, /// A server received a client Initial that contained an invalid Token field. InvalidToken, /// The application or application protocol caused the connection to be closed. Application, /// An endpoint has received more data in CRYPTO frames than it can buffer. 
CryptoBufferExceeded, /// An endpoint detected errors in performing key updates; see /// [Section 6](https://www.rfc-editor.org/rfc/rfc9001#section-6) /// of [QUIC-TLS](https://www.rfc-editor.org/rfc/rfc9000.html#QUIC-TLS). KeyUpdate, /// An endpoint has reached the confidentiality or integrity limit /// for the AEAD algorithm used by the given connection. AeadLimitReached, /// An endpoint has determined that the network path is incapable of supporting QUIC. /// An endpoint is unlikely to receive a CONNECTION_CLOSE frame carrying this code /// except when the path does not support a large enough MTU. NoViablePath, /// The cryptographic handshake failed. /// A range of 256 values is reserved for carrying error codes specific /// to the cryptographic handshake that is used. /// Codes for errors occurring when TLS is used for the cryptographic handshake are described /// in [Section 4.8](https://www.rfc-editor.org/rfc/rfc9001#section-4.8) /// of [QUIC-TLS](https://www.rfc-editor.org/rfc/rfc9000.html#QUIC-TLS). 
Crypto(u8), } impl Display for ErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let description = match self { ErrorKind::None => "No error", ErrorKind::Internal => "Implementation error", ErrorKind::ConnectionRefused => "Server refuses a connection", ErrorKind::FlowControl => "Flow control error", ErrorKind::StreamLimit => "Too many streams opened", ErrorKind::StreamState => "Frame received in invalid stream state", ErrorKind::FinalSize => "Change to final size", ErrorKind::FrameEncoding => "Frame encoding error", ErrorKind::TransportParameter => "Error in transport parameters", ErrorKind::ConnectionIdLimit => "Too many connection IDs received", ErrorKind::ProtocolViolation => "Generic protocol violation", ErrorKind::InvalidToken => "Invalid Token received", ErrorKind::Application => "Application error", ErrorKind::CryptoBufferExceeded => "CRYPTO data buffer overflowed", ErrorKind::KeyUpdate => "Invalid packet protection update", ErrorKind::AeadLimitReached => "Excessive use of packet protection keys", ErrorKind::NoViablePath => "No viable network path exists", ErrorKind::Crypto(x) => return write!(f, "TLS alert code: {x}"), }; write!(f, "{description}",) } } /// Invalid error code while parsing. /// The parsed [`VarInt`] error code exceeds the normal range of error codes. /// /// See [Initial QUIC Transport Error Codes Registry Entries](https://www.rfc-editor.org/rfc/rfc9000.html#table-7) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)] #[error("Invalid error code {0}")] pub struct InvalidErrorKind(u64); impl TryFrom for ErrorKind { type Error = InvalidErrorKind; fn try_from(value: VarInt) -> Result { Ok(match value.into_u64() { 0x00 => ErrorKind::None, 0x01 => ErrorKind::Internal, 0x02 => ErrorKind::ConnectionRefused, 0x03 => ErrorKind::FlowControl, 0x04 => ErrorKind::StreamLimit, 0x05 => ErrorKind::StreamState, 0x06 => ErrorKind::FinalSize, 0x07 => ErrorKind::FrameEncoding, 0x08 => ErrorKind::TransportParameter, 0x09 => ErrorKind::ConnectionIdLimit, 0x0a => ErrorKind::ProtocolViolation, 0x0b => ErrorKind::InvalidToken, 0x0c => ErrorKind::Application, 0x0d => ErrorKind::CryptoBufferExceeded, 0x0e => ErrorKind::KeyUpdate, 0x0f => ErrorKind::AeadLimitReached, 0x10 => ErrorKind::NoViablePath, 0x0100..=0x01ff => ErrorKind::Crypto((value.into_u64() & 0xff) as u8), other => return Err(InvalidErrorKind(other)), }) } } impl From for VarInt { fn from(value: ErrorKind) -> Self { match value { ErrorKind::None => VarInt::from(0x00u8), ErrorKind::Internal => VarInt::from(0x01u8), ErrorKind::ConnectionRefused => VarInt::from(0x02u8), ErrorKind::FlowControl => VarInt::from(0x03u8), ErrorKind::StreamLimit => VarInt::from(0x04u8), ErrorKind::StreamState => VarInt::from(0x05u8), ErrorKind::FinalSize => VarInt::from(0x06u8), ErrorKind::FrameEncoding => VarInt::from(0x07u8), ErrorKind::TransportParameter => VarInt::from(0x08u8), ErrorKind::ConnectionIdLimit => VarInt::from(0x09u8), ErrorKind::ProtocolViolation => VarInt::from(0x0au8), ErrorKind::InvalidToken => VarInt::from(0x0bu8), ErrorKind::Application => VarInt::from(0x0cu8), ErrorKind::CryptoBufferExceeded => VarInt::from(0x0du8), ErrorKind::KeyUpdate => VarInt::from(0x0eu8), ErrorKind::AeadLimitReached => VarInt::from(0x0fu8), ErrorKind::NoViablePath => VarInt::from(0x10u8), ErrorKind::Crypto(x) => VarInt::from(0x0100u16 | x as u16), } } } #[derive(Debug, Clone, PartialEq, Eq, Copy)] pub enum ErrorFrameType { 
V1(FrameType), Ext(VarInt), } /// QUIC transport error. /// /// Its definition conforms to the usage of [`ConnectionCloseFrame`]. /// A value of 0 (equivalent to the mention of the PADDING frame) is used when the frame type is unknown. #[derive(Debug, Clone, PartialEq, Eq, Error)] #[error("{kind} in {frame_type:?}, reason: {reason}")] pub struct QuicError { kind: ErrorKind, frame_type: ErrorFrameType, reason: Cow<'static, str>, } impl QuicError { /// Create a new error with the given kind, frame type, and reason. /// The frame type is the one that triggered this error. pub fn new>>( kind: ErrorKind, frame_type: ErrorFrameType, reason: T, ) -> Self { Self { kind, frame_type, reason: reason.into(), } } /// Create a new error with unknown frame type, and /// the [`FrameType::Padding`] type will be used by default. pub fn with_default_fty>>(kind: ErrorKind, reason: T) -> Self { Self { kind, frame_type: FrameType::Padding.into(), reason: reason.into(), } } /// Return the error kind. pub fn kind(&self) -> ErrorKind { self.kind } /// Return the frame type that triggered this error. pub fn frame_type(&self) -> ErrorFrameType { self.frame_type } /// Return the reason of this error. pub fn reason(&self) -> &str { &self.reason } } impl From for ErrorFrameType { fn from(value: FrameType) -> Self { Self::V1(value) } } impl From for VarInt { fn from(val: ErrorFrameType) -> Self { match val { ErrorFrameType::V1(frame) => frame.into(), ErrorFrameType::Ext(value) => value, } } } /// App specific error. #[derive(Debug, Clone, PartialEq, Eq, Error)] #[error("App layer error occur with error code {error_code}, reason: {reason}")] pub struct AppError { error_code: VarInt, reason: Cow<'static, str>, } impl AppError { /// Create a new app error with the given app error code and reason. pub fn new(error_code: VarInt, reason: impl Into>) -> Self { Self { error_code, reason: reason.into(), } } /// Return the error code. /// /// The error code is an application error code. 
pub fn error_code(&self) -> u64 { self.error_code.into_u64() } /// Return the reason of this error. pub fn reason(&self) -> &str { &self.reason } } #[derive(Debug, Clone, PartialEq, Eq, Error, From)] pub enum Error { #[error(transparent)] Quic(QuicError), #[error(transparent)] App(AppError), } impl Error { pub fn kind(&self) -> ErrorKind { match self { Error::Quic(e) => e.kind(), Error::App(_) => ErrorKind::Application, } } pub fn frame_type(&self) -> ErrorFrameType { match self { Error::Quic(e) => e.frame_type(), Error::App(_) => FrameType::Padding.into(), } } } impl From for std::io::Error { fn from(e: Error) -> Self { Self::new(std::io::ErrorKind::BrokenPipe, e) } } impl From for ConnectionCloseFrame { fn from(e: Error) -> Self { match e { Error::Quic(e) => Self::new_quic(e.kind, e.frame_type, e.reason), Error::App(app_error) => Self::new_app(app_error.error_code, app_error.reason), } } } impl From for ConnectionCloseFrame { fn from(e: AppError) -> Self { Self::new_app(e.error_code, e.reason) } } impl From for Error { fn from(frame: ConnectionCloseFrame) -> Self { match frame { ConnectionCloseFrame::Quic(frame) => Self::Quic(QuicError { kind: frame.error_kind(), frame_type: frame.frame_type(), reason: frame.reason().to_owned().into(), }), ConnectionCloseFrame::App(frame) => Self::App(AppError { error_code: VarInt::from_u64(frame.error_code()) .expect("error code never overflow"), reason: frame.reason().to_owned().into(), }), } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_error_kind_display() { assert_eq!(ErrorKind::None.to_string(), "No error"); assert_eq!(ErrorKind::Internal.to_string(), "Implementation error"); assert_eq!(ErrorKind::Crypto(10).to_string(), "TLS alert code: 10"); } #[test] fn test_error_kind_conversion() { // Test VarInt to ErrorKind assert_eq!( ErrorKind::try_from(VarInt::from(0x00u8)).unwrap(), ErrorKind::None ); assert_eq!( ErrorKind::try_from(VarInt::from(0x10u8)).unwrap(), ErrorKind::NoViablePath ); assert_eq!( 
ErrorKind::try_from(VarInt::from(0x0100u16)).unwrap(), ErrorKind::Crypto(0) ); // Test invalid error code assert_eq!( ErrorKind::try_from(VarInt::from(0x1000u16)).unwrap_err(), InvalidErrorKind(0x1000) ); // Test ErrorKind to VarInt assert_eq!(VarInt::from(ErrorKind::None), VarInt::from(0x00u8)); assert_eq!(VarInt::from(ErrorKind::NoViablePath), VarInt::from(0x10u8)); assert_eq!(VarInt::from(ErrorKind::Crypto(5)), VarInt::from(0x0105u16)); } #[test] fn test_error_creation() { let err = QuicError::new(ErrorKind::Internal, FrameType::Ping.into(), "test error"); assert_eq!(err.kind(), ErrorKind::Internal); assert_eq!(err.frame_type(), FrameType::Ping.into()); let err = QuicError::with_default_fty(ErrorKind::Application, "default frame type"); assert_eq!(err.frame_type(), FrameType::Padding.into()); } #[test] fn test_error_conversion() { let err = Error::Quic(QuicError::new( ErrorKind::Internal, FrameType::Ping.into(), "test conversion", )); // Test Error to ConnectionCloseFrame let frame: ConnectionCloseFrame = err.clone().into(); match frame { ConnectionCloseFrame::Quic(frame) => { assert_eq!(frame.error_kind(), err.kind()); assert_eq!(frame.frame_type(), err.frame_type()); } _ => panic!("unexpected frame type"), } // Test Error to io::Error let io_err: std::io::Error = err.into(); assert_eq!(io_err.kind(), std::io::ErrorKind::BrokenPipe); } } ================================================ FILE: qbase/src/flow.rs ================================================ use std::{ ops::{Deref, DerefMut}, sync::{Arc, Mutex}, }; use crate::{ error::{Error, ErrorFrameType, ErrorKind, QuicError}, frame::{ DataBlockedFrame, FrameType, MaxDataFrame, io::{ReceiveFrame, SendFrame}, }, net::tx::{ArcSendWakers, Signals}, varint::VarInt, }; /// Connection-level global Stream Flow Control in the sending direction, /// regulated by the peer's `initial_max_data` transport parameter /// and updated by the [`MaxDataFrame`] sent by the peer. 
///
/// Private controller state inside [`ArcSendControler`].
#[derive(Debug)]
struct SendControler<TX> {
    /// Total amount of fresh (non-retransmitted) stream data committed so far.
    sent_data: u64,
    /// Current connection-level send limit advertised by the peer.
    max_data: u64,
    /// Whether a DATA_BLOCKED frame has already been emitted for the current limit,
    /// so we do not spam one per blocked attempt.
    flow_limited: bool,
    /// Sink used to emit [`DataBlockedFrame`]s toward the peer.
    broker: TX,
    /// Wakers of sending tasks, woken when flow-control credit becomes available again.
    tx_wakers: ArcSendWakers,
}

impl<TX> SendControler<TX> {
    /// Create a controller with the peer's `initial_max_data` as the starting limit.
    fn new(initial_max_data: u64, broker: TX, tx_wakers: ArcSendWakers) -> Self {
        Self {
            sent_data: 0,
            max_data: initial_max_data,
            flow_limited: false,
            broker,
            tx_wakers,
        }
    }

    /// Raise the send limit (from a MAX_DATA frame). Limits never shrink;
    /// a smaller or equal value is ignored per RFC 9000.
    fn increase_limit(&mut self, max_data: u64) {
        if max_data > self.max_data {
            self.max_data = max_data;
            self.flow_limited = false;
            // Credit became available: wake every sender blocked on flow control.
            self.tx_wakers.wake_all_by(Signals::FLOW_CONTROL);
        }
    }

    /// Remaining flow-control credit.
    // NOTE(review): "avaliable" is a typo for "available"; kept because sibling
    // code in this file calls it by this name.
    fn avaliable(&self) -> u64 {
        self.max_data - self.sent_data
    }

    /// Consume `flow` bytes of credit. If that exhausts the window, report it
    /// to the peer once via a DATA_BLOCKED frame.
    fn commit(&mut self, flow: u64)
    where
        TX: SendFrame<DataBlockedFrame>,
    {
        self.sent_data += flow;
        if self.avaliable() == 0 && !self.flow_limited {
            self.flow_limited = true;
            self.broker.send_frame([DataBlockedFrame::new(
                VarInt::from_u64(self.max_data)
                    .expect("max_data of flow controller is very very hard to exceed 2^62 - 1"),
            )]);
        }
    }

    /// Return unused credit (see [`Credit`]'s `Drop`); wake blocked senders if
    /// any credit is available afterwards.
    fn return_back(&mut self, flow: u64) {
        self.sent_data -= flow;
        if self.avaliable() > 0 {
            self.tx_wakers.wake_all_by(Signals::FLOW_CONTROL);
        }
    }

    /// Re-apply the limit after the handshake: when 0-RTT was rejected, the
    /// provisional limit is discarded before adopting the negotiated one.
    // NOTE(review): if `sent_data > 0` when 0-RTT is rejected, `avaliable()`
    // would underflow until the new limit covers it — presumably rejected 0-RTT
    // data is also rolled back by the caller; verify against call sites.
    fn revise_max_data(&mut self, zero_rtt_rejected: bool, max_data: u64) {
        if zero_rtt_rejected {
            self.max_data = 0;
            self.flow_limited = false;
        }
        self.increase_limit(max_data);
    }
}

/// Shared connection-level Stream Flow Control in the sending direction,
/// regulated by the peer's `initial_max_data` transport parameter
/// and updated by the [`MaxDataFrame`] received from the peer.
///
/// Only the new data sent in [`StreamFrame`](`crate::frame::StreamFrame`) counts toward this limit.
/// Retransmitted stream data does not count towards this limit.
///
/// When flow control is 0,
/// retransmitted stream data can still be sent,
/// but new data cannot be sent.
/// When the stream has no data to retransmit,
/// meaning all old data has been successfully acknowledged,
/// it is necessary to wait for the receiver's [`MaxDataFrame`]
/// to increase the connection-level flow control limit.
/// /// To avoid having to pause sending tasks while waiting for the [`MaxDataFrame`], /// the receiver should promptly send the [`MaxDataFrame`] /// to increase the flow control limit, /// ensuring that the sender always has enough space to send smoothly. /// An extreme yet simple strategy is to set the flow control limit to infinity from the start, /// causing the connection-level flow control to never reach its limit, /// effectively rendering it useless. #[derive(Clone, Debug)] pub struct ArcSendControler(Arc, Error>>>); impl ArcSendControler { /// Creates a new [`ArcSendControler`] with `initial_max_data`. /// /// `initial_max_data` should be known to each other after the handshake is /// completed. If sending data in 0-RTT space, `initial_max_data` should be /// the value from the previous connection. /// /// `initial_max_data` is allowed to be 0, which is reasonable when creating a /// connection without knowing the peer's `iniitial_max_data` setting. pub fn new(initial_max_data: u64, broker: TX, tx_wakers: ArcSendWakers) -> Self { Self(Arc::new(Mutex::new(Ok(SendControler::new( initial_max_data, broker, tx_wakers, ))))) } fn increase_limit(&self, max_data: u64) { let mut guard = self.0.lock().unwrap(); if let Ok(inner) = guard.deref_mut() { inner.increase_limit(max_data); } } // Get some flow control credit to send fresh flow data. /// The returned value may be smaller than the parameter's intended value. /// If some QUIC error occured, it would return the error directly. /// /// # Note /// /// After obtaining the flow control, /// the traffic credit is considered to be consumed immediately. /// The unused flow control quota for this send will be returned to the sending controller. /// This design avoids the sending task’s exclusive access to the sending controller. 
pub fn credit(&self, quota: usize) -> Result, Error> where TX: SendFrame, { match self.0.lock().unwrap().as_mut() { Ok(inner) => { let avaliable = inner.avaliable().min(quota as u64); inner.commit(avaliable); Ok(Credit { available: avaliable as usize, controller: self, }) } Err(e) => Err(e.clone()), } } pub fn revise_max_data(&self, zero_rtt_rejected: bool, max_data: u64) { if let Ok(inner) = self.0.lock().unwrap().deref_mut() { inner.revise_max_data(zero_rtt_rejected, max_data); } } /// Connection-level Stream Flow Control can only be terminated /// if the connection encounters an error pub fn on_error(&self, error: &Error) { let mut guard = self.0.lock().unwrap(); if guard.deref().is_err() { return; } *guard = Err(error.clone()); } } /// [`ArcSendControler`] need to receive [`MaxDataFrame`] from peer /// to increase flow control limit continuely. impl ReceiveFrame for ArcSendControler { type Output = (); fn recv_frame(&self, frame: MaxDataFrame) -> Result { self.increase_limit(frame.max_data()); Ok(()) } } /// Exclusive access to the flow control limit. /// /// As mentioned in the [`ArcSendControler::credit`] method, /// the flow controller in the period between obtaining flow control /// and finally updating(or maybe not) the flow control should be exclusive. pub struct Credit<'a, TX> { available: usize, controller: &'a ArcSendControler, } impl Credit<'_, TX> { /// Return the available amount of new stream data that can be sent. pub fn available(&self) -> usize { self.available } } impl Credit<'_, TX> where TX: SendFrame, { /// Updates the amount of new data sent. pub fn post_sent(&mut self, amount: usize) { self.available -= amount; } } impl Drop for Credit<'_, TX> { fn drop(&mut self) { if let Ok(inner) = self.controller.0.lock().unwrap().as_mut() { inner.return_back(self.available as u64); } } } /// Receiver's flow controller for managing the flow limit of incoming stream data. 
#[derive(Debug, Default)] struct RecvController { rcvd_data: u64, max_data: u64, step: u64, broker: TX, } impl RecvController { /// Creates a new [`RecvController`] with the specified `initial_max_data`. fn new(initial_max_data: u64, broker: TX) -> Self { Self { rcvd_data: 0, max_data: initial_max_data, step: initial_max_data / 2, broker, } } } impl RecvController where TX: SendFrame, { /// Handles the event when new data is received. /// /// The data must be new, old retransmitted data does not count. Whether the data is /// new or not will be determined by each stream after delivering the data packet to them. /// The amount of new data will be passed as the `amount` parameter. fn on_new_rcvd(&mut self, frame_type: FrameType, amount: usize) -> Result { self.rcvd_data += amount as u64; if self.rcvd_data <= self.max_data { if self.rcvd_data + self.step >= self.max_data { self.max_data += self.step; self.broker .send_frame([MaxDataFrame::new(VarInt::from_u64(self.max_data).expect( "max_data of flow controller is very very hard to exceed 2^62 - 1", ))]) } Ok(amount) } else { // Err(Overflow((rcvd_data - max_data) as usize)) Err(QuicError::new( ErrorKind::FlowControl, ErrorFrameType::V1(frame_type), format!("flow control overflow: {}", self.rcvd_data - self.max_data), ) .into()) } } } /// Shared receiver's flow controller for managing the incoming stream data flow. /// /// Flow control on the receiving end, /// primarily used to regulate the data flow sent by the sender. /// Since the receive buffer is limited, /// if the application layer cannot read the data in time, /// the receive buffer will not expand, and the sender must be suspended. /// /// The sender must never send new stream data exceeding /// the flow control limit of the receiver advertised, /// otherwise it will be considered a [`FlowControl`](`crate::error::ErrorKind::FlowControl`) error. 
/// /// Additionally, the flow control on the receiving end also needs to /// promptly send a [`MaxDataFrame`] to the sender after the application layer reads the data, /// to expand the receive window since more receive buffer space is freed up, /// and to inform the sender that more data can be sent. #[derive(Debug, Default, Clone)] pub struct ArcRecvController(Arc>>); impl ArcRecvController { /// Creates a new [`ArcRecvController`] with local `initial_max_data` transport parameter. pub fn new(initial_max_data: u64, broker: TX) -> Self { Self(Arc::new(Mutex::new(RecvController::new( initial_max_data, broker, )))) } } impl ArcRecvController where TX: SendFrame, { /// Updates the total received data size and checks if the flow control limit is exceeded /// when new stream data is received. /// /// As mentioned in [`ArcSendControler`], if the flow control limit is exceeded, /// a [`Error`] error will be returned. pub fn on_new_rcvd(&self, frame_type: FrameType, amount: usize) -> Result { self.0.lock().unwrap().on_new_rcvd(frame_type, amount) } } /// [`ArcRecvController`] need to receive [`DataBlockedFrame`] from peer. /// /// However, the receiver may also not be able to immediately expand the receive window /// and must wait for the application layer to read the data to free up more space /// in the receive buffer. impl ReceiveFrame for ArcRecvController { type Output = (); fn recv_frame(&self, _frame: DataBlockedFrame) -> Result { // Do nothing Ok(()) } } /// Connection-level flow controller, including an [`ArcSendControler`] as the sending side /// and an [`ArcRecvController`] as the receiving side. #[derive(Debug, Clone)] pub struct FlowController { pub sender: ArcSendControler, pub recver: ArcRecvController, } impl FlowController { /// Creates a new `FlowController` with the specified initial send and receive window sizes. /// /// Unfortunately, at the beginning, the peer's `initial_max_data` is unknown. 
/// Therefore, peer's `initial_max_data` can be set to 0 initially, /// and then updated later after obtaining the peer's `initial_max_data` setting. pub fn new( peer_initial_max_data: u64, local_initial_max_data: u64, broker: TX, tx_wakers: ArcSendWakers, ) -> Self { Self { sender: ArcSendControler::new(peer_initial_max_data, broker.clone(), tx_wakers), recver: ArcRecvController::new(local_initial_max_data, broker), } } /// Updates the initial send window size, /// which should be the peer's `initial_max_data` transport parameter. /// So once the peer's [`Parameters`](`crate::param::Parameters`) are obtained, /// this method should be called immediately. pub fn reset_send_window(&self, snd_wnd: u64) { self.sender.increase_limit(snd_wnd); } /// Get some flow control credit to send fresh flow data. /// The returned value may be smaller than the parameter's intended value. /// If some QUIC error occured, it would return the error directly. pub fn send_limit(&self, quota: usize) -> Result, Error> where TX: SendFrame, { self.sender.credit(quota) } /// Handles the error event of the QUIC connection. /// /// It will makes /// the connection-level stream flow controller in the sending direction become unavailable, /// and the connection-level stream flow controller in the receiving direction terminate. pub fn on_conn_error(&self, error: &Error) { self.sender.on_error(error); } } impl FlowController where TX: SendFrame, { /// Updates the total received data size and checks if the flow control limit is exceeded. /// By the way, it will also send a [`MaxDataFrame`] to the sender /// to expand the receive window if necessary. 
pub fn on_new_rcvd(&self, frame_type: FrameType, amount: usize) -> Result { self.recver.on_new_rcvd(frame_type, amount) } } #[cfg(test)] mod tests { use derive_more::{Deref, DerefMut}; use super::*; #[derive(Clone, Debug, Default, Deref, DerefMut)] struct SendControllerBroker(Arc>>); impl SendFrame for SendControllerBroker { fn send_frame>(&self, iter: I) { self.0.lock().unwrap().extend(iter); } } #[test] fn test_send_controler() { let broker = SendControllerBroker::default(); let controler = ArcSendControler::new(0, broker.clone(), Default::default()); controler.increase_limit(100); let mut credit = controler.credit(200).unwrap(); assert_eq!(credit.available(), 100); credit.post_sent(50); assert_eq!(credit.available(), 50); credit.post_sent(50); assert_eq!(credit.available(), 0); drop(credit); // broker should have a DataBlockedFrame assert_eq!(broker.lock().unwrap().len(), 1); assert_eq!(broker.lock().unwrap()[0].limit(), 100); let credit = controler.credit(1).unwrap(); assert_eq!(credit.available(), 0); drop(credit); controler.increase_limit(200); let mut credit = controler.credit(200).unwrap(); assert_eq!(credit.available(), 100); credit.post_sent(50); assert_eq!(credit.available(), 50); credit.post_sent(50); assert_eq!(credit.available(), 0); drop(credit); // broker should have a DataBlockedFrame assert_eq!(broker.lock().unwrap().len(), 2); assert_eq!(broker.lock().unwrap()[1].limit(), 200); } #[derive(Clone, Debug, Default, Deref, DerefMut)] struct RecvControllerBroker(Arc>>); impl SendFrame for RecvControllerBroker { fn send_frame>(&self, iter: I) { self.0.lock().unwrap().extend(iter); } } #[test] fn test_recv_controller() { use crate::frame::{Fin, Len, Offset}; let broker = RecvControllerBroker::default(); let controler = ArcRecvController::new(100, broker.clone()); let amount = controler .on_new_rcvd(FrameType::Stream(Offset::Zero, Len::Omit, Fin::No), 20) .unwrap(); assert_eq!(amount, 20); assert_eq!(broker.lock().unwrap().len(), 0); let amount = 
controler .on_new_rcvd(FrameType::Stream(Offset::Zero, Len::Explicit, Fin::Yes), 30) .unwrap(); assert_eq!(amount, 30); // broker should have a MaxDataFrame assert_eq!(broker.lock().unwrap().len(), 1); assert_eq!(broker.lock().unwrap()[0].max_data(), 150); // test overflow let result = controler.on_new_rcvd(FrameType::ResetStream, 101); assert!(result.is_err()); assert_eq!(result.unwrap_err().kind(), ErrorKind::FlowControl); } } ================================================ FILE: qbase/src/frame/ack.rs ================================================ use std::ops::RangeInclusive; use nom::{Parser, combinator::map}; use crate::{ frame::{GetFrameType, io::WriteFrameType}, varint::{VarInt, WriteVarInt, be_varint}, }; /// ECN flag for ACK frames #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Ecn { /// ECN counts are not present None, /// ECN counts are present Exist, } impl From for u8 { fn from(ecn: Ecn) -> u8 { match ecn { Ecn::None => 0, Ecn::Exist => 1, } } } impl From for Ecn { fn from(value: u8) -> Self { match value & 0x01 { 0 => Ecn::None, _ => Ecn::Exist, } } } /// ACK Frame /// /// ```text /// ACK Frame { /// Type (i) = 0x02..0x03, /// Largest Acknowledged (i), /// ACK Delay (i), /// ACK Range Count (i), /// First ACK Range (i), /// ACK Range (..) ..., /// [ECN Counts (..)], /// } /// ``` /// /// Receiver sends ACK frames (types 0x02 and 0x03) to inform the sender of packets they have /// received and processed. The ACK frame contains one or more ACK Ranges. /// /// See [ack frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-ack-frames) of QUIC RFC 9000. /// /// The ACK Range Count is not included in the struct because it is an intermediate variable. /// It can be obtained from the ranges when writing and is no longer needed after generating /// the ranges when parsing. 
#[derive(Debug, Clone, Eq, PartialEq)] pub struct AckFrame { largest: VarInt, delay: VarInt, first_range: VarInt, ranges: Vec<(VarInt, VarInt)>, ecn: Option, } impl super::GetFrameType for AckFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::Ack(if self.ecn.is_some() { Ecn::Exist } else { Ecn::None }) } } impl super::EncodeSize for AckFrame { fn max_encoding_size(&self) -> usize { 1 + 8 + 8 + 8 + 8 + self.ranges.len() * 16 + if self.ecn.is_some() { 24 } else { 0 } } fn encoding_size(&self) -> usize { let ack_range_count = VarInt::try_from(self.ranges.len()).unwrap(); 1 + self.largest.encoding_size() + self.delay.encoding_size() + ack_range_count.encoding_size() + self.first_range.encoding_size() + self .ranges .iter() .map(|(gap, ack)| gap.encoding_size() + ack.encoding_size()) .sum::() + if let Some(e) = self.ecn.as_ref() { e.encoding_size() } else { 0 } } } impl AckFrame { /// Create a new [`AckFrame`]. pub fn new( largest: VarInt, delay: VarInt, first_range: VarInt, ranges: Vec<(VarInt, VarInt)>, ecn: Option, ) -> Self { Self { largest, delay, first_range, ranges, ecn, } } /// Return the largest acknowledged packet number. pub fn largest(&self) -> u64 { self.largest.into_u64() } /// Return the delay in microseconds. pub fn delay(&self) -> u64 { self.delay.into_u64() } /// Return the first range. pub fn first_range(&self) -> u64 { self.first_range.into_u64() } /// Return the ranges. pub fn ranges(&self) -> &Vec<(VarInt, VarInt)> { &self.ranges } /// Return the ECN (Explicit Congestion Notification) counter. 
pub fn ecn(&self) -> Option { self.ecn } /// Set the value of the ECN (Explicit Congestion Notification) counter pub fn set_ecn(&mut self, ecn: EcnCounts) { self.ecn = Some(ecn); } /// Take the value of the ECN (Explicit Congestion Notification) counter pub fn take_ecn(&mut self) -> Option { self.ecn.take() } /// Iterate through the sequence numbers of the packets acknowledged by the iterative ACK frame, /// starting from the largest and going down. pub fn iter(&self) -> impl Iterator> + '_ { let right = self.largest.into_u64(); let left = right - self.first_range.into_u64(); Some(left..=right).into_iter().chain( self.ranges .iter() .map(|(gap, range)| (gap.into_u64(), range.into_u64())) .scan(left, |largest, (gap, range)| { let right = *largest - gap - 2; let left = right - range; *largest = left; Some(left..=right) }), ) } } /// The counts of Explicit Congestion Notification (ECN) types. /// /// See [ecn-counts](https://www.rfc-editor.org/rfc/rfc9000.html#name-ecn-counts) of QUIC RFC 9000. #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct EcnCounts { ect0: VarInt, ect1: VarInt, ce: VarInt, } impl EcnCounts { /// Create a new [`EcnCounts`]. pub fn new(ect0: VarInt, ect1: VarInt, ce: VarInt) -> Self { Self { ect0, ect1, ce } } /// Get the value of the ECT0 counter. pub fn ect0(&self) -> u64 { self.ect0.into_u64() } /// Get the value of the ECT1 counter. pub fn ect1(&self) -> u64 { self.ect1.into_u64() } /// Get the value of the CE counter. pub fn ce(&self) -> u64 { self.ce.into_u64() } /// Calculates the encoding size of the [`EcnCounts`] struct. fn encoding_size(&self) -> usize { self.ect0.encoding_size() + self.ect1.encoding_size() + self.ce.encoding_size() } } /// Parser for parsing an ACK frame with the given ECN flag, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
pub fn ack_frame_with_ecn(ecn: Ecn) -> impl Fn(&[u8]) -> nom::IResult<&[u8], AckFrame> { move |input: &[u8]| { let (mut remain, (largest, delay, count, first_range)) = (be_varint, be_varint, be_varint, be_varint).parse(input)?; let mut ranges = Vec::new(); let mut count = count.into_u64() as usize; while count > 0 { let (i, (gap, ack)) = (be_varint, be_varint).parse(remain)?; ranges.push((gap, ack)); count -= 1; remain = i; } let ecn = if ecn == Ecn::Exist { let (i, ecn) = be_ecn_counts(remain)?; remain = i; Some(ecn) } else { None }; Ok(( remain, AckFrame { largest, delay, first_range, ranges, ecn, }, )) } } /// Parse the ECN counts from the input bytes, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub(super) fn be_ecn_counts(input: &[u8]) -> nom::IResult<&[u8], EcnCounts> { map((be_varint, be_varint, be_varint), |(ect0, ect1, ce)| { EcnCounts { ect0, ect1, ce } }) .parse(input) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &AckFrame) { let frame_type = frame.frame_type(); self.put_frame_type(frame_type); self.put_varint(&frame.largest); self.put_varint(&frame.delay); let ack_range_count = VarInt::try_from(frame.ranges.len()).unwrap(); self.put_varint(&ack_range_count); self.put_varint(&frame.first_range); for (gap, ack) in &frame.ranges { self.put_varint(gap); self.put_varint(ack); } if let Some(ecn) = &frame.ecn { self.put_varint(&ecn.ect0); self.put_varint(&ecn.ect1); self.put_varint(&ecn.ce); } } } #[cfg(test)] mod tests { use nom::{Parser, combinator::flat_map}; use super::*; use crate::{ frame::{EncodeSize, FrameType, GetFrameType, io::WriteFrame}, varint::{VarInt, be_varint}, }; #[test] fn test_ack_frame() { // test frame type, encoding size, and max encoding size let mut frame = AckFrame { largest: VarInt::from_u32(0x1234), delay: VarInt::from_u32(0x1234), first_range: VarInt::from_u32(0x1234), ranges: vec![(VarInt::from_u32(3), VarInt::from_u32(20))], ecn: None, }; assert_eq!(frame.frame_type(), 
FrameType::Ack(Ecn::None)); assert_eq!(frame.encoding_size(), 1 + 2 * 3 + 1 + 2); assert_eq!(frame.max_encoding_size(), 1 + 4 * 8 + 2 * 8); // test set_ecn and take_ecn let ecn = EcnCounts { ect0: VarInt::from_u32(0x1234), ect1: VarInt::from_u32(0x1234), ce: VarInt::from_u32(0x1234), }; frame.set_ecn(ecn); assert!(frame.ecn.is_some()); assert_eq!(frame.take_ecn(), Some(ecn)); } #[test] fn test_read_ecn_count() { let input = vec![0x52, 0x34, 0x52, 0x34, 0x52, 0x34]; let (input, ecn) = be_ecn_counts(&input).unwrap(); assert!(input.is_empty()); assert_eq!( ecn, EcnCounts { ect0: VarInt::from_u32(0x1234), ect1: VarInt::from_u32(0x1234), ce: VarInt::from_u32(0x1234), } ); } #[test] fn test_read_ack_frame() { let input = vec![0x02, 0x52, 0x34, 0x52, 0x34, 0x01, 0x52, 0x34, 3, 20]; let (input, ack_frame) = flat_map(be_varint, |frame_type| { let ack_frame_type: VarInt = FrameType::Ack(Ecn::None).into(); assert_eq!(frame_type, ack_frame_type); ack_frame_with_ecn(Ecn::None) }) .parse(&input) .unwrap(); assert!(input.is_empty()); assert_eq!( ack_frame, AckFrame { largest: VarInt::from_u32(0x1234), delay: VarInt::from_u32(0x1234), first_range: VarInt::from_u32(0x1234), ranges: vec![(VarInt::from_u32(3), VarInt::from_u32(20))], ecn: None, } ); } #[test] fn test_write_ack_frame() { let mut buf = Vec::new(); let frame = AckFrame { largest: VarInt::from_u32(0x1234), delay: VarInt::from_u32(0x1234), first_range: VarInt::from_u32(0x1234), ranges: vec![(VarInt::from_u32(3), VarInt::from_u32(20))], ecn: Some(EcnCounts { ect0: VarInt::from_u32(0x1234), ect1: VarInt::from_u32(0x1234), ce: VarInt::from_u32(0x1234), }), }; buf.put_frame(&frame); assert_eq!( buf, vec![ 0x03, 0x52, 0x34, 0x52, 0x34, 0x01, 0x52, 0x34, 3, 20, // frame 0x52, 0x34, 0x52, 0x34, 0x52, 0x34 // ecn ] ); } #[test] fn test_ack_frame_into_iter() { // let mut frame = AckFrame::new(1000, 0, 0x1234, None).unwrap(); let frame = AckFrame { largest: VarInt::from_u32(1000), delay: VarInt::from_u32(0x1234), first_range: 
VarInt::from_u32(0), ranges: vec![ (VarInt::from_u32(0), VarInt::from_u32(2)), (VarInt::from_u32(4), VarInt::from_u32(30)), (VarInt::from_u32(7), VarInt::from_u32(40)), ], ecn: None, }; // frame.alternating_gap_and_range(0, 2); // frame.alternating_gap_and_range(4, 30); // frame.alternating_gap_and_range(7, 40); let mut iter = frame.iter(); assert_eq!(iter.next(), Some(1000..=1000)); assert_eq!(iter.next(), Some(996..=998)); assert_eq!(iter.next(), Some(960..=990)); assert_eq!(iter.next(), Some(911..=951)); assert_eq!(iter.next(), None); } } ================================================ FILE: qbase/src/frame/add_address.rs ================================================ use std::net::{IpAddr, SocketAddr}; use derive_more::Deref; use super::{ EncodeSize, GetFrameType, io::{WriteFrame, WriteFrameType}, }; use crate::{ net::{AddrFamily, Family, NatType, WriteSocketAddr, be_socket_addr}, varint::{VarInt, WriteVarInt, be_varint}, }; // ADD_ADDRESS Frame { // Type (i) = 0x3d7e90..0x3d7e91, // Sequence Number (i), // [ IPv4 (32) ], // [ IPv6 (128) ], // Port (16), // Tire (i), // NAT Type (i), // } #[derive(Debug, Clone, Copy, PartialEq, Eq, Deref)] pub struct AddAddressFrame { #[deref] address: SocketAddr, seq_num: VarInt, tire: VarInt, nat_type: NatType, } pub(crate) fn be_add_address_frame( family: Family, ) -> impl Fn(&[u8]) -> nom::IResult<&[u8], AddAddressFrame> { move |input| { let (remain, seq_num) = be_varint(input)?; let (remain, addr) = be_socket_addr(remain, family)?; let (remain, tire) = be_varint(remain)?; let (remain, nat_type) = be_varint(remain)?; let nat_type = NatType::try_from(nat_type).map_err(|_| { nom::Err::Error(nom::error::Error::new( remain, nom::error::ErrorKind::Verify, )) })?; Ok(( remain, AddAddressFrame { seq_num, address: addr, tire, nat_type, }, )) } } impl GetFrameType for AddAddressFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::AddAddress(self.address.family()) } } impl EncodeSize for AddAddressFrame { fn 
max_encoding_size(&self) -> usize { 4 // frame type + 8 // seq_num + 2 // port + 16 // ipv6 IP + 8 // tire + 8 // nat_type } fn encoding_size(&self) -> usize { let addr_size = match self.address.ip() { IpAddr::V4(_) => 2 + 4, IpAddr::V6(_) => 2 + 16, }; VarInt::from(self.frame_type()).encoding_size() + self.seq_num.encoding_size() + addr_size + self.tire.encoding_size() + VarInt::from(self.nat_type).encoding_size() } } impl AddAddressFrame { pub fn new(seq_num: u32, address: SocketAddr, tire: u32, nat_type: NatType) -> Self { Self { seq_num: VarInt::from_u32(seq_num), address, tire: VarInt::from_u32(tire), nat_type, } } pub fn seq_num(&self) -> u32 { self.seq_num.into_u64() as u32 } pub fn tire(&self) -> u32 { self.tire.into_u64() as u32 } pub fn nat_type(&self) -> NatType { self.nat_type } } impl WriteFrame for T { fn put_frame(&mut self, frame: &AddAddressFrame) { self.put_frame_type(frame.frame_type()); self.put_varint(&frame.seq_num); self.put_socket_addr(&frame.address); self.put_varint(&frame.tire); self.put_varint(&VarInt::from(frame.nat_type)); } } #[cfg(test)] mod tests { use std::net::Ipv4Addr; use bytes::BytesMut; use super::*; use crate::frame::{GetFrameType, be_frame_type, io::WriteFrame}; #[test] fn test_add_address_frame() { let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); let frame = AddAddressFrame { seq_num: VarInt::from_u32(1u32), address: addr, tire: VarInt::from_u32(1u32), nat_type: NatType::FullCone, }; let mut buf = BytesMut::new(); buf.put_frame(&frame); let (remain, frame_type) = be_frame_type(&buf).unwrap(); assert_eq!(frame_type, frame.frame_type()); let frame2 = be_add_address_frame(Family::V4)(remain).unwrap().1; assert_eq!(frame, frame2); } } ================================================ FILE: qbase/src/frame/connection_close.rs ================================================ use std::borrow::Cow; use derive_more::From; use nom::bytes::complete::take; use super::FrameType; use crate::{ 
error::{ErrorFrameType, ErrorKind}, frame::{GetFrameType, be_frame_type, io::WriteFrameType}, varint::{VarInt, be_varint}, }; /// Layer flag for CONNECTION_CLOSE frames #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Layer { /// QUIC transport layer (0x1c) Quic, /// Application layer (0x1d) App, } impl From for u8 { fn from(layer: Layer) -> u8 { match layer { Layer::Quic => 0, Layer::App => 1, } } } impl From for Layer { fn from(value: u8) -> Self { match value & 0x01 { 0 => Layer::Quic, _ => Layer::App, } } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct AppCloseFrame { error_code: VarInt, reason: Cow<'static, str>, } impl AppCloseFrame { /// Return the error code of the frame. pub fn error_code(&self) -> u64 { self.error_code.into_u64() } /// Return the reason of the frame. pub fn reason(&self) -> &str { &self.reason } /// Otherwise, information about the application state might be revealed. /// /// Endpoints MUST clear the value of the Reason Phrase field and SHOULD use /// the APPLICATION_ERROR code when converting to a CONNECTION_CLOSE of type 0x1c. /// /// See [section-10.2.3-3](https://datatracker.ietf.org/doc/html/rfc9000#section-10.2.3-3) /// of [QUIC](https://datatracker.ietf.org/doc/html/rfc9000) for more details. pub fn conceal(&self) -> QuicCloseFrame { QuicCloseFrame { error_kind: ErrorKind::Application, frame_type: ErrorFrameType::V1(FrameType::Padding), reason: Cow::Borrowed(""), } } } impl From for QuicCloseFrame { fn from(_: AppCloseFrame) -> Self { QuicCloseFrame { error_kind: ErrorKind::Application, frame_type: ErrorFrameType::V1(FrameType::Padding), reason: Cow::Borrowed(""), } } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct QuicCloseFrame { error_kind: ErrorKind, frame_type: ErrorFrameType, reason: Cow<'static, str>, } impl QuicCloseFrame { /// Return the error kind of the frame. pub fn error_kind(&self) -> ErrorKind { self.error_kind } /// Return the frame type of the frame. 
pub fn frame_type(&self) -> ErrorFrameType { self.frame_type } /// Return the reason of the frame. pub fn reason(&self) -> &str { &self.reason } } /// CONNECTION_CLOSE Frame. /// /// ```text /// CONNECTION_CLOSE Frame { /// Type (i) = 0x1c..0x1d, /// Error Code (i), /// [Frame Type (i)], /// Reason Phrase Length (i), /// Reason Phrase (..), /// } /// ``` /// /// See [connection close frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-connection-close-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Clone, From, PartialEq, Eq)] pub enum ConnectionCloseFrame { App(AppCloseFrame), Quic(QuicCloseFrame), } impl super::GetFrameType for ConnectionCloseFrame { fn frame_type(&self) -> FrameType { match self { ConnectionCloseFrame::App(_) => FrameType::ConnectionClose(Layer::App), ConnectionCloseFrame::Quic(_) => FrameType::ConnectionClose(Layer::Quic), } } } impl super::EncodeSize for ConnectionCloseFrame { fn max_encoding_size(&self) -> usize { // reason's length could not exceed 16KB, so it can be encoded in 2 bytes. match self { ConnectionCloseFrame::App(frame) => 1 + 8 + 2 + frame.reason.len(), ConnectionCloseFrame::Quic(frame) => 1 + 8 + 8 + 2 + frame.reason.len(), } } fn encoding_size(&self) -> usize { match self { ConnectionCloseFrame::App(frame) => { 1 + frame.error_code.encoding_size() // reason's length could not exceed 16KB. + VarInt::try_from(frame.reason.len()).unwrap().encoding_size() + frame.reason.len() } ConnectionCloseFrame::Quic(frame) => { 1 + VarInt::from(frame.error_kind).encoding_size() + 1 // reason's length could not exceed 16KB. + VarInt::try_from(frame.reason.len()).unwrap().encoding_size() + frame.reason.len() } } } } impl ConnectionCloseFrame { /// Create a new `ConnectionCloseFrame` at QUIC layer. 
pub fn new_quic( error_kind: ErrorKind, frame_type: ErrorFrameType, reason: impl Into>, ) -> Self { Self::Quic(QuicCloseFrame { error_kind, frame_type, reason: reason.into(), }) } /// Create a new `ConnectionCloseFrame` at application layer. pub fn new_app(error_code: VarInt, reason: impl Into>) -> Self { Self::App(AppCloseFrame { error_code, reason: reason.into(), }) } } fn be_app_close_frame(input: &[u8]) -> nom::IResult<&[u8], AppCloseFrame> { let (remain, error_code) = be_varint(input)?; let (remain, reason_length) = be_varint(remain)?; let (remain, reason) = take(reason_length)(remain)?; let cow = String::from_utf8_lossy(reason).into_owned(); Ok(( remain, AppCloseFrame { error_code, reason: Cow::Owned(cow), }, )) } fn be_quic_close_frame(input: &[u8]) -> nom::IResult<&[u8], QuicCloseFrame> { let (remain, error_code) = be_varint(input)?; let error_kind = ErrorKind::try_from(error_code) .map_err(|_e| nom::Err::Error(nom::error::make_error(input, nom::error::ErrorKind::Alt)))?; let (remain, frame_type) = be_frame_type(remain) .map_err(|_e| nom::Err::Error(nom::error::make_error(input, nom::error::ErrorKind::Alt)))?; let (remain, reason_length) = be_varint(remain)?; let (remain, reason) = take(reason_length)(remain)?; let cow = String::from_utf8_lossy(reason).into_owned(); Ok(( remain, QuicCloseFrame { error_kind, frame_type: frame_type.into(), reason: Cow::Owned(cow), }, )) } /// Return a parser for a CONNECTION_CLOSE frame with the given layer. /// /// The `layer` parameter specifies which type of CONNECTION_CLOSE frame to parse: /// - `Layer::Conn`: Parse a QUIC transport layer CONNECTION_CLOSE frame (0x1c) /// - `Layer::App`: Parse an application layer CONNECTION_CLOSE frame (0x1d) /// /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
pub fn connection_close_frame_at_layer( layer: Layer, ) -> impl Fn(&[u8]) -> nom::IResult<&[u8], ConnectionCloseFrame> { move |input: &[u8]| match layer { Layer::App => { be_app_close_frame(input).map(|(remain, app)| (remain, ConnectionCloseFrame::App(app))) } Layer::Quic => be_quic_close_frame(input) .map(|(remain, quic)| (remain, ConnectionCloseFrame::Quic(quic))), } } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &ConnectionCloseFrame) { use crate::varint::WriteVarInt; self.put_frame_type(frame.frame_type()); match frame { ConnectionCloseFrame::App(frame) => { self.put_varint(&frame.error_code); let len = frame.reason.len().min(self.remaining_mut()); self.put_varint(&VarInt::from_u32(len as u32)); self.put_slice(&frame.reason.as_bytes()[..len]); } ConnectionCloseFrame::Quic(frame) => { self.put_varint(&frame.error_kind.into()); self.put_varint(&frame.frame_type.into()); let len = frame.reason.len().min(self.remaining_mut()); self.put_varint(&VarInt::from_u32(len as u32)); self.put_slice(&frame.reason.as_bytes()[..len]); } } } } #[cfg(test)] mod tests { use super::*; use crate::{ error::ErrorKind, frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteFrame, WriteFrameType}, stream::{Fin, Len, Offset}, }, varint::VarInt, }; #[test] fn test_connection_close_frame() { let frame = ConnectionCloseFrame::new_app(VarInt::from_u32(0x1234), "wrong"); assert_eq!(frame.frame_type(), FrameType::ConnectionClose(Layer::App)); assert_eq!(frame.max_encoding_size(), 1 + 8 + 2 + 5); assert_eq!(frame.encoding_size(), 1 + 2 + 1 + 5); } #[test] fn test_read_connection_close_frame() { use nom::{Parser, combinator::flat_map}; use crate::varint::be_varint; let mut buf = Vec::new(); buf.put_frame_type(FrameType::ConnectionClose(Layer::App)); buf.extend_from_slice(&[0x0c, 5, b'w', b'r', b'o', b'n', b'g']); let app_close_frame_type = VarInt::from(FrameType::ConnectionClose(Layer::App)); let (input, frame) = flat_map(be_varint, |frame_type| { if frame_type == 
app_close_frame_type { connection_close_frame_at_layer(Layer::App) } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()) .unwrap(); assert!(input.is_empty()); assert_eq!( frame, super::ConnectionCloseFrame::new_app(VarInt::from_u32(0x0c), "wrong",) ); } #[test] fn test_write_connection_close_frame() { use super::FrameType; let mut buf = Vec::::new(); let frame = ConnectionCloseFrame::new_quic( ErrorKind::FlowControl, FrameType::Stream(Offset::NonZero, Len::Explicit, Fin::No).into(), "wrong", ); buf.put_frame(&frame); let mut expected = Vec::new(); expected.put_frame_type(FrameType::ConnectionClose(Layer::Quic)); expected.extend_from_slice(&[0x03, 0xe, 5, b'w', b'r', b'o', b'n', b'g']); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/crypto.rs ================================================ use std::ops::Range; use nom::Parser; use crate::{ frame::{GetFrameType, io::WriteFrameType}, util::{ContinuousData, WriteData}, varint::{VARINT_MAX, VarInt, WriteVarInt, be_varint}, }; /// CRYPTO Frame /// /// ```text /// CRYPTO Frame { /// Type (i) = 0x06, /// Offset (i), /// Length (i), /// Crypto Data (..), /// } /// ``` /// /// See [crypto frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-crypto-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct CryptoFrame { offset: VarInt, length: VarInt, } impl super::GetFrameType for CryptoFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::Crypto } } impl super::EncodeSize for CryptoFrame { fn max_encoding_size(&self) -> usize { 1 + 8 + 8 } fn encoding_size(&self) -> usize { 1 + self.offset.encoding_size() + self.length.encoding_size() } } impl CryptoFrame { /// Create a new [`CryptoFrame`] with the given offset and length. pub fn new(offset: VarInt, length: VarInt) -> Self { Self { offset, length } } /// Return the offset of the frame. 
pub fn offset(&self) -> u64 { self.offset.into_u64() } /// Return the length of the frame. pub fn len(&self) -> u64 { self.length.into_u64() } /// Evaluate the maximum number of bytes of data that can be accommodated, /// starting from a certain offset, within a given capacity. If it cannot /// accommodate a CryptoFrame header or can only accommodate 0 bytes, return None. /// /// Note: Panic if the offset exceeds 2^62-1, or the the capacity is too large /// (about 2^32. It is impossible to have so much crypto stream data) pub fn estimate_max_capacity(capacity: usize, offset: u64) -> Option { assert!(offset <= VARINT_MAX); capacity // Must accommodate at least one byte, 'len' takes up 1 byte, // content takes up 1 byte. If these are not satisfied, return None. .checked_sub(1 + VarInt::from_u64(offset).unwrap().encoding_size() + 2) .map(|remaining| match remaining { // Including the 1 byte already considered in check_sub, // 'length' still takes up 1 byte. value @ 0..=62 => value + 1, // The encoding of 'length' directly takes up 2 bytes, the final 2 bytes // subtracted in 'check_sub' are all occupied by the encoding of 'length'. // Interestingly, if only 65 bytes are left after removing the encoding of // Type and Offset, whether the encoding of 'length' takes up 1 byte or 2 // bytes, only 63 bytes of data can be carried. value @ 0x3F..=0x3F_FF => value, // For the following lengths, the encoding of 'length' needs to occupy 4 bytes. // When the buffer capacity is 0x4000 or 0x40001, the encoding of 'length' // changes to 4 bytes, but the capacity is not enough, so it needs to be rolled back. 0x40_00..=0x40_01 => 0x3FFF, value @ 0x40_02..=0x40_00_00_01 => value - 2, // Any longer, a packet exceeding 100 million bytes is already impossible. _ => unreachable!("crypto frame length could not be too large"), }) } /// Return the range of bytes that this frame covers. 
pub fn range(&self) -> Range { let start = self.offset.into_u64(); let end = start + self.length.into_u64(); start..end } } /// Parse a CRYPTO frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn be_crypto_frame(input: &[u8]) -> nom::IResult<&[u8], CryptoFrame> { let (remain, (offset, length)) = (be_varint, be_varint).parse(input)?; if offset.into_u64() + offset.into_u64() > VARINT_MAX { return Err(nom::Err::Error(nom::error::make_error( input, nom::error::ErrorKind::TooLarge, ))); } Ok((remain, CryptoFrame { offset, length })) } impl super::io::WriteDataFrame for T where T: bytes::BufMut + WriteData, D: ContinuousData, { fn put_data_frame(&mut self, frame: &CryptoFrame, data: &D) { assert_eq!(frame.length.into_u64(), data.len() as u64); self.put_frame_type(frame.frame_type()); self.put_varint(&frame.offset); self.put_varint(&frame.length); self.put_data(data); } } #[cfg(test)] mod tests { use super::CryptoFrame; use crate::{ frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteDataFrame, WriteFrameType}, }, varint::VarInt, }; #[test] fn test_crypto_frame() { let frame = CryptoFrame::new(VarInt::from_u32(0), VarInt::from_u32(500)); assert_eq!(frame.frame_type(), super::super::FrameType::Crypto); assert_eq!(frame.max_encoding_size(), 1 + 8 + 8); assert_eq!(frame.encoding_size(), 1 + 1 + 2); assert_eq!(frame.offset(), 0); assert_eq!(frame.len(), 500); assert_eq!(frame.range(), 0..500); } #[test] fn test_read_crypto_frame() { use super::be_crypto_frame; let buf = vec![0x52, 0x34, 0x80, 0x00, 0x56, 0x78]; let (remain, frame) = be_crypto_frame(&buf).unwrap(); assert_eq!(remain, &[]); assert_eq!( frame, CryptoFrame::new(VarInt::from_u32(0x1234), VarInt::from_u32(0x5678)) ); } #[test] fn test_write_crypto_frame() { let mut buf = bytes::BytesMut::new(); let frame = CryptoFrame::new(VarInt::from_u32(0x1234), VarInt::from_u32(0x5)); buf.put_data_frame(&frame, b"hello"); let mut expected = Vec::new(); 
expected.put_frame_type(FrameType::Crypto); expected.extend_from_slice(&[0x52, 0x34, 0x05]); expected.extend_from_slice(b"hello"); assert_eq!(buf, bytes::Bytes::from(expected)); } #[test] fn test_encoding_capacity_estimate() { assert_eq!(CryptoFrame::estimate_max_capacity(1, 0), None); assert_eq!(CryptoFrame::estimate_max_capacity(4, 0), Some(1)); assert_eq!(CryptoFrame::estimate_max_capacity(4, 64), None); assert_eq!(CryptoFrame::estimate_max_capacity(5, 65), Some(1)); assert_eq!(CryptoFrame::estimate_max_capacity(67, 65), Some(63)); assert_eq!(CryptoFrame::estimate_max_capacity(68, 65), Some(63)); assert_eq!(CryptoFrame::estimate_max_capacity(69, 65), Some(64)); assert_eq!(CryptoFrame::estimate_max_capacity(16387, 65), Some(16382)); assert_eq!(CryptoFrame::estimate_max_capacity(16388, 65), Some(16383)); assert_eq!(CryptoFrame::estimate_max_capacity(16389, 65), Some(16383)); assert_eq!(CryptoFrame::estimate_max_capacity(16390, 65), Some(16383)); assert_eq!(CryptoFrame::estimate_max_capacity(16391, 65), Some(16384)); } #[test] #[should_panic] fn test_encoding_with_offset_exceeded() { CryptoFrame::estimate_max_capacity(60, 1 << 62); } #[test] #[should_panic] fn test_encoding_with_length_too_large() { CryptoFrame::estimate_max_capacity(1 << 31, 20); } } ================================================ FILE: qbase/src/frame/data_blocked.rs ================================================ use crate::{ frame::{GetFrameType, io::WriteFrameType}, varint::{VarInt, WriteVarInt, be_varint}, }; /// DATA_BLOCKED Frame /// /// ```text /// DATA_BLOCKED Frame { /// Type (i) = 0x14, /// Maximum Data (i), /// } /// ``` /// /// See [data-blocked frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-data_blocked-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct DataBlockedFrame { limit: VarInt, } impl super::GetFrameType for DataBlockedFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::DataBlocked } } impl super::EncodeSize for DataBlockedFrame { fn max_encoding_size(&self) -> usize { 1 + 8 } fn encoding_size(&self) -> usize { 1 + self.limit.encoding_size() } } impl DataBlockedFrame { /// Create a new [`DataBlockedFrame`] with the given limit. pub fn new(limit: VarInt) -> Self { Self { limit } } /// Return the limit of the frame. pub fn limit(&self) -> u64 { self.limit.into_u64() } } /// Parse a DATA_BLOCKED frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn be_data_blocked_frame(input: &[u8]) -> nom::IResult<&[u8], DataBlockedFrame> { use nom::{Parser, combinator::map}; map(be_varint, DataBlockedFrame::new).parse(input) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &DataBlockedFrame) { self.put_frame_type(frame.frame_type()); self.put_varint(&frame.limit); } } #[cfg(test)] mod tests { use super::DataBlockedFrame; use crate::{ frame::{EncodeSize, FrameType, GetFrameType, io::WriteFrame}, varint::VarInt, }; #[test] fn test_data_blocked_frame() { let frame = DataBlockedFrame::new(VarInt::from_u32(0x1234)); assert_eq!(frame.frame_type(), FrameType::DataBlocked); assert_eq!(frame.max_encoding_size(), 1 + 8); assert_eq!(frame.encoding_size(), 1 + 2); } #[test] fn test_read_data_blocked_frame() { use super::be_data_blocked_frame; let buf = vec![0x52, 0x34]; let (_, frame) = be_data_blocked_frame(&buf).unwrap(); assert_eq!(frame, DataBlockedFrame::new(VarInt::from_u32(0x1234))); } #[test] fn test_write_data_blocked_frame() { let mut buf = Vec::new(); buf.put_frame(&DataBlockedFrame::new(VarInt::from_u32(0x1234))); let frame_type: VarInt = FrameType::DataBlocked.into(); assert_eq!(buf, vec![frame_type.into_u64() as u8, 0x52, 0x34]); } } ================================================ 
FILE: qbase/src/frame/datagram.rs ================================================ use bytes::Buf; use nom::IResult; use super::{FrameType, GetFrameType, io::WriteFrameType}; use crate::{ util::{ContinuousData, WriteData}, varint::{VarInt, WriteVarInt, be_varint}, }; /// DATAGRAM Frame /// /// ```text /// DATAGRAM Frame { /// Type (i) = 0x30..0x31, /// [Length (i)], /// Datagram Data (..), /// } /// ``` /// /// See [datagram frame types](https://www.rfc-editor.org/rfc/rfc9000.html#name-datagram-frame-types) /// of [An Unreliable Datagram Extension to QUIC](https://www.rfc-editor.org/rfc/rfc9221.html) /// for more details. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct DatagramFrame { encode_len: bool, len: VarInt, } impl DatagramFrame { /// Create a new `DatagramFrame` with the given length. pub fn new(encode_len: bool, len: VarInt) -> Self { Self { encode_len, len } } #[inline] pub fn encode_len(&self) -> bool { self.encode_len } #[inline] pub fn len(&self) -> VarInt { self.len } } impl GetFrameType for DatagramFrame { fn frame_type(&self) -> FrameType { FrameType::Datagram(self.encode_len as _) } } impl super::EncodeSize for DatagramFrame { fn max_encoding_size(&self) -> usize { 1 + 8 } fn encoding_size(&self) -> usize { 1 + self .encode_len .then_some(self.len) .map(VarInt::encoding_size) .unwrap_or_default() } } /// Return a parser for DATAGRAM frames with a flag, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn datagram_frame_with_flag(flag: u8) -> impl FnOnce(&[u8]) -> IResult<&[u8], DatagramFrame> { move |input| { let (remain, len) = if flag == 1 { be_varint(input)? 
} else { let len = VarInt::try_from(input.remaining()) .expect("size of datagram frame payload never exceeds limit"); (input, len) }; let with_len = flag == 1; Ok(( remain, DatagramFrame { encode_len: with_len, len, }, )) } } impl super::io::WriteDataFrame for T where T: bytes::BufMut + WriteData, D: ContinuousData, { fn put_data_frame(&mut self, frame: &DatagramFrame, data: &D) { self.put_frame_type(frame.frame_type()); if frame.encode_len { self.put_varint(&frame.len); } self.put_data(data); } } #[cfg(test)] mod tests { use super::*; use crate::frame::{EncodeSize, io::WriteDataFrame}; #[test] fn test_datagram_frame() { let frame = DatagramFrame { encode_len: true, len: VarInt::from_u32(3), }; assert_eq!(frame.frame_type(), FrameType::Datagram(1)); assert_eq!(frame.max_encoding_size(), 1 + 8); assert_eq!(frame.encoding_size(), 1 + 1); } #[test] fn test_datagram_frame_with_flag() { let input = [0x05, 0x00, 0x00, 0x00, 0x00, 0x00]; let expected_output = DatagramFrame { encode_len: true, len: VarInt::from_u32(5), }; let (remain, frame) = datagram_frame_with_flag(1)(&input).unwrap(); assert_eq!(remain, &[0x00, 0x00, 0x00, 0x00, 0x00]); assert_eq!(frame, expected_output); } #[test] fn test_datagram_frame_with_flag_no_length() { let input = b"114514"; let expected_output = DatagramFrame { encode_len: false, len: VarInt::from_u32(6), }; let (remain, frame) = datagram_frame_with_flag(0)(input).unwrap(); assert_eq!(remain, input); assert_eq!(frame, expected_output); } #[test] fn test_put_datagram_frame_with_length() { let frame = DatagramFrame { encode_len: true, len: VarInt::from_u32(3), }; let mut buf = Vec::new(); buf.put_data_frame(&frame, &[0x01, 0x02, 0x03]); assert_eq!(&buf, &[0x31, 0x03, 0x01, 0x02, 0x03]); } #[test] fn test_put_datagram_frame_no_length() { let frame = DatagramFrame { encode_len: false, len: VarInt::from_u32(3), }; let mut buf = Vec::new(); buf.put_data_frame(&frame, &[0x01, 0x02, 0x03]); assert_eq!(&buf, &[0x30, 0x01, 0x02, 0x03]); } } 
================================================ FILE: qbase/src/frame/error.rs ================================================ use nom::error::ErrorKind as NomErrorKind; use thiserror::Error; use super::FrameType; use crate::{ error::{ErrorKind as QuicErrorKind, QuicError}, packet::r#type::Type, varint::VarInt, }; /// Parse errors when decoding QUIC frames. #[derive(Debug, Clone, Eq, PartialEq, Error)] pub enum Error { #[error("A packet containing no frames")] NoFrames, #[error("Incomplete frame type: {0}")] IncompleteType(String), #[error("Invalid frame type from {0}")] InvalidType(VarInt), #[error("Wrong frame type {0:?}")] WrongType(FrameType, Type), #[error("Incomplete frame {0:?}: {1}")] IncompleteFrame(FrameType, String), #[error("Error occurred when parsing frame {0:?}: {1}")] ParseError(FrameType, String), } impl From for QuicError { fn from(e: Error) -> Self { match e { // An endpoint MUST treat receipt of a packet containing no frames as a connection error of type PROTOCOL_VIOLATION. 
Error::NoFrames => { Self::with_default_fty(QuicErrorKind::ProtocolViolation, e.to_string()) } Error::IncompleteType(_) => { Self::with_default_fty(QuicErrorKind::FrameEncoding, e.to_string()) } Error::InvalidType(_) => { Self::with_default_fty(QuicErrorKind::FrameEncoding, e.to_string()) } Error::WrongType(fty, _) => { Self::new(QuicErrorKind::FrameEncoding, fty.into(), e.to_string()) } Error::IncompleteFrame(fty, _) => { Self::new(QuicErrorKind::FrameEncoding, fty.into(), e.to_string()) } Error::ParseError(fty, _) => { Self::new(QuicErrorKind::FrameEncoding, fty.into(), e.to_string()) } } } } impl From> for Error { fn from(error: nom::Err) -> Self { match error { nom::Err::Incomplete(_needed) => { unreachable!("Because the parsing of QUIC packets and frames is not stream-based.") } nom::Err::Error(err) | nom::Err::Failure(err) => err, } } } impl nom::error::ParseError<&[u8]> for Error { fn from_error_kind(_input: &[u8], _kind: NomErrorKind) -> Self { debug_assert_eq!(_kind, NomErrorKind::ManyTill); unreachable!("QUIC frame parser must always consume") } fn append(_input: &[u8], _kind: NomErrorKind, source: Self) -> Self { // 在解析帧时遇到了source错误,many_till期望通过ManyTill的错误类型告知 // 这里,源错误更有意义,所以直接返回源错误 debug_assert_eq!(_kind, NomErrorKind::ManyTill); source } } // TODO: conver DecodingError to quic error #[cfg(test)] mod tests { use nom::error::ParseError; use super::*; use crate::packet::r#type::{ Type, long::{Type::V1, Ver1}, }; #[test] fn test_error_conversion_to_transport_error() { let cases = vec![ (Error::NoFrames, QuicErrorKind::ProtocolViolation), ( Error::IncompleteType("test".to_string()), QuicErrorKind::FrameEncoding, ), ( Error::InvalidType(VarInt::from_u32(0x1f)), QuicErrorKind::FrameEncoding, ), ( Error::WrongType(FrameType::Ping, Type::Long(V1(Ver1::INITIAL))), QuicErrorKind::FrameEncoding, ), ( Error::IncompleteFrame(FrameType::Ping, "incomplete".to_string()), QuicErrorKind::FrameEncoding, ), ( Error::ParseError(FrameType::Ping, "parse error".to_string()), 
QuicErrorKind::FrameEncoding, ), ]; for (error, expected_kind) in cases { let quic_error: QuicError = error.into(); assert_eq!(quic_error.kind(), expected_kind); } } #[test] fn test_nom_error_conversion() { let error = Error::NoFrames; let nom_error = nom::Err::Error(error.clone()); let converted: Error = nom_error.into(); assert_eq!(converted, error); let nom_failure = nom::Err::Failure(error.clone()); let converted: Error = nom_failure.into(); assert_eq!(converted, error); } #[test] fn test_parse_error_impl() { let error = Error::ParseError(FrameType::Ping, "test error".to_string()); let appended = Error::append(&[], NomErrorKind::ManyTill, error.clone()); assert_eq!(appended, error); } #[test] #[should_panic(expected = "QUIC frame parser must always consume")] fn test_parse_error_unreachable() { Error::from_error_kind(&[], NomErrorKind::ManyTill); } #[test] fn test_error_display() { let error = Error::NoFrames; assert_eq!(error.to_string(), "A packet containing no frames"); let error = Error::IncompleteType("test".to_string()); assert_eq!(error.to_string(), "Incomplete frame type: test"); let error = Error::InvalidType(VarInt::from_u32(0x1f)); assert_eq!(error.to_string(), "Invalid frame type from 31"); } } ================================================ FILE: qbase/src/frame/handshake_done.rs ================================================ use super::EncodeSize; use crate::frame::{GetFrameType, io::WriteFrameType}; /// HandshakeDone frame /// /// ```text /// HANDSHAKE_DONE Frame { /// Type (i) = 0x1e, /// } /// ``` /// /// See [HANDSHAKE_DONE Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-handshake_done-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct HandshakeDoneFrame; impl super::GetFrameType for HandshakeDoneFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::HandshakeDone } } impl EncodeSize for HandshakeDoneFrame {} /// Parse a HANDSHAKE_DONE frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. #[allow(unused)] pub fn be_handshake_done_frame(input: &[u8]) -> nom::IResult<&[u8], HandshakeDoneFrame> { Ok((input, HandshakeDoneFrame)) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &HandshakeDoneFrame) { self.put_frame_type(frame.frame_type()); } } #[cfg(test)] mod tests { use crate::{ frame::{ EncodeSize, FrameType, GetFrameType, HandshakeDoneFrame, io::{WriteFrame, WriteFrameType}, }, varint::VarInt, }; #[test] fn test_handshake_done_frame() { assert_eq!(HandshakeDoneFrame.frame_type(), FrameType::HandshakeDone); assert_eq!(HandshakeDoneFrame.max_encoding_size(), 1); assert_eq!(HandshakeDoneFrame.encoding_size(), 1); } #[test] fn test_read_handshake_done_frame() { use nom::{Parser, combinator::flat_map}; use super::be_handshake_done_frame; use crate::varint::be_varint; let handshake_done_frame_type = VarInt::from(FrameType::HandshakeDone); let buf = vec![handshake_done_frame_type.into_u64() as u8]; let (input, frame) = flat_map(be_varint, |frame_type| { if frame_type == handshake_done_frame_type { be_handshake_done_frame } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()) .unwrap(); assert!(input.is_empty()); assert_eq!(frame, super::HandshakeDoneFrame); } #[test] fn test_write_handshake_done_frame() { let mut buf = Vec::new(); buf.put_frame(&HandshakeDoneFrame); let mut expected = Vec::new(); expected.put_frame_type(FrameType::HandshakeDone); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/io.rs ================================================ use std::{ pin::Pin, task::{Context, Poll}, }; use 
bytes::Bytes; use super::{ ack::ack_frame_with_ecn, add_address::be_add_address_frame, connection_close::connection_close_frame_at_layer, crypto::be_crypto_frame, data_blocked::be_data_blocked_frame, datagram::datagram_frame_with_flag, max_data::be_max_data_frame, max_stream_data::be_max_stream_data_frame, max_streams::max_streams_frame_with_dir, new_connection_id::be_new_connection_id_frame, new_token::be_new_token_frame, path_challenge::be_path_challenge_frame, path_response::be_path_response_frame, punch_done::be_punch_done_frame, punch_hello::be_punch_hello_frame, punch_me_now::be_punch_me_now_frame, remove_address::be_remove_address_frame, reset_stream::be_reset_stream_frame, retire_connection_id::be_retire_connection_id_frame, stop_sending::be_stop_sending_frame, stream::stream_frame_with_flag, stream_data_blocked::be_stream_data_blocked_frame, streams_blocked::streams_blocked_frame_with_dir, *, }; use crate::{ArcReceiving, Receiving, ResetError, util::ContinuousData}; /// Return a parser for a complete frame from the raw bytes with the given type, /// [nom](https://docs.rs/nom/latest/nom/) parser style. /// /// Some frames like [`StreamFrame`] and [`CryptoFrame`] have a data body, /// which use `bytes::Bytes` to store. 
fn complete_frame( frame_type: FrameType, raw: Bytes, ) -> impl Fn(&[u8]) -> nom::IResult<&[u8], Frame> { use nom::{Parser, combinator::map}; move |input: &[u8]| match frame_type { FrameType::Padding => Ok((input, Frame::Padding(PaddingFrame))), FrameType::Ping => Ok((input, Frame::Ping(PingFrame))), FrameType::ConnectionClose(layer) => { map(connection_close_frame_at_layer(layer), Frame::Close).parse(input) } FrameType::NewConnectionId => { map(be_new_connection_id_frame, Frame::NewConnectionId).parse(input) } FrameType::RetireConnectionId => { map(be_retire_connection_id_frame, Frame::RetireConnectionId).parse(input) } FrameType::DataBlocked => map(be_data_blocked_frame, Frame::DataBlocked).parse(input), FrameType::MaxData => map(be_max_data_frame, Frame::MaxData).parse(input), FrameType::PathChallenge => map(be_path_challenge_frame, Frame::PathChallenge).parse(input), FrameType::PathResponse => map(be_path_response_frame, Frame::PathResponse).parse(input), FrameType::HandshakeDone => Ok((input, Frame::HandshakeDone(HandshakeDoneFrame))), FrameType::NewToken => map(be_new_token_frame, Frame::NewToken).parse(input), FrameType::Ack(ecn) => map(ack_frame_with_ecn(ecn), Frame::Ack).parse(input), FrameType::ResetStream => { map(be_reset_stream_frame, |f| Frame::StreamCtl(f.into())).parse(input) } FrameType::StopSending => { map(be_stop_sending_frame, |f| Frame::StreamCtl(f.into())).parse(input) } FrameType::MaxStreamData => { map(be_max_stream_data_frame, |f| Frame::StreamCtl(f.into())).parse(input) } FrameType::MaxStreams(dir) => map(max_streams_frame_with_dir(dir), |f| { Frame::StreamCtl(f.into()) }) .parse(input), FrameType::StreamsBlocked(dir) => map(streams_blocked_frame_with_dir(dir), |f| { Frame::StreamCtl(f.into()) }) .parse(input), FrameType::StreamDataBlocked => { map(be_stream_data_blocked_frame, |f| Frame::StreamCtl(f.into())).parse(input) } FrameType::Crypto => { let (input, frame) = be_crypto_frame(input)?; let start = raw.len() - input.len(); let len = 
frame.len() as usize; if input.len() < len { Err(nom::Err::Incomplete(nom::Needed::new(len - input.len()))) } else { let data = raw.slice(start..start + len); Ok((&input[len..], Frame::Crypto(frame, data))) } } FrameType::Stream(offset, len, fin) => { let (input, frame) = stream_frame_with_flag(offset, len, fin)(input)?; let start = raw.len() - input.len(); let len = frame.len(); if input.len() < len { Err(nom::Err::Incomplete(nom::Needed::new(len - input.len()))) } else { let data = raw.slice(start..start + len); Ok((&input[len..], Frame::Stream(frame, data))) } } FrameType::Datagram(with_len) => { let (input, frame) = datagram_frame_with_flag(with_len)(input)?; let start = raw.len() - input.len(); match frame.encode_len() { true if frame.len().into_u64() > input.len() as u64 => Err(nom::Err::Incomplete( nom::Needed::new((frame.len().into_u64() - input.len() as u64) as usize), )), true => { let data = raw.slice(start..start + frame.len().into_u64() as usize); Ok(( &input[frame.len().into_u64() as usize..], Frame::Datagram(frame, data), )) } false => { let data = raw.slice(start..); Ok((&[], Frame::Datagram(frame, data))) } } } FrameType::AddAddress(family) => { map(be_add_address_frame(family), Frame::AddAddress).parse(input) } FrameType::RemoveAddress => map(be_remove_address_frame, Frame::RemoveAddress).parse(input), FrameType::PunchMeNow(family) => { map(be_punch_me_now_frame(family), Frame::PunchMeNow).parse(input) } FrameType::PunchHello => map(be_punch_hello_frame, Frame::PunchHello).parse(input), FrameType::PunchDone => map(be_punch_done_frame, Frame::PunchDone).parse(input), } } /// Parse a frame type from the raw bytes, [nom](https://docs.rs/nom/latest/nom/) parser style. 
pub fn be_frame(raw: &Bytes, packet_type: Type) -> Result<(usize, Frame, FrameType), Error> { let input = raw.as_ref(); let (remain, frame_type) = be_frame_type(input)?; if !frame_type.belongs_to(packet_type) { return Err(Error::WrongType(frame_type, packet_type)); } let (remain, frame) = complete_frame(frame_type, raw.clone())(remain).map_err(|e| match e { ne @ nom::Err::Incomplete(_) => { nom::Err::Error(Error::IncompleteFrame(frame_type, ne.to_string())) } nom::Err::Error(ne) => { // may be TooLarge in MaxStreamsFrame/CryptoFrame/StreamFrame, // or may be Verify in NewConnectionIdFrame, // or may be Alt in ConnectionCloseFrame nom::Err::Error(Error::ParseError( frame_type, ne.code.description().to_owned(), )) } _ => unreachable!("parsing frame never fails"), })?; Ok((input.len() - remain.len(), frame, frame_type)) } /// A [`bytes::BufMut`] extension trait, makes buffer more friendly /// to write all kinds of frames. pub trait WriteFrame: bytes::BufMut { /// Write a frame to the buffer. 
fn put_frame(&mut self, frame: &F); } impl WriteFrame> for B where D: ContinuousData, B: BufMut + ?Sized, for<'b> &'b mut B: crate::util::WriteData, { fn put_frame(&mut self, frame: &Frame) { #[inline(always)] fn put + ?Sized>(buf: &mut B, frame: &F) { buf.put_frame(frame); } let mut buf = self; match frame { Frame::Padding(f) => put(&mut buf, f), Frame::Ping(f) => put(&mut buf, f), Frame::Ack(f) => put(&mut buf, f), Frame::Close(f) => put(&mut buf, f), Frame::NewToken(f) => put(&mut buf, f), Frame::MaxData(f) => put(&mut buf, f), Frame::DataBlocked(f) => put(&mut buf, f), Frame::AddAddress(f) => put(&mut buf, f), Frame::RemoveAddress(f) => put(&mut buf, f), Frame::PunchMeNow(f) => put(&mut buf, f), Frame::PunchHello(f) => put(&mut buf, f), Frame::PunchDone(f) => put(&mut buf, f), Frame::NewConnectionId(f) => put(&mut buf, f), Frame::RetireConnectionId(f) => put(&mut buf, f), Frame::HandshakeDone(f) => put(&mut buf, f), Frame::PathChallenge(f) => put(&mut buf, f), Frame::PathResponse(f) => put(&mut buf, f), Frame::StreamCtl(f) => put(&mut buf, f), Frame::Stream(f, d) => buf.put_data_frame(f, d), Frame::Crypto(f, d) => buf.put_data_frame(f, d), Frame::Datagram(f, d) => buf.put_data_frame(f, d), } } } /// A [`bytes::BufMut`] extension trait, makes buffer more friendly /// to write frame with data. pub trait WriteDataFrame: bytes::BufMut { /// Write a frame and its data to the buffer. fn put_data_frame(&mut self, frame: &F, data: &D); } /// A [`bytes::BufMut`] extension trait to write [`FrameType`]. pub trait WriteFrameType: bytes::BufMut { /// Write a frame type to the buffer. fn put_frame_type(&mut self, frame_type: FrameType); } impl WriteFrameType for T { fn put_frame_type(&mut self, frame_type: FrameType) { use crate::varint::WriteVarInt; let fty: VarInt = frame_type.into(); self.put_varint(&fty); } } /// Some modules that need send specific frames can implement `SendFrame` trait directly. 
/// /// Alternatively, a temporary buffer that stores certain frames can also implement this trait, /// But additional processing is required to ensure that the frames in the buffer are eventually /// sent to the peer. pub trait SendFrame { /// Need send the frames to the peer fn send_frame>(&self, iter: I); } /// Some modules that need receive specific frames can implement `ReceiveFrame` trait directly. /// /// Alternatively, a temporary buffer that stores certain frames can also implement this trait, /// But additional processing is required to ensure that the frames in the buffer are eventually /// delivered to the corresponding modules. pub trait ReceiveFrame { type Output; /// Receive the frames from the peer fn recv_frame(&self, frame: T) -> Result; } impl Future for Receiving { type Output = Result, ResetError>; fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { let state = self.get_mut(); match std::mem::take(state) { Self::Pending => Poll::Pending, Self::Waiting(waker) => { *state = Self::Waiting(waker); Poll::Pending } Self::Rcvd(frame) => { *state = Self::Read; Poll::Ready(Ok(Some(frame))) } Self::Read => { *state = Self::Read; Poll::Ready(Ok(None)) } Self::Reset => { *state = Self::Reset; Poll::Ready(Err(ResetError)) } } } } impl ReceiveFrame for ArcReceiving { type Output = (); fn recv_frame(&self, frame: F) -> Result { self.0.lock().unwrap().recv_frame(frame); Ok(()) } } ================================================ FILE: qbase/src/frame/max_data.rs ================================================ use crate::{ frame::{GetFrameType, io::WriteFrameType}, varint::{VarInt, WriteVarInt, be_varint}, }; /// MAX_DATA Frame /// /// ```text /// MAX_DATA Frame { /// Type (i) = 0x10, /// Maximum Data (i), /// } /// ``` /// /// See [MAX_DATA Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-max_data-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct MaxDataFrame { max_data: VarInt, } impl super::GetFrameType for MaxDataFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::MaxData } } impl super::EncodeSize for MaxDataFrame { fn max_encoding_size(&self) -> usize { 1 + 8 } fn encoding_size(&self) -> usize { 1 + self.max_data.encoding_size() } } impl MaxDataFrame { /// Create a new [`MaxDataFrame`] with the given maximum data. pub fn new(max_data: VarInt) -> Self { Self { max_data } } /// Return the maximum data of the frame. pub fn max_data(&self) -> u64 { self.max_data.into_u64() } } /// Parse a MAX_DATA frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn be_max_data_frame(input: &[u8]) -> nom::IResult<&[u8], MaxDataFrame> { use nom::{Parser, combinator::map}; map(be_varint, MaxDataFrame::new).parse(input) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &MaxDataFrame) { self.put_frame_type(frame.frame_type()); self.put_varint(&frame.max_data); } } #[cfg(test)] mod tests { use super::MaxDataFrame; use crate::{ frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteFrame, WriteFrameType}, }, varint::VarInt, }; #[test] fn test_max_data_frame() { let frame = MaxDataFrame::new(VarInt::from_u32(0x1234)); assert_eq!(frame.frame_type(), FrameType::MaxData); assert_eq!(frame.max_encoding_size(), 1 + 8); assert_eq!(frame.encoding_size(), 1 + 2); } #[test] fn test_read_max_data_frame() { use nom::{Parser, combinator::flat_map}; use super::be_max_data_frame; use crate::varint::be_varint; let max_data_frame_type = VarInt::from(FrameType::MaxData); let buf = vec![max_data_frame_type.into_u64() as u8, 0x52, 0x34]; let (input, frame) = flat_map(be_varint, |frame_type| { if frame_type == max_data_frame_type { be_max_data_frame } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()) .unwrap(); assert!(input.is_empty()); assert_eq!(frame, 
MaxDataFrame::new(VarInt::from_u32(0x1234),)); } #[test] fn test_write_max_data_frame() { let mut buf = Vec::new(); buf.put_frame(&MaxDataFrame::new(VarInt::from_u32(0x1234))); let mut expected = Vec::new(); expected.put_frame_type(FrameType::MaxData); expected.extend_from_slice(&[0x52, 0x34]); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/max_stream_data.rs ================================================ use crate::{ frame::{GetFrameType, io::WriteFrameType}, sid::{StreamId, WriteStreamId, be_streamid}, varint::{VarInt, WriteVarInt, be_varint}, }; /// MAX_STREAM_DATA frame. /// /// ```text /// MAX_STREAM_DATA Frame { /// Type (i) = 0x11, /// Stream ID (i), /// Maximum Stream Data (i), /// } /// ``` /// /// See [MAX_STREAM_DATA Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-max_stream_data-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct MaxStreamDataFrame { stream_id: StreamId, max_stream_data: VarInt, } impl MaxStreamDataFrame { /// Create a new [`MaxStreamDataFrame`]. pub fn new(stream_id: StreamId, max_stream_data: VarInt) -> Self { Self { stream_id, max_stream_data, } } /// Return the stream ID of the frame. pub fn stream_id(&self) -> StreamId { self.stream_id } /// Return the maximum stream data of the frame. pub fn max_stream_data(&self) -> u64 { self.max_stream_data.into_u64() } } impl super::GetFrameType for MaxStreamDataFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::MaxStreamData } } impl super::EncodeSize for MaxStreamDataFrame { fn max_encoding_size(&self) -> usize { 1 + 8 + 8 } fn encoding_size(&self) -> usize { 1 + self.stream_id.encoding_size() + self.max_stream_data.encoding_size() } } /// Parse a MAX_STREAM_DATA frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
pub fn be_max_stream_data_frame(input: &[u8]) -> nom::IResult<&[u8], MaxStreamDataFrame> {
    use nom::{Parser, combinator::map, sequence::pair};
    map(
        pair(be_streamid, be_varint),
        |(stream_id, max_stream_data)| MaxStreamDataFrame {
            stream_id,
            max_stream_data,
        },
    )
    .parse(input)
}

// NOTE(review): generic bound reconstructed (extraction stripped `<...>`).
impl<T: bytes::BufMut> super::io::WriteFrame<MaxStreamDataFrame> for T {
    fn put_frame(&mut self, frame: &MaxStreamDataFrame) {
        self.put_frame_type(frame.frame_type());
        self.put_streamid(&frame.stream_id);
        self.put_varint(&frame.max_stream_data);
    }
}

#[cfg(test)]
mod tests {
    use super::MaxStreamDataFrame;
    use crate::{
        frame::{
            EncodeSize, FrameType, GetFrameType,
            io::{WriteFrame, WriteFrameType},
        },
        varint::VarInt,
    };

    #[test]
    fn test_max_stream_data_frame() {
        let frame =
            MaxStreamDataFrame::new(VarInt::from_u32(0x1234).into(), VarInt::from_u32(0x5678));
        assert_eq!(frame.stream_id, VarInt::from_u32(0x1234).into());
        assert_eq!(frame.max_stream_data, VarInt::from_u32(0x5678));
        assert_eq!(frame.frame_type(), FrameType::MaxStreamData);
        assert_eq!(frame.max_encoding_size(), 1 + 8 + 8);
        assert_eq!(frame.encoding_size(), 1 + 2 + 4);
    }

    #[test]
    fn test_read_max_stream_data_frame() {
        use super::be_max_stream_data_frame;

        // stream id 0x1234 (2-byte varint), max data 0x5678 (4-byte varint).
        let buf = vec![0x52, 0x34, 0x80, 0, 0x56, 0x78];
        let (_, frame) = be_max_stream_data_frame(&buf).unwrap();
        assert_eq!(frame.stream_id(), VarInt::from_u32(0x1234).into());
        assert_eq!(frame.max_stream_data(), 0x5678);
    }

    #[test]
    fn test_write_max_stream_data_frame() {
        let mut buf = Vec::new();
        buf.put_frame(&MaxStreamDataFrame::new(
            VarInt::from_u32(0x1234).into(),
            VarInt::from_u32(0x5678),
        ));

        let mut expected = Vec::new();
        expected.put_frame_type(FrameType::MaxStreamData);
        expected.extend_from_slice(&[0x52, 0x34, 0x80, 0, 0x56, 0x78]);
        assert_eq!(buf, expected);
    }
}

================================================
FILE: qbase/src/frame/max_streams.rs
================================================
use crate::{
    frame::{GetFrameType, io::WriteFrameType},
    sid::{Dir, MAX_STREAMS_LIMIT},
    varint::{VarInt, WriteVarInt, be_varint},
};

/// MAX_STREAMS frame.
///
/// ```text
/// MAX_STREAMS Frame {
///     Type (i) = 0x12..0x13,
///     Maximum Streams (i),
/// }
/// ```
///
/// See [MAX_STREAMS Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-max_streams-frames)
/// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxStreamsFrame {
    Bi(VarInt),
    Uni(VarInt),
}

impl MaxStreamsFrame {
    /// Create a MAX_STREAMS frame for the given stream direction.
    pub fn with(dir: Dir, max_streams: VarInt) -> Self {
        match dir {
            Dir::Bi => MaxStreamsFrame::Bi(max_streams),
            Dir::Uni => MaxStreamsFrame::Uni(max_streams),
        }
    }
}

impl super::GetFrameType for MaxStreamsFrame {
    fn frame_type(&self) -> super::FrameType {
        super::FrameType::MaxStreams(match self {
            MaxStreamsFrame::Bi(_) => Dir::Bi,
            MaxStreamsFrame::Uni(_) => Dir::Uni,
        })
    }
}

impl super::EncodeSize for MaxStreamsFrame {
    fn max_encoding_size(&self) -> usize {
        1 + 8
    }

    fn encoding_size(&self) -> usize {
        1 + match self {
            MaxStreamsFrame::Bi(max_streams) => max_streams.encoding_size(),
            MaxStreamsFrame::Uni(max_streams) => max_streams.encoding_size(),
        }
    }
}

/// Returns a parser for MAX_STREAMS frame with the given direction,
/// [nom](https://docs.rs/nom/latest/nom/) parser style.
pub fn max_streams_frame_with_dir( dir: Dir, ) -> impl Fn(&[u8]) -> nom::IResult<&[u8], MaxStreamsFrame> { move |input: &[u8]| { let (remain, max_streams) = be_varint(input)?; if max_streams > MAX_STREAMS_LIMIT { Err(nom::Err::Error(nom::error::Error::new( input, nom::error::ErrorKind::TooLarge, ))) } else { Ok(( remain, match dir { Dir::Bi => MaxStreamsFrame::Bi(max_streams), Dir::Uni => MaxStreamsFrame::Uni(max_streams), }, )) } } } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &MaxStreamsFrame) { match frame { MaxStreamsFrame::Bi(max_streams) => { // self.put_frame_type(frame.frame_type()); self.put_frame_type(frame.frame_type()); self.put_varint(max_streams); } MaxStreamsFrame::Uni(max_streams) => { self.put_frame_type(frame.frame_type()); self.put_varint(max_streams); } } } } #[cfg(test)] mod tests { use nom::{Parser, combinator::flat_map}; use super::{MaxStreamsFrame, max_streams_frame_with_dir}; use crate::{ frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteFrame, WriteFrameType}, }, sid::Dir, varint::{VarInt, be_varint}, }; #[test] fn test_max_streams_frame() { let frame = MaxStreamsFrame::Bi(VarInt::from_u32(0x1234)); assert_eq!(frame.frame_type(), FrameType::MaxStreams(Dir::Bi)); assert_eq!(frame.max_encoding_size(), 1 + 8); assert_eq!(frame.encoding_size(), 1 + 2); let frame = MaxStreamsFrame::Uni(VarInt::from_u32(0x1236)); assert_eq!(frame.frame_type(), FrameType::MaxStreams(Dir::Uni)); assert_eq!(frame.max_encoding_size(), 1 + 8); assert_eq!(frame.encoding_size(), 1 + 2); } #[test] fn test_read_max_streams_frame() { let max_streams_bi_type = VarInt::from(FrameType::MaxStreams(Dir::Bi)); let max_streams_uni_type = VarInt::from(FrameType::MaxStreams(Dir::Uni)); let buf = vec![max_streams_bi_type.into_u64() as u8, 0x52, 0x34]; let (input, frame) = flat_map(be_varint, |frame_type| { if frame_type == max_streams_bi_type { max_streams_frame_with_dir(Dir::Bi) } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()) 
.unwrap(); assert!(input.is_empty()); assert_eq!(frame, MaxStreamsFrame::Bi(VarInt::from_u32(0x1234))); let buf = vec![max_streams_uni_type.into_u64() as u8, 0x52, 0x36]; let (input, frame) = flat_map(be_varint, |frame_type| { if frame_type == max_streams_uni_type { max_streams_frame_with_dir(Dir::Uni) } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()) .unwrap(); assert!(input.is_empty()); assert_eq!(frame, MaxStreamsFrame::Uni(VarInt::from_u32(0x1236))); } #[test] fn test_read_too_large_max_streams_frame() { let mut buf = Vec::new(); buf.put_frame_type(FrameType::MaxStreams(Dir::Bi)); buf.extend_from_slice(&[0xd0, 0x34, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80]); let result = flat_map(be_varint, |frame_type| { if frame_type == VarInt::from(FrameType::MaxStreams(Dir::Bi)) { max_streams_frame_with_dir(Dir::Bi) } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()); assert_eq!( result, Err(nom::Err::Error(nom::error::Error::new( &buf[1..], nom::error::ErrorKind::TooLarge, ))) ); } #[test] fn test_write_max_streams_frame() { let mut buf = Vec::new(); buf.put_frame(&MaxStreamsFrame::Bi(VarInt::from_u32(0x1234))); let mut expected = Vec::new(); expected.put_frame_type(FrameType::MaxStreams(Dir::Bi)); expected.extend_from_slice(&[0x52, 0x34]); assert_eq!(buf, expected); buf.clear(); buf.put_frame(&MaxStreamsFrame::Uni(VarInt::from_u32(0x1236))); expected.clear(); expected.put_frame_type(FrameType::MaxStreams(Dir::Uni)); expected.extend_from_slice(&[0x52, 0x36]); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/new_connection_id.rs ================================================ use crate::{ cid::{ConnectionId, WriteConnectionId, be_connection_id}, frame::{GetFrameType, io::WriteFrameType}, token::{RESET_TOKEN_SIZE, ResetToken, be_reset_token}, varint::{VarInt, WriteVarInt, be_varint}, }; /// NEW_CONNECTION_ID frame. 
///
/// ```text
/// NEW_CONNECTION_ID Frame {
///     Type (i) = 0x18,
///     Sequence Number (i),
///     Retire Prior To (i),
///     Length (8),
///     Connection ID (8..160),
///     Stateless Reset Token (128),
/// }
/// ```
///
/// See [NEW_CONNECTION_ID Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-new_connection_id-frames)
/// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NewConnectionIdFrame {
    sequence: VarInt,
    retire_prior_to: VarInt,
    id: ConnectionId,
    reset_token: ResetToken,
}

impl NewConnectionIdFrame {
    /// Create a new [`NewConnectionIdFrame`].
    ///
    /// A fresh stateless reset token is generated randomly for the new CID.
    pub fn new(cid: ConnectionId, sequence: VarInt, retire_prior_to: VarInt) -> Self {
        let reset_token = ResetToken::random_gen();
        Self {
            sequence,
            retire_prior_to,
            id: cid,
            reset_token,
        }
    }

    /// Return the sequence number of the frame.
    pub fn sequence(&self) -> u64 {
        self.sequence.into_u64()
    }

    /// Return the retire prior to of the frame.
    pub fn retire_prior_to(&self) -> u64 {
        self.retire_prior_to.into_u64()
    }

    /// Return the connection ID of the frame.
    pub fn connection_id(&self) -> &ConnectionId {
        &self.id
    }

    /// Return the reset token of the frame.
    pub fn reset_token(&self) -> &ResetToken {
        &self.reset_token
    }
}

impl super::GetFrameType for NewConnectionIdFrame {
    fn frame_type(&self) -> super::FrameType {
        super::FrameType::NewConnectionId
    }
}

impl super::EncodeSize for NewConnectionIdFrame {
    fn max_encoding_size(&self) -> usize {
        1 + 8 + 8 + 21 + RESET_TOKEN_SIZE
    }

    fn encoding_size(&self) -> usize {
        1 + self.sequence.encoding_size()
            + self.retire_prior_to.encoding_size()
            // 1 byte for the CID length prefix, then the CID itself
            + 1
            + self.id.len as usize
            + RESET_TOKEN_SIZE
    }
}

/// Parse a NEW_CONNECTION_ID frame from the input buffer,
/// [nom](https://docs.rs/nom/latest/nom/) parser style.
pub fn be_new_connection_id_frame(input: &[u8]) -> nom::IResult<&[u8], NewConnectionIdFrame> {
    let (remain, sequence) = be_varint(input)?;
    let (remain, retire_prior_to) = be_varint(remain)?;
    // The value in the Retire Prior To field MUST be less than or equal to the value in the
    // Sequence Number field. Receiving a value in the Retire Prior To field that is greater
    // than that in the Sequence Number field MUST be treated as a connection error of type
    // FRAME_ENCODING_ERROR.
    if retire_prior_to > sequence {
        // TODO: information is lost here — the specific reason is flattened
        // into ErrorKind::Verify
        return Err(nom::Err::Error(nom::error::make_error(
            input,
            nom::error::ErrorKind::Verify,
        )));
    }
    let (remain, cid) = be_connection_id(remain)?;
    if cid.is_empty() {
        // TODO: information is lost here — the specific reason is flattened
        // into ErrorKind::Verify
        return Err(nom::Err::Error(nom::error::make_error(
            input,
            nom::error::ErrorKind::Verify,
        )));
    }
    let (remain, reset_token) = be_reset_token(remain)?;
    Ok((
        remain,
        NewConnectionIdFrame {
            sequence,
            retire_prior_to,
            id: cid,
            reset_token,
        },
    ))
}

// NOTE(review): generic bound reconstructed (extraction stripped `<...>`).
impl<T: bytes::BufMut> super::io::WriteFrame<NewConnectionIdFrame> for T {
    fn put_frame(&mut self, frame: &NewConnectionIdFrame) {
        self.put_frame_type(frame.frame_type());
        self.put_varint(&frame.sequence);
        self.put_varint(&frame.retire_prior_to);
        self.put_connection_id(&frame.id);
        self.put_slice(frame.reset_token.as_slice());
    }
}

#[cfg(test)]
mod tests {
    use bytes::{BufMut, BytesMut};

    use super::*;
    use crate::frame::{
        EncodeSize, FrameType, GetFrameType,
        io::{WriteFrame, WriteFrameType},
    };

    #[test]
    fn test_new_connection_id_frame() {
        let new_cid_frame = NewConnectionIdFrame::new(
            ConnectionId::from_slice(&[1, 2, 3, 4][..]),
            VarInt::from_u32(1),
            VarInt::from_u32(0),
        );
        assert_eq!(new_cid_frame.sequence(), 1);
        assert_eq!(new_cid_frame.retire_prior_to(), 0);
        assert_eq!(
            new_cid_frame.id,
            ConnectionId::from_slice(&[1, 2, 3, 4][..])
        );
        assert_eq!(new_cid_frame.frame_type(), FrameType::NewConnectionId);
        assert_eq!(
            new_cid_frame.max_encoding_size(),
            1 + 8 + 8 + 21 + RESET_TOKEN_SIZE
        );
        assert_eq!(new_cid_frame.encoding_size(), 1 + 1 + 1 + 1 + 4 + 16);
    }

    #[test]
    fn
test_frame_parsing() { let mut buf = BytesMut::new(); let original_cid = ConnectionId::from_slice(&[1, 2, 3, 4][..]); let original_frame = NewConnectionIdFrame::new(original_cid, VarInt::from_u32(1), VarInt::from_u32(0)); // Write frame to buffer buf.put_frame(&original_frame); // Skip frame type byte let (_, parsed_frame) = be_new_connection_id_frame(&buf[1..]).unwrap(); assert_eq!(parsed_frame.sequence(), original_frame.sequence()); assert_eq!( parsed_frame.retire_prior_to(), original_frame.retire_prior_to() ); assert_eq!(parsed_frame.connection_id(), original_frame.connection_id()); assert_eq!(parsed_frame.reset_token(), original_frame.reset_token()); } #[test] fn test_invalid_retire_prior_to() { let mut buf = BytesMut::new(); buf.put_frame_type(FrameType::NewConnectionId); buf.put_varint(&VarInt::from_u32(1)); // sequence buf.put_varint(&VarInt::from_u32(2)); // retire_prior_to > sequence assert!(be_new_connection_id_frame(&buf[1..]).is_err()); } #[test] fn test_zero_length_connection_id() { let mut buf = BytesMut::new(); buf.put_frame_type(FrameType::NewConnectionId); buf.put_varint(&VarInt::from_u32(1)); buf.put_varint(&VarInt::from_u32(0)); buf.put_u8(0); // zero length CID assert!(be_new_connection_id_frame(&buf[1..]).is_err()); } } ================================================ FILE: qbase/src/frame/new_token.rs ================================================ use derive_more::Deref; use crate::{ frame::{GetFrameType, io::WriteFrameType}, varint::{VarInt, WriteVarInt, be_varint}, }; /// NEW_TOKEN frame. /// /// ```text /// NEW_TOKEN Frame { /// Type (i) = 0x07, /// Token Length (i), /// Token (..), /// } /// ``` /// /// See [NEW_TOKEN Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-new_token-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
#[derive(Deref, Debug, Clone, PartialEq, Eq)] pub struct NewTokenFrame { #[deref] token: Vec, } impl super::GetFrameType for NewTokenFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::NewToken } } impl super::EncodeSize for NewTokenFrame { fn max_encoding_size(&self) -> usize { // token's length could not exceed 20 1 + 1 + self.token.len() } fn encoding_size(&self) -> usize { 1 + 1 + self.token.len() } } impl NewTokenFrame { /// Create a new [`NewTokenFrame`] with the given token. pub fn new(token: Vec) -> Self { Self { token } } /// Create a new [`NewTokenFrame`] from the given token slice. pub fn from_slice(token: &[u8]) -> Self { Self { token: token.to_vec(), } } /// Return the token of the frame. pub fn token(&self) -> &[u8] { &self.token } } /// Parse a NEW_TOKEN frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn be_new_token_frame(input: &[u8]) -> nom::IResult<&[u8], NewTokenFrame> { use nom::{ Parser, bytes::streaming::take, combinator::{flat_map, map}, }; flat_map(be_varint, |length| { map(take(length.into_u64() as usize), NewTokenFrame::from_slice) }) .parse(input) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &NewTokenFrame) { self.put_frame_type(frame.frame_type()); self.put_varint(&VarInt::from_u32(frame.token.len() as u32)); self.put_slice(&frame.token); } } #[cfg(test)] mod tests { use crate::frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteFrame, WriteFrameType}, }; #[test] fn test_new_token_frame() { let frame = super::NewTokenFrame::new(vec![0x01, 0x02]); assert_eq!(frame.frame_type(), FrameType::NewToken); assert_eq!(frame.max_encoding_size(), 1 + 1 + 2); assert_eq!(frame.encoding_size(), 1 + 1 + 2); } #[test] fn test_read_new_token_frame() { use super::be_new_token_frame; let buf = vec![0x02, 0x01, 0x02]; let (input, frame) = be_new_token_frame(&buf).unwrap(); assert!(input.is_empty()); assert_eq!(frame.token, vec![0x01, 0x02]); } #[test] fn 
test_write_new_token_frame() { let mut buf = Vec::::new(); let frame = super::NewTokenFrame::from_slice(&[0x01, 0x02]); buf.put_frame(&frame); let mut expected = Vec::new(); expected.put_frame_type(FrameType::NewToken); expected.extend_from_slice(&[0x02, 0x01, 0x02]); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/padding.rs ================================================ use crate::frame::{GetFrameType, io::WriteFrameType}; /// PADDING Frame. /// /// ```text /// PADDING Frame { /// Type (i) = 0x00, /// } /// ``` /// /// See [PADDING Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-padding-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct PaddingFrame; impl super::GetFrameType for PaddingFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::Padding } } impl super::EncodeSize for PaddingFrame {} /// Parse a PADDING frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
#[allow(dead_code)] pub fn be_padding_frame(input: &[u8]) -> nom::IResult<&[u8], PaddingFrame> { Ok((input, PaddingFrame)) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &PaddingFrame) { self.put_frame_type(frame.frame_type()); } } #[cfg(test)] mod tests { use super::{PaddingFrame, be_padding_frame}; use crate::{ frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteFrame, WriteFrameType}, }, varint::VarInt, }; #[test] fn test_padding_frame() { assert_eq!(PaddingFrame.frame_type(), FrameType::Padding); assert_eq!(PaddingFrame.max_encoding_size(), 1); assert_eq!(PaddingFrame.encoding_size(), 1); } #[test] fn test_read_padding_frame() { use nom::{Parser, combinator::flat_map}; use crate::varint::be_varint; let padding_frame_type = VarInt::from(FrameType::Padding); let buf = vec![padding_frame_type.into_u64() as u8]; let (input, frame) = flat_map(be_varint, |frame_type| { if frame_type == padding_frame_type { be_padding_frame } else { unreachable!("wrong frame type: {}", frame_type) } }) .parse(buf.as_ref()) .unwrap(); assert!(input.is_empty()); assert_eq!(frame, PaddingFrame); } #[test] fn test_write_padding_frame() { let mut buf = Vec::new(); buf.put_frame(&PaddingFrame); let mut expected = Vec::new(); expected.put_frame_type(FrameType::Padding); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/path_challenge.rs ================================================ use derive_more::Deref; use rand::RngExt; use crate::frame::{GetFrameType, io::WriteFrameType}; /// PATH_CHALLENGE frame. /// /// ```text /// PATH_CHALLENGE Frame { /// Type (i) = 0x1a, /// Data (64), /// } /// ``` /// /// See [PATH_CHALLENGE Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-path_challenge-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deref)] pub struct PathChallengeFrame { #[deref] data: [u8; 8], } impl PathChallengeFrame { pub fn from_slice(data: &[u8]) -> Self { let mut frame = Self { data: [0; 8] }; frame.data.copy_from_slice(data); frame } pub fn random() -> Self { let mut rng = rand::rng(); let mut data = [0; 8]; rng.fill(&mut data); Self { data } } } impl super::GetFrameType for PathChallengeFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::PathChallenge } } impl super::EncodeSize for PathChallengeFrame { fn max_encoding_size(&self) -> usize { 1 + self.data.len() } fn encoding_size(&self) -> usize { 1 + self.data.len() } } /// Parse a PATH_CHALLENGE frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn be_path_challenge_frame(input: &[u8]) -> nom::IResult<&[u8], PathChallengeFrame> { use nom::{Parser, bytes::streaming::take, combinator::map}; map(take(8usize), PathChallengeFrame::from_slice).parse(input) } // BufMut write extension for PATH_CHALLENGE_FRAME impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &PathChallengeFrame) { self.put_frame_type(frame.frame_type()); self.put_slice(&frame.data); } } #[cfg(test)] mod tests { use nom::{Parser, combinator::flat_map}; use super::be_path_challenge_frame; use crate::{ frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteFrame, WriteFrameType}, }, varint::{VarInt, be_varint}, }; #[test] fn test_path_challenge_frame() { let frame = super::PathChallengeFrame::from_slice(&[ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, ]); assert_eq!(frame.frame_type(), FrameType::PathChallenge); assert_eq!(frame.max_encoding_size(), 1 + 8); assert_eq!(frame.encoding_size(), 1 + 8); } #[test] fn test_read_path_challenge_frame() { let path_challenge_frame_type = VarInt::from(FrameType::PathChallenge); let mut buf = Vec::new(); buf.put_frame_type(FrameType::PathChallenge); buf.extend_from_slice(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 
0x07, 0x08]); let (input, frame) = flat_map(be_varint, |frame_type| { if frame_type == path_challenge_frame_type { be_path_challenge_frame } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()) .unwrap(); assert!(input.is_empty()); assert_eq!( frame, super::PathChallengeFrame { data: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08] } ); } #[test] fn test_write_path_challenge_frame() { let mut buf = Vec::new(); let frame = super::PathChallengeFrame::from_slice(&[ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, ]); buf.put_frame(&frame); let mut expected = Vec::new(); expected.put_frame_type(FrameType::PathChallenge); expected.extend_from_slice(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/path_response.rs ================================================ use std::ops::Deref; use derive_more::Deref; use crate::frame::{GetFrameType, io::WriteFrameType}; /// PATH_RESPONSE Frame. /// /// ```text /// PATH_RESPONSE Frame { /// Type (i) = 0x1b, /// Data (64), /// } /// ``` /// /// See [PATH_RESPONSE Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-path_response-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
#[derive(Debug, Clone, Copy, Default, Deref, PartialEq, Eq)] pub struct PathResponseFrame { #[deref] data: [u8; 8], } impl PathResponseFrame { fn from_slice(data: &[u8]) -> Self { let mut frame = Self { data: [0; 8] }; frame.data.copy_from_slice(data); frame } } /// The only public way to create a PathResponseFrame is from a PathChallengeFrame impl From for PathResponseFrame { fn from(challenge: super::PathChallengeFrame) -> Self { Self::from_slice(challenge.deref()) } } impl super::GetFrameType for PathResponseFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::PathResponse } } impl super::EncodeSize for PathResponseFrame { fn max_encoding_size(&self) -> usize { 1 + self.data.len() } fn encoding_size(&self) -> usize { 1 + self.data.len() } } /// Parse a PATH_RESPONSE frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn be_path_response_frame(input: &[u8]) -> nom::IResult<&[u8], PathResponseFrame> { use nom::{Parser, bytes::complete::take, combinator::map}; map(take(8usize), PathResponseFrame::from_slice).parse(input) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &PathResponseFrame) { self.put_frame_type(frame.frame_type()); self.put_slice(&frame.data); } } #[cfg(test)] mod tests { use super::*; use crate::frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteFrame, WriteFrameType}, }; #[test] fn test_path_response_frame() { let frame = PathResponseFrame::from_slice(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); assert_eq!(frame.frame_type(), FrameType::PathResponse); assert_eq!(frame.max_encoding_size(), 1 + 8); assert_eq!(frame.encoding_size(), 1 + 8); } #[test] fn test_read_path_response_frame() { use nom::{Parser, combinator::flat_map}; use crate::{ frame::FrameType, varint::{VarInt, be_varint}, }; let path_response_frame_type = VarInt::from(FrameType::PathResponse); let mut buf = Vec::new(); buf.put_frame_type(FrameType::PathResponse); buf.extend_from_slice(&[0x01, 0x02, 
0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); let (input, frame) = flat_map(be_varint, |frame_type| { if frame_type == path_response_frame_type { be_path_response_frame } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()) .unwrap(); assert!(input.is_empty()); assert_eq!( frame, PathResponseFrame::from_slice(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]) ); } #[test] fn test_write_path_response_frame() { let mut buf = Vec::::new(); let frame = PathResponseFrame::from_slice(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); buf.put_frame(&frame); let mut expected = Vec::new(); expected.put_frame_type(FrameType::PathResponse); expected.extend_from_slice(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/ping.rs ================================================ use crate::frame::{GetFrameType, io::WriteFrameType}; /// PING Frame. /// /// ```text /// PING Frame { /// Type (i) = 0x01, /// } /// ``` /// /// See [PING Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-ping-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct PingFrame; impl super::GetFrameType for PingFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::Ping } } impl super::EncodeSize for PingFrame {} /// Parse a PING frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
#[allow(unused)] pub fn be_ping_frame(input: &[u8]) -> nom::IResult<&[u8], PingFrame> { Ok((input, PingFrame)) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &PingFrame) { self.put_frame_type(frame.frame_type()); } } #[cfg(test)] mod tests { use super::PingFrame; use crate::{ frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteFrame, WriteFrameType}, }, varint::VarInt, }; #[test] fn test_ping_frame() { assert_eq!(PingFrame.frame_type(), FrameType::Ping); assert_eq!(PingFrame.max_encoding_size(), 1); assert_eq!(PingFrame.encoding_size(), 1); } #[test] fn test_read_ping_frame() { use nom::{Parser, combinator::flat_map}; use super::be_ping_frame; use crate::varint::be_varint; let ping_frame_type = VarInt::from(FrameType::Ping); let buf = vec![ping_frame_type.into_u64() as u8]; let (input, frame) = flat_map(be_varint, |frame_type| { if frame_type == ping_frame_type { be_ping_frame } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()) .unwrap(); assert!(input.is_empty()); assert_eq!(frame, PingFrame); } #[test] fn test_write_ping_frame() { let mut buf = Vec::new(); buf.put_frame(&PingFrame); let mut expected = Vec::new(); expected.put_frame_type(FrameType::Ping); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/punch_done.rs ================================================ use super::{ EncodeSize, GetFrameType, io::{WriteFrame, WriteFrameType}, }; use crate::{ frame::PunchHelloFrame, varint::{VarInt, WriteVarInt, be_varint}, }; /// PUNCH_Done Frame { /// Type (i) = 0x3d7e96, /// Local Sequence Number (i), /// Remote Sequence Number (i), /// Probe Identifier (i), /// } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct PunchDoneFrame { local_seq: VarInt, remote_seq: VarInt, probe_id: VarInt, } impl PunchDoneFrame { pub fn new(local_seq: u32, remote_seq: u32, probe_id: u32) -> Self { Self { local_seq: VarInt::from_u32(local_seq), remote_seq: 
VarInt::from_u32(remote_seq), probe_id: VarInt::from_u32(probe_id), } } pub fn local_seq(&self) -> u32 { self.local_seq.into_u64() as u32 } pub fn remote_seq(&self) -> u32 { self.remote_seq.into_u64() as u32 } pub fn probe_id(&self) -> u32 { self.probe_id.into_u64() as u32 } /// Construct a PunchDone responding to a received PunchHello, /// automatically swapping local/remote seq to reflect our perspective. pub fn respond_to(hello: &PunchHelloFrame) -> Self { Self::new(hello.remote_seq(), hello.local_seq(), hello.probe_id()) } } impl GetFrameType for PunchDoneFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::PunchDone } } impl EncodeSize for PunchDoneFrame { fn max_encoding_size(&self) -> usize { 4 + 8 + 8 + 8 } fn encoding_size(&self) -> usize { VarInt::from(self.frame_type()).encoding_size() + self.local_seq.encoding_size() + self.remote_seq.encoding_size() + self.probe_id.encoding_size() } } impl WriteFrame for T { fn put_frame(&mut self, frame: &PunchDoneFrame) { self.put_frame_type(frame.frame_type()); self.put_varint(&frame.local_seq); self.put_varint(&frame.remote_seq); self.put_varint(&frame.probe_id); } } pub(crate) fn be_punch_done_frame(input: &[u8]) -> nom::IResult<&[u8], PunchDoneFrame> { let (input, local_seq) = be_varint(input)?; let (input, remote_seq) = be_varint(input)?; let (input, probe_id) = be_varint(input)?; Ok(( input, PunchDoneFrame { local_seq, remote_seq, probe_id, }, )) } ================================================ FILE: qbase/src/frame/punch_hello.rs ================================================ use super::{ EncodeSize, GetFrameType, io::{WriteFrame, WriteFrameType}, }; use crate::varint::{VarInt, WriteVarInt, be_varint}; /// PUNCH_Hello Frame { /// Type (i) = 0x3d7e95, /// Local Sequence Number (i), /// Remote Sequence Number (i), /// Probe Identifier (i), /// } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct PunchHelloFrame { local_seq: VarInt, remote_seq: VarInt, probe_id: VarInt, } impl 
PunchHelloFrame { pub fn new(local_seq: u32, remote_seq: u32, probe_id: u32) -> Self { Self { local_seq: VarInt::from_u32(local_seq), remote_seq: VarInt::from_u32(remote_seq), probe_id: VarInt::from_u32(probe_id), } } pub fn local_seq(&self) -> u32 { self.local_seq.into_u64() as u32 } pub fn remote_seq(&self) -> u32 { self.remote_seq.into_u64() as u32 } pub fn probe_id(&self) -> u32 { self.probe_id.into_u64() as u32 } } impl GetFrameType for PunchHelloFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::PunchHello } } impl EncodeSize for PunchHelloFrame { fn max_encoding_size(&self) -> usize { 4 + 8 + 8 + 8 } fn encoding_size(&self) -> usize { VarInt::from(self.frame_type()).encoding_size() + self.local_seq.encoding_size() + self.remote_seq.encoding_size() + self.probe_id.encoding_size() } } impl WriteFrame for T { fn put_frame(&mut self, frame: &PunchHelloFrame) { self.put_frame_type(frame.frame_type()); self.put_varint(&frame.local_seq); self.put_varint(&frame.remote_seq); self.put_varint(&frame.probe_id); } } pub(crate) fn be_punch_hello_frame(input: &[u8]) -> nom::IResult<&[u8], PunchHelloFrame> { let (input, local_seq) = be_varint(input)?; let (input, remote_seq) = be_varint(input)?; let (input, probe_id) = be_varint(input)?; Ok(( input, PunchHelloFrame { local_seq, remote_seq, probe_id, }, )) } ================================================ FILE: qbase/src/frame/punch_me_now.rs ================================================ use std::net::SocketAddr; use derive_more::Deref; use super::{ EncodeSize, GetFrameType, io::{WriteFrame, WriteFrameType}, }; use crate::{ net::{AddrFamily, Family, NatType, be_socket_addr}, varint::{VarInt, WriteVarInt, be_varint}, }; /// PUNCH_ME_NOW Frame /// ///```text /// PUNCH_ME_NOW Frame { /// Type (i) = 0x3d7e92,0x3d7e93 /// Local Sequence Number (i), /// Remote Sequence Number (i), /// [ IPv4 (32) ], /// [ IPv6 (128) ], /// Port (16), /// Tire (i), /// Nat type (i), /// } /// ``` #[derive(Debug, Clone, Copy, 
PartialEq, Eq, Deref)] pub struct PunchMeNowFrame { local_seq: VarInt, remote_seq: VarInt, #[deref] address: SocketAddr, tire: VarInt, nat_type: NatType, } pub(crate) fn be_punch_me_now_frame( family: Family, ) -> impl Fn(&[u8]) -> nom::IResult<&[u8], PunchMeNowFrame> { move |input| { let (remain, local_seq) = be_varint(input)?; let (remain, remote_seq) = be_varint(remain)?; let (remain, address) = be_socket_addr(remain, family)?; let (remain, tire) = be_varint(remain)?; let (remain, nat_type) = be_varint(remain)?; let nat_type = NatType::try_from(nat_type).map_err(|_| { nom::Err::Error(nom::error::Error::new( remain, nom::error::ErrorKind::Verify, )) })?; Ok(( remain, PunchMeNowFrame { local_seq, remote_seq, address, tire, nat_type, }, )) } } impl GetFrameType for PunchMeNowFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::PunchMeNow(self.address.family()) } } impl EncodeSize for PunchMeNowFrame { fn max_encoding_size(&self) -> usize { 4 + 8 + 8 + self.address.max_encoding_size() + 8 + 8 } fn encoding_size(&self) -> usize { VarInt::from(self.frame_type()).encoding_size() + self.local_seq.encoding_size() + self.remote_seq.encoding_size() + self.address.encoding_size() + self.tire.encoding_size() + VarInt::from(self.nat_type).encoding_size() } } impl PunchMeNowFrame { pub fn new( local_seq: u32, remote_seq: u32, address: SocketAddr, tire: u32, nat_type: NatType, ) -> Self { Self { local_seq: VarInt::from_u32(local_seq), remote_seq: VarInt::from_u32(remote_seq), address, tire: VarInt::from_u32(tire), nat_type, } } pub fn local_seq(&self) -> u32 { self.local_seq.into_u64() as u32 } pub fn remote_seq(&self) -> u32 { self.remote_seq.into_u64() as u32 } pub fn nat_type(&self) -> NatType { self.nat_type } pub fn set_addr(&mut self, addr: SocketAddr) { self.address = addr; } pub fn address(&self) -> SocketAddr { self.address } pub fn tire(&self) -> u32 { self.tire.into_u64() as u32 } } impl WriteFrame for T { fn put_frame(&mut self, frame: 
&PunchMeNowFrame) { self.put_frame_type(frame.frame_type()); self.put_varint(&frame.local_seq); self.put_varint(&frame.remote_seq); self.put_u16(frame.address.port()); match frame.address.ip() { std::net::IpAddr::V4(ipv4) => self.put_slice(&ipv4.octets()), std::net::IpAddr::V6(ipv6) => self.put_slice(&ipv6.octets()), } self.put_varint(&frame.tire); self.put_varint(&VarInt::from(frame.nat_type)); } } #[cfg(test)] mod tests { use bytes::BytesMut; use super::*; use crate::frame::{GetFrameType, be_frame_type, io::WriteFrame}; #[test] fn test_punch_me_now_frame() { let frame = PunchMeNowFrame { local_seq: VarInt::from_u32(1), remote_seq: VarInt::from_u32(2), address: "127.0.0.1:12345".parse().unwrap(), tire: VarInt::from_u32(0x01u32), nat_type: NatType::FullCone, }; let mut buf = BytesMut::with_capacity(frame.max_encoding_size()); buf.put_frame(&frame); let (remain, frame_type) = be_frame_type(&buf).unwrap(); assert_eq!(frame_type, frame.frame_type()); let frame2 = be_punch_me_now_frame(Family::V4)(remain).unwrap().1; assert_eq!(frame, frame2); } } ================================================ FILE: qbase/src/frame/remove_address.rs ================================================ use derive_more::Deref; use super::{ EncodeSize, GetFrameType, io::{WriteFrame, WriteFrameType}, }; use crate::varint::{VarInt, WriteVarInt, be_varint}; /// REMOVE_ADDRESS Frame { /// Type (i) = 0x3d7e94, /// Sequence Number (i), /// } #[derive(Debug, Clone, Copy, PartialEq, Eq, Deref)] pub struct RemoveAddressFrame { #[deref] pub seq_num: VarInt, } pub(crate) fn be_remove_address_frame(input: &[u8]) -> nom::IResult<&[u8], RemoveAddressFrame> { let (input, sequence_number) = be_varint(input)?; Ok(( input, RemoveAddressFrame { seq_num: sequence_number, }, )) } impl GetFrameType for RemoveAddressFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::RemoveAddress } } impl EncodeSize for RemoveAddressFrame { fn max_encoding_size(&self) -> usize { 4 + 8 } fn encoding_size(&self) 
-> usize { VarInt::from(self.frame_type()).encoding_size() + self.seq_num.encoding_size() } } impl WriteFrame for T { fn put_frame(&mut self, frame: &RemoveAddressFrame) { self.put_frame_type(frame.frame_type()); self.put_varint(&frame.seq_num); } } #[cfg(test)] mod tests { use bytes::BytesMut; use super::*; use crate::frame::{GetFrameType, be_frame_type, io::WriteFrame}; #[test] fn test_remove_address_frame() { let frame = RemoveAddressFrame { seq_num: VarInt::from_u32(0x1234), }; assert_eq!(frame.max_encoding_size(), 12); assert_eq!(frame.encoding_size(), 6); let mut buf = BytesMut::new(); buf.put_frame(&frame); let (remain, frame_type) = be_frame_type(&buf).unwrap(); assert_eq!(frame_type, frame.frame_type()); let frame2 = be_remove_address_frame(remain).unwrap().1; assert_eq!(frame, frame2); } } ================================================ FILE: qbase/src/frame/reset_stream.rs ================================================ use thiserror::Error; use crate::{ frame::{GetFrameType, io::WriteFrameType}, sid::{StreamId, WriteStreamId, be_streamid}, varint::{VarInt, WriteVarInt, be_varint}, }; /// RESET_STREAM frame. /// /// ```text /// RESET_STREAM Frame { /// Type (i) = 0x04, /// Stream ID (i), /// Application Protocol Error Code (i), /// Final Size (i), /// } /// ``` /// /// See [RESET_STREAM Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-reset_stream-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ResetStreamFrame { stream_id: StreamId, app_error_code: VarInt, final_size: VarInt, } impl super::GetFrameType for ResetStreamFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::ResetStream } } impl super::EncodeSize for ResetStreamFrame { fn max_encoding_size(&self) -> usize { 1 + 8 + 8 + 8 } fn encoding_size(&self) -> usize { 1 + self.stream_id.encoding_size() + self.app_error_code.encoding_size() + self.final_size.encoding_size() } } impl ResetStreamFrame { /// Create a new [`ResetStreamFrame`]. pub fn new(stream_id: StreamId, app_error_code: VarInt, final_size: VarInt) -> Self { Self { stream_id, app_error_code, final_size, } } /// Return the stream ID of the frame. pub fn stream_id(&self) -> StreamId { self.stream_id } /// Return the application protocol error code of the frame. pub fn app_error_code(&self) -> u64 { self.app_error_code.into_u64() } /// Return the final size of the frame. pub fn final_size(&self) -> u64 { self.final_size.into_u64() } } /// Parse a RESET_STREAM frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
pub fn be_reset_stream_frame(input: &[u8]) -> nom::IResult<&[u8], ResetStreamFrame> { use nom::{Parser, combinator::map}; map( (be_streamid, be_varint, be_varint), |(stream_id, app_error_code, final_size)| ResetStreamFrame { stream_id, app_error_code, final_size, }, ) .parse(input) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &ResetStreamFrame) { self.put_frame_type(frame.frame_type()); self.put_streamid(&frame.stream_id); self.put_varint(&frame.app_error_code); self.put_varint(&frame.final_size); } } #[derive(Clone, Copy, Debug, Error, PartialEq, Eq)] #[error("The stream was reset with app error code: {app_error_code}, final size: {final_size}")] pub struct ResetStreamError { app_error_code: VarInt, final_size: VarInt, } impl ResetStreamError { pub fn new(app_error_code: VarInt, final_size: VarInt) -> Self { Self { app_error_code, final_size, } } pub fn error_code(&self) -> u64 { self.app_error_code.into_u64() } pub fn combine(self, sid: StreamId) -> ResetStreamFrame { ResetStreamFrame { stream_id: sid, app_error_code: self.app_error_code, final_size: self.final_size, } } } impl From<&ResetStreamFrame> for ResetStreamError { fn from(frame: &ResetStreamFrame) -> Self { Self { app_error_code: frame.app_error_code, final_size: frame.final_size, } } } #[cfg(test)] mod tests { use nom::{Parser, combinator::flat_map}; use super::{ResetStreamError, ResetStreamFrame}; use crate::{ frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteFrame, WriteFrameType}, }, varint::{VarInt, be_varint}, }; #[test] fn test_reset_stream_frame() { let frame = ResetStreamFrame::new( VarInt::from_u32(0x1234).into(), VarInt::from_u32(0x5678), VarInt::from_u32(0x9abc), ); assert_eq!(frame.frame_type(), FrameType::ResetStream); assert_eq!(frame.max_encoding_size(), 1 + 8 + 8 + 8); assert_eq!(frame.encoding_size(), 1 + 2 + 4 + 4); assert_eq!(frame.stream_id(), VarInt::from_u32(0x1234).into()); assert_eq!(frame.app_error_code(), 0x5678); assert_eq!(frame.final_size(), 
0x9abc); let reset_stream_error: ResetStreamError = (&frame).into(); assert_eq!( reset_stream_error, ResetStreamError::new(VarInt::from_u32(0x5678), VarInt::from_u32(0x9abc)) ); } #[test] fn test_read_reset_stream_frame() { let mut buf = Vec::new(); buf.put_frame_type(FrameType::ResetStream); buf.extend_from_slice(&[0x52, 0x34, 0x80, 0, 0x56, 0x78, 0x80, 0, 0x9a, 0xbc]); let (input, frame) = flat_map(be_varint, |frame_type| { if frame_type == VarInt::from(FrameType::ResetStream) { super::be_reset_stream_frame } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()) .unwrap(); assert!(input.is_empty()); assert_eq!( frame, ResetStreamFrame::new( VarInt::from_u32(0x1234).into(), VarInt::from_u32(0x5678), VarInt::from_u32(0x9abc), ) ); } #[test] fn test_write_reset_stream_frame() { let mut buf = Vec::new(); buf.put_frame(&ResetStreamFrame::new( VarInt::from_u32(0x1234).into(), // 0x5678 = 0b01010110 01111000 => 0b10000000 0x00 0x56 0x78 VarInt::from_u32(0x5678), // 0x9abc = 0b10011010 10111100 => 0b10000000 0x00 0x9a 0xbc VarInt::from_u32(0x9abc), )); let mut expected = Vec::new(); expected.put_frame_type(FrameType::ResetStream); expected.extend_from_slice(&[0x52, 0x34, 0x80, 0, 0x56, 0x78, 0x80, 0, 0x9a, 0xbc]); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/retire_connection_id.rs ================================================ use crate::{ frame::{GetFrameType, io::WriteFrameType}, varint::{VarInt, WriteVarInt, be_varint}, }; /// RETIRE_CONNECTION_ID frame. /// /// ```text /// RETIRE_CONNECTION_ID Frame { /// Type (i) = 0x19, /// Sequence Number (i), /// } /// ``` /// /// See [RETIRE_CONNECTION_ID Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-retire_connection_id-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RetireConnectionIdFrame { sequence: VarInt, } impl super::GetFrameType for RetireConnectionIdFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::RetireConnectionId } } impl super::EncodeSize for RetireConnectionIdFrame { fn max_encoding_size(&self) -> usize { 1 + 8 } fn encoding_size(&self) -> usize { 1 + self.sequence.encoding_size() } } impl RetireConnectionIdFrame { /// Create a new [`RetireConnectionIdFrame`]. pub fn new(sequence: VarInt) -> Self { Self { sequence } } /// Return the sequence number of the frame. pub fn sequence(&self) -> u64 { self.sequence.into_u64() } } /// Parse a RETIRE_CONNECTION_ID frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn be_retire_connection_id_frame(input: &[u8]) -> nom::IResult<&[u8], RetireConnectionIdFrame> { use nom::{Parser, combinator::map}; map(be_varint, RetireConnectionIdFrame::new).parse(input) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &RetireConnectionIdFrame) { self.put_frame_type(frame.frame_type()); self.put_varint(&frame.sequence); } } #[cfg(test)] mod tests { use super::{RetireConnectionIdFrame, be_retire_connection_id_frame}; use crate::{ frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteFrame, WriteFrameType}, }, varint::VarInt, }; #[test] fn test_retire_connection_id_frame() { let frame = RetireConnectionIdFrame::new(VarInt::from_u32(0x1234)); assert_eq!(frame.frame_type(), FrameType::RetireConnectionId); assert_eq!(frame.max_encoding_size(), 1 + 8); assert_eq!(frame.encoding_size(), 1 + 2); assert_eq!(frame.sequence(), 0x1234); } #[test] fn test_read_retire_connection_id_frame() { let buf = vec![0x52, 0x34]; let (remain, frame) = be_retire_connection_id_frame(&buf).unwrap(); assert!(remain.is_empty()); assert_eq!( frame, RetireConnectionIdFrame::new(VarInt::from_u32(0x1234)) ); } #[test] fn test_write_retire_connection_id_frame() { let mut buf = Vec::new(); let 
frame = RetireConnectionIdFrame::new(VarInt::from_u32(0x1234)); buf.put_frame(&frame); let mut expected = Vec::new(); expected.put_frame_type(FrameType::RetireConnectionId); expected.extend_from_slice(&[0x52, 0x34]); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/stop_sending.rs ================================================ use crate::{ frame::{GetFrameType, io::WriteFrameType}, sid::{StreamId, WriteStreamId, be_streamid}, varint::{VarInt, WriteVarInt, be_varint}, }; /// STOP_SENDING frame. /// /// ```text /// STOP_SENDING Frame { /// Type (i) = 0x05, /// Stream ID (i), /// Application Protocol Error Code (i), /// } /// ``` /// /// See [STOP_SENDING Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-stop_sending-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct StopSendingFrame { stream_id: StreamId, app_err_code: VarInt, } impl StopSendingFrame { /// Create a new [`StopSendingFrame`]. pub fn new(stream_id: StreamId, app_err_code: VarInt) -> Self { Self { stream_id, app_err_code, } } /// Return the stream ID of the frame. pub fn stream_id(&self) -> StreamId { self.stream_id } /// Return the application protocol error code of the frame. pub fn app_err_code(&self) -> u64 { self.app_err_code.into_u64() } /// Compose a RESET_STREAM frame from the STOP_SENDING frame with the given final size. 
pub fn reset_stream(&self, final_size: VarInt) -> super::ResetStreamFrame { super::ResetStreamFrame::new(self.stream_id, self.app_err_code, final_size) } } impl super::GetFrameType for StopSendingFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::StopSending } } impl super::EncodeSize for StopSendingFrame { fn max_encoding_size(&self) -> usize { 1 + 8 + 8 } fn encoding_size(&self) -> usize { 1 + self.stream_id.encoding_size() + self.app_err_code.encoding_size() } } /// Parse a STOP_SENDING frame from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn be_stop_sending_frame(input: &[u8]) -> nom::IResult<&[u8], StopSendingFrame> { use nom::{Parser, combinator::map}; map((be_streamid, be_varint), |(stream_id, app_err_code)| { StopSendingFrame { stream_id, app_err_code, } }) .parse(input) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &StopSendingFrame) { self.put_frame_type(frame.frame_type()); self.put_streamid(&frame.stream_id); self.put_varint(&frame.app_err_code); } } #[cfg(test)] mod tests { use super::{StopSendingFrame, be_stop_sending_frame}; use crate::{ frame::{ EncodeSize, FrameType, GetFrameType, io::{WriteFrame, WriteFrameType}, }, varint::{VarInt, be_varint}, }; #[test] fn test_stop_sending_frame() { let frame = StopSendingFrame::new(VarInt::from_u32(0x1234).into(), VarInt::from_u32(0x5678)); assert_eq!(frame.stream_id(), VarInt::from_u32(0x1234).into()); assert_eq!(frame.app_err_code(), 0x5678); assert_eq!(frame.frame_type(), FrameType::StopSending); assert_eq!(frame.max_encoding_size(), 1 + 8 + 8); assert_eq!(frame.encoding_size(), 1 + 2 + 4); } #[test] fn test_parse_stop_sending_frame() { use nom::{Parser, combinator::flat_map}; let frame = StopSendingFrame::new(VarInt::from_u32(0x1234).into(), VarInt::from_u32(0x5678)); let mut buf = Vec::new(); buf.put_frame(&frame); let stop_sending_frame_type = VarInt::from(FrameType::StopSending); let (input, parsed) = flat_map(be_varint, 
|frame_type| { if frame_type == stop_sending_frame_type { be_stop_sending_frame } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()) .unwrap(); assert!(input.is_empty()); assert_eq!(parsed, frame); } #[test] fn test_write_stop_sending_frame() { let mut buf = Vec::new(); let frame = StopSendingFrame { stream_id: VarInt::from_u32(0x1234).into(), app_err_code: VarInt::from_u32(0x5678), }; buf.put_frame(&frame); let mut expected = Vec::new(); expected.put_frame_type(FrameType::StopSending); expected.extend_from_slice(&[0x52, 0x34, 0x80, 0, 0x56, 0x78]); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/stream.rs ================================================ use std::ops::Range; use super::GetFrameType; use crate::{ frame::EncodeSize, sid::{StreamId, WriteStreamId, be_streamid}, util::{ContinuousData, WriteData}, varint::{VARINT_MAX, VarInt, WriteVarInt, be_varint}, }; /// Offset flag for STREAM frames #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Offset { /// Offset field is zero (not present in frame) Zero, /// Offset field is non-zero (present in frame) NonZero, } impl From for u8 { fn from(offset: Offset) -> u8 { match offset { Offset::Zero => 0, Offset::NonZero => 0x04, } } } impl From for Offset { fn from(value: u64) -> Self { match value & 0x04 { 0 => Offset::Zero, _ => Offset::NonZero, } } } /// Length flag for STREAM frames #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Len { /// Length field is present Explicit, /// Length field is omitted (extends to end of packet) Omit, } impl From for u8 { fn from(length: Len) -> u8 { match length { Len::Explicit => 0x02, Len::Omit => 0, } } } impl From for Len { fn from(value: u64) -> Self { match value & 0x02 { 0 => Len::Omit, _ => Len::Explicit, } } } /// Fin flag for STREAM frames #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Fin { /// Stream is finished Yes, /// Stream is not finished No, } impl From for u8 { fn 
from(fin: Fin) -> u8 { match fin { Fin::Yes => 0x01, Fin::No => 0, } } } impl From for Fin { fn from(value: u64) -> Self { match value & 0x01 { 0 => Fin::No, _ => Fin::Yes, } } } /// STREAM frame. /// /// ```text /// STREAM Frame { /// Type (i) = 0x08..0x0f, /// Stream ID (i), /// [Offset (i)], /// [Length (i)], /// Stream Data (..), /// } /// ``` /// /// The lower 3 bits of the frame type are used to indicate the presence of the following fields: /// - OFF bit: 0x04 /// - LEN bit: 0x02 /// - FIN bit: 0x01 /// /// See [STREAM Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-stream-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct StreamFrame { id: StreamId, offset: VarInt, length: usize, len_bit: Len, fin_bit: Fin, } pub const STREAM_FRAME_MAX_ENCODING_SIZE: usize = 1 + 8 + 8 + 8; impl GetFrameType for StreamFrame { fn frame_type(&self) -> super::FrameType { let offset = if self.offset == 0 { Offset::Zero } else { Offset::NonZero }; super::FrameType::Stream(offset, self.len_bit, self.fin_bit) } } impl super::EncodeSize for StreamFrame { fn max_encoding_size(&self) -> usize { STREAM_FRAME_MAX_ENCODING_SIZE } fn encoding_size(&self) -> usize { 1 + self.id.encoding_size() + if self.offset.into_u64() != 0 { self.offset.encoding_size() } else { 0 } + if self.len_bit == Len::Explicit { VarInt::from_u64(self.length as u64) .expect("msg length must be less than 2^62") .encoding_size() } else { 0 } } } /// Efficient strategies for encoding stream frames #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct EncodingStrategy { len_bit: Len, pre_padding: usize, } impl EncodingStrategy { /// Cound the stream frame carry its data's length. pub fn len_bit(&self) -> Len { self.len_bit } /// How many padding frames should be put before the stream frame. 
pub fn pre_padding(&self) -> usize { self.pre_padding } } impl StreamFrame { /// Create a new stream frame with the given stream id, offset, and length. pub fn new(id: StreamId, offset: u64, length: usize) -> Self { assert!(offset <= VARINT_MAX); Self { id, offset: VarInt::from_u64(offset) .expect("offset of stream frame must be less than 2^62"), length, len_bit: Len::Omit, fin_bit: Fin::No, } } /// Return the stream id of this stream frame. pub fn stream_id(&self) -> StreamId { self.id } /// Return whether this stream frame is the end of the stream. pub fn is_fin(&self) -> bool { self.fin_bit == Fin::Yes } /// Return the offset of this stream frame. pub fn offset(&self) -> u64 { self.offset.into_u64() } /// Return the length of this stream frame. pub fn len(&self) -> usize { self.length } /// Return whether this stream frame is empty. pub fn is_empty(&self) -> bool { self.length == 0 } /// Return the range of this stream frame covered. pub fn range(&self) -> Range { self.offset.into_u64()..self.offset.into_u64() + self.length as u64 } /// Set the end of stream flag of this stream frame. pub fn set_eos_flag(&mut self, is_eos: bool) { if is_eos { self.fin_bit = Fin::Yes; } else { self.fin_bit = Fin::No; } } /// Set the length bit of this stream frame. pub fn set_len_bit(&mut self, len_bit: Len) { self.len_bit = len_bit; } /// Returns the most efficient stream frame encoding strategy. /// /// By default, a stream frame is considered the last frame within a data packet, /// allowing it to carry data up to the maximum payload capacity. However, if the /// data does not fill the entire frame and there is sufficient space remaining /// in the packet, other data frames can be carried after it. In this case, the /// frame is designated as carrying length. However, when a stream frame with a length /// is put into the data packet, the remaining space may be too small to put another /// stream frame. 
Filling the remaining space is sometimes more beneficial to taking /// advantage of GSO features. pub fn encoding_strategy(&self, capacity: usize) -> EncodingStrategy { // this method is used to determine the encoding strategy of the stream frame debug_assert_eq!(self.len_bit, Len::Omit); let encoding_size_without_length = self.encoding_size() + self.length; assert!(encoding_size_without_length <= capacity); let len_encoding_size = VarInt::try_from(self.length) .expect("length of stream frame must be less than 2^62") .encoding_size(); let remaining = capacity - encoding_size_without_length; if remaining >= len_encoding_size { let remaining = remaining - len_encoding_size; // TODO: It doesn't make sense, STREAM_FRAME_MAX_ENCODING_SIZE is 25 bytes // but the minium stream size can be as small as 3 bytes // with stream id less than 64 and offset 0 and without length if remaining < STREAM_FRAME_MAX_ENCODING_SIZE { EncodingStrategy { len_bit: Len::Explicit, pre_padding: remaining, } } else { EncodingStrategy { len_bit: Len::Explicit, pre_padding: 0, } } } else { EncodingStrategy { len_bit: Len::Omit, pre_padding: remaining, } } } /// Estimate the maximum capacity that one stream frame with the given capacity, /// stream id, and offset can carry. pub fn estimate_max_capacity(capacity: usize, sid: StreamId, offset: u64) -> Option { assert!(offset <= VARINT_MAX); let mut least = 1 + sid.encoding_size(); if offset != 0 { least += VarInt::from_u64(offset).unwrap().encoding_size(); } if capacity <= least { None } else { Some(capacity - least) } } } /// Return a parser for a stream frame with the given flag, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn stream_frame_with_flag( offset: Offset, len: Len, fin: Fin, ) -> impl Fn(&[u8]) -> nom::IResult<&[u8], StreamFrame> { move |input| { let (remain, id) = be_streamid(input)?; let (remain, offset) = if offset == Offset::NonZero { be_varint(remain)? 
} else { (remain, VarInt::default()) }; let (remain, length) = if len == Len::Explicit { let (remain, length) = be_varint(remain)?; (remain, length.into_u64() as usize) } else { (remain, remain.len()) }; if offset.into_u64() + length as u64 > VARINT_MAX { return Err(nom::Err::Error(nom::error::make_error( input, nom::error::ErrorKind::TooLarge, ))); } Ok(( remain, StreamFrame { id, offset, length, len_bit: len, fin_bit: fin, }, )) } } impl super::io::WriteDataFrame for T where T: bytes::BufMut + WriteData, D: ContinuousData, { fn put_data_frame(&mut self, frame: &StreamFrame, data: &D) { use crate::frame::io::WriteFrameType; self.put_frame_type(frame.frame_type()); self.put_streamid(&frame.id); if frame.offset.into_u64() != 0 { self.put_varint(&frame.offset); } if frame.len_bit == Len::Explicit { // Generally, a data frame will not exceed 4GB. self.put_varint(&VarInt::from_u32(frame.length as u32)); } self.put_data(data); } } #[cfg(test)] mod tests { use bytes::Bytes; use nom::{Parser, combinator::flat_map}; use super::*; use crate::{ frame::{EncodeSize, FrameType, GetFrameType, io::WriteDataFrame}, varint::{VarInt, be_varint}, }; #[test] fn test_stream_frame() { let stream_frame = StreamFrame { id: VarInt::from_u32(0x1234).into(), offset: VarInt::from_u32(0x1234), length: 11, len_bit: Len::Explicit, fin_bit: Fin::No, }; assert_eq!( stream_frame.frame_type(), FrameType::Stream(Offset::NonZero, Len::Explicit, Fin::No) ); assert_eq!(stream_frame.max_encoding_size(), 1 + 8 + 8 + 8); assert_eq!(stream_frame.encoding_size(), 1 + 2 + 2 + 1); } #[test] fn test_read_stream_frame() { let raw = Bytes::from_static(&[ 0x0e, 0x52, 0x34, 0x52, 0x34, 0x0b, b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', 0, ]); let input = raw.as_ref(); let (input, frame) = flat_map(be_varint, |frame_type| { let stream_frame_type: VarInt = FrameType::Stream(Offset::NonZero, Len::Explicit, Fin::No).into(); assert_eq!(frame_type, stream_frame_type); 
stream_frame_with_flag(Offset::NonZero, Len::Explicit, Fin::No) }) .parse(input) .unwrap(); assert_eq!( input, &[ b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', 0, ][..] ); assert_eq!( frame, StreamFrame { id: VarInt::from_u32(0x1234).into(), offset: VarInt::from_u32(0x1234), length: 11, len_bit: Len::Explicit, fin_bit: Fin::No, } ); } #[test] fn test_read_last_stream_frame() { let raw = Bytes::from_static(&[ 0x0c, 0x52, 0x34, 0x52, 0x34, b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', ]); let input = raw.as_ref(); let (input, frame) = flat_map(be_varint, |frame_type| { let stream_frame_type: VarInt = FrameType::Stream(Offset::NonZero, Len::Omit, Fin::No).into(); assert_eq!(frame_type, stream_frame_type); stream_frame_with_flag(Offset::NonZero, Len::Omit, Fin::No) }) .parse(input) .unwrap(); assert_eq!( input, &[ b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', ][..] ); assert_eq!( frame, StreamFrame { id: VarInt::from_u32(0x1234).into(), offset: VarInt::from_u32(0x1234), length: 11, len_bit: Len::Omit, fin_bit: Fin::No, } ); } #[test] fn test_write_initial_stream_frame() { let mut buf = Vec::new(); let frame = StreamFrame { id: VarInt::from_u32(0x1234).into(), offset: VarInt::from_u32(0), length: 11, len_bit: Len::Explicit, fin_bit: Fin::Yes, }; buf.put_data_frame(&frame, b"hello world"); assert_eq!( buf, vec![ 0xb, 0x52, 0x34, 0x0b, b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd' ] ); } #[test] fn test_write_last_stream_frame() { let mut buf = Vec::new(); let frame = StreamFrame { id: VarInt::from_u32(0x1234).into(), offset: VarInt::from_u32(0), length: 11, len_bit: Len::Omit, fin_bit: Fin::Yes, }; buf.put_data_frame(&frame, b"hello world"); assert_eq!( buf, vec![ 0x9, 0x52, 0x34, b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd' ] ); } #[test] fn test_write_eos_frame() { let mut buf = Vec::new(); let frame = StreamFrame { id: VarInt::from_u32(0x1234).into(), offset: 
VarInt::from_u32(0x1234), length: 11, len_bit: Len::Explicit, fin_bit: Fin::Yes, }; buf.put_data_frame(&frame, b"hello world"); assert_eq!( buf, vec![ 0x0f, 0x52, 0x34, 0x52, 0x34, 0x0b, b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd' ] ); } #[test] fn test_write_unfinished_stream_frame() { let mut buf = Vec::new(); let frame = StreamFrame { id: VarInt::from_u32(0x1234).into(), offset: VarInt::from_u32(0x1234), length: 11, len_bit: Len::Explicit, fin_bit: Fin::No, }; buf.put_data_frame(&frame, b"hello world"); assert_eq!( buf, vec![ 0x0e, 0x52, 0x34, 0x52, 0x34, 0x0b, b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd' ] ); } } ================================================ FILE: qbase/src/frame/stream_data_blocked.rs ================================================ use crate::{ frame::{GetFrameType, io::WriteFrameType}, sid::{StreamId, WriteStreamId, be_streamid}, varint::{VarInt, WriteVarInt, be_varint}, }; /// STREAM_DATA_BLOCKED frame. /// /// ```text /// STREAM_DATA_BLOCKED Frame { /// Type (i) = 0x15, /// Stream ID (i), /// Maximum Stream Data (i), /// } /// ``` /// /// See [STREAM_DATA_BLOCKED Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-stream_data_blocked-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct StreamDataBlockedFrame { stream_id: StreamId, maximum_stream_data: VarInt, } impl super::GetFrameType for StreamDataBlockedFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::StreamDataBlocked } } impl super::EncodeSize for StreamDataBlockedFrame { fn max_encoding_size(&self) -> usize { 1 + 8 + 8 } fn encoding_size(&self) -> usize { 1 + self.stream_id.encoding_size() + self.maximum_stream_data.encoding_size() } } impl StreamDataBlockedFrame { /// Create a new [`StreamDataBlockedFrame`]. 
    pub fn new(stream_id: StreamId, maximum_stream_data: VarInt) -> Self {
        Self {
            stream_id,
            maximum_stream_data,
        }
    }

    /// Return the stream ID of the frame.
    pub fn stream_id(&self) -> StreamId {
        self.stream_id
    }

    /// Return the maximum stream data of the frame.
    pub fn maximum_stream_data(&self) -> u64 {
        self.maximum_stream_data.into_u64()
    }
}

/// Parse a STREAM_DATA_BLOCKED frame from the input buffer,
/// [nom](https://docs.rs/nom/latest/nom/) parser style.
///
/// The frame type varint is expected to have been consumed by the caller;
/// this parses only the Stream ID and Maximum Stream Data fields.
pub fn be_stream_data_blocked_frame(input: &[u8]) -> nom::IResult<&[u8], StreamDataBlockedFrame> {
    let (input, stream_id) = be_streamid(input)?;
    let (input, maximum_stream_data) = be_varint(input)?;
    Ok((
        input,
        StreamDataBlockedFrame {
            stream_id,
            maximum_stream_data,
        },
    ))
}

// NOTE(review): the generic parameters of this impl appear to have been lost
// in extraction (`for T` with no `<...>` bound); presumably a blanket impl
// over a `BufMut`-like writer — confirm against the original source.
impl super::io::WriteFrame for T {
    // Encode the frame as: type varint, Stream ID, Maximum Stream Data.
    fn put_frame(&mut self, frame: &StreamDataBlockedFrame) {
        self.put_frame_type(frame.frame_type());
        self.put_streamid(&frame.stream_id);
        self.put_varint(&frame.maximum_stream_data);
    }
}

#[cfg(test)]
mod tests {
    use super::StreamDataBlockedFrame;
    use crate::{
        frame::{
            EncodeSize, FrameType, GetFrameType,
            io::{WriteFrame, WriteFrameType},
        },
        varint::VarInt,
    };

    #[test]
    fn test_stream_data_blocked_frame() {
        let frame =
            StreamDataBlockedFrame::new(VarInt::from_u32(0x1234).into(), VarInt::from_u32(0x5678));
        assert_eq!(frame.frame_type(), FrameType::StreamDataBlocked);
        assert_eq!(frame.max_encoding_size(), 1 + 8 + 8);
        // 1 type byte + 2-byte varint stream id (0x1234) + 4-byte varint
        // maximum stream data (0x5678 > 16383, so the 2-byte form cannot hold it).
        assert_eq!(frame.encoding_size(), 1 + 2 + 4);
        assert_eq!(frame.stream_id(), VarInt::from_u32(0x1234).into());
        assert_eq!(frame.maximum_stream_data(), 0x5678);
    }

    #[test]
    fn test_read_stream_data_blocked() {
        use super::be_stream_data_blocked_frame;
        // 0x52 0x34: 2-byte varint stream id 0x1234;
        // 0x80 0x00 0x56 0x78: 4-byte varint 0x5678.
        let buf = [0x52, 0x34, 0x80, 0, 0x56, 0x78];
        let (_, frame) = be_stream_data_blocked_frame(&buf).unwrap();
        assert_eq!(
            frame,
            StreamDataBlockedFrame::new(VarInt::from_u32(0x1234).into(), VarInt::from_u32(0x5678))
        );
    }

    #[test]
    fn test_write_stream_data_blocked_frame() {
        let mut buf = Vec::new();
        buf.put_frame(&StreamDataBlockedFrame::new(
VarInt::from_u32(0x1234).into(), VarInt::from_u32(0x5678), )); let mut expected = Vec::new(); expected.put_frame_type(FrameType::StreamDataBlocked); expected.extend_from_slice(&[0x52, 0x34, 0x80, 0, 0x56, 0x78]); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame/streams_blocked.rs ================================================ use crate::{ frame::{GetFrameType, io::WriteFrameType}, sid::Dir, varint::{VarInt, WriteVarInt, be_varint}, }; /// STREAMS_BLOCKED frame. /// /// ```text /// STREAMS_BLOCKED Frame { /// Type (i) = 0x16..0x17, /// Maximum Streams (i), /// } /// ``` /// /// See [STREAMS_BLOCKED Frames](https://www.rfc-editor.org/rfc/rfc9000.html#name-streams_blocked-frames) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StreamsBlockedFrame { Bi(VarInt), Uni(VarInt), } impl StreamsBlockedFrame { pub fn with(dir: Dir, max_streams: VarInt) -> Self { match dir { Dir::Bi => StreamsBlockedFrame::Bi(max_streams), Dir::Uni => StreamsBlockedFrame::Uni(max_streams), } } } impl super::GetFrameType for StreamsBlockedFrame { fn frame_type(&self) -> super::FrameType { super::FrameType::StreamsBlocked(match self { StreamsBlockedFrame::Bi(_) => Dir::Bi, StreamsBlockedFrame::Uni(_) => Dir::Uni, }) } } impl super::EncodeSize for StreamsBlockedFrame { fn max_encoding_size(&self) -> usize { 1 + 8 } fn encoding_size(&self) -> usize { 1 + match self { StreamsBlockedFrame::Bi(stream_id) => stream_id.encoding_size(), StreamsBlockedFrame::Uni(stream_id) => stream_id.encoding_size(), } } } /// Return a parser for STREAMS_BLOCKED frame with the given direction, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
pub fn streams_blocked_frame_with_dir(
    dir: Dir,
) -> impl Fn(&[u8]) -> nom::IResult<&[u8], StreamsBlockedFrame> {
    // The frame type varint (which carries the direction bit) has already been
    // consumed by the caller; only the Maximum Streams field remains.
    move |input: &[u8]| {
        let (input, max_streams) = be_varint(input)?;
        Ok((
            input,
            match dir {
                Dir::Bi => StreamsBlockedFrame::Bi(max_streams),
                Dir::Uni => StreamsBlockedFrame::Uni(max_streams),
            },
        ))
    }
}

// NOTE(review): generic parameters appear stripped by extraction (`for T`
// with no `<...>` bound); presumably a blanket impl over a `BufMut`-like
// writer — confirm against the original source.
impl super::io::WriteFrame for T {
    fn put_frame(&mut self, frame: &StreamsBlockedFrame) {
        // Both arms emit the same byte sequence (type varint, then Maximum
        // Streams); the direction is distinguished solely by the frame type
        // produced by `frame.frame_type()`.
        // NOTE(review): the two arms are identical and could be collapsed.
        match frame {
            StreamsBlockedFrame::Bi(max_streams) => {
                self.put_frame_type(frame.frame_type());
                self.put_varint(max_streams);
            }
            StreamsBlockedFrame::Uni(max_streams) => {
                self.put_frame_type(frame.frame_type());
                self.put_varint(max_streams);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::StreamsBlockedFrame;
    use crate::{
        frame::{
            EncodeSize, FrameType, GetFrameType,
            io::{WriteFrame, WriteFrameType},
        },
        sid::Dir,
        varint::VarInt,
    };

    // NOTE(review): the test name looks copy-pasted from stream_data_blocked.rs;
    // it exercises STREAMS_BLOCKED, so `test_streams_blocked_frame` would fit better.
    #[test]
    fn test_stream_data_blocked_frame() {
        let frame = StreamsBlockedFrame::Bi(VarInt::from_u32(0x1234));
        assert_eq!(frame.frame_type(), FrameType::StreamsBlocked(Dir::Bi));
        assert_eq!(frame.max_encoding_size(), 1 + 8);
        // 1 type byte + 2-byte varint for 0x1234.
        assert_eq!(frame.encoding_size(), 1 + 2);
        let frame = StreamsBlockedFrame::Uni(VarInt::from_u32(0x1234));
        assert_eq!(frame.frame_type(), FrameType::StreamsBlocked(Dir::Uni));
        assert_eq!(frame.max_encoding_size(), 1 + 8);
        assert_eq!(frame.encoding_size(), 1 + 2);
    }

    #[test]
    fn test_read_streams_blocked_frame() {
        use nom::{Parser, combinator::flat_map};

        use super::streams_blocked_frame_with_dir;
        use crate::varint::be_varint;
        let streams_blocked_bi_type = VarInt::from(FrameType::StreamsBlocked(Dir::Bi));
        let streams_blocked_uni_type = VarInt::from(FrameType::StreamsBlocked(Dir::Uni));
        // One frame type byte followed by the 2-byte varint 0x1234.
        let buf = vec![streams_blocked_bi_type.into_u64() as u8, 0x52, 0x34];
        let (input, frame) = flat_map(be_varint, |frame_type| {
            if frame_type == streams_blocked_bi_type {
                streams_blocked_frame_with_dir(Dir::Bi)
            } else {
                panic!("wrong frame type: {frame_type}")
            }
        })
        .parse(buf.as_ref())
        .unwrap();
        assert!(input.is_empty());
        assert_eq!(frame,
StreamsBlockedFrame::Bi(VarInt::from_u32(0x1234))); let buf = vec![streams_blocked_uni_type.into_u64() as u8, 0x52, 0x34]; let (input, frame) = flat_map(be_varint, |frame_type| { if frame_type == streams_blocked_uni_type { streams_blocked_frame_with_dir(Dir::Uni) } else { panic!("wrong frame type: {frame_type}") } }) .parse(buf.as_ref()) .unwrap(); assert!(input.is_empty()); assert_eq!(frame, StreamsBlockedFrame::Uni(VarInt::from_u32(0x1234))); } #[test] fn test_write_streams_blocked_frame() { let mut buf = Vec::new(); buf.put_frame(&StreamsBlockedFrame::Bi(VarInt::from_u32(0x1234))); let mut expected = Vec::new(); expected.put_frame_type(FrameType::StreamsBlocked(Dir::Bi)); expected.extend_from_slice(&[0x52, 0x34]); assert_eq!(buf, expected); let mut buf = Vec::new(); buf.put_frame(&StreamsBlockedFrame::Uni(VarInt::from_u32(0x1234))); let mut expected = Vec::new(); expected.put_frame_type(FrameType::StreamsBlocked(Dir::Uni)); expected.extend_from_slice(&[0x52, 0x34]); assert_eq!(buf, expected); } } ================================================ FILE: qbase/src/frame.rs ================================================ use std::fmt::Debug; use bytes::{Buf, BufMut, Bytes}; use derive_more::{Deref, DerefMut, From, TryInto}; use enum_dispatch::enum_dispatch; use io::WriteFrame; use super::varint::VarInt; use crate::{net::Family, packet::r#type::Type, sid::Dir}; mod ack; mod connection_close; mod crypto; mod data_blocked; mod datagram; mod handshake_done; mod max_data; mod max_stream_data; mod max_streams; mod new_connection_id; mod new_token; mod padding; mod path_challenge; mod path_response; mod ping; mod reset_stream; mod retire_connection_id; mod stop_sending; mod stream; mod stream_data_blocked; mod streams_blocked; mod add_address; mod punch_done; mod punch_hello; mod punch_me_now; mod remove_address; /// Error module for parsing frames pub mod error; /// IO module for frame encoding and decoding pub mod io; pub use ack::{AckFrame, Ecn, EcnCounts}; pub use 
add_address::AddAddressFrame; pub use connection_close::{AppCloseFrame, ConnectionCloseFrame, Layer, QuicCloseFrame}; pub use crypto::CryptoFrame; pub use data_blocked::DataBlockedFrame; pub use datagram::DatagramFrame; #[doc(hidden)] pub use error::Error; pub use handshake_done::HandshakeDoneFrame; pub use max_data::MaxDataFrame; pub use max_stream_data::MaxStreamDataFrame; pub use max_streams::MaxStreamsFrame; pub use new_connection_id::NewConnectionIdFrame; pub use new_token::NewTokenFrame; pub use padding::PaddingFrame; pub use path_challenge::PathChallengeFrame; pub use path_response::PathResponseFrame; pub use ping::PingFrame; pub use punch_done::PunchDoneFrame; pub use punch_hello::PunchHelloFrame; pub use punch_me_now::PunchMeNowFrame; pub use remove_address::RemoveAddressFrame; pub use reset_stream::{ResetStreamError, ResetStreamFrame}; pub use retire_connection_id::RetireConnectionIdFrame; pub use stop_sending::StopSendingFrame; pub use stream::{EncodingStrategy, Fin, Len, Offset, STREAM_FRAME_MAX_ENCODING_SIZE, StreamFrame}; pub use stream_data_blocked::StreamDataBlockedFrame; pub use streams_blocked::StreamsBlockedFrame; /// Define the basic behaviors for all kinds of frames #[enum_dispatch] pub trait GetFrameType { /// Return the type of frame fn frame_type(&self) -> FrameType; } #[enum_dispatch] pub trait EncodeSize { /// Return the max number of bytes needed to encode this value /// /// Calculate the maximum size by summing up the maximum length of each field. /// If a field type has a maximum length, use it, otherwise use the actual length /// of the data in that field. /// /// When packaging data, by pre-estimating this value to effectively avoid spending /// extra resources to calculate the actual encoded size. 
fn max_encoding_size(&self) -> usize { 1 } /// Return the exact number of bytes needed to encode this value fn encoding_size(&self) -> usize { 1 } } /// The `Spec` summarizes any special rules governing the processing /// or generation of the frame type, as indicated by the following characters. /// /// See [table-3](https://www.rfc-editor.org/rfc/rfc9000.html#table-3) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. pub enum Spec { /// Packets containing only frames with this marking are not ack-eliciting. /// /// See [Section 13.2](https://www.rfc-editor.org/rfc/rfc9000.html#generating-acks) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. NonAckEliciting = 1, /// Packets containing only frames with this marking do not count toward bytes /// in flight for congestion control purposes. /// See [section-12.4-14.4](https://www.rfc-editor.org/rfc/rfc9000.html#section-12.4-14.4) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html). /// /// Similar to TCP, packets containing only ACK frames do not count toward bytes /// in flight and are not congestion controlled. /// See [Section 7.4](https://www.rfc-editor.org/rfc/rfc9002#section-7-4) /// of [QUIC-RECOVERY](https://www.rfc-editor.org/rfc/rfc9002). CongestionControlFree = 2, /// Packets containing only frames with this marking can be used to probe /// new network paths during connection migration. /// /// See [Section 9.1](https://www.rfc-editor.org/rfc/rfc9000.html#probing) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html). ProbeNewPath = 4, /// The contents of frames with this marking are flow controlled. /// /// See [Section 4](https://www.rfc-editor.org/rfc/rfc9000.html#flow-control) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
FlowControlled = 8, } pub trait ContainSpec { fn contain(&self, spec: Spec) -> bool; } impl ContainSpec for u8 { #[inline] fn contain(&self, spec: Spec) -> bool { *self & spec as u8 != 0 } } /// The sum type of all the core QUIC frame types. /// /// See [table-3](https://www.rfc-editor.org/rfc/rfc9000.html#table-3) /// and [frame types and formats](https://www.rfc-editor.org/rfc/rfc9000.html#name-frame-types-and-formats) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum FrameType { /// PADDED frame, see [`PaddingFrame`]. Padding, /// PING frame, see [`PingFrame`]. Ping, /// ACK frame, see [`AckFrame`]. Ack(Ecn), /// RESET_STREAM frame, see [`ResetStreamFrame`]. ResetStream, /// STOP_SENDING frame, see [`StopSendingFrame`]. StopSending, /// CRYPTO frame, see [`CryptoFrame`]. Crypto, /// NEW_TOKEN frame, see [`NewTokenFrame`]. NewToken, /// STREAM frame, see [`StreamFrame`]. Stream(Offset, Len, Fin), /// MAX_DATA frame, see [`MaxDataFrame`]. MaxData, /// MAX_STREAM_DATA frame, see [`MaxStreamDataFrame`]. MaxStreamData, /// MAX_STREAMS frame, see [`MaxStreamsFrame`]. MaxStreams(Dir), /// DATA_BLOCKED frame, see [`DataBlockedFrame`]. DataBlocked, /// STREAM_DATA_BLOCKED frame, see [`StreamDataBlockedFrame`]. StreamDataBlocked, /// STREAMS_BLOCKED frame, see [`StreamsBlockedFrame`]. StreamsBlocked(Dir), /// NEW_CONNECTION_ID frame, see [`NewConnectionIdFrame`]. NewConnectionId, /// RETIRE_CONNECTION_ID frame, see [`RetireConnectionIdFrame`]. RetireConnectionId, /// PATH_CHALLENGE frame, see [`PathChallengeFrame`]. PathChallenge, /// PATH_RESPONSE frame, see [`PathResponseFrame`]. PathResponse, /// CONNECTION_CLOSE frame, see [`ConnectionCloseFrame`]. ConnectionClose(Layer), /// HANDSHAKE_DONE frame, see [`HandshakeDoneFrame`]. HandshakeDone, /// DATAGRAM frame, see [`DatagramFrame`]. Datagram(u8), /// ADD_ADDRESS frame, see [`AddAddressFrame`]. 
AddAddress(Family), /// REMOVE_ADDRESS frame, see [`RemoveAddressFrame`]. RemoveAddress, /// PUNCH_ME_NOW frame, see [`PunchMeNowFrame`]. PunchMeNow(Family), /// PUNCH_HELLO frame, see [`PunchHelloFrame`]. PunchHello, /// PUNCH_DONE frame, see [`PunchDoneFrame`]. PunchDone, } #[enum_dispatch] pub trait FrameFeature { /// Return whether a frame type belongs to the given packet_type fn belongs_to(&self, packet_type: Type) -> bool; /// Return the specs of the frame type fn specs(&self) -> u8; } impl FrameFeature for T { fn belongs_to(&self, packet_type: Type) -> bool { self.frame_type().belongs_to(packet_type) } fn specs(&self) -> u8 { self.frame_type().specs() } } impl FrameFeature for FrameType { fn belongs_to(&self, packet_type: Type) -> bool { use crate::packet::r#type::{ long::{Type::V1, Ver1}, short::OneRtt, }; // IH01 let i = matches!(packet_type, Type::Long(V1(Ver1::INITIAL))); let h = matches!(packet_type, Type::Long(V1(Ver1::HANDSHAKE))); let o = matches!(packet_type, Type::Long(V1(Ver1::ZERO_RTT))); let l = matches!(packet_type, Type::Short(OneRtt(_))); match self { FrameType::Padding => i | h | o | l, FrameType::Ping => i | h | o | l, FrameType::Ack(_) => i | h | l, FrameType::ResetStream => o | l, FrameType::StopSending => o | l, FrameType::Crypto => i | h | l, FrameType::NewToken => l, FrameType::Stream(..) => o | l, FrameType::MaxData => o | l, FrameType::MaxStreamData => o | l, FrameType::MaxStreams(_) => o | l, FrameType::DataBlocked => o | l, FrameType::StreamDataBlocked => o | l, FrameType::StreamsBlocked(_) => o | l, FrameType::NewConnectionId => o | l, FrameType::RetireConnectionId => o | l, FrameType::PathChallenge => o | l, FrameType::PathResponse => l, // The application-specific variant of CONNECTION_CLOSE (type 0x1d) can only be // sent using 0-RTT or 1-RTT packets; // See [Section 12.5](https://www.rfc-editor.org/rfc/rfc9000.html#section-12.5). 
// // When an application wishes to abandon a connection during the handshake, // an endpoint can send a CONNECTION_CLOSE frame (type 0x1c) with an error code // of APPLICATION_ERROR in an Initial or Handshake packet. FrameType::ConnectionClose(layer) => match layer { Layer::App => o | l, Layer::Quic => i | h | o | l, }, FrameType::HandshakeDone => l, FrameType::Datagram(_) => o | l, FrameType::AddAddress(_) => o | l, FrameType::RemoveAddress => o | l, FrameType::PunchMeNow(_) => o | l, FrameType::PunchHello => o | l, FrameType::PunchDone => o | l, } } fn specs(&self) -> u8 { let (n, c, p, f) = ( Spec::NonAckEliciting as u8, Spec::CongestionControlFree as u8, Spec::ProbeNewPath as u8, Spec::FlowControlled as u8, ); match self { FrameType::Padding => n | p, FrameType::Ack(_) => n | c, FrameType::Stream(..) => f, FrameType::NewConnectionId => p, FrameType::PathChallenge => p, FrameType::PathResponse => p, // different from [table 3](https://www.rfc-editor.org/rfc/rfc9000.html#table-3), // add the [`Spec::Con`] for the CONNECTION_CLOSE frame FrameType::ConnectionClose(_) => n | c, FrameType::PunchHello => n, FrameType::PunchDone => n, _ => 0, } } } impl TryFrom for FrameType { type Error = Error; fn try_from(frame_type: VarInt) -> Result { Ok(match frame_type.into_u64() { 0x00 => FrameType::Padding, 0x01 => FrameType::Ping, // The last bit is the ECN flag. 0x02 => FrameType::Ack(Ecn::None), 0x03 => FrameType::Ack(Ecn::Exist), 0x04 => FrameType::ResetStream, 0x05 => FrameType::StopSending, 0x06 => FrameType::Crypto, 0x07 => FrameType::NewToken, // The last three bits are the offset, length, and fin flag bits respectively. ty @ 0x08..=0x0f => FrameType::Stream(Offset::from(ty), Len::from(ty), Fin::from(ty)), 0x10 => FrameType::MaxData, 0x11 => FrameType::MaxStreamData, // The last bit is the direction flag bit, 0 indicates bidirectional, 1 indicates unidirectional. 
0x12 => FrameType::MaxStreams(Dir::Bi), 0x13 => FrameType::MaxStreams(Dir::Uni), 0x14 => FrameType::DataBlocked, 0x15 => FrameType::StreamDataBlocked, // The last bit is the direction flag bit, 0 indicates bidirectional, 1 indicates unidirectional. 0x16 => FrameType::StreamsBlocked(Dir::Bi), 0x17 => FrameType::StreamsBlocked(Dir::Uni), 0x18 => FrameType::NewConnectionId, 0x19 => FrameType::RetireConnectionId, 0x1a => FrameType::PathChallenge, 0x1b => FrameType::PathResponse, 0x1c => FrameType::ConnectionClose(Layer::Quic), 0x1d => FrameType::ConnectionClose(Layer::App), 0x1e => FrameType::HandshakeDone, // The last bit is the length flag bit, 0 the length field is absent and the Datagram Data // field extends to the end of the packet, 1 the length field is present. ty @ (0x30 | 0x31) => FrameType::Datagram(ty as u8 & 1), 0x3d7e90 => FrameType::AddAddress(Family::V4), 0x3d7e91 => FrameType::AddAddress(Family::V6), 0x3d7e92 => FrameType::PunchMeNow(Family::V4), 0x3d7e93 => FrameType::PunchMeNow(Family::V6), 0x3d7e94 => FrameType::RemoveAddress, 0x3d7e95 => FrameType::PunchHello, 0x3d7e96 => FrameType::PunchDone, // May be extension frame _ => return Err(Self::Error::InvalidType(frame_type)), }) } } impl From for VarInt { fn from(frame_type: FrameType) -> Self { match frame_type { FrameType::Padding => VarInt::from_u32(0x00), FrameType::Ping => VarInt::from_u32(0x01), FrameType::Ack(Ecn::None) => VarInt::from_u32(0x02), FrameType::Ack(Ecn::Exist) => VarInt::from_u32(0x03), FrameType::ResetStream => VarInt::from_u32(0x04), FrameType::StopSending => VarInt::from_u32(0x05), FrameType::Crypto => VarInt::from_u32(0x06), FrameType::NewToken => VarInt::from_u32(0x07), FrameType::Stream(offset, len, fin) => { let offset: u8 = offset.into(); let len: u8 = len.into(); let fin: u8 = fin.into(); VarInt::from(0x08u8 | offset | len | fin) } FrameType::MaxData => VarInt::from_u32(0x10), FrameType::MaxStreamData => VarInt::from_u32(0x11), FrameType::MaxStreams(Dir::Bi) => 
VarInt::from_u32(0x12), FrameType::MaxStreams(Dir::Uni) => VarInt::from_u32(0x13), FrameType::DataBlocked => VarInt::from_u32(0x14), FrameType::StreamDataBlocked => VarInt::from_u32(0x15), FrameType::StreamsBlocked(Dir::Bi) => VarInt::from_u32(0x16), FrameType::StreamsBlocked(Dir::Uni) => VarInt::from_u32(0x17), FrameType::NewConnectionId => VarInt::from_u32(0x18), FrameType::RetireConnectionId => VarInt::from_u32(0x19), FrameType::PathChallenge => VarInt::from_u32(0x1a), FrameType::PathResponse => VarInt::from_u32(0x1b), FrameType::ConnectionClose(Layer::Quic) => VarInt::from_u32(0x1c), FrameType::ConnectionClose(Layer::App) => VarInt::from_u32(0x1d), FrameType::HandshakeDone => VarInt::from_u32(0x1e), FrameType::Datagram(with_len) => VarInt::from(0x30 | with_len), FrameType::AddAddress(family) => VarInt::from_u32(0x3d7e90 | family as u32), FrameType::PunchMeNow(family) => VarInt::from_u32(0x3d7e92 | family as u32), FrameType::RemoveAddress => VarInt::from_u32(0x3d7e94), FrameType::PunchHello => VarInt::from_u32(0x3d7e95), FrameType::PunchDone => VarInt::from_u32(0x3d7e96), } } } /// Parse the frame type from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn be_frame_type(input: &[u8]) -> nom::IResult<&[u8], FrameType, Error> { let (remain, frame_type) = crate::varint::be_varint(input).map_err(|_| { nom::Err::Error(Error::IncompleteType(format!( "Incomplete frame type from input: {input:?}" ))) })?; let frame_type = FrameType::try_from(frame_type).map_err(nom::Err::Error)?; Ok((remain, frame_type)) } /// Sum type of all the stream related frames except [`StreamFrame`]. #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[enum_dispatch(EncodeSize, GetFrameType)] pub enum StreamCtlFrame { /// RESET_STREAM frame, see [`ResetStreamFrame`]. ResetStream(ResetStreamFrame), /// STOP_SENDING frame, see [`StopSendingFrame`]. StopSending(StopSendingFrame), /// MAX_STREAM_DATA frame, see [`MaxStreamDataFrame`]. 
MaxStreamData(MaxStreamDataFrame), /// MAX_STREAMS frame, see [`MaxStreamsFrame`]. MaxStreams(MaxStreamsFrame), /// STREAM_DATA_BLOCKED frame, see [`StreamDataBlockedFrame`]. StreamDataBlocked(StreamDataBlockedFrame), /// STREAMS_BLOCKED frame, see [`StreamsBlockedFrame`]. StreamsBlocked(StreamsBlockedFrame), } /// Sum type of all the reliable frames. #[derive(Debug, Clone, Eq, PartialEq)] #[enum_dispatch(EncodeSize, GetFrameType)] pub enum ReliableFrame { /// NEW_TOKEN frame, see [`NewTokenFrame`]. NewToken(NewTokenFrame), /// MAX_DATA frame, see [`MaxDataFrame`]. MaxData(MaxDataFrame), /// DATA_BLOCKED frame, see [`DataBlockedFrame`]. DataBlocked(DataBlockedFrame), /// NEW_CONNECTION_ID frame, see [`NewConnectionIdFrame`]. NewConnectionId(NewConnectionIdFrame), /// RETIRE_CONNECTION_ID frame, see [`RetireConnectionIdFrame`]. RetireConnectionId(RetireConnectionIdFrame), /// HANDSHAKE_DONE frame, see [`HandshakeDoneFrame`]. HandshakeDone(HandshakeDoneFrame), /// ADD_ADDRESS frame, see [`AddAddressFrame`]. AddAddress(AddAddressFrame), /// REMOVE_ADDRESS frame, see [`RemoveAddressFrame`]. RemoveAddress(RemoveAddressFrame), /// PUNCH_ME_NOW frame, see [`PunchMeNowFrame`]. PunchMeNow(PunchMeNowFrame), /// PUNCH_DONE frame, see [`PunchDoneFrame`]. PunchDone(PunchDoneFrame), /// STREAM control frame, see [`StreamCtlFrame`]. StreamCtl(StreamCtlFrame), } /// Sum type of all the frames. /// /// The data frames' body are stored in the second field. #[derive(Debug, Clone, From, TryInto, Eq, PartialEq)] pub enum Frame { /// PADDING frame, see [`PaddingFrame`]. Padding(PaddingFrame), /// PING frame, see [`PingFrame`]. Ping(PingFrame), /// ACK frame, see [`AckFrame`]. Ack(AckFrame), /// CONNECTION_CLOSE frame, see [`ConnectionCloseFrame`]. Close(ConnectionCloseFrame), /// NEW_TOKEN frame, see [`NewTokenFrame`]. NewToken(NewTokenFrame), /// MAX_DATA frame, see [`MaxDataFrame`]. MaxData(MaxDataFrame), /// DATA_BLOCKED frame, see [`DataBlockedFrame`]. 
DataBlocked(DataBlockedFrame), /// NEW_CONNECTION_ID frame, see [`NewConnectionIdFrame`]. NewConnectionId(NewConnectionIdFrame), /// RETIRE_CONNECTION_ID frame, see [`RetireConnectionIdFrame`]. RetireConnectionId(RetireConnectionIdFrame), /// HANDSHAKE_DONE frame, see [`HandshakeDoneFrame`]. HandshakeDone(HandshakeDoneFrame), /// PATH_CHALLENGE frame, see [`PathChallengeFrame`]. PathChallenge(PathChallengeFrame), /// PATH_RESPONSE frame, see [`PathResponseFrame`]. PathResponse(PathResponseFrame), /// Stream control frame, see [`StreamCtlFrame`]. StreamCtl(StreamCtlFrame), /// STREAM frame and its data, see [`StreamFrame`]. Stream(StreamFrame, D), /// CRYPTO frame and its data, see [`CryptoFrame`]. Crypto(CryptoFrame, D), /// DATAGRAM frame and its data, see [`DatagramFrame`]. Datagram(DatagramFrame, D), /// ADD_ADDRESS frame, see [`AddAddressFrame`]. AddAddress(AddAddressFrame), /// REMOVE_ADDRESS frame, see [`RemoveAddressFrame`]. RemoveAddress(RemoveAddressFrame), /// PUNCH_ME_NOW frame, see [`PunchMeNowFrame`]. PunchMeNow(PunchMeNowFrame), /// PUNCH_HELLO frame, see [`PunchHelloFrame`]. PunchHello(PunchHelloFrame), /// PUNCH_DONE frame, see [`PunchDoneFrame`]. 
PunchDone(PunchDoneFrame), } impl From for Frame { #[inline] fn from(frame: ReliableFrame) -> Self { match frame { ReliableFrame::NewToken(new_token_frame) => Frame::NewToken(new_token_frame), ReliableFrame::MaxData(max_data_frame) => Frame::MaxData(max_data_frame), ReliableFrame::DataBlocked(data_blocked_frame) => { Frame::DataBlocked(data_blocked_frame) } ReliableFrame::NewConnectionId(new_connection_id_frame) => { Frame::NewConnectionId(new_connection_id_frame) } ReliableFrame::RetireConnectionId(retire_connection_id_frame) => { Frame::RetireConnectionId(retire_connection_id_frame) } ReliableFrame::HandshakeDone(handshake_done_frame) => { Frame::HandshakeDone(handshake_done_frame) } ReliableFrame::AddAddress(add_address_frame) => Frame::AddAddress(add_address_frame), ReliableFrame::RemoveAddress(remove_address_frame) => { Frame::RemoveAddress(remove_address_frame) } ReliableFrame::PunchMeNow(punch_me_now_frame) => Frame::PunchMeNow(punch_me_now_frame), ReliableFrame::PunchDone(punch_done_frame) => Frame::PunchDone(punch_done_frame), ReliableFrame::StreamCtl(stream_frame) => Frame::StreamCtl(stream_frame), } } } impl<'f, D> TryFrom<&'f Frame> for CryptoFrame { type Error = &'f Frame; #[inline] fn try_from(frame: &'f Frame) -> Result { match frame { Frame::Crypto(frame, _data) => Ok(*frame), frame => Err(frame), } } } impl<'f, D> TryFrom<&'f Frame> for ReliableFrame { type Error = &'f Frame; #[inline] fn try_from(frame: &'f Frame) -> Result { match frame { Frame::NewToken(new_token_frame) => { Ok(ReliableFrame::NewToken(new_token_frame.clone())) } Frame::MaxData(max_data_frame) => Ok(ReliableFrame::MaxData(*max_data_frame)), Frame::DataBlocked(data_blocked_frame) => { Ok(ReliableFrame::DataBlocked(*data_blocked_frame)) } Frame::NewConnectionId(new_connection_id_frame) => { Ok(ReliableFrame::NewConnectionId(*new_connection_id_frame)) } Frame::RetireConnectionId(retire_connection_id_frame) => Ok( ReliableFrame::RetireConnectionId(*retire_connection_id_frame), ), 
Frame::HandshakeDone(handshake_done_frame) => { Ok(ReliableFrame::HandshakeDone(*handshake_done_frame)) } Frame::AddAddress(add_address_frame) => { Ok(ReliableFrame::AddAddress(*add_address_frame)) } Frame::RemoveAddress(remove_address_frame) => { Ok(ReliableFrame::RemoveAddress(*remove_address_frame)) } Frame::PunchMeNow(punch_me_now_frame) => { Ok(ReliableFrame::PunchMeNow(*punch_me_now_frame)) } Frame::PunchDone(punch_done_frame) => Ok(ReliableFrame::PunchDone(*punch_done_frame)), Frame::StreamCtl(stream_frame) => Ok(ReliableFrame::StreamCtl(*stream_frame)), frame => Err(frame), } } } impl GetFrameType for Frame { #[doc = " Return the type of frame"] #[inline] fn frame_type(&self) -> FrameType { match self { Frame::Padding(f) => f.frame_type(), Frame::Ping(f) => f.frame_type(), Frame::Ack(f) => f.frame_type(), Frame::Close(f) => f.frame_type(), Frame::NewToken(f) => f.frame_type(), Frame::MaxData(f) => f.frame_type(), Frame::DataBlocked(f) => f.frame_type(), Frame::NewConnectionId(f) => f.frame_type(), Frame::RetireConnectionId(f) => f.frame_type(), Frame::HandshakeDone(f) => f.frame_type(), Frame::PathChallenge(f) => f.frame_type(), Frame::PathResponse(f) => f.frame_type(), Frame::StreamCtl(f) => f.frame_type(), Frame::Stream(f, _) => f.frame_type(), Frame::Crypto(f, _) => f.frame_type(), Frame::Datagram(f, _) => f.frame_type(), Frame::AddAddress(f) => f.frame_type(), Frame::RemoveAddress(f) => f.frame_type(), Frame::PunchMeNow(f) => f.frame_type(), Frame::PunchHello(f) => f.frame_type(), Frame::PunchDone(f) => f.frame_type(), } } } impl EncodeSize for Frame { #[doc = " Return the max number of bytes needed to encode this value"] #[doc = ""] #[doc = " Calculate the maximum size by summing up the maximum length of each field."] #[doc = " If a field type has a maximum length, use it, otherwise use the actual length"] #[doc = " of the data in that field."] #[doc = ""] #[doc = " When packaging data, by pre-estimating this value to effectively avoid spending"] #[doc 
= " extra resources to calculate the actual encoded size."] #[inline] fn max_encoding_size(&self) -> usize { match self { Frame::Padding(f) => f.max_encoding_size(), Frame::Ping(f) => f.max_encoding_size(), Frame::Ack(f) => f.max_encoding_size(), Frame::Close(f) => f.max_encoding_size(), Frame::NewToken(f) => f.max_encoding_size(), Frame::MaxData(f) => f.max_encoding_size(), Frame::DataBlocked(f) => f.max_encoding_size(), Frame::NewConnectionId(f) => f.max_encoding_size(), Frame::RetireConnectionId(f) => f.max_encoding_size(), Frame::HandshakeDone(f) => f.max_encoding_size(), Frame::PathChallenge(f) => f.max_encoding_size(), Frame::PathResponse(f) => f.max_encoding_size(), Frame::StreamCtl(f) => f.max_encoding_size(), Frame::Stream(f, _) => f.max_encoding_size(), Frame::Crypto(f, _) => f.max_encoding_size(), Frame::Datagram(f, _) => f.max_encoding_size(), Frame::AddAddress(f) => f.max_encoding_size(), Frame::RemoveAddress(f) => f.max_encoding_size(), Frame::PunchMeNow(f) => f.max_encoding_size(), Frame::PunchHello(f) => f.max_encoding_size(), Frame::PunchDone(f) => f.max_encoding_size(), } } #[doc = " Return the exact number of bytes needed to encode this value"] #[inline] fn encoding_size(&self) -> usize { match self { Frame::Padding(f) => f.encoding_size(), Frame::Ping(f) => f.encoding_size(), Frame::Ack(f) => f.encoding_size(), Frame::Close(f) => f.encoding_size(), Frame::NewToken(f) => f.encoding_size(), Frame::MaxData(f) => f.encoding_size(), Frame::DataBlocked(f) => f.encoding_size(), Frame::NewConnectionId(f) => f.encoding_size(), Frame::RetireConnectionId(f) => f.encoding_size(), Frame::HandshakeDone(f) => f.encoding_size(), Frame::PathChallenge(f) => f.encoding_size(), Frame::PathResponse(f) => f.encoding_size(), Frame::StreamCtl(f) => f.encoding_size(), Frame::Stream(f, _) => f.encoding_size(), Frame::Crypto(f, _) => f.encoding_size(), Frame::Datagram(f, _) => f.encoding_size(), Frame::AddAddress(f) => f.encoding_size(), Frame::RemoveAddress(f) => 
                f.encoding_size(),
            Frame::PunchMeNow(f) => f.encoding_size(),
            Frame::PunchHello(f) => f.encoding_size(),
            Frame::PunchDone(f) => f.encoding_size(),
        }
    }
}

/// Reads frames from a buffer until the packet buffer is empty.
#[derive(Deref, DerefMut)]
pub struct FrameReader {
    #[deref]
    #[deref_mut]
    payload: Bytes,
    // Packet type the payload was carried in; passed to the frame parser so
    // frames illegal in this packet type can be rejected during iteration.
    packet_type: Type,
}

impl FrameReader {
    /// Creates a [`FrameReader`] for a packet of type `packet_type`
    pub fn new(payload: Bytes, packet_type: Type) -> Self {
        Self {
            payload,
            packet_type,
        }
    }
}

// NOTE(review): generic arguments appear stripped by extraction (`Option`
// without `<Self::Item>`, and `Frame` without its data-type parameter);
// confirm against the original source.
impl Iterator for FrameReader {
    type Item = Result<(Frame, FrameType), Error>;

    fn next(&mut self) -> Option {
        // An empty payload means the whole packet has been consumed.
        if self.payload.is_empty() {
            return None;
        }
        match io::be_frame(&self.payload, self.packet_type) {
            Ok((consumed, frame, frame_type)) => {
                // Advance past the bytes the parser consumed, then yield the frame
                // together with its decoded type.
                self.payload.advance(consumed);
                Some(Ok((frame, frame_type)))
            }
            // A parse error is surfaced as an item so the caller can report it.
            Err(e) => Some(Err(e)),
        }
    }
}

// Dispatch stream-control frame encoding to the concrete frame's writer.
// NOTE(review): generic parameters appear stripped (`for T`); presumably a
// blanket impl over a `BufMut`-like writer — confirm against the original source.
impl WriteFrame for T {
    fn put_frame(&mut self, frame: &StreamCtlFrame) {
        match frame {
            StreamCtlFrame::ResetStream(frame) => self.put_frame(frame),
            StreamCtlFrame::StopSending(frame) => self.put_frame(frame),
            StreamCtlFrame::MaxStreamData(frame) => self.put_frame(frame),
            StreamCtlFrame::MaxStreams(frame) => self.put_frame(frame),
            StreamCtlFrame::StreamDataBlocked(frame) => self.put_frame(frame),
            StreamCtlFrame::StreamsBlocked(frame) => self.put_frame(frame),
        }
    }
}

// Dispatch reliable-frame encoding to the concrete frame's writer.
impl WriteFrame for T {
    fn put_frame(&mut self, frame: &ReliableFrame) {
        match frame {
            ReliableFrame::NewToken(frame) => self.put_frame(frame),
            ReliableFrame::MaxData(frame) => self.put_frame(frame),
            ReliableFrame::DataBlocked(frame) => self.put_frame(frame),
            ReliableFrame::NewConnectionId(frame) => self.put_frame(frame),
            ReliableFrame::RetireConnectionId(frame) => self.put_frame(frame),
            ReliableFrame::HandshakeDone(frame) => self.put_frame(frame),
            ReliableFrame::AddAddress(frame) => self.put_frame(frame),
            ReliableFrame::RemoveAddress(frame) => self.put_frame(frame),
            ReliableFrame::PunchMeNow(frame) => self.put_frame(frame),
            ReliableFrame::PunchDone(frame) =>
self.put_frame(frame), ReliableFrame::StreamCtl(frame) => self.put_frame(frame), } } } #[cfg(test)] mod tests { use std::net::SocketAddr; use nom::Parser; use super::*; use crate::{ net::Family, packet::{ PacketContent, r#type::{ Type, long::{Type::V1, Ver1}, short::OneRtt, }, }, varint::{WriteVarInt, be_varint}, }; #[test] fn test_frame_type_conversion() { let frame_types = vec![ FrameType::Padding, FrameType::Ping, FrameType::Ack(Ecn::None), FrameType::Stream(Offset::Zero, Len::Omit, Fin::No), FrameType::MaxData, FrameType::ConnectionClose(Layer::Quic), FrameType::HandshakeDone, FrameType::Datagram(0), ]; for frame_type in frame_types { let byte: VarInt = frame_type.into(); assert_eq!(FrameType::try_from(byte).unwrap(), frame_type); } } #[test] fn test_frame_type_specs() { assert!(FrameType::Padding.specs().contain(Spec::NonAckEliciting)); assert!( FrameType::Ack(Ecn::None) .specs() .contain(Spec::CongestionControlFree) ); assert!( FrameType::Stream(Offset::Zero, Len::Omit, Fin::No) .specs() .contain(Spec::FlowControlled) ); assert!(FrameType::PathChallenge.specs().contain(Spec::ProbeNewPath)); } #[test] fn test_frame_type_belongs_to() { let initial = Type::Long(V1(Ver1::INITIAL)); assert!(FrameType::Padding.belongs_to(initial)); assert!(FrameType::Ping.belongs_to(initial)); assert!(FrameType::Ack(Ecn::None).belongs_to(initial)); assert!(!FrameType::Stream(Offset::Zero, Len::Omit, Fin::No).belongs_to(initial)); } #[test] fn test_frame_reader() { let mut buf = bytes::BytesMut::new(); buf.put_u8(0x00); // PADDING buf.put_u8(0x01); // PING let packet_type = Type::Long(V1(Ver1::INITIAL)); let mut reader = FrameReader::new(buf.freeze(), packet_type); // Read PADDING frame let (frame, frame_type) = reader.next().unwrap().unwrap(); assert!(matches!(frame, Frame::Padding(_))); assert!(frame_type.specs().contain(Spec::NonAckEliciting)); // Read PING frame let (frame, frame_type) = reader.next().unwrap().unwrap(); assert!(matches!(frame, Frame::Ping(_))); 
assert!(!frame_type.specs().contain(Spec::NonAckEliciting)); // No more frames assert!(reader.next().is_none()); } #[test] fn test_invalid_frame_type() { assert!(FrameType::try_from(VarInt::from_u32(0xFF)).is_err()); } #[test] fn test_frame_reader_parses_add_address_frame() { use super::io::WriteFrame; let add_address = AddAddressFrame::new( 1, "127.0.0.1:4433".parse::().unwrap(), 2, crate::net::NatType::RestrictedPort, ); let expected = add_address; let mut buf = bytes::BytesMut::new(); buf.put_frame(&ReliableFrame::AddAddress(add_address)); let mut reader = FrameReader::new(buf.freeze(), Type::Short(OneRtt(0.into()))); let (frame, frame_type) = reader.next().unwrap().unwrap(); assert_eq!(frame_type, FrameType::AddAddress(Family::V4)); assert_eq!(frame, Frame::AddAddress(expected)); assert!(reader.next().is_none()); } #[test] fn test_frame_reader_rejects_add_address_frame_in_non_data_packets() { use super::io::WriteFrame; let mut buf = bytes::BytesMut::new(); buf.put_frame(&ReliableFrame::AddAddress(AddAddressFrame::new( 7, "127.0.0.1:8443".parse::().unwrap(), 4, crate::net::NatType::Dynamic, ))); for packet_type in [ Type::Long(V1(Ver1::INITIAL)), Type::Long(V1(Ver1::HANDSHAKE)), ] { let mut reader = FrameReader::new(buf.clone().freeze(), packet_type); assert_eq!( reader.next().unwrap().unwrap_err(), Error::WrongType(FrameType::AddAddress(Family::V4), packet_type) ); } } #[test] fn test_manual_unknown_custom_frame_fallback() { use crate::varint::WriteVarInt; #[derive(Debug, Clone, Eq, PartialEq)] struct UnknownCustomFrame { pub seq_num: VarInt, pub tire: VarInt, pub nat_type: VarInt, } fn be_unknown_custom_frame(input: &[u8]) -> nom::IResult<&[u8], UnknownCustomFrame> { use nom::{combinator::verify, sequence::preceded}; preceded( verify(be_varint, |typ| typ == &VarInt::from_u32(0xff)), (be_varint, be_varint, be_varint), ) .map(|(seq_num, tire, nat_type)| UnknownCustomFrame { seq_num, tire, nat_type, }) .parse(input) } fn parse_unknown_custom_frame(input: &[u8]) 
-> Result<(usize, UnknownCustomFrame), Error> { let origin = input.len(); let (remain, frame) = be_unknown_custom_frame(input).map_err(|_| { Error::IncompleteType(format!("Incomplete frame type from input: {input:?}")) })?; let consumed = origin - remain.len(); Ok((consumed, frame)) } impl super::io::WriteFrame for T { fn put_frame(&mut self, frame: &UnknownCustomFrame) { self.put_varint(&0xff_u32.into()); self.put_varint(&frame.seq_num); self.put_varint(&frame.tire); self.put_varint(&frame.nat_type); } } let mut buf = bytes::BytesMut::new(); let unknown_custom_frame = UnknownCustomFrame { seq_num: VarInt::from_u32(0x01), tire: VarInt::from_u32(0x02), nat_type: VarInt::from_u32(0x03), }; buf.put_frame(&unknown_custom_frame); buf.put_frame(&PaddingFrame); buf.put_frame(&PaddingFrame); buf.put_frame(&unknown_custom_frame); buf.put_varint(&0xfe_u32.into()); let mut padding_count = 0; let mut unknown_custom_count = 0; let mut reader = FrameReader::new(buf.freeze(), Type::Short(OneRtt(0.into()))); loop { match reader.next() { Some(Ok((frame, typ))) => { assert!(matches!(frame, Frame::Padding(_))); assert_eq!(typ, FrameType::Padding); padding_count += 1; } Some(Err(_e)) => { // Parse the unknown custom frame manually. 
if let Ok((consum, frame)) = parse_unknown_custom_frame(&reader) { reader.advance(consum); assert_eq!(frame, unknown_custom_frame); unknown_custom_count += 1; } else { reader.clear(); } } None => break, }; } assert_eq!(padding_count, 2); assert_eq!(unknown_custom_count, 2); } #[test] fn test_frame_reader_stops_at_unknown_custom_frame() { let mut buf = bytes::BytesMut::new(); buf.put_frame(&PaddingFrame); buf.put_frame(&PaddingFrame); // error frame type buf.put_varint(&0xfe_u32.into()); buf.put_frame(&PaddingFrame); let mut padding_count = 0; let _ = FrameReader::new(buf.freeze(), Type::Short(OneRtt(0.into()))).try_fold( PacketContent::default(), |packet_contains, frame| { let (frame, frame_type) = frame?; assert!(matches!(frame, Frame::Padding(_))); assert_eq!(frame_type, FrameType::Padding); padding_count += 1; Result::<_, Error>::Ok(packet_contains) }, ); assert_eq!(padding_count, 2); } } ================================================ FILE: qbase/src/handshake.rs ================================================ use std::sync::{ Arc, atomic::{AtomicBool, Ordering}, }; use crate::{ error::{Error, ErrorKind, QuicError}, frame::{ HandshakeDoneFrame, io::{ReceiveFrame, SendFrame}, }, role::Role, }; /// The completion flag for the client handshake. /// /// The client considers the handshake complete only after /// receiving the [`HandshakeDoneFrame`] from the server. /// In the QUIC protocol, there are no tasks that specifically /// require waiting for the client handshake to complete. /// Instead, it simply queries the handshake status. #[derive(Debug, Default, Clone)] pub struct ClientHandshake { done: Arc, } impl ClientHandshake { /// Check if the client handshake is complete. pub fn is_handshake_done(&self) -> bool { self.done.load(Ordering::Acquire) } /// Receive the HANDSHAKE_DONE frame. /// /// Once the client receives the HANDSHAKE_DONE frame, /// it marks the completion of the client handshake. 
/// /// Return whether it is the first time to receive the HANDSHAKE_DONE frame. pub fn recv_handshake_done_frame(&self, _frame: HandshakeDoneFrame) -> bool { !self.done.swap(true, Ordering::AcqRel) } } /// Server's handshake status. /// /// - `T` is responsible for reliably sending [`HandshakeDoneFrame`] to the client. /// It can be a channel, a queue, or a buffer. Whatever, it must be able to send the /// [`HandshakeDoneFrame`] to the client. /// /// The server considers the handshake complete only after receiving /// the [finished message](https://www.rfc-editor.org/rfc/rfc8446.html#section-4.4.4) /// from the client during the TLS handshake process. /// If the [finished message](https://www.rfc-editor.org/rfc/rfc8446.html#section-4.4.4) /// from the TLS handshake is not received, /// the server can also consider the handshake complete upon receiving and /// successfully decrypting the client's 1-RTT packet. /// Once the server's handshake is complete, the server will send a [`HandshakeDoneFrame`] immediately. #[derive(Debug, Clone)] pub struct ServerHandshake where T: SendFrame + Clone, { is_done: Arc, output: T, } impl ServerHandshake where T: SendFrame + Clone, { /// Create a new server handshake signal. /// /// The `output` is responsible for sending the [`HandshakeDoneFrame`] to the client, /// see [`ServerHandshake`]. pub fn new(output: T) -> Self { ServerHandshake { is_done: Arc::new(AtomicBool::new(false)), output, } } /// Check if the server handshake is complete. pub fn is_handshake_done(&self) -> bool { self.is_done.load(Ordering::Acquire) } /// Actively set the server's handshake status to complete. /// /// Call this method when the TLS handshake /// [finished message](https://www.rfc-editor.org/rfc/rfc8446.html#section-4.4.4) is received. /// If the TLS handshake completion message is not received, /// receiving and successfully decrypting the client's 1-RTT packet /// is also considered handshake completion. 
/// Servers MUST NOT send a [`HandshakeDoneFrame`] before completing the handshake. /// and once the server handshake is complete, /// servers should send the [`HandshakeDoneFrame`] immediately. /// See [`ServerHandshake`]. /// /// This method return [`true`] when it first time set the handshake status to complete. pub fn done(&self) -> bool { if self .is_done .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) .is_ok() { self.output.send_frame([HandshakeDoneFrame]); true } else { false } } } /// A merged handshake state that can be used by both the client and the server. /// /// For convenience, a unified [`Handshake`]` should be used, /// which will internally choose the corresponding behavior based on the role. #[derive(Debug, Clone)] pub enum Handshake where T: SendFrame + Clone, { /// The client's handshake state if the endpoint is a client. Client(ClientHandshake), /// The server's handshake state if the endpoint is a server. Server(ServerHandshake), } impl Handshake where T: SendFrame + Clone, { /// Create a new handshake state, based on the role. pub fn new(role: Role, output: T) -> Self { match role { Role::Client => Handshake::Client(ClientHandshake::default()), Role::Server => Handshake::Server(ServerHandshake::new(output)), } } /// Create a new client handshake state. pub fn new_client() -> Self { Handshake::Client(ClientHandshake::default()) } /// Create a new server handshake state. /// The `output` is responsible for sending the [`HandshakeDoneFrame`] to the client, /// see [`ServerHandshake::new`]. pub fn new_server(output: T) -> Self { Handshake::Server(ServerHandshake::new(output)) } /// Check if the handshake is complete. pub fn is_handshake_done(&self) -> bool { match self { Handshake::Client(h) => h.is_handshake_done(), Handshake::Server(h) => h.is_handshake_done(), } } /// Set the handshake status to complete(for server) /// /// For client, this method does nothing and always returns [`false`]. 
/// /// This method return [`true`] when it first time set the handshake status to complete. pub fn done(&self) -> bool { match self { Handshake::Client(..) => false, /* for client, do nothing */ Handshake::Server(h) => h.done(), } } /// Return the role of this handshake signal. pub fn role(&self) -> Role { match self { Handshake::Client(_) => Role::Client, Handshake::Server(_) => Role::Server, } } } impl ReceiveFrame for Handshake where T: SendFrame + Clone, { type Output = bool; /// Receive the [`HandshakeDoneFrame`]. /// /// A [`HandshakeDoneFrame`] can only be received by the client. /// A server MUST treat receipt of a [`HandshakeDoneFrame`] /// as a connection error of type PROTOCOL_VIOLATION. /// See [section 19.20](https://www.rfc-editor.org/rfc/rfc9000.html#section-19.20) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html). /// /// Return whether it is the first time to receive the HANDSHAKE_DONE frame(for client). fn recv_frame(&self, frame: HandshakeDoneFrame) -> Result { match self { Handshake::Client(h) => Ok(h.recv_handshake_done_frame(frame)), _ => Err(QuicError::with_default_fty( ErrorKind::ProtocolViolation, "Server received a HANDSHAKE_DONE frame", ) .into()), } } } #[cfg(test)] mod tests { use derive_more::Deref; use super::*; use crate::{ error::ErrorKind, frame::io::{ReceiveFrame, SendFrame}, util::ArcAsyncDeque, }; #[derive(Debug, Default, Clone, Deref)] struct HandshakeDoneFrameTx(ArcAsyncDeque); impl SendFrame for HandshakeDoneFrameTx { fn send_frame>(&self, iter: I) { (&self.0).extend(iter); } } #[test] fn test_client_handshake() { let handshake = Handshake::::new_client(); assert!(!handshake.is_handshake_done()); let ret = handshake.recv_frame(HandshakeDoneFrame); assert!(ret.is_ok()); assert!(handshake.is_handshake_done()); } #[test] fn test_client_handshake_done() { let handshake = Handshake::::new_client(); assert!(!handshake.is_handshake_done()); assert!(handshake.recv_frame(HandshakeDoneFrame).unwrap()); 
assert!(handshake.is_handshake_done()); // recv_frame will only return `true` once when handshake first done assert!(!handshake.recv_frame(HandshakeDoneFrame).unwrap()); assert!(handshake.is_handshake_done()); } #[test] fn test_server_handshake() { let handshake = Handshake::new_server(HandshakeDoneFrameTx::default()); assert!(!handshake.is_handshake_done()); assert!(handshake.done()); assert!(handshake.is_handshake_done()); // same as last test assert!(!handshake.done()); assert!(handshake.is_handshake_done()); } #[test] fn test_server_recv_handshake_done_frame() { let handshake = Handshake::new_server(HandshakeDoneFrameTx::default()); assert!(!handshake.is_handshake_done()); let ret = handshake.recv_frame(HandshakeDoneFrame); assert_eq!( ret, Err(QuicError::with_default_fty( ErrorKind::ProtocolViolation, "Server received a HANDSHAKE_DONE frame", ) .into()) ); } #[test] fn test_server_send_handshake_done_frame() { let handshake = ServerHandshake::new(HandshakeDoneFrameTx::default()); handshake.done(); assert!(handshake.is_handshake_done()); assert_eq!(handshake.output.len(), 1); } } ================================================ FILE: qbase/src/lib.rs ================================================ #![allow(clippy::all)] //! # The QUIC base library //! //! The `qbase` library defines the necessary basic structures in the QUIC protocol, //! including connection IDs, stream IDs, frames, packets, keys, parameters, error codes, etc. //! //! Additionally, based on these basic structures, //! it defines components for various mechanisms in QUIC, //! including flow control, handshake, tokens, stream ID management, connection ID management, etc. //! //! Finally, the `qbase` module also defines some utility functions //! for handling common data structures in the QUIC protocol. //! 
#![allow(clippy::all)] use std::{ ops::{Index, IndexMut}, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll, Waker}, }; use futures::FutureExt; use thiserror::Error; /// Operations about QUIC connection IDs. pub mod cid; /// [QUIC errors](https://www.rfc-editor.org/rfc/rfc9000.html#name-error-codes). pub mod error; /// QUIC connection-level flow control. pub mod flow; /// QUIC frames and their codec. pub mod frame; /// Handshake signal for QUIC connections. pub mod handshake; /// QUIC connection metrics for tracking data volumes. pub mod metric; /// Endpoint address and Pathway. pub mod net; /// QUIC packets and their codec. pub mod packet; /// [QUIC transport parameters and their codec](https://www.rfc-editor.org/rfc/rfc9000.html#name-transport-parameter-encodin). pub mod param; /// QUIC client and server roles. pub mod role; /// Stream id types and controllers for different roles and different directions. pub mod sid; /// Max idle timer and defer idle timer. pub mod time; /// Issuing, storing and verifing tokens operations. pub mod token; /// Utilities for common data structures. pub mod util; /// [Variable-length integers](https://www.rfc-editor.org/rfc/rfc9000.html#name-variable-length-integer-enc). pub mod varint; /// The epoch of sending, usually been seen as the index of spaces. #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] pub enum Epoch { Initial = 0, Handshake = 1, Data = 2, } pub trait GetEpoch { fn epoch(&self) -> Epoch; } impl Epoch { pub const EPOCHS: [Epoch; 3] = [Epoch::Initial, Epoch::Handshake, Epoch::Data]; /// An iterator for the epoch of each spaces. /// /// Equals to `Epoch::EPOCHES.iter()` pub fn iter() -> std::slice::Iter<'static, Epoch> { Self::EPOCHS.iter() } /// The number of epoches. 
pub const fn count() -> usize { Self::EPOCHS.len() } } impl Index for [T] where T: Sized, { type Output = T; fn index(&self, index: Epoch) -> &Self::Output { self.index(index as usize) } } impl IndexMut for [T] where T: Sized, { fn index_mut(&mut self, index: Epoch) -> &mut Self::Output { self.index_mut(index as usize) } } #[derive(Debug, Default)] pub enum Receiving { #[default] Pending, Waiting(Waker), Rcvd(F), Read, Reset, } impl Receiving { fn recv_frame(&mut self, frame: F) { match std::mem::take(self) { Self::Pending => { *self = Self::Rcvd(frame); } Self::Waiting(waker) => { waker.wake(); *self = Self::Rcvd(frame); } _ => (), } } fn reset(&mut self) { if let Self::Waiting(waker) = std::mem::replace(self, Self::Reset) { waker.wake(); } } } #[derive(Debug, Error)] #[error("Reset")] pub struct ResetError; #[derive(Debug, Default, Clone)] pub struct ArcReceiving(Arc>>); impl ArcReceiving { pub fn reset(&self) { self.0.lock().unwrap().reset(); } } impl Future for ArcReceiving { type Output = Result, ResetError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.0.lock().unwrap().poll_unpin(cx) } } #[cfg(test)] mod tests {} ================================================ FILE: qbase/src/metric.rs ================================================ use std::sync::{ Arc, atomic::{AtomicU64, Ordering}, }; /// Metrics for tracking data volumes in a QUIC connection. /// /// This struct provides atomic counters to track: /// - Data written by application but not yet sent /// - Data sent but not yet acknowledged /// - Data sent and acknowledged #[derive(Debug, Default)] pub struct ConnectionMetrics { /// Data written by application layer but not yet sent by transport layer pending_bytes: AtomicU64, /// Data sent by transport layer but not yet acknowledged by peer inflight_bytes: AtomicU64, /// Data sent and acknowledged by peer acked_bytes: AtomicU64, } impl ConnectionMetrics { /// Increments the pending send bytes counter when application writes data. 
    ///
    /// Called when application layer writes data to a stream.
    pub fn new_pending(&self, bytes: u64) {
        // Relaxed: this is a statistics counter, not a synchronization point —
        // presumably no other memory is ordered against it. TODO confirm.
        self.pending_bytes.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Updates counters when transport layer sends new data.
    ///
    /// Increments sent_unacked_bytes and decrements pending_send_bytes.
    /// Called when transport layer sends new stream data.
    // NOTE(review): assumes `bytes` never exceeds the current pending amount;
    // a larger value would wrap the unsigned counter — confirm with callers.
    pub fn on_data_sent(&self, bytes: u64) {
        self.inflight_bytes.fetch_add(bytes, Ordering::Relaxed);
        self.pending_bytes.fetch_sub(bytes, Ordering::Relaxed);
    }

    /// Updates counters when data is acknowledged by peer.
    ///
    /// Increments sent_acked_bytes and decrements sent_unacked_bytes.
    /// Called when receiving acknowledgment for stream data.
    // NOTE(review): same wrap caveat as `on_data_sent` — `bytes` is assumed
    // to be at most the current inflight amount.
    pub fn on_data_acked(&self, bytes: u64) {
        self.acked_bytes.fetch_add(bytes, Ordering::Relaxed);
        self.inflight_bytes.fetch_sub(bytes, Ordering::Relaxed);
    }

    /// Gets the current amount of data pending to be sent.
    pub fn pending_bytes(&self) -> u64 {
        self.pending_bytes.load(Ordering::Relaxed)
    }

    /// Gets the current amount of data sent but not acknowledged.
    pub fn inflight_bytes(&self) -> u64 {
        self.inflight_bytes.load(Ordering::Relaxed)
    }

    /// Gets the total amount of data sent and acknowledged.
    pub fn acked_bytes(&self) -> u64 {
        self.acked_bytes.load(Ordering::Relaxed)
    }
}

/// Arc-wrapped ConnectionMetrics for shared ownership across the connection.
pub type ArcConnectionMetrics = Arc<ConnectionMetrics>;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_metrics_new() {
        let metrics = ConnectionMetrics::default();
        assert_eq!(metrics.pending_bytes(), 0);
        assert_eq!(metrics.inflight_bytes(), 0);
        assert_eq!(metrics.acked_bytes(), 0);
    }

    #[test]
    fn test_add_pending_send() {
        let metrics = ConnectionMetrics::default();
        metrics.new_pending(100);
        assert_eq!(metrics.pending_bytes(), 100);
        metrics.new_pending(50);
        assert_eq!(metrics.pending_bytes(), 150);
    }

    #[test]
    fn test_on_data_sent() {
        let metrics = ConnectionMetrics::default();
        metrics.new_pending(200);
        metrics.on_data_sent(150);
        assert_eq!(metrics.pending_bytes(), 50);
        assert_eq!(metrics.inflight_bytes(), 150);
    }

    #[test]
    fn test_on_data_acked() {
        let metrics = ConnectionMetrics::default();
        metrics.new_pending(200);
        metrics.on_data_sent(150);
        metrics.on_data_acked(100);
        assert_eq!(metrics.pending_bytes(), 50);
        assert_eq!(metrics.inflight_bytes(), 50);
        assert_eq!(metrics.acked_bytes(), 100);
    }

    #[test]
    fn test_full_data_flow() {
        let metrics = ConnectionMetrics::default();

        // Application writes 1000 bytes
        metrics.new_pending(1000);
        assert_eq!(metrics.pending_bytes(), 1000);
        assert_eq!(metrics.inflight_bytes(), 0);
        assert_eq!(metrics.acked_bytes(), 0);

        // Transport layer sends 600 bytes
        metrics.on_data_sent(600);
        assert_eq!(metrics.pending_bytes(), 400);
        assert_eq!(metrics.inflight_bytes(), 600);
        assert_eq!(metrics.acked_bytes(), 0);

        // Peer acknowledges 300 bytes
        metrics.on_data_acked(300);
        assert_eq!(metrics.pending_bytes(), 400);
        assert_eq!(metrics.inflight_bytes(), 300);
        assert_eq!(metrics.acked_bytes(), 300);

        // Transport layer sends remaining 400 bytes
        metrics.on_data_sent(400);
        assert_eq!(metrics.pending_bytes(), 0);
        assert_eq!(metrics.inflight_bytes(), 700);
        assert_eq!(metrics.acked_bytes(), 300);

        // Peer acknowledges all remaining data
        metrics.on_data_acked(700);
        assert_eq!(metrics.pending_bytes(), 0);
        assert_eq!(metrics.inflight_bytes(), 0);
assert_eq!(metrics.acked_bytes(), 1000); } #[test] fn test_arc_connection_metrics() { let metrics = Arc::new(ConnectionMetrics::default()); let metrics_clone = Arc::clone(&metrics); metrics.new_pending(100); assert_eq!(metrics_clone.pending_bytes(), 100); metrics_clone.on_data_sent(100); assert_eq!(metrics.inflight_bytes(), 100); assert_eq!(metrics.pending_bytes(), 0); } } ================================================ FILE: qbase/src/net/addr.rs ================================================ use std::{ fmt::Display, net::{AddrParseError, SocketAddr}, ops::Deref, str::FromStr, }; use bytes::BufMut; use serde::{Deserialize, Serialize}; use crate::net::{Family, be_socket_addr}; #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum EndpointAddr { Direct { addr: SocketAddr, }, Agent { agent: SocketAddr, outer: SocketAddr, }, } impl EndpointAddr { pub fn direct(addr: SocketAddr) -> Self { EndpointAddr::Direct { addr } } pub fn with_agent(agent: SocketAddr, outer: SocketAddr) -> Self { EndpointAddr::Agent { agent, outer } } /// Returns the outer addr of this EndpointAddr /// /// Note: Before successful hole punching with this Endpoint, packets should be sent to the addr /// returned by deref() to establish communication. Once hole punching is successful or about to /// begin, use the addr returned by this function. pub fn addr(&self) -> SocketAddr { match self { EndpointAddr::Direct { addr } => *addr, EndpointAddr::Agent { outer, .. 
} => *outer, } } pub fn encoding_size(&self) -> usize { match self { EndpointAddr::Direct { addr: SocketAddr::V4(_), } => 2 + 4, EndpointAddr::Direct { addr: SocketAddr::V6(_), } => 2 + 16, EndpointAddr::Agent { agent: SocketAddr::V4(_), outer: SocketAddr::V4(_), } => 2 + 4 + 2 + 4, EndpointAddr::Agent { agent: SocketAddr::V6(_), outer: SocketAddr::V6(_), } => 2 + 16 + 2 + 16, _ => unimplemented!("Unix socket addresses are not supported"), } } } pub trait WriteEndpointAddr { fn put_endpoint_addr(&mut self, endpoint: EndpointAddr); } impl WriteEndpointAddr for T { fn put_endpoint_addr(&mut self, endpoint: EndpointAddr) { use crate::net::WriteSocketAddr; match endpoint { EndpointAddr::Direct { addr } => self.put_socket_addr(&addr), EndpointAddr::Agent { agent, outer: inner, } => { self.put_socket_addr(&agent); self.put_socket_addr(&inner); } } } } pub fn be_endpoint_addr( input: &[u8], relay: u8, family: Family, ) -> nom::IResult<&[u8], EndpointAddr> { if relay != 0 { let (remain, agent) = be_socket_addr(input, family)?; let (remain, outer) = be_socket_addr(remain, family)?; Ok((remain, EndpointAddr::with_agent(agent, outer))) } else { let (remain, addr) = be_socket_addr(input, family)?; Ok((remain, EndpointAddr::direct(addr))) } } impl Display for EndpointAddr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { EndpointAddr::Direct { addr } => write!(f, "{addr}"), EndpointAddr::Agent { agent, outer } => write!(f, "{agent}-{outer}"), } } } impl Deref for EndpointAddr { type Target = SocketAddr; fn deref(&self) -> &Self::Target { match self { EndpointAddr::Direct { addr } => addr, EndpointAddr::Agent { agent, .. 
} => agent, } } } impl FromStr for EndpointAddr { type Err = AddrParseError; fn from_str(s: &str) -> Result { if let Some((first, second)) = s.split_once("-") { // Agent format: "inet:1.12.124.56:1234-inet:202.106.68.43:6080" let agent = first.trim().parse()?; let outer = second.trim().parse()?; Ok(EndpointAddr::with_agent(agent, outer)) } else { // Direct format: "1.12.124.56:1234" let addr = s.trim().parse()?; Ok(EndpointAddr::direct(addr)) } } } impl From for EndpointAddr { fn from(addr: SocketAddr) -> Self { EndpointAddr::direct(addr) } } impl From<(SocketAddr, SocketAddr)> for EndpointAddr { fn from((agent, outer): (SocketAddr, SocketAddr)) -> Self { EndpointAddr::with_agent(agent, outer) } } ================================================ FILE: qbase/src/net/nat.rs ================================================ use std::io; use crate::varint::VarInt; bitflags::bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct NetFeature: u8 { const Blocked = 0x01; const Public = 0x02; const Restricted = 0x04; const PortRestricted = 0x08; const Symmetric = 0x10; const Dynamic = 0x20; } } #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] pub enum NatType { Blocked = 0x00, FullCone = 0x01, RestrictedCone = 0x02, RestrictedPort = 0x03, Symmetric = 0x04, Dynamic = 0x05, } impl From for VarInt { fn from(nat_type: NatType) -> Self { VarInt::from(nat_type as u8) } } impl TryFrom for NatType { type Error = io::Error; fn try_from(value: u8) -> Result { match value { 0x00 => Ok(NatType::Blocked), 0x01 => Ok(NatType::FullCone), 0x02 => Ok(NatType::RestrictedCone), 0x03 => Ok(NatType::RestrictedPort), 0x04 => Ok(NatType::Symmetric), 0x05 => Ok(NatType::Dynamic), _ => Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid value for NatType", )), } } } impl TryFrom for NatType { type Error = io::Error; fn try_from(value: VarInt) -> Result { Self::try_from(value.into_u64() as u8) } } impl From for NatType { fn from(value: NetFeature) -> Self { if 
value.contains(NetFeature::Blocked) { NatType::Blocked } else if value.contains(NetFeature::Symmetric) { NatType::Symmetric } else if value.contains(NetFeature::Dynamic) { NatType::Dynamic } else if value.contains(NetFeature::PortRestricted) { NatType::RestrictedPort } else if value.contains(NetFeature::Restricted) { NatType::RestrictedCone } else { NatType::FullCone } } } ================================================ FILE: qbase/src/net/route.rs ================================================ use std::{fmt::Display, net::SocketAddr}; use bytes::BufMut; use derive_more::{Deref, DerefMut}; use nom::number::streaming::be_u8; use serde::{Deserialize, Serialize}; use crate::{ frame::EncodeSize, net::{Family, addr::EndpointAddr, be_socket_addr}, }; #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct Pathway { local: E, remote: E, } impl Pathway { #[inline] pub fn new(local: E, remote: E) -> Self { Self { local, remote } } #[inline] pub fn local(&self) -> E where E: Clone, { self.local.clone() } #[inline] pub fn remote(&self) -> E where E: Clone, { self.remote.clone() } #[inline] pub fn map(self, mut f: impl FnMut(E) -> E1) -> Pathway { Pathway { local: f(self.local), remote: f(self.remote), } } #[inline] pub fn flip(self) -> Self { Self { local: self.remote, remote: self.local, } } } impl Display for Pathway { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}---{}", self.local, self.remote) } } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct Link { pub src: SocketAddr, pub dst: SocketAddr, } impl Display for Link { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}<->{}", self.src, self.dst) } } pub fn be_link(input: &[u8]) -> nom::IResult<&[u8], Link> { let (remain, family) = be_u8(input)?; let family = match family { 0 => Family::V4, 1 => Family::V6, _ => { return Err(nom::Err::Error(nom::error::Error::new( 
input, nom::error::ErrorKind::Alt, ))); } }; let (remain, src) = be_socket_addr(remain, family)?; let (remain, dst) = be_socket_addr(remain, family)?; Ok((remain, Link { src, dst })) } pub trait WriteLink { fn put_link(&mut self, link: &Link); } impl WriteLink for T { fn put_link(&mut self, link: &Link) { use crate::net::WriteSocketAddr; self.put_u8(link.src.is_ipv6() as u8); self.put_socket_addr(&link.src); self.put_socket_addr(&link.dst); } } impl EncodeSize for Link { fn max_encoding_size(&self) -> usize { 1 + self.src.max_encoding_size() + self.dst.max_encoding_size() } fn encoding_size(&self) -> usize { 1 + self.src.encoding_size() + self.dst.encoding_size() } } impl Link { #[inline] pub fn new(src: SocketAddr, dst: SocketAddr) -> Self { Self { src, dst } } #[inline] pub fn flip(self) -> Self { Self { src: self.dst, dst: self.src, } } } impl> From for Pathway { fn from(link: Link) -> Self { Pathway::new(E::from(link.src), E::from(link.dst)) } } #[derive(Clone, Copy, Debug, Deref, DerefMut)] pub struct Line { #[deref] #[deref_mut] pub link: Link, pub ttl: u8, // Explicit congestion notification (ECN) pub ecn: Option, // packet segment size pub seg_size: u16, } impl Line { pub const DEFAULT_TTL: u8 = 64; pub fn new(link: Link, ttl: u8, ecn: Option, seg_size: u16) -> Self { Self { link, ttl, ecn, seg_size, } } } impl Default for Line { fn default() -> Self { Self { link: Link::new( SocketAddr::from(([0, 0, 0, 0], 0)), SocketAddr::from(([0, 0, 0, 0], 0)), ), ttl: Self::DEFAULT_TTL, ecn: None, seg_size: 0, } } } #[derive(Debug, Clone, Copy, Deref, DerefMut)] pub struct Route { pub pathway: Pathway, #[deref] #[deref_mut] pub line: Line, } impl Route { pub fn new(pathway: Pathway, line: Line) -> Self { Self { pathway, line } } /// Create a new empty packet header for receive packets. 
pub fn empty() -> Self { let src = SocketAddr::from(([0, 0, 0, 0], 0)); let dst = SocketAddr::from(([0, 0, 0, 0], 0)); let link = Link::new(SocketAddr::from(src), SocketAddr::from(dst)); Self::new(link.into(), Line::default()) } pub fn pathway(&self) -> Pathway { self.pathway } pub fn link(&self) -> Link { self.line.link } pub fn ttl(&self) -> u8 { self.line.ttl } pub fn ecn(&self) -> Option { self.line.ecn } pub fn seg_size(&self) -> u16 { self.line.seg_size } } #[cfg(test)] mod tests { use super::*; #[test] fn test_endpoint_addr_from_str() { // Test direct format let addr = "127.0.0.1:8080".parse::().unwrap(); assert!(matches!(addr, EndpointAddr::Direct { .. })); // Test agent format let addr = "127.0.0.1:8080-192.168.1.1:9000" .parse::() .unwrap(); assert!(matches!(addr, EndpointAddr::Agent { .. })); // Test with whitespace let addr = " 127.0.0.1:8080 - 192.168.1.1:9000 " .parse::() .unwrap(); assert!(matches!(addr, EndpointAddr::Agent { .. })); // Test invalid format assert!("invalid".parse::().is_err()); } } ================================================ FILE: qbase/src/net/tx.rs ================================================ use std::{ collections::BTreeMap, future::poll_fn, sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll, Waker}, }; use super::route::Pathway; type SignalsBits = u16; bitflags::bitflags! { #[derive(Debug, Clone, Copy,PartialEq, Eq)] pub struct Signals: SignalsBits { const CONGESTION = 1 << 0; // cc const FLOW_CONTROL = 1 << 1; // flow const TRANSPORT = 1 << 2; // ack/retran/reliable.... 
const WRITTEN = 1 << 3; // fresh stream const CONNECTION_ID = 1 << 4; // cid const CREDIT = 1 << 5; // aa const KEYS = 1 << 6; // key(no waker in SendWaker) const PING = 1 << 7; // packet which contains ping frames only const TLS_FIN = 1 << 8; // TLS handshake is required to send and receive 1rtt data const PATH_VALIDATE = 1 << 9; // path validated } } #[derive(Default, Debug)] pub struct SendWaker { waker: Option, // Signals 对应的bit设置为1意为该位的条件已经满足,为0表示需要该条件满足 state: SignalsBits, } impl SendWaker { pub fn new() -> Self { Self::default() } const WAITING: SignalsBits = 0; #[inline] pub fn poll_wait_for(&mut self, cx: &mut Context, signals: Signals) -> Poll<()> { if self.state & signals.bits() == 0 { self.state = !signals.bits(); match self.waker.as_ref() { Some(old_waker) if old_waker.will_wake(cx.waker()) => {} _ => self.waker = Some(cx.waker().clone()), } Poll::Pending } else { self.state = Self::WAITING; Poll::Ready(()) } } #[inline] fn wake_by(&mut self, signals: Signals) { if self.state | signals.bits() != self.state { if let Some(waker) = self.waker.as_ref() { waker.wake_by_ref(); } } self.state |= signals.bits(); } } unsafe impl Send for SendWaker {} unsafe impl Sync for SendWaker {} #[derive(Debug, Default, Clone)] pub struct ArcSendWaker(Arc>); impl ArcSendWaker { #[inline] pub fn new() -> Self { Self(Arc::new(Mutex::new(SendWaker::new()))) } #[inline] pub async fn wait_for(&self, signals: Signals) { poll_fn(|cx| self.0.lock().unwrap().poll_wait_for(cx, signals)).await } #[inline] pub fn wake_by(&self, signals: Signals) { self.0.lock().unwrap().wake_by(signals); } } /// connection level send wakers #[derive(Debug, Default)] pub struct SendWakers { last_woken: Option, paths: BTreeMap, } impl SendWakers { #[inline] pub fn new() -> Self { Default::default() } #[inline] pub fn insert(&mut self, pathway: Pathway, waker: &ArcSendWaker) { self.paths.entry(pathway).or_insert_with(|| waker.clone()); } #[inline] pub fn remove(&mut self, pathway: &Pathway) { 
self.paths.remove(pathway);
    }

    /// Wake every registered path with `signals`, starting just after the
    /// path woken first last time so that no single path is always served first.
    #[inline]
    pub fn wake_all_by(&mut self, signals: Signals) {
        // Wakes every waker in iteration order; returns the first pathway
        // visited (the new round-robin cursor), or None when empty.
        fn wake_all_by<'a>(
            paths: impl IntoIterator<Item = (&'a Pathway, &'a ArcSendWaker)>,
            signals: Signals,
        ) -> Option<Pathway> {
            let mut paths = paths.into_iter().peekable();
            let first_path = paths.peek().map(|(pathway, _)| pathway).copied().copied();
            paths.for_each(|(_, waker)| {
                waker.wake_by(signals);
            });
            first_path
        }
        use std::ops::Bound::*;
        self.last_woken = match self.last_woken {
            // Rotate: everything strictly after the cursor first, then wrap
            // around through the cursor itself.
            Some(last_woken) => wake_all_by(
                self.paths
                    .range((Excluded(last_woken), Unbounded))
                    .chain(self.paths.range((Unbounded, Included(last_woken)))),
                signals,
            ),
            None => wake_all_by(self.paths.range(..), signals),
        }
    }
}

/// Cloneable, mutex-protected handle to the connection-level [`SendWakers`].
#[derive(Default, Debug, Clone)]
pub struct ArcSendWakers(Arc<Mutex<SendWakers>>);

impl ArcSendWakers {
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    fn lock_guard(&self) -> MutexGuard<'_, SendWakers> {
        self.0.lock().unwrap()
    }

    #[inline]
    pub fn insert(&self, pathway: Pathway, waker: &ArcSendWaker) {
        self.lock_guard().insert(pathway, waker);
    }

    #[inline]
    pub fn remove(&self, pathway: &Pathway) {
        self.lock_guard().remove(pathway);
    }

    #[inline]
    pub fn wake_all_by(&self, signals: Signals) {
        self.lock_guard().wake_all_by(signals);
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering::*};

    impl ArcSendWaker {
        fn state(&self) -> SignalsBits {
            self.0.lock().unwrap().state
        }
    }

    use super::*;

    #[tokio::test]
    async fn single_condition() {
        let waker = ArcSendWaker::new();
        let woken_times = Arc::new(AtomicUsize::new(0));
        tokio::spawn({
            let waker = waker.clone();
            let wake_times = woken_times.clone();
            async move {
                loop {
                    waker.wait_for(Signals::CONGESTION).await;
                    wake_times.fetch_add(1, Release);
                }
            }
        });

        waker.wake_by(Signals::FLOW_CONTROL);
        tokio::task::yield_now().await;
        assert_eq!(woken_times.load(Acquire), 0); // not woken
        waker.wake_by(Signals::TRANSPORT);
        tokio::task::yield_now().await;
        assert_eq!(woken_times.load(Acquire), 0); // not woken
        waker.wake_by(Signals::CONGESTION);
        tokio::task::yield_now().await;
assert_eq!(woken_times.load(Acquire), 1); // woken } #[tokio::test] async fn all_condition() { let waker = ArcSendWaker::new(); let woken_times = Arc::new(AtomicUsize::new(0)); tokio::spawn({ let waker = waker.clone(); let wake_times = woken_times.clone(); async move { loop { waker.wait_for(Signals::all()).await; wake_times.fetch_add(1, Release); } } }); let wait_for_all_cond_state = !Signals::all().bits(); waker.wake_by(Signals::FLOW_CONTROL); tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 1); // woken assert_eq!(waker.state(), wait_for_all_cond_state); waker.wake_by(Signals::TRANSPORT); tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 2); // woken assert_eq!(waker.state(), wait_for_all_cond_state); waker.wake_by(Signals::CONGESTION); tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 3); // woken assert_eq!(waker.state(), wait_for_all_cond_state); } #[tokio::test] async fn wake_before_register() { let waker = ArcSendWaker::new(); let woken_times = Arc::new(AtomicUsize::new(0)); waker.wake_by(Signals::CONGESTION); // pre set woken state tokio::spawn({ let waker = waker.clone(); let wake_times = woken_times.clone(); async move { loop { waker.wait_for(Signals::CONGESTION).await; wake_times.fetch_add(1, Release); } } }); let wait_for_quota_state = !Signals::CONGESTION.bits(); tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 1); // woken assert_eq!(waker.state(), wait_for_quota_state); } #[tokio::test] async fn state_change() { let waker = ArcSendWaker::new(); let woken_times = Arc::new(AtomicUsize::new(0)); tokio::spawn({ let waker = waker.clone(); let wake_times = woken_times.clone(); let wait_for = move |r#for| { let wake_times = wake_times.clone(); let waker = waker.clone(); async move { waker.wait_for(r#for).await; wake_times.fetch_add(1, Release); } }; async move { wait_for(Signals::all()).await; wait_for(Signals::CONGESTION | Signals::TRANSPORT).await; 
wait_for(Signals::TRANSPORT).await; } }); let wait_for_all_cond_state = !Signals::all().bits(); let wait_for_quota_state = !(Signals::CONGESTION | Signals::TRANSPORT).bits(); let wait_for_data_state = !Signals::TRANSPORT.bits(); tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 0); // not woken assert_eq!(waker.state(), wait_for_all_cond_state); waker.wake_by(Signals::TRANSPORT); // all condition will be met tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 1); // woken assert_eq!(waker.state(), wait_for_quota_state); waker.wake_by(Signals::CONGESTION); // quota\data will be met tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 2); // woken assert_eq!(waker.state(), wait_for_data_state); waker.wake_by(Signals::CONGESTION); // only data will be met tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 2); // not woken waker.wake_by(Signals::FLOW_CONTROL); // only data will be met tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 2); // not woken waker.wake_by(Signals::TRANSPORT); // only data will be met tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 3); // woken assert_eq!(waker.state(), SendWaker::WAITING); // state reset } #[tokio::test] async fn mult_wake_signals() { let waker = ArcSendWaker::new(); let woken_times = Arc::new(AtomicUsize::new(0)); tokio::spawn({ let waker = waker.clone(); let wake_times = woken_times.clone(); async move { loop { wake_times.fetch_add(1, Release); waker.wait_for(Signals::TRANSPORT).await; } } }); tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 1); // wake assert_eq!(waker.state(), !Signals::TRANSPORT.bits()); waker.wake_by(Signals::TRANSPORT); tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 2); // enter + wake assert_eq!(waker.state(), !Signals::TRANSPORT.bits()); waker.wake_by(Signals::CONGESTION | Signals::TRANSPORT); tokio::task::yield_now().await; 
assert_eq!(woken_times.load(Acquire), 3); // enter + wake * 2 assert_eq!(waker.state(), !Signals::TRANSPORT.bits()); } #[tokio::test] async fn not_wake() { let waker = ArcSendWaker::new(); let woken_times = Arc::new(AtomicUsize::new(0)); tokio::spawn({ let waker = waker.clone(); let wake_times = woken_times.clone(); async move { loop { wake_times.fetch_add(1, Release); waker.wait_for(Signals::CONGESTION).await; } } }); tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 1); // not woken waker.wake_by(Signals::FLOW_CONTROL); tokio::task::yield_now().await; assert_eq!(woken_times.load(Acquire), 1); // not woken } } ================================================ FILE: qbase/src/net.rs ================================================ use std::{ fmt::Display, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, str::FromStr, }; use bytes::BufMut; use nom::{ IResult, Parser, combinator::{flat_map, map}, number::complete::{be_u16, be_u32, be_u128}, }; use serde::{Deserialize, Serialize}; use thiserror::Error; use crate::frame::EncodeSize; pub mod addr; pub mod nat; pub mod route; pub mod tx; pub use nat::{NatType, NetFeature}; /// IP protocol family /// /// Represents IPv4 or IPv6 protocol family. #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum Family { /// IPv4 protocol family V4 = 0, /// IPv6 protocol family V6 = 1, } impl Display for Family { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Family::V4 => write!(f, "v4"), Family::V6 => write!(f, "v6"), } } } /// Invalid IP protocol family error /// /// Returned when attempting to parse an unsupported IP protocol family string. 
/// /// Supported values: `v4`, `V4`, `v6`, `V6` #[derive(Debug, Clone, Error, PartialEq, Eq)] #[error("Invalid ip family")] pub struct UnknownFamily; impl FromStr for Family { type Err = UnknownFamily; fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { "v4" => Ok(Family::V4), "v6" => Ok(Family::V6), _ => Err(UnknownFamily), } } } pub trait AddrFamily { /// Get the IP protocol family /// /// Returns `IpFamily::V4` for IPv4 addresses and `IpFamily::V6` for IPv6 addresses. fn family(&self) -> Family; } impl AddrFamily for std::net::Ipv4Addr { fn family(&self) -> Family { Family::V4 } } impl AddrFamily for std::net::Ipv6Addr { fn family(&self) -> Family { Family::V6 } } impl AddrFamily for std::net::IpAddr { fn family(&self) -> Family { match self { std::net::IpAddr::V4(_) => Family::V4, std::net::IpAddr::V6(_) => Family::V6, } } } impl AddrFamily for std::net::SocketAddr { fn family(&self) -> Family { self.ip().family() } } pub trait WriteSocketAddr { fn put_socket_addr(&mut self, addr: &SocketAddr); } impl WriteSocketAddr for T { fn put_socket_addr(&mut self, addr: &SocketAddr) { self.put_u16(addr.port()); match addr.ip() { IpAddr::V4(ipv4) => self.put_u32(ipv4.into()), IpAddr::V6(ipv6) => self.put_u128(ipv6.into()), } } } pub fn be_socket_addr(input: &[u8], family: Family) -> IResult<&[u8], SocketAddr> { flat_map(be_u16, |port| { map(be_ip_addr(family), move |ip| SocketAddr::new(ip, port)) }) .parse(input) } pub fn be_ip_addr(family: Family) -> impl Fn(&[u8]) -> IResult<&[u8], IpAddr> { move |input| match family { Family::V6 => map(be_u128, |ip| IpAddr::V6(Ipv6Addr::from(ip))).parse(input), Family::V4 => map(be_u32, |ip| IpAddr::V4(Ipv4Addr::from(ip))).parse(input), } } impl EncodeSize for SocketAddr { fn max_encoding_size(&self) -> usize { 2 + 16 // IPv6 address } fn encoding_size(&self) -> usize { match self.ip() { IpAddr::V4(_) => 2 + 4, IpAddr::V6(_) => 2 + 16, } } } #[cfg(test)] mod tests { use super::*; #[test] fn 
test_ip_family_display_and_parse() { assert_eq!(Family::V4.to_string(), "v4"); assert_eq!(Family::V6.to_string(), "v6"); assert_eq!("v4".parse::().unwrap(), Family::V4); assert_eq!("V4".parse::().unwrap(), Family::V4); assert_eq!("v6".parse::().unwrap(), Family::V6); assert_eq!("V6".parse::().unwrap(), Family::V6); assert!(matches!("v7".parse::(), Err(UnknownFamily))); } } ================================================ FILE: qbase/src/packet/decrypt.rs ================================================ use rustls::quic::{HeaderProtectionKey, PacketKey}; use super::{ GetPacketNumberLength, KeyPhaseBit, LongSpecificBits, PacketNumber, ShortSpecificBits, error::Error, take_pn_len, }; /// Removes the header protection of the long packet. /// Returns the undecoded packet number in the header finally. /// /// When receiving a long packet, the header protection must be removed before /// the packet number can be decoded. If removing header protection is failed, it /// indicates that the packet is problematic and can be ignored. /// In this case, no error but None will be returned. /// If not so, it will put the QUIC connection in a situation that is highly susceptible /// to denial-of-service attacks. /// /// Note that after removing the long header protection, the 2-bit reserved bits of the /// long header, i.e., the 5th and 6th bits of the first byte of the first byte, must /// be 0, otherwise it will return a connection error of type PROTOCOL_VIOLATION. /// /// See [Section 17.2](https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-8.2) of /// QUIC RFC 9000. /// /// After obtaining the undecoded packet number, it is necessary to rely on the largest /// received packet number to further decode the actual packet number. 
pub fn remove_protection_of_long_packet( key: &dyn HeaderProtectionKey, pkt_buf: &mut [u8], payload_offset: usize, ) -> Result, Error> { let (pre_data, payload) = pkt_buf.split_at_mut(payload_offset); let first_byte = &mut pre_data[0]; let (max_pn_buf, sample) = payload.split_at_mut(4); // 去除包头保护失败,忽略即可 if key .decrypt_in_place(&sample[..key.sample_len()], first_byte, max_pn_buf) .is_err() { return Ok(None); } let specific_bits = LongSpecificBits::from(*first_byte); let pn_len = specific_bits.pn_len()?; let (_, undecoded_pn) = take_pn_len(pn_len)(max_pn_buf).unwrap(); Ok(Some(undecoded_pn)) } /// Removes the header protection of the short packet. /// Returns the undecoded packet number and the key phase bit in the header. /// /// When receiving a short packet, the header protection must be removed first before /// the packet number can be decoded. If removing header protection is failed, it /// indicates that the packet is problematic and can be ignored. /// In this case, no error but None will be returned instead. /// If not so, it will put the QUIC connection in a situation that is highly susceptible /// to denial-of-service attacks. /// /// Note that after removing the long header protection, the 2-bit reserved bits of the /// long header, i.e., the 4th and 5th bits of the first byte of the first byte, must /// be 0, otherwise it will return a connection error of type PROTOCOL_VIOLATION. /// /// See [Section 17.3.1](https://www.rfc-editor.org/rfc/rfc9000.html#section-17.3.1-4.8) of /// QUIC RFC 9000. /// /// After obtaining the undecoded packet number, it is necessary to rely on the maximum /// receiving packet number to further decode the actual packet number. 
pub fn remove_protection_of_short_packet( key: &dyn HeaderProtectionKey, pkt_buf: &mut [u8], payload_offset: usize, ) -> Result, Error> { let (pre_data, payload) = pkt_buf.split_at_mut(payload_offset); let first_byte = &mut pre_data[0]; let (max_pn_buf, sample) = payload.split_at_mut(4); // 去除包头保护失败,忽略即可 if key .decrypt_in_place(&sample[..key.sample_len()], first_byte, max_pn_buf) .is_err() { return Ok(None); } let clear_bits = ShortSpecificBits::from(*first_byte); let pn_len = clear_bits.pn_len()?; let (_, undecoded_pn) = take_pn_len(pn_len)(max_pn_buf).unwrap(); Ok(Some((undecoded_pn, clear_bits.key_phase()))) } /// Decrypt the body of a packet, applicable to both long and short packets. /// /// It will decrypt the body data of the packet in place and return the length of the valid /// plaintext body data in the packet. /// The final valid plaintext body length is not equal to the raw ciphered body length of the packet. /// This is because the ciphertext body length usually contains checksum codes at the end, /// which is not part of the plaintext body. /// /// Decrypting a packet relies on the packet number decoded from the packet header, and then /// uses the corresponding level of packet decryption key to decrypt the packet body. /// The packet body refers to the content located after the packet number. /// Decrypting a packet will verify the integrity of the packet. /// If decryption fails, it indicates that the packet is incorrect (strangely, removing the /// header protection succeeded, right?), indicating an error in the peer's packaging /// and encrypting logic, and then the QUIC connection should be terminated. 
pub fn decrypt_packet( key: &dyn PacketKey, pn: u64, pkt_buf: &mut [u8], body_offset: usize, ) -> Result { let (aad, body) = pkt_buf.split_at_mut(body_offset); let plain = key .decrypt_in_place(pn, aad, body) .map_err(|_| Error::DecryptPacketFailure)?; // should return plain.len() Ok(plain.len()) } ================================================ FILE: qbase/src/packet/encrypt.rs ================================================ use std::ops::Deref; use rustls::quic::{HeaderProtectionKey, PacketKey}; use super::{KeyPhaseBit, LongSpecificBits, ShortSpecificBits}; /// Encrypt the packet body, applicable to both long and short packets. /// /// It relies on the packet encryption key of the corresponding level and the packet /// number to encrypt the packet body. /// The packet body refers to the packet data located after the packet number, /// specifically including the intergrity checksum codes at the end, which usually consist of /// 16 bytes depending on the encryption algorithm. /// /// # Note /// /// Before encrypting the packet body, the entire packet content must be fully and /// correctly populated, including the packet header and body, especially the last /// few bits of the first byte. pub fn encrypt_packet(key: &dyn PacketKey, pn: u64, pkt_buf: &mut [u8], body_offset: usize) { let (aad, body_tag) = pkt_buf.split_at_mut(body_offset); let (body, tag_buf) = body_tag.split_at_mut(body_tag.len() - key.tag_len()); let tag = key.encrypt_in_place(pn, aad, body).unwrap(); tag_buf.copy_from_slice(tag.as_ref()); } /// Add header protection, applicable to both long and short packets. /// Mainly protects the Reserved Bits and Packet Number Length in the packet header, /// as well as the Packet Number. /// /// Use the header protection key of the corresponding level to protect the header. /// For long headers, the last 4 bits of the first byte are protected; /// and for short headers, the last 5 bits of the first byte are protected. 
/// /// This function uses the first bit of the first byte of the packet to determine /// whether it is a long packet or a short packet, and then performs the corresponding /// header protection. /// /// ## Note /// /// Before encrypting the packet body, the entire packet content must be fully and /// correctly filled, including the packet header and body, especially the last /// few bits of the first byte, and the packet body encryption must be completed. pub fn protect_header( key: &dyn HeaderProtectionKey, pkt_buf: &mut [u8], payload_offset: usize, pn_len: usize, ) { let (predata, payload) = pkt_buf.split_at_mut(payload_offset); let first_byte = &mut predata[0]; let (max_pn_buf, sample) = payload.split_at_mut(4); let sample_len = key.sample_len(); key.encrypt_in_place(&sample[..sample_len], first_byte, &mut max_pn_buf[..pn_len]) .unwrap(); } /// Encode the last 4 specific bits of the first byte of the long packet, i.e., /// two reserved bits of 0 and two bits of packet number encoding length. pub fn encode_long_first_byte(first_byte: &mut u8, pn_len: usize) { let specific_bits = LongSpecificBits::with_pn_len(pn_len); *first_byte |= specific_bits.deref(); } /// Encode the last 5 specific bits of the first byte of the short packet, i.e., /// two reserverd bits of 0, one bit of key phase, and two bits of packet number encoding length. pub fn encode_short_first_byte(first_byte: &mut u8, pn_len: usize, key_phase: KeyPhaseBit) { let mut specific_bits = ShortSpecificBits::with_pn_len(pn_len); specific_bits.set_key_phase(key_phase); *first_byte |= specific_bits.deref(); } #[cfg(test)] mod tests {} ================================================ FILE: qbase/src/packet/error.rs ================================================ use nom::error::ErrorKind as NomErrorKind; use thiserror::Error; use super::r#type::Type; /// Parse error of QUIC packet. 
#[derive(Debug, PartialEq, Eq, Error)] pub enum Error { #[error("Unsupport version {0}")] UnsupportedVersion(u32), #[error("Invalid fixed bit in long header")] InvalidFixedBit, #[error("Incomplete packet type: {0}")] IncompleteType(String), #[error("Incomplete packet header {0:?}: {1}")] IncompleteHeader(Type, String), #[error("Incomplete packet body {0:?}: {1}")] IncompletePacket(Type, String), #[error("Sampling of {0:?} packet content less than 20 bytes, only {1} bytes available")] UnderSampling(Type, usize), #[error("Fail to remove protection")] RemoveProtectionFailure, #[error("Invalid reserved bits: {0:05b} & {1:05b} must be 0")] InvalidReservedBits(u8, u8), #[error("Fail to decrypt packet")] DecryptPacketFailure, } impl nom::error::ParseError<&[u8]> for Error { fn from_error_kind(_input: &[u8], _kind: NomErrorKind) -> Self { debug_assert_eq!(_kind, NomErrorKind::ManyTill); unreachable!("QUIC frame parser must always consume") } fn append(_input: &[u8], _kind: NomErrorKind, source: Self) -> Self { // 在解析帧时遇到了source错误,many_till期望通过ManyTill的错误类型告知 // 这里,源错误更有意义,所以直接返回源错误 debug_assert_eq!(_kind, NomErrorKind::ManyTill); source } } impl From for crate::error::QuicError { fn from(e: Error) -> Self { match e { Error::InvalidReservedBits(_, _) => crate::error::QuicError::with_default_fty( crate::error::ErrorKind::ProtocolViolation, e.to_string(), ), _ => unreachable!(), } } } ================================================ FILE: qbase/src/packet/header/long.rs ================================================ use derive_more::{Deref, DerefMut}; use super::*; use crate::{cid::ConnectionId, varint::VarInt}; /// The long header structure, whose specific contents are determined by the /// concrete packet type, including VN/Retry/Initial/0Rtt/Handshake packet. /// /// Long headers are used for packets that are sent prior to the establishment /// of 1-RTT keys. Once 1-RTT keys are available, a sender switches to sending /// packets using the short header. 
///
/// ```text
/// +---------------+-------------+------+--------------+------+--------------+----------+
/// |1|1|X X 0 0 0 0| Version(32) | DCIL | DCID(0..160) | SCIL | SCID(0..160) | Specific |
/// +---+---+---+---+-------------+------+--------------+------+--------------+----------+
///                                                                           |<->|<->|<->|
///                                                                           |   |   |
///                                                                           |   |   +---> packet number length
///                                                                           |   +---> reserved bits, must be zero
///                                                                           +---> represent specific long packet type
/// ```
///
/// See [Long Header Packet Format](https://www.rfc-editor.org/rfc/rfc9000.html#name-long-header-packets)
/// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details.
#[derive(Debug, Default, Clone, Deref, DerefMut)]
pub struct LongHeader<T> {
    dcid: ConnectionId,
    scid: ConnectionId,
    // The type-specific part (token, versions, ...); deref targets it.
    #[deref]
    #[deref_mut]
    specific: T,
}

impl<T> super::GetDcid for LongHeader<T> {
    fn dcid(&self) -> &ConnectionId {
        &self.dcid
    }
}

impl<T> super::GetScid for LongHeader<T> {
    fn scid(&self) -> &ConnectionId {
        &self.scid
    }
}

// The following is the header definition, which may exist in all future versions
// of QUIC, so it is placed in this file without distinguishing versions.

/// The specific contents of the version negotiation packet, which includes all the
/// version numbers supported by the server.
///
/// When the server receives an initial packet or 0-RTT packet with an unsupported
/// version number, it will respond with a version negotiation packet that contains
/// all the version numbers supported by the server, each version being 32 bits.
#[derive(Debug, Default, Clone)]
pub struct VersionNegotiation {
    versions: Vec<u32>,
}

impl VersionNegotiation {
    /// Create a new VersionNegotiation packet from the version numbers.
    pub fn new(versions: Vec<u32>) -> Self {
        VersionNegotiation { versions }
    }

    /// Get the version numbers supported by the server.
    pub fn versions(&self) -> &Vec<u32> {
        &self.versions
    }
}

/// The specific contents of the retry packet, which includes a retry token and a
/// 16-byte integrity checksum codes.
///
/// After accepting the client's new connection, the server may return a retry packet
/// due to load balancing strategies or simply for address verification,
/// requiring the client to reconnect to the new address with the token.
#[derive(Debug, Default, Clone)]
pub struct Retry {
    token: Vec<u8>,
    integrity: [u8; 16],
}

impl Retry {
    /// Create a new Retry packet from the token and integrity value.
    ///
    /// The token is required to be carried by the Initial packet when the client
    /// reconnects in the future and will be used by the server for address verification.
    pub fn new(token: &[u8], integrity: &[u8]) -> Self {
        let mut retry = Retry {
            token: Vec::from(token),
            integrity: [0; 16],
        };
        retry.integrity.copy_from_slice(integrity);
        retry
    }

    /// Get the retry token.
    pub fn token(&self) -> &Vec<u8> {
        &self.token
    }

    /// Get the integrity value.
    pub fn integrity(&self) -> &[u8; 16] {
        &self.integrity
    }
}

/// The specific contents of the initial packet, which just includes a token.
///
/// The token comes from the Retry packet responded by the server, or it is issued to
/// the client by the server through the NewToken frame in past QUIC connections.
/// After the server receives this token, it will be used for address verification.
/// If the client connects to the server for the first time, the token is empty.
#[derive(Debug, Default, Clone)]
pub struct Initial {
    token: Vec<u8>,
}

impl Initial {
    /// Create a new Initial packet from the token.
    pub fn with_token(token: Vec<u8>) -> Self {
        Initial { token }
    }

    /// Create a new Initial packet from the token slice.
    pub fn from_slice(token: &[u8]) -> Self {
        Initial {
            token: Vec::from(token),
        }
    }

    /// Get the token.
    pub fn token(&self) -> &Vec<u8> {
        &self.token
    }
}

/// The specific contents of the 0-RTT packet, which is empty.
#[derive(Debug, Default, Clone)]
pub struct ZeroRtt;

/// The specific contents of the handshake packet, which is empty.
#[derive(Debug, Default, Clone)] pub struct Handshake; impl EncodeHeader for Initial { fn size(&self) -> usize { VarInt::try_from(self.token.len()) .expect("token length can not be more than 2^62") .encoding_size() + self.token.len() } } impl EncodeHeader for ZeroRtt {} impl EncodeHeader for Handshake {} /// Version negotiation packet, which is a long header packet. /// /// See [version negotiation packet](https://www.rfc-editor.org/rfc/rfc9000.html#name-version-negotiation-packet) /// in [RFC9000](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. pub type VersionNegotiationHeader = LongHeader; /// Retry packet, which is a long header packet. /// /// See [retry packet](https://www.rfc-editor.org/rfc/rfc9000.html#name-retry-packet) /// in [RFC9000](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. pub type RetryHeader = LongHeader; /// Initial packet header, which is a long header packet. /// /// See [initial packet](https://www.rfc-editor.org/rfc/rfc9000.html#name-initial-packet) /// in [RFC9000](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. pub type InitialHeader = LongHeader; /// Handshake packet header, which is a long header packet. /// /// See [handshake packet](https://www.rfc-editor.org/rfc/rfc9000.html#name-handshake-packet) /// in [RFC9000](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. pub type HandshakeHeader = LongHeader; /// 0-RTT packet header, which is a long header packet. /// /// See [0-RTT packet](https://www.rfc-editor.org/rfc/rfc9000.html#name-0-rtt-packet) /// in [RFC9000](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. pub type ZeroRttHeader = LongHeader; impl EncodeHeader for LongHeader { fn size(&self) -> usize { 1 + 4 + 1 + self.dcid.len() // dcid长度最多20字节,长度编码只占1字节,加上cid本身的长度 + 1 + self.scid.len() // scid一样 + self.specific.size() } fn length_encoding(&self) -> usize { 2 // 长包头都带有length字段,统一2字节,能表达1~16KB的长度的包 } } macro_rules! 
bind_type { ($($type:ty => $value:expr),*) => { $( impl GetType for $type { fn get_type(&self) -> Type { $value } } )* }; } bind_type!( VersionNegotiationHeader => Type::Long(LongType::VersionNegotiation), RetryHeader => Type::Long(LongType::V1(Version::<1, _>(v1::Type::Retry))), InitialHeader => Type::Long(LongType::V1(Version::<1, _>(v1::Type::Initial))), ZeroRttHeader => Type::Long(LongType::V1(Version::<1, _>(v1::Type::ZeroRtt))), HandshakeHeader => Type::Long(LongType::V1(Version::<1, _>(v1::Type::Handshake))) ); /// The sum type of long packets that carry data, /// including Initial, ZeroRtt, and Handshake packets. #[derive(Debug, Clone)] #[enum_dispatch(Encode, GetType, GetDcid, GetScid)] pub enum DataHeader { Initial(InitialHeader), ZeroRtt(ZeroRttHeader), Handshake(HandshakeHeader), } /// The io module provides functions for parsing and writing long headers. pub mod io { use std::ops::Deref; use bytes::BufMut; use nom::{ Err, Parser, bytes::streaming::take, combinator::{eof, map}, multi::{length_data, many_till}, number::streaming::be_u32, }; use super::*; use crate::{ cid::WriteConnectionId, packet::{ header::io::WriteHeader, r#type::{ io::WritePacketType, long::{Type as LongType, v1::Type as LongV1Type}, }, }, varint::{WriteVarInt, be_varint}, }; /// Parse the version negotiation packet, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn be_version_negotiation(input: &[u8]) -> nom::IResult<&[u8], VersionNegotiation> { let (remain, (versions, _)) = many_till(be_u32, eof).parse(input)?; Ok((remain, VersionNegotiation::new(versions))) } /// Parse the retry packet, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
pub fn be_retry(input: &[u8]) -> nom::IResult<&[u8], Retry> {
    // A Retry packet ends with a 16-byte integrity tag; everything before it is the token.
    if input.len() < 16 {
        return Err(Err::Incomplete(nom::Needed::new(16)));
    }
    let token_length = input.len() - 16;
    // `take` returns (remaining, taken): the last 16 bytes are the tag, the prefix is the token.
    let (integrity, token) = take(token_length)(input)?;
    Ok((&[][..], Retry::new(token, integrity)))
}

/// Parse the initial packet,
/// [nom](https://docs.rs/nom/latest/nom/) parser style.
pub fn be_initial(input: &[u8]) -> nom::IResult<&[u8], Initial> {
    // The token is a varint-length-prefixed byte string.
    map(length_data(be_varint), Initial::from_slice).parse(input)
}

/// Parse the 0-RTT packet,
/// [nom](https://docs.rs/nom/latest/nom/) parser style.
pub fn be_zero_rtt(input: &[u8]) -> nom::IResult<&[u8], ZeroRtt> {
    // The 0-RTT specific part is empty; nothing is consumed.
    Ok((input, ZeroRtt))
}

/// Parse the handshake packet,
/// [nom](https://docs.rs/nom/latest/nom/) parser style.
pub fn be_handshake(input: &[u8]) -> nom::IResult<&[u8], Handshake> {
    // The handshake specific part is empty; nothing is consumed.
    Ok((input, Handshake))
}

/// The builder for the long header, which is used to create a long header.
///
/// ## Example
/// ```
/// use qbase::{cid::ConnectionId, packet::header::long::io::LongHeaderBuilder};
///
/// let scid = ConnectionId::from_slice(b"scid");
/// let dcid = ConnectionId::from_slice(b"dcid");
///
/// let handshake_header = LongHeaderBuilder::with_cid(dcid, scid).handshake();
/// ```
pub struct LongHeaderBuilder {
    pub(crate) dcid: ConnectionId,
    pub(crate) scid: ConnectionId,
}

impl LongHeaderBuilder {
    /// Create a new long header builder with the given destination
    /// and source connection IDs.
    pub fn with_cid(dcid: ConnectionId, scid: ConnectionId) -> Self {
        Self { dcid, scid }
    }

    /// Build into a version negotiation header.
    pub fn vn(self, versions: Vec<u32>) -> LongHeader<VersionNegotiation> {
        self.wrap(VersionNegotiation::new(versions))
    }

    /// Build into a retry header.
    pub fn retry(self, token: Vec<u8>, integrity: [u8; 16]) -> LongHeader<Retry> {
        self.wrap(Retry { token, integrity })
    }

    /// Build into an initial header.
    pub fn initial(self, token: Vec<u8>) -> LongHeader<Initial> {
        self.wrap(Initial::with_token(token))
    }

    /// Build into a 0-RTT header.
    pub fn zero_rtt(self) -> LongHeader<ZeroRtt> {
        self.wrap(ZeroRtt)
    }

    /// Build into a handshake header.
    pub fn handshake(self) -> LongHeader<Handshake> {
        self.wrap(Handshake)
    }

    /// Wrap the specific header into the long generic header.
    /// Return the specific long header.
    pub fn wrap<T>(self, specific: T) -> LongHeader<T> {
        LongHeader {
            dcid: self.dcid,
            scid: self.scid,
            specific,
        }
    }

    /// Parse a long header from the input buffer,
    /// [nom](https://docs.rs/nom/latest/nom/) parser style.
    ///
    /// The input buffer would be the remaining data of the buffer.
    pub fn parse(self, ty: LongType, input: &[u8]) -> nom::IResult<&[u8], Header> {
        match ty {
            LongType::VersionNegotiation => {
                let (remain, versions) = be_version_negotiation(input)?;
                Ok((remain, Header::VN(self.wrap(versions))))
            }
            LongType::V1(ty) => match ty.deref() {
                LongV1Type::Retry => {
                    let (remain, retry) = be_retry(input)?;
                    Ok((remain, Header::Retry(self.wrap(retry))))
                }
                LongV1Type::Initial => {
                    let (remain, initial) = be_initial(input)?;
                    Ok((remain, Header::Initial(self.wrap(initial))))
                }
                LongV1Type::ZeroRtt => {
                    let (remain, zero_rtt) = be_zero_rtt(input)?;
                    Ok((remain, Header::ZeroRtt(self.wrap(zero_rtt))))
                }
                LongV1Type::Handshake => {
                    let (remain, handshake) = be_handshake(input)?;
                    Ok((remain, Header::Handshake(self.wrap(handshake))))
                }
            },
        }
    }
}

/// A [`bytes::BufMut`] extension trait, makes buffer more friendly to write long headers.
pub trait WriteSpecific<S>: BufMut {
    /// Write the specific header content.
    fn put_specific(&mut self, _specific: &S) {}
}

impl<T: BufMut> WriteSpecific<VersionNegotiation> for T {
    fn put_specific(&mut self, specific: &VersionNegotiation) {
        // Each supported version is written as a big-endian u32.
        for version in &specific.versions {
            self.put_u32(*version);
        }
    }
}

impl<T: BufMut> WriteSpecific<Retry> for T {
    fn put_specific(&mut self, specific: &Retry) {
        // Token first, then the 16-byte integrity tag, mirroring `be_retry`.
        self.put_slice(&specific.token);
        self.put_slice(&specific.integrity);
    }
}

impl<T: BufMut> WriteSpecific<Initial> for T {
    fn put_specific(&mut self, specific: &Initial) {
        // The token is varint-length-prefixed, mirroring `be_initial`.
        self.put_varint(
            &VarInt::try_from(specific.token.len())
                .expect("token length can not be more than 2^62"),
        );
        self.put_slice(&specific.token);
    }
}

/// 0-Rtt headers are empty, so there is nothing to write.
impl<T: BufMut> WriteSpecific<ZeroRtt> for T {}

/// Handshake headers are empty, so there is nothing to write.
impl<T: BufMut> WriteSpecific<Handshake> for T {}

impl<T, S> WriteHeader<LongHeader<S>> for T
where
    T: BufMut + WriteSpecific<S>,
    LongHeader<S>: GetType,
{
    fn put_header(&mut self, header: &LongHeader<S>) {
        let ty = header.get_type();
        self.put_packet_type(&ty);
        self.put_connection_id(&header.dcid);
        self.put_connection_id(&header.scid);
        self.put_specific(&header.specific);
    }
}
}

#[cfg(test)]
mod tests {
    use crate::packet::header::WriteSpecific;

    #[test]
    fn test_be_version_negotiation() {
        use super::io::be_version_negotiation;
        let buf = vec![0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02];
        let (remain, versions) = be_version_negotiation(buf.as_ref()).unwrap();
        assert_eq!(versions.versions, vec![0x01, 0x02]);
        assert_eq!(remain.len(), 0);
    }

    #[test]
    fn test_be_retry() {
        use super::io::be_retry;
        let buf = vec![
            0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
            0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
        ];
        let (remain, retry) = be_retry(buf.as_ref()).unwrap();
        assert_eq!(
            retry.integrity,
            [
                0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c,
                0x0d, 0x0e, 0x0f
            ]
        );
        assert_eq!(retry.token, vec![0x00, 0x00, 0x00]);
        assert_eq!(remain.len(), 0);

        let buf = vec![
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
            0x0f,
        ];
        match be_retry(&buf) {
            Err(e) => assert_eq!(e, nom::Err::Incomplete(nom::Needed::new(16))),
            _ => panic!("unexpected result"),
        }
    }

    #[test]
    fn test_be_initial() {
        use crate::packet::header::long::io::be_initial;
        // Note: The length of the last bit is filled in when sending, here set as 0x01
        // Consistent behavior with zero_rtt and handshake
        let buf = vec![0x03, 0x00, 0x00, 0x00];
        let (remain, initial) = be_initial(buf.as_ref()).unwrap();
        assert_eq!(initial.token, vec![0x00, 0x00, 0x00]);
        assert_eq!(remain.len(), 0);
    }

    #[test]
    fn test_write_version_negotiation_long_header() {
        use super::{LongHeaderBuilder, VersionNegotiation};
        use crate::cid::ConnectionId;
        let mut buf = Vec::<u8>::new();
        let vn_long_header =
            LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default())
                .wrap(VersionNegotiation::new(vec![0x01, 0x02]));
        buf.put_specific(&vn_long_header.specific);
        assert_eq!(buf, vec![0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02]);
    }

    #[test]
    fn test_write_retry_long_header() {
        use super::{LongHeaderBuilder, Retry};
        use crate::cid::ConnectionId;
        let mut buf = Vec::<u8>::new();
        let retry_long_header =
            LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default()).wrap(
                Retry::new(
                    &[0x00, 0x00, 0x00],
                    &[
                        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
                        0x0c, 0x0d, 0x0e, 0x0f,
                    ],
                ),
            );
        buf.put_specific(&retry_long_header.specific);
        assert_eq!(
            buf,
            vec![
                0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09,
                0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
            ]
        );
    }

    #[test]
    fn test_write_initial_long_header() {
        use super::LongHeaderBuilder;
        use crate::cid::ConnectionId;
        let mut buf = Vec::<u8>::new();
        let initial_long_header =
            LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default())
                .initial(vec![0x00, 0x00, 0x00]);
        buf.put_specific(&initial_long_header.specific);
        assert_eq!(buf, vec![0x03, 0x00, 0x00, 0x00,]);
    }
}
================================================
FILE: qbase/src/packet/header/short.rs
================================================ use super::*; use crate::{cid::ConnectionId, packet::SpinBit}; /// A packet with a short header does not include a length, /// so it can only be the last packet in a UDP datagram. /// /// ```text /// +---spin bit /// | +---key phase bits /// | | /// +----+-----+----+------+--------------+----......---+ /// |1|1|S 0 0 K 0 0| DCIL | DCID(0..160) | Payload ... | /// +-----+---+-+---+------+--------------+----......---+ /// |<->| |<->| /// | | /// | +---> packet number length /// +---> reserved bits, must be 0 /// ``` /// /// See [1-RTT Packet](https://www.rfc-editor.org/rfc/rfc9000.html#name-1-rtt-packet) /// in [RFC9000](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] pub struct OneRttHeader { // For simplicity, the spin bit is also part of the 1RTT header. spin: SpinBit, dcid: ConnectionId, } impl OneRttHeader { /// Create a new 1RTT header. pub fn new(spin: SpinBit, dcid: ConnectionId) -> Self { Self { spin, dcid } } /// Get the spin bit. pub fn spin(&self) -> SpinBit { self.spin } } impl EncodeHeader for OneRttHeader { fn size(&self) -> usize { 1 + self.dcid.len() } } impl GetType for OneRttHeader { fn get_type(&self) -> Type { Type::Short(OneRtt(self.spin)) } } impl super::GetDcid for OneRttHeader { fn dcid(&self) -> &ConnectionId { &self.dcid } } /// The io module provides functions for parsing and writing 1RTT headers. pub mod io { use bytes::BufMut; use super::{GetType, OneRttHeader}; use crate::packet::{header::io::WriteHeader, signal::SpinBit, r#type::io::WritePacketType}; /// Parse a 1RTT header from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
pub fn be_one_rtt_header( spin: SpinBit, dcid_len: usize, input: &[u8], ) -> nom::IResult<&[u8], OneRttHeader> { use nom::bytes::streaming::take; let (remain, dcid) = take(dcid_len)(input)?; let dcid = crate::cid::ConnectionId::from_slice(dcid); Ok((remain, OneRttHeader { spin, dcid })) } impl WriteHeader for T { fn put_header(&mut self, header: &OneRttHeader) { let ty = header.get_type(); self.put_packet_type(&ty); self.put_slice(&header.dcid); } } } #[cfg(test)] mod tests { use crate::packet::header::io::WriteHeader; #[test] fn test_read_one_rtt_header() { use super::io::be_one_rtt_header; use crate::packet::{SpinBit, header::ConnectionId}; let (remain, header) = be_one_rtt_header(SpinBit::One, 0, &[][..]).unwrap(); assert_eq!(remain.len(), 0); assert_eq!(header.spin, SpinBit::One); assert_eq!(header.dcid, ConnectionId::default()); } #[test] fn test_write_one_rtt_header() { use super::OneRttHeader; use crate::{cid::ConnectionId, packet::SpinBit}; let mut buf = vec![]; let header = OneRttHeader { spin: SpinBit::One, dcid: ConnectionId::default(), }; buf.put_header(&header); // Note: 0x60 == SHORT_HEADER_BIT | FIXED_BIT | Toggle.value() assert_eq!(buf, [0x60]); } } ================================================ FILE: qbase/src/packet/header.rs ================================================ use enum_dispatch::enum_dispatch; use crate::cid::ConnectionId; /// All structure definitions related to long headers. pub mod long; /// All structure definitions related to short headers. pub mod short; #[doc(hidden)] pub use long::{ DataHeader, HandshakeHeader, InitialHeader, LongHeader, RetryHeader, VersionNegotiationHeader, ZeroRttHeader, io::{LongHeaderBuilder, WriteSpecific}, }; #[doc(hidden)] pub use short::OneRttHeader; use super::r#type::{ Type, long::{Type as LongType, Version, v1}, short::OneRtt, }; /// Each packet has its type. For more detailed definition on packet types, see [`Type`]. #[enum_dispatch] pub trait GetType { /// Get the packet type. 
fn get_type(&self) -> Type; } /// When encoding a packet for sending, we need to know the size of the packet encoding, /// so this trait needs to be implemented. /// /// However, the length field of the packet payload is variable-length encoded and /// requires special encoding, which is not considered here. #[enum_dispatch] pub trait EncodeHeader { /// Returns the length of the encoded packet header. fn size(&self) -> usize { 0 } fn length_encoding(&self) -> usize { 0 } } /// Get the Destination Connection ID (DCID) of the packet, each packet has a DCID. #[enum_dispatch] pub trait GetDcid { /// Get the Destination Connection ID (DCID) of the packet. fn dcid(&self) -> &ConnectionId; } /// Get the Source Connection ID (SCID) of the packet, only long packets have SCID. #[enum_dispatch] pub trait GetScid { /// Get the Source Connection ID (SCID) of the packet. fn scid(&self) -> &ConnectionId; } /// The sum type of all packet headers. #[derive(Debug, Clone)] #[enum_dispatch(GetDcid)] pub enum Header { VN(long::VersionNegotiationHeader), Retry(long::RetryHeader), Initial(long::InitialHeader), ZeroRtt(long::ZeroRttHeader), Handshake(long::HandshakeHeader), OneRtt(short::OneRttHeader), } /// The io module for packet headers, including /// how to parse the header from a UDP packet and /// how to write the header into a UDP packet. pub mod io { use super::{ Header, LongHeader, OneRttHeader, long::{Handshake, Initial, Retry, VersionNegotiation, ZeroRtt, io::LongHeaderBuilder}, }; use crate::{ cid::be_connection_id, packet::{ header::short::io::be_one_rtt_header, r#type::{Type, short::OneRtt}, }, }; /// Parse a packet header from the input buffer, /// returns [`Header`] if succeed, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
pub fn be_header( packet_type: Type, dcid_len: usize, input: &[u8], ) -> nom::IResult<&[u8], Header> { match packet_type { Type::Long(long_ty) => { let (remain, dcid) = be_connection_id(input)?; let (remain, scid) = be_connection_id(remain)?; let builder = LongHeaderBuilder { dcid, scid }; builder.parse(long_ty, remain) } Type::Short(OneRtt(spin)) => { let (remain, one_rtt) = be_one_rtt_header(spin, dcid_len, input)?; Ok((remain, Header::OneRtt(one_rtt))) } } } /// A [`bytes::BufMut`] extension trait for writing packet headers. /// /// When sending packets, it is necessary to organize the data and write /// various types of QUIC packets into an UDP datagram. This trait will /// be used to write the packet header. pub trait WriteHeader: bytes::BufMut { /// Write a packet header to the buffer. fn put_header(&mut self, header: &H); } impl WriteHeader
for T where T: bytes::BufMut + WriteHeader> + WriteHeader> + WriteHeader> + WriteHeader> + WriteHeader> + WriteHeader, { fn put_header(&mut self, header: &Header) { match header { Header::VN(vn) => self.put_header(vn), Header::Retry(retry) => self.put_header(retry), Header::Initial(initial) => self.put_header(initial), Header::ZeroRtt(zero_rtt) => self.put_header(zero_rtt), Header::Handshake(handshake) => self.put_header(handshake), Header::OneRtt(one_rtt) => self.put_header(one_rtt), } } } } #[cfg(test)] mod tests { use std::ops::Deref; use super::{ Header, LongHeaderBuilder, io::be_header, long::{Handshake, Initial, Retry, VersionNegotiation, ZeroRtt}, }; use crate::{ cid::ConnectionId, packet::{ GetDcid, OneRttHeader, SpinBit, header::{GetScid, io::WriteHeader}, r#type::{ Type, long::{self, Ver1}, short::OneRtt, }, }, }; #[test] fn test_read_header() { // VersionNegotiation Header let buf = vec![0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02]; let (remain, vn_long_header) = be_header(Type::Long(long::Type::VersionNegotiation), 0, &buf).unwrap(); assert_eq!(remain.len(), 0); match vn_long_header { Header::VN(vn) => { assert_eq!(vn.dcid(), &ConnectionId::default()); assert_eq!(vn.scid(), &ConnectionId::default()); assert_eq!(vn.versions(), &vec![0x01, 0x02]); } _ => panic!("unexpected header type"), } // Retry Header let buf = vec![ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, ]; let (remain, retry_long_header) = be_header(Type::Long(long::Type::V1(Ver1::RETRY)), 0, &buf).unwrap(); assert_eq!(remain.len(), 0); match retry_long_header { Header::Retry(retry) => { assert_eq!(retry.dcid(), &ConnectionId::default()); assert_eq!(retry.scid(), &ConnectionId::default()); assert_eq!(retry.token().deref(), &[0x00, 0x00, 0x00]); assert_eq!( retry.integrity(), &[ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f ] ); } _ => 
panic!("unexpected header type"), } // Retry Header with invalid length let buf = vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, ]; match be_header(Type::Long(long::Type::V1(Ver1::RETRY)), 0, &buf) { Err(e) => assert_eq!(e, nom::Err::Incomplete(nom::Needed::new(16))), _ => panic!("unexpected result"), } // Initial Header let buf = vec![0x00, 0x00, 0x03, 0x01, 0x02, 0x03]; let (remain, initial_long_header) = be_header(Type::Long(long::Type::V1(Ver1::INITIAL)), 0, &buf).unwrap(); assert_eq!(remain.len(), 0); match initial_long_header { Header::Initial(initial) => { assert_eq!(initial.dcid(), &ConnectionId::default()); assert_eq!(initial.scid(), &ConnectionId::default()); assert_eq!(initial.token().deref(), [0x01, 0x02, 0x03,]); } _ => panic!("unexpected header type"), } // ZeroRTT Header let buf = vec![0x00, 0x00]; let (remain, zero_rtt_long_header) = be_header(Type::Long(long::Type::V1(Ver1::ZERO_RTT)), 0, &buf).unwrap(); assert_eq!(remain.len(), 0); match zero_rtt_long_header { Header::ZeroRtt(zero_rtt) => { assert_eq!(zero_rtt.dcid(), &ConnectionId::default()); assert_eq!(zero_rtt.scid(), &ConnectionId::default()); } _ => panic!("unexpected header type"), } // Handshake Header let buf = vec![0x00, 0x00]; let (remain, handshake_long_header) = be_header(Type::Long(long::Type::V1(Ver1::HANDSHAKE)), 0, &buf).unwrap(); assert_eq!(remain.len(), 0); match handshake_long_header { Header::Handshake(handshake) => { assert_eq!(handshake.dcid(), &ConnectionId::default()); assert_eq!(handshake.scid(), &ConnectionId::default()); } _ => panic!("unexpected header type"), } // OneRtt Header let buf = vec![]; let (remain, one_rtt_header) = be_header(Type::Short(OneRtt(SpinBit::One)), 0, &buf).unwrap(); assert_eq!(remain.len(), 0); match one_rtt_header { Header::OneRtt(one_rtt) => { assert_eq!( one_rtt, OneRttHeader::new(SpinBit::One, ConnectionId::default()) ); } _ => panic!("unexpected header type"), } } #[test] fn 
test_write_header() { // VersionNegotiation Header let mut buf = vec![]; let vn_long_header = Header::VN( LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default()) .wrap(VersionNegotiation::new(vec![0x01, 0x02])), ); buf.put_header(&vn_long_header); assert_eq!( buf, [ 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02 ] ); // Retry Header let mut buf = vec![]; let retry_long_header = Header::Retry( LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default()).wrap( Retry::new( &[0x00, 0x00, 0x00], &[ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, ], ), ), ); buf.put_header(&retry_long_header); assert_eq!( buf, [ 0xf0, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f ] ); // Initial Header let mut buf = vec![]; let initial_header = Header::Initial( LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default()) .wrap(Initial::with_token(vec![0x01, 0x02, 0x03])), ); buf.put_header(&initial_header); assert_eq!( buf, [ 0xc0, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03 ] ); // ZeroRtt Header let mut buf = vec![]; let zero_rtt_header = Header::ZeroRtt( LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default()) .wrap(ZeroRtt), ); buf.put_header(&zero_rtt_header); assert_eq!(buf, [0xd0, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00]); // Handshake Header let mut buf = vec![]; let handshake_header = Header::Handshake( LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default()) .wrap(Handshake), ); buf.put_header(&handshake_header); assert_eq!(buf, [0xe0, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00]); // OneRtt Header with SpinBit::On let mut buf = vec![]; let one_rtt_header = Header::OneRtt(OneRttHeader::new(SpinBit::One, ConnectionId::default())); buf.put_header(&one_rtt_header); assert_eq!(buf, [0x60]); // 
OneRtt Header with SpinBit::Off let mut buf = vec![]; let one_rtt_header = Header::OneRtt(OneRttHeader::new(SpinBit::Zero, ConnectionId::default())); buf.put_header(&one_rtt_header); assert_eq!(buf, [0x40]); } } ================================================ FILE: qbase/src/packet/io.rs ================================================ use std::{any::Any, mem}; use bytes::BytesMut; use nom::{Parser, multi::length_data}; use super::{ error::Error, header::io::be_header, r#type::{Type, io::be_packet_type}, *, }; use crate::{ Epoch, frame::{io::WriteFrame, *}, net::tx::Signals, util::{ContinuousData, NonData, WriteData}, varint::be_varint, }; /// Parse the payload of a packet. /// /// - For long packets, the payload is a [`nom::multi::length_data`]. /// - For 1-RTT packet, the payload is the remaining content of the datagram. fn be_payload( pkty: Type, datagram: &mut BytesMut, remain_len: usize, ) -> Result<(BytesMut, usize), Error> { let offset = datagram.len() - remain_len; let input = &datagram[offset..]; let (remain, payload) = length_data(be_varint).parse(input).map_err(|e| match e { ne @ nom::Err::Incomplete(_) => Error::IncompleteHeader(pkty, ne.to_string()), _ => unreachable!("parsing packet header never generates error or failure"), })?; let payload_len = payload.len(); if payload_len < 20 { // The payload needs at least 20 bytes to have enough samples to remove the packet header protection. return Err(Error::UnderSampling(pkty, payload.len())); } let packet_length = datagram.len() - remain.len(); let bytes = datagram.split_to(packet_length); Ok((bytes, packet_length - payload_len)) } /// Parse the QUIC packet from the datagram, given the length of the DCID. /// Returns the parsed packet or an error, and the datagram removed the packet's content. 
pub fn be_packet(datagram: &mut BytesMut, dcid_len: usize) -> Result { let input = datagram.as_ref(); let (remain, pkty) = be_packet_type(input).map_err(|e| match e { ne @ nom::Err::Incomplete(_) => Error::IncompleteType(ne.to_string()), nom::Err::Error(e) => e, _ => unreachable!("parsing packet type never generates failure"), })?; let (remain, header) = be_header(pkty, dcid_len, remain).map_err(|e| match e { ne @ nom::Err::Incomplete(_) => Error::IncompleteHeader(pkty, ne.to_string()), _ => unreachable!("parsing packet header never generates error or failure"), })?; match header { Header::VN(header) => { datagram.clear(); Ok(Packet::VN(header)) } Header::Retry(header) => { datagram.clear(); Ok(Packet::Retry(header)) } Header::Initial(header) => { let (bytes, offset) = be_payload(pkty, datagram, remain.len())?; Ok(Packet::Data(DataPacket { header: DataHeader::Long(long::DataHeader::Initial(header)), bytes, offset, })) } Header::ZeroRtt(header) => { let (bytes, offset) = be_payload(pkty, datagram, remain.len())?; Ok(Packet::Data(DataPacket { header: DataHeader::Long(long::DataHeader::ZeroRtt(header)), bytes, offset, })) } Header::Handshake(header) => { let (bytes, offset) = be_payload(pkty, datagram, remain.len())?; Ok(Packet::Data(DataPacket { header: DataHeader::Long(long::DataHeader::Handshake(header)), bytes, offset, })) } Header::OneRtt(header) => { if remain.len() < 20 { // The payload needs at least 20 bytes to have enough samples to remove the packet header protection. 
return Err(Error::UnderSampling(pkty, remain.len())); } let remain_len = remain.len(); let bytes = mem::replace(datagram, BytesMut::new()); let offset = bytes.len() - remain_len; datagram.clear(); Ok(Packet::Data(DataPacket { header: DataHeader::Short(header), bytes, offset, })) } } } pub trait ProductHeader { fn new_header(&self) -> Result; } pub trait PacketSpace { type PacketAssembler<'b>: AssemblePacket where Self: 'b; fn new_packet<'b>( &'b self, header: H, buffer: &'b mut [u8], ) -> Result, Signals>; } // Target -> Target pub trait Package { fn dump(&mut self, target: &mut Target) -> Result; } impl + ?Sized> Package for &mut P { #[inline] fn dump(&mut self, target: &mut Target) -> Result { P::dump(self, target) } } impl + ?Sized> Package for Box

{ #[inline] fn dump(&mut self, target: &mut Target) -> Result { P::dump(self, target) } } impl> Package for Option

{ #[inline] fn dump(&mut self, target: &mut Target) -> Result { self.take() .map_or_else(|| Err(Signals::empty()), |mut package| package.dump(target)) } } impl> Package for [P] { #[inline] fn dump(&mut self, target: &mut Target) -> Result { let origin = target.remaining_mut(); let mut signals = Signals::empty(); let mut packet_content = PacketContent::default(); for package in self { match package.dump(target) { Ok(content) => packet_content += content, Err(s) => signals |= s, } } (origin != target.remaining_mut()) .then_some(packet_content) .ok_or(signals) } } impl, const N: usize> Package for [P; N] { #[inline] fn dump(&mut self, target: &mut Target) -> Result { let origin = target.remaining_mut(); let mut signals = Signals::empty(); let mut packet_content = PacketContent::default(); for package in self { match package.dump(target) { Ok(content) => packet_content += content, Err(s) => signals |= s, } } (origin != target.remaining_mut()) .then_some(packet_content) .ok_or(signals) } } pub struct PadTo20; impl<'b, P> Package

for PadTo20 where P: AsRef> + BufMut + ?Sized, { #[inline] fn dump(&mut self, target: &mut P) -> Result { let packet = target.as_ref(); match packet.payload_len() + packet.tag_len() { _ if packet.is_empty() => Err(Signals::empty()), len if len < 20 => { target.put_bytes(0, 20 - len); Ok(PacketContent::NonAckEliciting) } _ => Ok(PacketContent::NonAckEliciting), } } } pub struct PadToFull; impl<'b, P> Package

for PadToFull where P: AsRef> + BufMut + ?Sized, { #[inline] fn dump(&mut self, target: &mut P) -> Result { let packet = target.as_ref(); match packet.payload_len() + packet.tag_len() { _ if packet.is_empty() => Err(Signals::empty()), len if len < packet.buffer().len() => { target.put_bytes(0, packet.remaining_mut()); Ok(PacketContent::NonAckEliciting) } _ => Ok(PacketContent::NonAckEliciting), } } } pub struct PadProbe; impl<'b, P> Package

for PadProbe where P: AsRef> + BufMut + ?Sized, { #[inline] fn dump(&mut self, target: &mut P) -> Result { if target.as_ref().is_probe_new_path() { return PadToFull.dump(target); } Err(Signals::empty()) } } #[derive(Debug, Clone, Copy)] pub struct Repeat

(pub P); impl> Package for Repeat

{ #[inline] fn dump(&mut self, target: &mut Target) -> Result { let origin = target.remaining_mut(); let mut packet_content = PacketContent::default(); let signals = loop { match self.0.dump(target) { Ok(content) => packet_content += content, Err(signals) => break signals, } }; (origin != target.remaining_mut()) .then_some(packet_content) .ok_or(signals) } } pub struct Packages(pub T); macro_rules! impl_package_for_tuple { () => {}; ($head:ident $($tail:ident)*) => { impl_package_for_tuple!(@imp $head $($tail)*); impl_package_for_tuple!( $($tail)*); }; (@imp $($t:ident)*) => { impl),*> Package for Packages<($($t,)*)> { #[inline] fn dump(&mut self, target: &mut Target) -> Result { let origin = target.remaining_mut(); let mut signals = Signals::empty(); let mut packet_content = PacketContent::default(); #[allow(non_snake_case)] let ($($t,)*) = &mut self.0; $( #[allow(non_snake_case)] match $t.dump(target) { Ok(content) => packet_content += content, Err(s) => signals |= s, } )* (origin != target.remaining_mut()) .then_some(packet_content) .ok_or(signals) } } } } impl_package_for_tuple! { Z Y X W V U T S R Q P O N M L K J I H G F E D C B A } macro_rules! 
frame_packages { () => {}; (@imp_frame $($frame:tt)*) => { impl Package for $($frame)* where Target: BufMut + RecordFrame, NonData> + ?Sized, { #[inline] fn dump(&mut self, target: &mut Target) -> Result { if !(target.remaining_mut() >= self.max_encoding_size() || target.remaining_mut() >= self.encoding_size()) { return Err(Signals::CONGESTION); } let frame = self.clone().into(); target.record_frame(&frame); target.put_frame(&frame); Ok(PacketContent::from(self.frame_type())) } } }; (impl> Package for $frame:ident {} $($tail:tt)*) => { frame_packages!{ @imp_frame $frame } frame_packages!{ @imp_frame &$frame } frame_packages!{ $($tail)* } }; (@imp_data_frame $($frame_with_data:tt)*) => { impl Package for $($frame_with_data)* where Target: BufMut + RecordFrame, D> + ?Sized, D: ContinuousData + Clone, for<'b> &'b mut Target: WriteData, { #[inline] fn dump(&mut self, target: &mut Target) -> Result { let (frame, data) = self; if !(target.remaining_mut() >= frame.max_encoding_size() || target.remaining_mut() >= frame.encoding_size()) { return Err(Signals::CONGESTION); } let frame = (frame.clone(), data.clone()).into(); target.record_frame(&frame); target.put_frame(&frame); Ok(PacketContent::from(frame.frame_type())) } } }; (impl, D: ContinuousData> Package for ($frame:ident, D) {} $($tail:tt)*) => { frame_packages!{ @imp_data_frame ($frame, D) } frame_packages!{ @imp_data_frame &($frame, D) } frame_packages!{ $($tail)* } }; } frame_packages! 
{ impl> Package for PaddingFrame {} impl> Package for PingFrame {} impl> Package for AckFrame {} impl> Package for ConnectionCloseFrame {} impl> Package for NewTokenFrame {} impl> Package for MaxDataFrame {} impl> Package for DataBlockedFrame {} impl> Package for HandshakeDoneFrame {} impl> Package for PathChallengeFrame {} impl> Package for PathResponseFrame {} impl> Package for StreamCtlFrame {} impl> Package for ReliableFrame {} impl> Package for PunchHelloFrame {} impl> Package for PunchDoneFrame {} impl, D: ContinuousData> Package for (StreamFrame, D) {} impl, D: ContinuousData> Package for (CryptoFrame, D) {} impl, D: ContinuousData> Package for (DatagramFrame, D) {} } pub enum Keys { LongHeaderPacket { keys: DirectionalKeys, }, ShortHeaderPacket { keys: DirectionalKeys, key_phase: KeyPhaseBit, }, } impl Debug for Keys { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::LongHeaderPacket { .. } => f .debug_struct("LongHeaderPacket") .field("keys", &"...") .finish(), Self::ShortHeaderPacket { key_phase, .. } => f .debug_struct("ShortHeaderPacket") .field("keys", &"...") .field("key_phase", key_phase) .finish(), } } } impl Keys { fn hpk(&self) -> &dyn rustls::quic::HeaderProtectionKey { match self { Self::LongHeaderPacket { keys } | Self::ShortHeaderPacket { keys, .. } => { keys.header.as_ref() } } } fn pk(&self) -> &dyn rustls::quic::PacketKey { match self { Self::LongHeaderPacket { keys } | Self::ShortHeaderPacket { keys, .. } => { keys.packet.as_ref() } } } fn key_phase(&self) -> Option { match self { Self::LongHeaderPacket { .. } => None, Self::ShortHeaderPacket { key_phase, .. 
} => Some(*key_phase), } } } #[derive(Debug)] struct PacketLayout { hdr_len: usize, len_encoding: usize, pn_len: usize, cursor: usize, end: usize, } impl PacketLayout { pub fn payload_len(&self) -> usize { self.cursor - self.hdr_len - self.len_encoding } pub fn is_empty(&self) -> bool { self.payload_len() == self.pn_len } } #[derive(Debug, CopyGetters)] pub struct PacketInfo { #[getset(get_copy = "pub")] packet_type: Type, #[getset(get_copy = "pub")] packet_number: u64, // Packets containing only frames with [`Spec::N`] are not ack-eliciting; // otherwise, they are ack-eliciting. #[getset(get_copy = "pub")] ack_eliciting: bool, // A Boolean that indicates whether the packet counts toward bytes in flight. // See [Section 2](https://www.rfc-editor.org/rfc/rfc9002#section-2) // and [Appendix A.1](https://www.rfc-editor.org/rfc/rfc9002#section-a.1) // of [QUIC Recovery](https://www.rfc-editor.org/rfc/rfc9002). // // Packets containing only frames with [`Spec::C`] do not // count toward bytes in flight for congestion control purposes. #[getset(get_copy = "pub")] in_flight: bool, // Packets containing only frames with [`Spec::P`] can be used to // probe new network paths during connection migration. #[getset(get_copy = "pub")] probe_new_path: bool, #[getset(get_copy = "pub")] largest_ack: Option, } impl PacketInfo { pub fn new(ty: Type, pn: u64) -> Self { Self { packet_type: ty, packet_number: pn, ack_eliciting: false, in_flight: false, probe_new_path: false, largest_ack: None, } } pub fn epoch(&self) -> Option { match self.packet_type() { Type::Long(long) => match long { r#type::long::Type::VersionNegotiation => None, r#type::long::Type::V1(version) => match version.0 { r#type::long::v1::Type::Initial => Some(Epoch::Initial), r#type::long::v1::Type::ZeroRtt => Some(Epoch::Data), r#type::long::v1::Type::Handshake => Some(Epoch::Handshake), r#type::long::v1::Type::Retry => None, }, }, Type::Short(..) 
=> Some(Epoch::Data), } } pub fn add_frame(&mut self, frame: &F) { debug_assert!( frame.belongs_to(self.packet_type()), "Frame {:?} does not belong to packet type {:?}", std::any::type_name_of_val(frame), self.packet_type() ); self.ack_eliciting |= !frame.specs().contain(Spec::NonAckEliciting); self.in_flight |= !frame.specs().contain(Spec::CongestionControlFree); self.probe_new_path |= frame.specs().contain(Spec::ProbeNewPath); if let Some(ack_frame) = (frame as &dyn Any).downcast_ref::() { self.largest_ack = Some(match self.largest_ack { Some(largest_ack) => largest_ack.max(ack_frame.largest()), None => ack_frame.largest(), }); } } } pub trait RecordFrame { fn record_frame(&mut self, frame: &F); } impl RecordFrame, D> for PacketInfo { fn record_frame(&mut self, frame: &Frame) { debug_assert!( frame.belongs_to(self.packet_type(),), "Frame {:?} does not belong to packet type {:?}", frame.frame_type(), self.packet_type() ); self.ack_eliciting |= !frame.specs().contain(Spec::NonAckEliciting); self.in_flight |= !frame.specs().contain(Spec::CongestionControlFree); self.probe_new_path |= frame.specs().contain(Spec::ProbeNewPath); if let Frame::Ack(ack_frame) = frame { self.largest_ack = Some(match self.largest_ack { Some(largest_ack) => largest_ack.max(ack_frame.largest()), None => ack_frame.largest(), }); } } } impl RecordFrame for PacketWriter<'_> where PacketInfo: RecordFrame, { #[inline] fn record_frame(&mut self, frame: &F) { self.pkt_info.record_frame(frame); } } pub struct PacketWriter<'b> { keys: Keys, layout: PacketLayout, pkt_info: PacketInfo, buffer: &'b mut [u8], } impl<'b> PacketWriter<'b> { pub fn new_long( header: &LongHeader, buffer: &'b mut [u8], (actual_pn, encoded_pn): (u64, PacketNumber), keys: DirectionalKeys, ) -> Result where S: EncodeHeader, LongHeader: GetType, for<'a> &'a mut [u8]: WriteHeader>, { let hdr_len = header.size(); let len_encoding = header.length_encoding(); if buffer.len() < hdr_len + len_encoding + 20 { return 
Err(Signals::CONGESTION); } let (mut hdr_buf, mut payload_buf) = buffer.split_at_mut(hdr_len + len_encoding); hdr_buf.put_header(header); payload_buf.put_packet_number(encoded_pn); let cursor = hdr_len + len_encoding + encoded_pn.size(); Ok(Self { layout: PacketLayout { hdr_len, len_encoding, pn_len: encoded_pn.size(), cursor, end: buffer.len() - keys.packet.tag_len(), }, keys: Keys::LongHeaderPacket { keys }, pkt_info: PacketInfo::new(header.get_type(), actual_pn), buffer, }) } pub fn new_short( header: &OneRttHeader, buffer: &'b mut [u8], (actual_pn, encoded_pn): (u64, PacketNumber), keys: DirectionalKeys, key_phase: KeyPhaseBit, ) -> Result { let hdr_len = header.size(); if buffer.len() < hdr_len + 20 { return Err(Signals::CONGESTION); } let (mut hdr_buf, mut payload_buf) = buffer.split_at_mut(hdr_len); hdr_buf.put_header(header); payload_buf.put_packet_number(encoded_pn); Ok(Self { layout: PacketLayout { hdr_len, len_encoding: 0, pn_len: encoded_pn.size(), cursor: hdr_len + encoded_pn.size(), end: buffer.len() - keys.packet.tag_len(), }, keys: Keys::ShortHeaderPacket { keys, key_phase }, pkt_info: PacketInfo::new(header.get_type(), actual_pn), buffer, }) } #[inline] pub fn buffer(&self) -> &[u8] { self.buffer } #[inline] pub fn is_short_header(&self) -> bool { self.keys.key_phase().is_some() } #[inline] pub fn packet_type(&self) -> Type { self.pkt_info.packet_type() } #[inline] pub fn packet_number(&self) -> u64 { self.pkt_info.packet_number } #[inline] pub fn is_ack_eliciting(&self) -> bool { self.pkt_info.ack_eliciting } #[inline] pub fn in_flight(&self) -> bool { self.pkt_info.in_flight } #[inline] pub fn is_probe_new_path(&self) -> bool { self.pkt_info.probe_new_path } #[inline] pub fn payload_len(&self) -> usize { self.layout.payload_len() } #[inline] pub fn tag_len(&self) -> usize { self.keys.pk().tag_len() } #[inline] pub fn is_empty(&self) -> bool { self.layout.is_empty() } #[inline] pub fn packet_len(&self) -> usize { self.layout.cursor + 
self.keys.pk().tag_len() } } unsafe impl BufMut for PacketWriter<'_> { #[inline] fn remaining_mut(&self) -> usize { self.layout.end - self.layout.cursor } #[inline] unsafe fn advance_mut(&mut self, cnt: usize) { if self.remaining_mut() < cnt { panic!( "advance out of bounds: the len is {} but advancing by {}", cnt, self.remaining_mut() ); } self.layout.cursor += cnt; } #[inline] fn chunk_mut(&mut self) -> &mut UninitSlice { let range = self.layout.cursor..self.layout.end; UninitSlice::new(&mut self.buffer[range]) } } pub trait AssemblePacket: BufMut { #[inline] fn assemble_packet( &mut self, package: &mut dyn Package, ) -> Result { package.dump(self) } fn encrypt_and_protect_packet(self) -> (usize, PacketInfo); } impl AssemblePacket for PacketWriter<'_> { fn encrypt_and_protect_packet(self) -> (usize, PacketInfo) { use crate::{ packet::encrypt::*, varint::{EncodeBytes, VarInt, WriteVarInt}, }; let Self { keys, layout, pkt_info, buffer, } = self; let payload_len = layout.payload_len(); let tag_len = keys.pk().tag_len(); let actual_pn = pkt_info.packet_number; let pn_len = layout.pn_len; let pkt_size = layout.cursor + tag_len; assert!( payload_len + tag_len >= 20, "The payload and tag needs at least 20 bytes to have enough samples for the packet header protection." 
); if let Some(key_phase) = keys.key_phase() { encode_short_first_byte(&mut buffer[0], pn_len, key_phase); let pk = keys.pk(); let payload_offset = layout.hdr_len; let body_offset = payload_offset + pn_len; encrypt_packet(pk, actual_pn, &mut buffer[..pkt_size], body_offset); let hpk = keys.hpk(); protect_header(hpk, &mut buffer[..pkt_size], payload_offset, pn_len); } else { let packet_len = payload_len + tag_len; let len_buffer_range = layout.hdr_len..layout.hdr_len + layout.len_encoding; let mut len_buf = &mut buffer[len_buffer_range]; len_buf.encode_varint(&VarInt::try_from(packet_len).unwrap(), EncodeBytes::Two); encode_long_first_byte(&mut buffer[0], pn_len); let pk = keys.pk(); let payload_offset = layout.hdr_len + layout.len_encoding; let body_offset = payload_offset + pn_len; encrypt_packet(pk, actual_pn, &mut buffer[..pkt_size], body_offset); let hpk = keys.hpk(); protect_header(hpk, &mut buffer[..pkt_size], payload_offset, pn_len); } (pkt_size, pkt_info) } } #[cfg(test)] mod tests { use std::sync::Arc; use super::*; use crate::{frame::CryptoFrame, varint::VarInt}; struct TransparentKeys; impl rustls::quic::PacketKey for TransparentKeys { fn decrypt_in_place<'a>( &self, _packet_number: u64, _header: &[u8], payload: &'a mut [u8], ) -> Result<&'a [u8], rustls::Error> { Ok(&payload[..payload.len() - self.tag_len()]) } fn encrypt_in_place( &self, _packet_number: u64, _header: &[u8], _payload: &mut [u8], ) -> Result { Ok(rustls::quic::Tag::from("transparent_keys".as_bytes())) } fn confidentiality_limit(&self) -> u64 { 0 } fn integrity_limit(&self) -> u64 { 0 } fn tag_len(&self) -> usize { 16 } } impl rustls::quic::HeaderProtectionKey for TransparentKeys { fn decrypt_in_place( &self, _sample: &[u8], _first_byte: &mut u8, _payload: &mut [u8], ) -> Result<(), rustls::Error> { Ok(()) } fn encrypt_in_place( &self, _sample: &[u8], _first_byte: &mut u8, _payload: &mut [u8], ) -> Result<(), rustls::Error> { Ok(()) } fn sample_len(&self) -> usize { 20 } } #[test] fn 
test_initial_packet_writer() { let mut buffer = vec![0u8; 128]; let header = LongHeaderBuilder::with_cid( ConnectionId::from_slice("testdcid".as_bytes()), ConnectionId::from_slice("testscid".as_bytes()), ) .initial(b"test_token".to_vec()); let pn = (0, PacketNumber::encode(0, 0)); let keys = DirectionalKeys { packet: Arc::new(TransparentKeys), header: Arc::new(TransparentKeys), }; let mut writer = PacketWriter::new_long(&header, &mut buffer, pn, keys).unwrap(); let frame = CryptoFrame::new(VarInt::from_u32(0), VarInt::from_u32(12)); writer .assemble_packet(&mut (frame, "client_hello".as_bytes())) .unwrap(); assert!(writer.is_ack_eliciting()); assert!(writer.in_flight()); let (sent_bytes, final_packet_layout) = writer.encrypt_and_protect_packet(); assert!(final_packet_layout.ack_eliciting()); assert!(final_packet_layout.in_flight()); assert_eq!(sent_bytes, 69); assert_eq!( &buffer[..sent_bytes], [ // initial packet: // header form (1) = 1,, long header // fixed bit (1) = 1, // long packet type (2) = 0, initial packet // reserved bits (2) = 0, // packet number length (2) = 0, 1 byte 193, // first byte 0, 0, 0, 1, // quic version // destination connection id, "testdcid" 8, // dcid length b't', b'e', b's', b't', b'd', b'c', b'i', b'd', // dcid bytes // source connection id, "testscid" 8, // scid length b't', b'e', b's', b't', b's', b'c', b'i', b'd', // scid bytes 10, // token length, no token b't', b'e', b's', b't', b'_', b't', b'o', b'k', b'e', b'n', // token bytes 64, 33, // payload length, 2 bytes encoded varint 0, 0, // encoded packet number // crypto frame header 6, // crypto frame type 0, // crypto frame offset 12, // crypto frame length // crypto frame data, "client hello" b'c', b'l', b'i', b'e', b'n', b't', b'_', b'h', b'e', b'l', b'l', b'o', // tag, "transparent_keys" b't', b'r', b'a', b'n', b's', b'p', b'a', b'r', b'e', b'n', b't', b'_', b'k', b'e', b'y', b's', ] .as_slice() ); } } ================================================ FILE: 
qbase/src/packet/keys.rs ================================================ use std::{ future::Future, ops::DerefMut, pin::Pin, sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll, Waker}, }; use futures::FutureExt; use rustls::quic::{ DirectionalKeys as RustlsDirectionalKeys, HeaderProtectionKey, Keys as RustlsKeys, PacketKey, Secrets, }; /// Keys used to communicate in a single direction #[derive(Clone)] pub struct DirectionalKeys { /// Encrypts or decrypts a packet's headers pub header: Arc, /// Encrypts or decrypts the payload of a packet pub packet: Arc, } impl From for DirectionalKeys { fn from(keys: RustlsDirectionalKeys) -> Self { Self { header: keys.header.into(), packet: keys.packet.into(), } } } /// Complete set of keys used to communicate with the peer #[derive(Clone)] pub struct Keys { /// Encrypts outgoing packets pub local: DirectionalKeys, /// Decrypts incoming packets pub remote: DirectionalKeys, } impl From for Keys { fn from(keys: RustlsKeys) -> Self { Self { local: keys.local.into(), remote: keys.remote.into(), } } } use super::KeyPhaseBit; use crate::role::Role; #[derive(Clone)] enum KeysState { Pending(Option), Ready(K), Invalid, } impl KeysState { fn set(&mut self, keys: K) { match self { KeysState::Pending(waker) => { if let Some(waker) = waker.take() { waker.wake(); } *self = KeysState::Ready(keys); } KeysState::Ready(_) => unreachable!("KeysState::set called twice"), KeysState::Invalid => unreachable!("KeysState::set called after invalidation"), } } fn get(&mut self) -> Option<&K> { match self { KeysState::Ready(keys) => Some(keys), KeysState::Pending(..) 
| KeysState::Invalid => None, } } fn invalid(&mut self) -> Option { match std::mem::replace(self, KeysState::Invalid) { KeysState::Pending(waker) => { if let Some(waker) = waker { waker.wake(); } None } KeysState::Ready(keys) => Some(keys), KeysState::Invalid => None, } } } impl Future for KeysState { type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.get_mut() { KeysState::Pending(waker) => { if waker .as_ref() .is_some_and(|waker| !waker.will_wake(cx.waker())) { unreachable!( "Try to get remote keys from multiple tasks! This is a bug, please report it." ) } *waker = Some(cx.waker().clone()); Poll::Pending } KeysState::Ready(keys) => Poll::Ready(Some(keys.clone())), KeysState::Invalid => Poll::Ready(None), } } } /// Long packet keys, for encryption and decryption keys for those long packets, /// as well as keys for adding and removing long packet header protection. /// /// - When sending, obtain the local keys for packet encryption and adding header protection. /// If the keys are not ready, skip sending the packet of this level immidiately. /// - When receiving a packet and decrypting it, obtain the remote keys for removing header /// protection and packet decryption. /// If the keys are not ready, wait asynchronously until the keys to be ready to continue. /// /// ## Note /// /// The keys for 1-RTT packets are a separate structure, see [`ArcOneRttKeys`]. #[derive(Clone)] pub struct ArcKeys(Arc>>); impl ArcKeys { fn lock_guard(&self) -> MutexGuard<'_, KeysState> { self.0.lock().unwrap() } /// Create a Pending state [`ArcKeys`]. /// /// For a new Quic connection, initially only the Initial key is known, and the 0-RTT /// and Handshake keys are unknown. /// Therefore, the 0-RTT and Handshake keys can be created in a Pending state, waiting /// for updates during the TLS handshake process. 
pub fn new_pending() -> Self { Self(Arc::new(KeysState::Pending(None).into())) } /// Create an [`ArcKeys`] with a specified [`rustls::quic::Keys`]. /// /// The initial keys are known at first, can use this method to create the [`ArcKeys`]. pub fn with_keys(keys: Keys) -> Self { Self(Arc::new(KeysState::Ready(keys).into())) } /// Asynchronously obtain the remote keys for removing header protection and packet decryption. /// /// Rreturn [`GetRemoteKeys`], which implemented Future trait. /// /// ## Example /// /// The following is only a demonstration. /// In fact, removing header protection and decrypting packets are far more complex! /// /// ``` /// use qbase::packet::keys::ArcKeys; /// /// async fn decrypt_demo(keys: ArcKeys, cipher_text: &mut [u8]) { /// let Some(keys) = keys.get_remote_keys().await else { /// return; /// }; /// /// let hpk = keys.remote.header.as_ref(); /// let pk = keys.remote.packet.as_ref(); /// /// // use hpk to remove header protection... /// // use pk to decrypt packet body... /// } /// ``` pub fn get_remote_keys(&self) -> GetRemoteKeys<'_, Keys> { GetRemoteKeys(&self.0) } /// Get the local keys for packet encryption and adding header protection. /// If the keys is not ready, just return None immediately. /// /// ## Example /// /// The following is only a demonstration. /// In fact, encrypting packets and adding header protection are far more complex! /// /// ``` /// use qbase::packet::keys::ArcKeys; /// /// fn encrypt_demo(keys: ArcKeys, plain_text: &mut [u8]) { /// let Some(keys) = keys.get_local_keys() else { /// return; /// }; /// /// let hpk = keys.local.header.as_ref(); /// let pk = keys.local.packet.as_ref(); /// /// // use pk to encrypt packet body... /// // use hpk to add header protection... /// } /// ``` pub fn get_local_keys(&self) -> Option { self.lock_guard().get().cloned() } /// Set the keys to the [`ArcKeys`]. /// /// As the TLS handshake progresses, higher-level keys will be obtained. 
/// These keys are set to the related [`ArcKeys`] through this method, and /// its internal waker will be awakened to notify the packet decryption task /// to continue, if the internal waker was registered. pub fn set_keys(&self, keys: Keys) { self.lock_guard().set(keys); } /// Retire the keys, which means that the keys are no longer available. /// /// This is used when the connection enters the closing state or draining state. /// Especially in the closing state, the return keys are used to generate the final packet /// containing the ConnectionClose frame, and decrypt the data packets received from the /// peer for a while. pub fn invalid(&self) -> Option { self.lock_guard().invalid() } } /// To obtain the remote keys from [`ArcKeys`] or [`ArcOneRttKeys`] for removing long header protection /// and packet decryption. pub struct GetRemoteKeys<'k, K>(&'k Mutex>); impl Future for GetRemoteKeys<'_, K> { type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(self.0.lock().unwrap()).poll_unpin(cx) } } #[derive(Clone)] pub struct ArcZeroRttKeys { role: Role, keys: Arc>>, } impl ArcZeroRttKeys { pub fn new_pending(role: Role) -> Self { Self { role, keys: Arc::new(Mutex::new(KeysState::Pending(None))), } } fn lock_guard(&self) -> MutexGuard<'_, KeysState> { self.keys.lock().unwrap() } pub fn set_keys(&self, keys: DirectionalKeys) { self.lock_guard().set(keys); } pub fn get_encrypt_keys(&self) -> Option { match self.role { Role::Client => self.lock_guard().get().cloned(), Role::Server => None, } } pub fn get_decrypt_keys(&self) -> Option> { match self.role { Role::Client => None, Role::Server => Some(GetRemoteKeys(&self.keys)), } } pub fn invalid(&self) -> Option { self.lock_guard().invalid() } } /// The packet encryption and decryption keys for 1-RTT packets, /// which will still change after negotiation between the two endpoints. 
///
/// See [key update](https://www.rfc-editor.org/rfc/rfc9001#name-key-update)
/// of [RFC 9001](https://www.rfc-editor.org/rfc/rfc9001) for more details.
// NOTE(review): several generic argument lists in this span appear to have been
// stripped by text extraction (e.g. `Arc,`, `[Option>; 2]`, `Box,`) — the inner
// types must be restored against the upstream source before this compiles.
pub struct OneRttPacketKeys {
    // Current key phase; toggled on every key update.
    cur_phase: KeyPhaseBit,
    // TLS secrets from which the next generation of packet keys is derived.
    secrets: Secrets,
    // TODO: store three
    // Remote packet keys indexed by key phase (two generations kept at once).
    remote: [Option>; 2],
    // Local packet key for the current phase.
    local: Arc,
}

impl OneRttPacketKeys {
    /// Create new [`OneRttPacketKeys`].
    ///
    /// The TLS handshake session must exchange enough information to generate the 1-RTT keys.
    fn new(remote: Box, local: Box, secrets: Secrets) -> Self {
        Self {
            cur_phase: KeyPhaseBit::default(),
            secrets,
            // Only the current-phase remote key exists initially.
            remote: [Some(Arc::from(remote)), None],
            local: Arc::from(local),
        }
    }

    /// Proactively update the 1-RTT packet key locally.
    /// Or be informed by the peer to update the key.
    ///
    /// The key phase bit will be toggled and sent to the peer,
    /// informing the peer to update the key to next 1-RTT packet key too.
    pub fn update(&mut self) {
        self.cur_phase.toggle();
        let key_set = self.secrets.next_packet_keys();
        // Install the next-generation keys under the new phase index.
        self.remote[self.cur_phase.as_index()] = Some(Arc::from(key_set.remote));
        self.local = Arc::from(key_set.local);
    }

    /// Old key must be phased out within a certain period of time.
    ///
    /// If the old one don't go, the new ones won't come.
    /// If it is not phased out, it will be considered as new keys and
    /// fail to decrypt the packet in future.
    pub fn phase_out(&mut self) {
        // Drop the previous-phase remote key so its slot can host the next generation.
        self.remote[(!self.cur_phase).as_index()].take();
    }

    /// Get the remote key to decrypt the incoming 1-RTT packet.
    /// If the key phase is not the current key phase, update the key, see [`Self::update`].
    ///
    /// Return `Arc` to decrypt the incoming 1-RTT packet.
    pub fn get_remote(&mut self, key_phase: KeyPhaseBit, _pn: u64) -> Arc {
        // A mismatched phase with no stored key means the peer initiated a key
        // update; derive the next generation before decrypting.
        if key_phase != self.cur_phase && self.remote[key_phase.as_index()].is_none() {
            self.update();
        }
        // NOTE(review): unwrap relies on phase_out() never clearing the
        // current-phase slot — TODO confirm.
        self.remote[key_phase.as_index()].clone().unwrap()
    }

    /// Get the local current key to encrypt the outgoing packet.
    ///
    /// Return `Arc` to encrypt the outgoing 1-RTT packet.
    pub fn get_local(&self) -> (KeyPhaseBit, Arc) {
        (self.cur_phase, self.local.clone())
    }
}

/// The packet encryption and decryption keys for 1-RTT packets, which will still
/// change based on the KeyPhase bit in the receiving packet, or be updated
/// proactively locally.
///
/// For performance reasons, the second element of the tuple is the length of the
/// tag of the local packet key's underlying AEAD algorithm redundantly.
#[derive(Clone)]
pub struct ArcOneRttPacketKeys(Arc<(Mutex, usize)>);

impl ArcOneRttPacketKeys {
    /// Obtain exclusive access to the 1-RTT packet keys.
    /// During the exclusive period of encrypting or decrypting packets,
    /// the keys must not be updated elsewhere.
    pub fn lock_guard(&self) -> MutexGuard<'_, OneRttPacketKeys> {
        self.0.0.lock().unwrap()
    }

    /// Get the length of the tag of the packet key's underlying AEAD algorithm.
    ///
    /// For example, when collecting data to send, buffer needs to reserve
    /// the tag length space to fill in the integrity checksum codes.
    /// After collecting the data, encryption will be performed, and exclusive
    /// access will be obtained during encryption.
    /// There is no need to acquire the lock at the beginning to get the tag
    /// length, because nothing might be sent later, and the task might be canceled.
    /// This would save the initial locking overhead.
    /// Keeping a redundant tag length that can be obtained without locking
    /// will improve lock performance.
    pub fn tag_len(&self) -> usize {
        self.0.1
    }
}

/// The header protection keys for 1-RTT packets.
#[derive(Clone)]
pub struct HeaderProtectionKeys {
    pub local: Arc,
    pub remote: Arc,
}

// Lifecycle of the 1-RTT key set: not yet derived -> usable -> retired.
enum OneRttKeysState {
    Pending(Option),
    Ready {
        hpk: HeaderProtectionKeys,
        pk: ArcOneRttPacketKeys,
    },
    Invalid,
}

/// 1-RTT packet keys, for packet encryption and decryption for 1-RTT packets,
/// as well as keys for adding and removing 1-RTT packet header protection.
///
/// and its packet key will be updated.
/// /// Unlike [`ArcKeys`], the HeaderProtectionKey for 1-RTT keys does not change, /// but the PacketKey may still be updated with changes in the KeyPhase bit. /// Therefore, the HeaderProtectionKey and PacketKey need to be managed separately. #[derive(Clone)] pub struct ArcOneRttKeys(Arc>); impl ArcOneRttKeys { fn lock_guard(&self) -> MutexGuard<'_, OneRttKeysState> { self.0.lock().unwrap() } /// Create a Pending state [`ArcOneRttKeys`], waiting for the keys being ready /// from TLS handshaking. pub fn new_pending() -> Self { Self(Arc::new(OneRttKeysState::Pending(None).into())) } /// Set the keys to the [`ArcOneRttKeys`]. /// /// As the TLS handshake progresses, 1-RTT keys will finally be obtained. /// And then its internal waker will be awakened to notify the packet /// decryption task to continue, if the internal waker was registered. pub fn set_keys(&self, keys: RustlsKeys, secrets: Secrets) { let mut state = self.lock_guard(); match &mut *state { OneRttKeysState::Pending(waker) => { let hpk = HeaderProtectionKeys { local: Arc::from(keys.local.header), remote: Arc::from(keys.remote.header), }; let tag_len = keys.local.packet.tag_len(); let pk = ArcOneRttPacketKeys(Arc::new(( Mutex::new(OneRttPacketKeys::new( keys.remote.packet, keys.local.packet, secrets, )), tag_len, ))); if let Some(w) = waker.take() { w.wake(); } *state = OneRttKeysState::Ready { hpk, pk }; } OneRttKeysState::Ready { .. } => panic!("set_keys called twice"), OneRttKeysState::Invalid => panic!("set_keys called after invalidation"), } } pub fn invalid(&self) -> Option<(HeaderProtectionKeys, ArcOneRttPacketKeys)> { let mut state = self.lock_guard(); match std::mem::replace(state.deref_mut(), OneRttKeysState::Invalid) { OneRttKeysState::Pending(rx_waker) => { if let Some(waker) = rx_waker { waker.wake(); } None } OneRttKeysState::Ready { hpk, pk } => Some((hpk, pk)), OneRttKeysState::Invalid => unreachable!(), } } /// Get the local keys for packet encryption and adding header protection. 
/// If the keys are not ready, just return None immediately. /// /// Return a tuple of HeaderProtectionKey and OneRttPacketKeys. /// The OneRttPacketKeys need to be locked during the entire packet encryption process. pub fn get_local_keys(&self) -> Option<(Arc, ArcOneRttPacketKeys)> { let mut keys = self.lock_guard(); match &mut *keys { OneRttKeysState::Ready { hpk, pk, .. } => Some((hpk.local.clone(), pk.clone())), _ => None, } } pub fn remote_keys(&self) -> Option<(Arc, ArcOneRttPacketKeys)> { match &mut *self.lock_guard() { OneRttKeysState::Ready { hpk, pk, .. } => Some((hpk.remote.clone(), pk.clone())), _ => None, } } /// Asynchronously obtain the remote keys for removing header protection and packet decryption. /// /// Rreturn [`GetRemoteKeys`], which implemented the Future trait. pub fn get_remote_keys(&self) -> GetRemoteOneRttKeys<'_> { GetRemoteOneRttKeys(self) } } /// To obtain the remote key from [`ArcOneRttKeys`]` for removing 1-RTT header /// protection and packet decryption. pub struct GetRemoteOneRttKeys<'k>(&'k ArcOneRttKeys); impl Future for GetRemoteOneRttKeys<'_> { type Output = Option<(Arc, ArcOneRttPacketKeys)>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut keys = self.0.lock_guard(); match &mut *keys { OneRttKeysState::Pending(waker) => { if waker .as_ref() .is_some_and(|waker| !waker.will_wake(cx.waker())) { unreachable!( "Try to get remote keys from multiple tasks! This is a bug, please report it." ) } *waker = Some(cx.waker().clone()); Poll::Pending } OneRttKeysState::Ready { hpk, pk, .. 
} => { Poll::Ready(Some((hpk.remote.clone(), pk.clone()))) } OneRttKeysState::Invalid => Poll::Ready(None), } } } ================================================ FILE: qbase/src/packet/number.rs ================================================ use std::cmp::max; use bytes::BufMut; use thiserror::Error; /// An encoded or undecoded packet number /// /// The actual packet number is an integer in the range 0 to 2^62 - 1 and encoded in 1 to 4 bytes. /// /// See [packet numbers](https://www.rfc-editor.org/rfc/rfc9000.html#name-packet-numbers) and /// [packet number encoding and decoding](https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1) /// of [RFC 9000](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum PacketNumber { U8(u8), U16(u16), U24(u32), U32(u32), } #[derive(Debug, Error, PartialEq, Eq)] pub enum InvalidPacketNumber { #[error("Packet number too old")] TooOld, #[error("Packet number too large")] TooLarge, #[error("Packet with this number has been received")] Duplicate, } /// Implement this trait for buffer, which can be used to write the packet number into the buffer. pub trait WritePacketNumber { /// Write the encoded packet number to the buffer. fn put_packet_number(&mut self, pn: PacketNumber); } impl WritePacketNumber for T { fn put_packet_number(&mut self, pn: PacketNumber) { use self::PacketNumber::*; match pn { U8(x) => self.put_u8(x), U16(x) => self.put_u16(x), U24(x) => { self.put_u8((x >> 16) as u8); self.put_u16(x as u16); } U32(x) => self.put_u32(x), } } } /// Parse the packet number from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
/// /// ## Example /// /// ``` /// use qbase::packet::number::{PacketNumber, take_pn_len}; /// /// let buf = [0x01, 0x00]; /// assert_eq!( /// (&[][..], PacketNumber::U16(1 << 8)), /// take_pn_len(2)(&buf).unwrap() /// ); /// ``` pub fn take_pn_len(pn_len: u8) -> impl FnMut(&[u8]) -> nom::IResult<&[u8], PacketNumber> { use nom::{ Parser, combinator::map, number::complete::{be_u8, be_u16, be_u24, be_u32}, }; move |input: &[u8]| match pn_len { 1 => map(be_u8, PacketNumber::U8).parse(input), 2 => map(be_u16, PacketNumber::U16).parse(input), 3 => map(be_u24, PacketNumber::U24).parse(input), 4 => map(be_u32, PacketNumber::U32).parse(input), _ => unreachable!(), } } impl PacketNumber { /// Encode the packet number, based on the maximum confirmed packet number. /// /// The size of the packet number encoding is at least one bit more than the /// base-2 logarithm of the number of contiguous unacknowledged packet numbers /// /// See [Section 17.1-5](https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1-5) and /// [Appendix A.2](https://www.rfc-editor.org/rfc/rfc9000.html#section-a.2) /// for more details. pub fn encode(pn: u64, largest_acked: u64) -> Self { // Minimum 16-bit PN encoding ensures delayed packets on slower paths remain decodable let range = max((pn - largest_acked) * 2, (1 << 16) - 1); if range < 1 << 8 { Self::U8(pn as u8) } else if range < 1 << 16 { Self::U16(pn as u16) } else if range < 1 << 24 { Self::U24(pn as u32) } else if range < 1 << 32 { Self::U32(pn as u32) } else { panic!("packet number too large to encode") } } /// Return the size of the packet number encoding. pub fn size(self) -> usize { use self::PacketNumber::*; match self { U8(_) => 1, U16(_) => 2, U24(_) => 3, U32(_) => 4, } } /// Decode the packet number after header protection has been removed. /// /// The packet number is decoded based on the largest received packet number. /// The next expected packet is the largest received packet number plus one. 
/// /// See [Section 17.1-7](https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1-7) and /// [Section A.3](https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3) /// for more details. pub fn decode(self, expected: u64) -> u64 { use self::PacketNumber::*; let (truncated, nbits) = match self { U8(x) => (u64::from(x), 8), U16(x) => (u64::from(x), 16), U24(x) => (u64::from(x), 24), U32(x) => (u64::from(x), 32), }; let win = 1 << nbits; let hwin = win / 2; let mask = win - 1; // The incoming packet number should be greater than expected - hwin and less than or equal // to expected + hwin // // This means we can't just strip the trailing bits from expected and add the truncated // because that might yield a value outside the window. // // The following code calculates a candidate value and makes sure it's within the packet // number window. let candidate = (expected & !mask) | truncated; if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) { candidate + win } else if candidate > expected + hwin && candidate > win { candidate - win } else { candidate } } } #[cfg(test)] mod tests { use super::{PacketNumber, WritePacketNumber}; #[test] fn test_read_packet_number() { let buf = [0x00]; assert_eq!( (&[][..], super::PacketNumber::U8(0)), super::take_pn_len(1)(&buf).unwrap() ); let buf = [0x01, 0x00]; assert_eq!( (&[][..], super::PacketNumber::U16(1 << 8)), super::take_pn_len(2)(&buf).unwrap() ); let buf = [0x01, 0x00, 0x00]; assert_eq!( (&[][..], super::PacketNumber::U24(1 << 16)), super::take_pn_len(3)(&buf).unwrap() ); let buf = [0x01, 0x00, 0x00, 0x00]; assert_eq!( (&[][..], super::PacketNumber::U32(1 << 24)), super::take_pn_len(4)(&buf).unwrap() ); } #[test] #[should_panic] fn test_read_packet_number_too_large() { let buf = [0x01, 0x00, 0x00, 0x00, 0x00]; super::take_pn_len(5)(&buf).unwrap(); } #[test] fn test_write_packet_number() { let mut buf = vec![]; buf.put_packet_number(PacketNumber::encode(0, 0)); // Minimum 16-bit PN encoding ensures delayed packets 
on slower paths remain decodable assert_eq!(buf, [0x00, 0x00]); buf.clear(); buf.put_packet_number(PacketNumber::encode(1 << 8, 0)); assert_eq!(buf, [0x01, 0x00]); buf.clear(); buf.put_packet_number(PacketNumber::encode(1 << 16, 0)); assert_eq!(buf, [0x01, 0x00, 0x00]); buf.clear(); buf.put_packet_number(PacketNumber::encode(1 << 24, 0)); assert_eq!(buf, [0x01, 0x00, 0x00, 0x00]); } #[test] fn test_encode_packet_number() { let pn = super::PacketNumber::encode((1 << 31) - 1, 0); assert_eq!(pn.decode(0), (1 << 31) - 1); let pn = super::PacketNumber::encode(0, 0); assert_eq!(pn.decode(0), 0); } #[test] #[should_panic] fn test_encode_packet_number_overflow() { PacketNumber::encode(1 << 31, 0); } } ================================================ FILE: qbase/src/packet/signal.rs ================================================ /// The spin bit in 1-RTT packets const SPIN_BIT: u8 = 0x20; /// The key phase bit in 1-RTT packets const KEY_PHASE_BIT: u8 = 0x04; /// The toggle type, which can be used to represent the spin bit and key phase bit. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub enum Toggle { /// Represents the bit is 0 #[default] Zero, /// Represents the bit is 1 One, } /// The spin bit in the 1-RTT packet. pub type SpinBit = Toggle; /// The key phase bit in the 1-RTT packet. pub type KeyPhaseBit = Toggle; impl Toggle { /// Toggle the bit, from 0 to 1, or from 1 to 0. pub fn toggle(&mut self) { *self = match self { Toggle::Zero => Toggle::One, Toggle::One => Toggle::Zero, } } /// Get the value of the bit. pub fn value(&self) -> u8 { match self { Toggle::Zero => 0, Toggle::One => B, } } /// Imply the bit to the byte. 
pub fn imply(&self, byte: &mut u8) { match self { Toggle::Zero => *byte &= !B, Toggle::One => *byte |= B, } } /// Treat Toggle as an index and get the index value it represents, i.e., 0 or 1 pub(crate) fn as_index(&self) -> usize { match self { Toggle::Zero => 0, Toggle::One => 1, } } } impl std::ops::Not for Toggle { type Output = Self; fn not(self) -> Self::Output { match self { Toggle::Zero => Toggle::One, Toggle::One => Toggle::Zero, } } } impl From for Toggle { fn from(value: u8) -> Self { if value & B == 0 { Toggle::Zero } else { Toggle::One } } } impl From> for u8 { fn from(value: Toggle) -> Self { value.value() } } impl From for Toggle { fn from(value: bool) -> Self { if value { Toggle::One } else { Toggle::Zero } } } impl From> for bool { fn from(value: Toggle) -> Self { match value { Toggle::Zero => false, Toggle::One => true, } } } ================================================ FILE: qbase/src/packet/type/long/v1.rs ================================================ use crate::packet::{error::Error, r#type::FIXED_BIT}; /// Long packet types. The 3th and 4th bits of the first byte of the long header /// represent the specific packet type. /// /// See [long header packet types](https://www.rfc-editor.org/rfc/rfc9000.html#name-long-header-packet-types) /// of [RFC 9000](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. 
#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Type { /// Initial packet type, represented by 0b00 Initial, /// 0-RTT packet type, represented by 0b01 ZeroRtt, /// Handshake packet type, represented by 0b10 Handshake, /// Retry packet type, represented by 0b11 Retry, } const LONG_PACKET_TYPE_MASK: u8 = 0x30; const INITIAL_PACKET_TYPE: u8 = 0x00; const ZERO_RTT_PACKET_TYPE: u8 = 0x10; const HANDSHAKE_PACKET_TYPE: u8 = 0x20; const RETRY_PACKET_TYPE: u8 = 0x30; impl From for u8 { fn from(value: Type) -> u8 { match value { Type::Retry => RETRY_PACKET_TYPE, Type::Initial => INITIAL_PACKET_TYPE, Type::ZeroRtt => ZERO_RTT_PACKET_TYPE, Type::Handshake => HANDSHAKE_PACKET_TYPE, } } } impl TryFrom for Type { type Error = Error; fn try_from(value: u8) -> Result { if value & FIXED_BIT == 0 { return Err(Error::InvalidFixedBit); } match value & LONG_PACKET_TYPE_MASK { INITIAL_PACKET_TYPE => Ok(Type::Initial), ZERO_RTT_PACKET_TYPE => Ok(Type::ZeroRtt), HANDSHAKE_PACKET_TYPE => Ok(Type::Handshake), RETRY_PACKET_TYPE => Ok(Type::Retry), _ => unreachable!(), } } } #[cfg(test)] mod tests { #[test] fn test_try_from() { use super::Type; use crate::packet::error::Error; assert_eq!(Type::try_from(0xc0), Ok(Type::Initial)); assert_eq!(Type::try_from(0xd0), Ok(Type::ZeroRtt)); assert_eq!(Type::try_from(0xe0), Ok(Type::Handshake)); assert_eq!(Type::try_from(0xf0), Ok(Type::Retry)); assert_eq!(Type::try_from(0x00), Err(Error::InvalidFixedBit)); } } ================================================ FILE: qbase/src/packet/type/long.rs ================================================ use derive_more::Deref; /// Supports IQuic version 1, if other versions are supported in the future, add them here. pub mod v1; /// The long packet header contains version information, so the 32-bit /// version number info is also one part of the versioned packet type. /// /// `N`` represents an 32-bit version number, and /// `Ty`` represents the specific type of the version. 
#[derive(Debug, Clone, Copy, Deref, PartialEq, Eq)] pub struct Version(#[deref] pub(crate) Ty); /// Long packet types all have a Version, so the version number can be obtained /// from the long packet type. pub trait GetVersion { /// Get the version number from long packet type. fn get_version(&self) -> u32; } impl GetVersion for Version { fn get_version(&self) -> u32 { N } } /// Mainly define the long packet types of the IQuic version 1. impl Version<1, v1::Type> { /// Retry packet type of the IQuic version 1. pub const RETRY: Self = Self(v1::Type::Retry); /// Initial packet type of the IQuic version 1. pub const INITIAL: Self = Self(v1::Type::Initial); /// 0-RTT packet type of the IQuic version 1. pub const ZERO_RTT: Self = Self(v1::Type::ZeroRtt); /// Handshake packet type of the IQuic version 1. pub const HANDSHAKE: Self = Self(v1::Type::Handshake); } /// Represent the packet types in the IQuic version 1, including Retry/Initial/0-RTT/Handshake. pub type Ver1 = Version<1, v1::Type>; /// The sum types of the long packets. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Type { VersionNegotiation, V1(Version<1, v1::Type>), // in the future, add other versions here // V2(v2::HeaderType), } /// The io module provides the functions to parse and write the long packet type. pub mod io { use bytes::BufMut; use nom::number::streaming::be_u32; use super::{super::FIXED_BIT, *}; use crate::packet::error::Error; const LONG_HEADER_BIT: u8 = 0x80; /// Parse the long packet type from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
pub fn parse_long_type(ty: u8) -> impl FnMut(&[u8]) -> nom::IResult<&[u8], Type, Error> { move |input| { let (remain, version) = be_u32(input)?; match version { 0 => Ok((remain, Type::VersionNegotiation)), 1 => Ok(( remain, Type::V1(Version::<1, v1::Type>( ty.try_into().map_err(nom::Err::Error)?, )), )), v => Err(nom::Err::Error(Error::UnsupportedVersion(v))), } } } /// A [`bytes::BufMut`] extension trait, makes buffer more friendly to write long packet type. pub trait WriteLongType: BufMut { /// Write the long packet type to the buffer. fn put_long_type(&mut self, value: &Type); } impl WriteLongType for B { fn put_long_type(&mut self, value: &Type) { match value { Type::VersionNegotiation => { self.put_u8(LONG_HEADER_BIT); self.put_u32(0); } Type::V1(Version::<1, _>(ty)) => { let ty: u8 = (*ty).into(); self.put_u8(LONG_HEADER_BIT | FIXED_BIT | ty); self.put_u32(1); } } } } } #[cfg(test)] mod tests { use crate::packet::r#type::long::Ver1; #[test] fn test_read_long_type() { use super::{Type, io::parse_long_type}; let buf = vec![0x00, 0x00, 0x00, 0x01]; let (remain, ty) = parse_long_type(0xc0)(&buf).unwrap(); assert_eq!(remain.len(), 0); assert_eq!(ty, Type::V1(Ver1::INITIAL)); let buf = vec![0x00, 0x00, 0x00, 0x00]; let (remain, ty) = parse_long_type(0x80)(&buf).unwrap(); assert_eq!(remain.len(), 0); assert_eq!(ty, Type::VersionNegotiation); } #[test] #[should_panic] fn test_read_long_type_with_wrong_version() { use super::{Type, io::parse_long_type}; let buf = vec![0x00, 0x00, 0x00, 0x03]; let (remain, ty) = parse_long_type(0xc0)(&buf).unwrap(); assert_eq!(remain.len(), 0); assert_eq!(ty, Type::V1(Ver1::INITIAL)); } #[test] fn test_write_long_type() { use super::Type; use crate::packet::r#type::long::io::WriteLongType; let mut buf = vec![]; let ty = Type::V1(Ver1::INITIAL); buf.put_long_type(&ty); assert_eq!(buf, vec![0xc0, 0x00, 0x00, 0x00, 0x01]); } #[test] fn test_write_version_negotiation_long_type() { use super::Type; use 
crate::packet::r#type::long::io::WriteLongType; let mut buf = vec![]; let ty = Type::VersionNegotiation; buf.put_long_type(&ty); assert_eq!(buf, vec![0x80, 0x00, 0x00, 0x00, 0x00]); } } ================================================ FILE: qbase/src/packet/type/short.rs ================================================ use bytes::BufMut; use derive_more::Deref; use crate::packet::SpinBit; const SHORT_HEADER_BIT: u8 = 0x00; /// The type of the 1-Rtt packet. /// For simplicity, the spin bit is also one part of the 1-Rtt packet type. #[derive(Debug, Clone, Copy, Deref, PartialEq, Eq)] pub struct OneRtt(#[deref] pub SpinBit); impl From for OneRtt { fn from(value: u8) -> Self { OneRtt(SpinBit::from(value)) } } impl From for u8 { fn from(one_rtt: OneRtt) -> Self { SHORT_HEADER_BIT | super::FIXED_BIT | one_rtt.0.value() } } /// A [`bytes::BufMut`] extension trait, makes buffer more friendly to write the short packet type. pub trait WriteShortType: BufMut { /// Write the short packet type to the buffer. 
fn put_short_type(&mut self, ty: &OneRtt); } impl WriteShortType for B { fn put_short_type(&mut self, ty: &OneRtt) { self.put_u8((*ty).into()); } } #[cfg(test)] mod tests { use super::*; #[test] fn test_write_short_type() { use super::OneRtt; let mut buf = vec![]; let ty = OneRtt::from(0x00); buf.put_short_type(&ty); // Note: 0x40 == SHORT_HEADER_BIT | super::FIXED_BIT | 0x00 assert_eq!(buf, vec![0x40]); let mut buf = vec![]; let ty = OneRtt::from(0x20); buf.put_short_type(&ty); // Note: 0x60 == SHORT_HEADER_BIT | super::FIXED_BIT | 0x20 assert_eq!(buf, vec![0x60]); } } ================================================ FILE: qbase/src/packet/type.rs ================================================ use derive_more::Deref; use super::{KeyPhaseBit, PacketNumber, error::Error}; /// Definitions of packet types related to long headers pub mod long; /// Definitions of packet types related to short headers pub mod short; /// Header form bit const HEADER_FORM_MASK: u8 = 0x80; /// The next bit (0x40) of byte 0 is set to 1, unless the packet is a Version Negotiation packet. const FIXED_BIT: u8 = 0x40; /// Reserved bits mask for long headers, for the 5th and 6th bits of the first byte of the long header pub const LONG_RESERVED_MASK: u8 = 0x0C; /// Reserved bits mask for short headers, for the 4th and 5th bits of the first byte of the short header pub const SHORT_RESERVED_MASK: u8 = 0x18; /// The lower specific bits of the first byte of the long or short header. /// 'R' represents the reserved bits. /// /// - For long packet headers, it is the lower 4 bits of the first byte, and R is 0x0C. /// - For the short packet header, it is the lower 5 bits of the first byte, and R is 0x18. #[derive(Debug, Clone, Copy, Deref)] pub struct SpecificBits(pub(super) u8); /// The lower 4 bits of the first byte of the long header. /// /// Include 2 reserved bits that must be 0, and 2 bits for the packet number length. /// All of them are protected. 
pub type LongSpecificBits = SpecificBits; /// The lower 5 bits of the first byte of the short header, i.e., the last 5 bits. /// /// Include 2 reserved bits that must be 0, 1 bit for the key phase, /// and 2 bits for the packet number length. /// All of them are protected. pub type ShortSpecificBits = SpecificBits; impl SpecificBits { /// Create a [`SpecificBits`] with the [`PacketNumber`]. pub fn from_pn(pn: &PacketNumber) -> Self { Self(pn.size() as u8 - 1) } /// Create a [`SpecificBits`] with the packet number length. pub fn with_pn_len(pn_size: usize) -> Self { debug_assert!(pn_size <= 4 && pn_size > 0); Self(pn_size as u8 - 1) } } impl ShortSpecificBits { /// Set the Key Phase bit to the specific bits for 1rtt header. pub fn set_key_phase(&mut self, key_phase_bit: KeyPhaseBit) { key_phase_bit.imply(&mut self.0); } /// Get the Key Phase bit from the specific bits of 1rtt header. pub fn key_phase(&self) -> KeyPhaseBit { KeyPhaseBit::from(self.0) } } impl From for SpecificBits { fn from(byte: u8) -> Self { Self(byte) } } /// Get the packet number length from the protected first byte of the long or short header. /// The reserved bits must be 0; otherwise, a connection error of type PROTOCOL_VIOLATION /// is returned. /// /// See [Section 17.2](https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-8.2) and /// [Section 17.3.1](https://www.rfc-editor.org/rfc/rfc9000.html#section-17.3.1-4.8) of QUIC. 
pub trait GetPacketNumberLength { /// The last two bits of first byte contain the length of the Packet Number const PN_LEN_MASK: u8 = 0x03; /// Get the encoding length of the Packet Number fn pn_len(&self) -> Result; } impl GetPacketNumberLength for SpecificBits { fn pn_len(&self) -> Result { let reserved_bit = self.0 & R; if reserved_bit == 0 { Ok((self.0 & Self::PN_LEN_MASK) + 1) } else { Err(Error::InvalidReservedBits(reserved_bit, R)) } } } /// The Type of the packet /// /// The Type is only extracted from the first 3 or 4 bits of the first byte, these contents /// are not protected. /// For simplicity and future-oriented considerations, the Version of the long packet header /// is also considered part of the Type, such as the Initial packet of V1 version, /// That is, the Initial packet only makes sense under the V1 version, and it is uncertain /// whether future versions of QUIC will still have Initial packets. /// The SpinBit of the short packet header should be part of the short packet header, but for /// simplicity, the SpinBit is also part of the 1RTT header type. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Type { Long(long::Type), Short(short::OneRtt), } impl Type { #[inline] pub fn encoding_size(&self) -> usize { match self { Type::Short(_) => 1, Type::Long(_) => 5, } } } /// The io module provides the functions to parse and write the packet type. pub mod io { use bytes::BufMut; use super::{long::io::WriteLongType, short::WriteShortType, *}; /// Parse the packet type from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. pub fn be_packet_type(input: &[u8]) -> nom::IResult<&[u8], Type, Error> { let (remain, ty) = nom::number::streaming::be_u8(input)?; if ty & HEADER_FORM_MASK == 0 { Ok((remain, Type::Short(short::OneRtt::from(ty)))) } else { let (remain, ty) = long::io::parse_long_type(ty)(remain)?; Ok((remain, Type::Long(ty))) } } /// A [`bytes::BufMut`] extension trait, makes buffer more friendly to write packet type. 
pub trait WritePacketType: BufMut { /// Write the packet type to the buffer. fn put_packet_type(&mut self, ty: &Type); } impl WritePacketType for B { fn put_packet_type(&mut self, ty: &Type) { match ty { Type::Short(one_rtt) => self.put_short_type(one_rtt), Type::Long(long_type) => self.put_long_type(long_type), } } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_long_clear_bits() { let specific_bits = SpecificBits::<0x0C>(0x0C); assert_eq!( specific_bits.pn_len(), Err(Error::InvalidReservedBits(0x0C, 0x0C)) ); let specific_bits = SpecificBits::<0x0C>(0x04); assert_eq!( specific_bits.pn_len(), Err(Error::InvalidReservedBits(0x04, 0x0C)) ); let specific_bits = SpecificBits::<0x0C>(0x08); assert_eq!( specific_bits.pn_len(), Err(Error::InvalidReservedBits(0x08, 0x0C)) ); let specific_bits = LongSpecificBits::with_pn_len(4); assert_eq!(specific_bits.pn_len().unwrap(), 4); let specific_bits = LongSpecificBits::with_pn_len(3); assert_eq!(specific_bits.pn_len().unwrap(), 3); let specific_bits = LongSpecificBits::with_pn_len(2); assert_eq!(specific_bits.pn_len().unwrap(), 2); let specific_bits = LongSpecificBits::with_pn_len(1); assert_eq!(specific_bits.pn_len().unwrap(), 1); } #[test] fn test_short_specific_bits() { let specific_bits = SpecificBits::<0x18>(0x18); assert_eq!( specific_bits.pn_len(), Err(Error::InvalidReservedBits(0x18, 0x18)) ); let specific_bits = SpecificBits::<0x18>(0x11); assert_eq!( specific_bits.pn_len(), Err(Error::InvalidReservedBits(0x10, 0x18)) ); let specific_bits = SpecificBits::<0x18>(0x0A); assert_eq!( specific_bits.pn_len(), Err(Error::InvalidReservedBits(0x08, 0x18)) ); let specific_bits = ShortSpecificBits::with_pn_len(4); assert_eq!(specific_bits.pn_len().unwrap(), 4); let specific_bits = ShortSpecificBits::with_pn_len(3); assert_eq!(specific_bits.pn_len().unwrap(), 3); let specific_bits = ShortSpecificBits::with_pn_len(2); assert_eq!(specific_bits.pn_len().unwrap(), 2); let specific_bits = ShortSpecificBits::with_pn_len(1); 
assert_eq!(specific_bits.pn_len().unwrap(), 1); } #[test] fn test_set_key_phase_bit() { let mut specific_bits = ShortSpecificBits::with_pn_len(4); assert_eq!(specific_bits.0, 0x03); specific_bits.set_key_phase(KeyPhaseBit::One); assert_eq!(specific_bits.0, 0x07); assert_eq!(specific_bits.key_phase(), KeyPhaseBit::One); specific_bits.set_key_phase(KeyPhaseBit::Zero); assert_eq!(specific_bits.0, 0x03); assert_eq!(specific_bits.key_phase(), KeyPhaseBit::Zero); } } ================================================ FILE: qbase/src/packet.rs ================================================ use std::{fmt::Debug, ops}; use bytes::{BufMut, BytesMut, buf::UninitSlice}; use derive_more::{Deref, DerefMut}; use enum_dispatch::enum_dispatch; use getset::CopyGetters; use header::{LongHeader, io::WriteHeader}; use crate::{ cid::ConnectionId, frame::{ContainSpec, FrameFeature, FrameType, Spec}, packet::keys::DirectionalKeys, }; /// QUIC packet parse error definitions. pub mod error; /// Define signal util, such as key phase bit and spin bit. pub mod signal; #[doc(hidden)] pub use signal::{KeyPhaseBit, SpinBit}; /// Definitions of QUIC packet types. pub mod r#type; #[doc(hidden)] pub use r#type::{ GetPacketNumberLength, LONG_RESERVED_MASK, LongSpecificBits, SHORT_RESERVED_MASK, ShortSpecificBits, Type, }; /// Definitions of QUIC packet headers. pub mod header; #[doc(hidden)] pub use header::{ EncodeHeader, GetDcid, GetScid, GetType, HandshakeHeader, Header, InitialHeader, LongHeaderBuilder, OneRttHeader, RetryHeader, VersionNegotiationHeader, ZeroRttHeader, long, }; /// The io module provides the functions to parse the QUIC packet. /// /// The writing of the QUIC packet is not provided here, they are written in place. 
pub mod io; pub use io::{ AssemblePacket, Package, PacketInfo, PacketSpace, PacketWriter, ProductHeader, RecordFrame, }; /// Encoding and decoding of packet number pub mod number; #[doc(hidden)] pub use number::{InvalidPacketNumber, PacketNumber, WritePacketNumber, take_pn_len}; /// Include operations such as decrypting QUIC packets, removing header protection, /// and parsing the first byte of the packet to get the right packet numbers pub mod decrypt; /// Include operations such as encrypting QUIC packets, adding header protection, /// and encoding the first byte of the packet with pn_len and key_phase optionally. pub mod encrypt; /// Encapsulate the crypto keys's logic for long headers and 1-RTT headers. pub mod keys; /// The sum type of all QUIC packet headers. #[derive(Debug, Clone)] #[enum_dispatch(GetDcid, GetType)] pub enum DataHeader { Long(long::DataHeader), Short(OneRttHeader), } /// The sum type of all QUIC data packets. /// /// The long header has the len field, the short header does not have the len field. /// Remember, the len field is not an attribute of the header, but a attribute of the packet. /// /// ```text /// +---> payload length in long packet /// | |<----------- payload --------->| /// +-----------+---+--------+------+-----+-----------+---......--+-------+ /// |X|1|X X 0 0|0 0| ...hdr | len(0..16) | pn(8..32) | body... 
| tag | /// +---+-------+-+-+--------+------------+-----+-----+---......--+-------+ /// | | /// +---> encoded pn length +---> encoded packet number /// ``` #[derive(Debug, Clone, Deref, DerefMut)] pub struct DataPacket { #[deref] #[deref_mut] pub header: DataHeader, pub bytes: BytesMut, // payload_offset pub offset: usize, } impl GetType for DataPacket { fn get_type(&self) -> Type { self.header.get_type() } } #[derive(Default, Debug, Clone, Copy, PartialEq)] pub enum PacketContent { #[default] NonAckEliciting, JustPing, EffectivePayload, } impl PacketContent { pub fn is_ack_eliciting(self) -> bool { self != Self::NonAckEliciting } } impl From for PacketContent { fn from(frame_type: FrameType) -> Self { match frame_type { FrameType::Ping => Self::JustPing, fty if !fty.specs().contain(Spec::NonAckEliciting) => Self::EffectivePayload, _ => Self::NonAckEliciting, } } } impl ops::AddAssign for PacketContent { fn add_assign(&mut self, rhs: FrameType) { match rhs { FrameType::Ping if *self == PacketContent::NonAckEliciting => *self = Self::JustPing, fty if !fty.specs().contain(Spec::NonAckEliciting) => *self = Self::EffectivePayload, _ => (), } } } impl ops::AddAssign for PacketContent { fn add_assign(&mut self, rhs: Self) { match rhs { PacketContent::EffectivePayload => *self = PacketContent::EffectivePayload, PacketContent::JustPing if *self == PacketContent::NonAckEliciting => { *self = PacketContent::JustPing } _ => {} } } } /// The sum type of all QUIC packets. #[derive(Debug, Clone)] pub enum Packet { VN(VersionNegotiationHeader), Retry(RetryHeader), // Data(header, bytes, payload_offset) Data(DataPacket), } /// QUIC packet reader, reading packets from the incoming datagrams. /// /// The parsing here does not involve removing header protection or decrypting the packet. /// It only parses information such as packet type and connection ID, /// and prepares for further delivery to the connection by finding the connection ID. 
/// /// The received packet is a BytesMut, in order to be decrypted in future, and make as few /// copies cheaply until it is read by the application layer. #[derive(Debug)] pub struct PacketReader { raw_bytes: BytesMut, dcid_len: usize, // TODO: 添加level,各种包类型顺序不能错乱,否则失败 } impl PacketReader { pub fn new(raw_bytes: BytesMut, dcid_len: usize) -> Self { Self { raw_bytes, dcid_len, } } } impl Iterator for PacketReader { type Item = Result; fn next(&mut self) -> Option { if self.raw_bytes.is_empty() { return None; } match io::be_packet(&mut self.raw_bytes, self.dcid_len) { Ok(packet) => Some(Ok(packet)), Err(error) => { tracing::debug!(target: "quic", ?error, "dropped unparsed packet"); self.raw_bytes.clear(); // no longer parsing Some(Err(error)) } } } } ================================================ FILE: qbase/src/param/core.rs ================================================ use std::{collections::HashMap, marker::PhantomData, time::Duration}; use bytes::Bytes; use derive_more::{From, TryInto, TryIntoError}; use super::{error::Error, preferred_address::PreferredAddress}; use crate::{ cid::ConnectionId, role::*, token::ResetToken, varint::{VARINT_MAX, VarInt}, }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ParameterValueType { VarInt, Boolean, Bytes, Duration, ResetToken, ConnectionId, PreferredAddress, } #[derive(Debug, Clone, PartialEq, From)] pub enum ParameterValue { Bytes(Bytes), True, VarInt(VarInt), Duration(Duration), ConnectionId(ConnectionId), ResetToken(ResetToken), PreferredAddress(PreferredAddress), } impl ParameterValue { pub fn value_type(&self) -> ParameterValueType { match self { ParameterValue::VarInt(_) => ParameterValueType::VarInt, ParameterValue::True => ParameterValueType::Boolean, ParameterValue::Bytes(_) => ParameterValueType::Bytes, ParameterValue::Duration(_) => ParameterValueType::Duration, ParameterValue::ConnectionId(_) => ParameterValueType::ConnectionId, ParameterValue::ResetToken(_) => ParameterValueType::ResetToken, 
ParameterValue::PreferredAddress(_) => ParameterValueType::PreferredAddress, } } } impl From for ParameterValue { fn from(value: u32) -> Self { ParameterValue::VarInt(VarInt::from_u32(value)) } } impl From for ParameterValue { fn from(value: String) -> Self { ParameterValue::Bytes(Bytes::from(Vec::from(value))) } } impl TryFrom for Duration { type Error = TryIntoError; #[inline] fn try_from(value: ParameterValue) -> Result> { match value { ParameterValue::Duration(v) => Ok(v), _ => Err(TryIntoError::new(value, "Duration", "Duration")), } } } impl TryFrom for ConnectionId { type Error = TryIntoError; #[inline] fn try_from(value: ParameterValue) -> Result> { match value { ParameterValue::ConnectionId(v) => Ok(v), _ => Err(TryIntoError::new(value, "ConnectionId", "ConnectionId")), } } } impl TryFrom for VarInt { type Error = TryIntoError; #[inline] fn try_from(value: ParameterValue) -> Result> { match value { ParameterValue::VarInt(v) => Ok(v), _ => Err(TryIntoError::new(value, "VarInt", "VarInt")), } } } impl TryFrom for u64 { type Error = >::Error; #[inline] fn try_from(value: ParameterValue) -> Result { VarInt::try_from(value).map(|value| value.into_u64()) } } impl TryFrom for PreferredAddress { type Error = TryIntoError; #[inline] fn try_from(value: ParameterValue) -> Result> { match value { ParameterValue::PreferredAddress(v) => Ok(v), _ => Err(TryIntoError::new( value, "PreferredAddress", "PreferredAddress", )), } } } impl TryFrom for Bytes { type Error = TryIntoError; #[inline] fn try_from(value: ParameterValue) -> Result> { match value { ParameterValue::Bytes(v) => Ok(v), _ => Err(TryIntoError::new(value, "Bytes", "Bytes")), } } } impl TryFrom for bool { type Error = TryIntoError; #[inline] fn try_from(value: ParameterValue) -> Result { match value { ParameterValue::True => Ok(true), _ => Err(TryIntoError::new(value, "Enabled", "bool")), } } } impl TryFrom for ResetToken { type Error = TryIntoError; #[inline] fn try_from(value: ParameterValue) -> Result> { 
match value { ParameterValue::ResetToken(v) => Ok(v), _ => Err(TryIntoError::new(value, "ResetToken", "ResetToken")), } } } impl TryFrom for String { type Error = >::Error; #[inline] fn try_from(value: ParameterValue) -> Result { Bytes::try_from(value).map(|bytes| String::from_utf8_lossy(&bytes).into_owned()) } } #[repr(u64)] // qmacro::TransportParameter #[derive(qmacro::ParameterId, Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum ParameterId { #[param(value_type = ConnectionId)] OriginalDestinationConnectionId = 0x0000, #[param(value_type = Duration, default = Duration::ZERO)] MaxIdleTimeout = 0x0001, #[param(value_type = ResetToken)] StatelessResetToken = 0x0002, #[param(value_type = VarInt, default = 65527u32, bound = 1200..=65527)] MaxUdpPayloadSize = 0x0003, #[param(value_type = VarInt, default = 0u32)] InitialMaxData = 0x0004, #[param(value_type = VarInt, default = 0u32)] InitialMaxStreamDataBidiLocal = 0x0005, #[param(value_type = VarInt, default = 0u32)] InitialMaxStreamDataBidiRemote = 0x0006, #[param(value_type = VarInt, default = 0u32)] InitialMaxStreamDataUni = 0x0007, #[param(value_type = VarInt, default = 0u32)] InitialMaxStreamsBidi = 0x0008, #[param(value_type = VarInt, default = 0u32)] InitialMaxStreamsUni = 0x0009, #[param(value_type = VarInt, default = 3u32, bound = 0..=20)] AckDelayExponent = 0x000a, #[param(value_type = Duration, default = Duration::from_millis(25))] MaxAckDelay = 0x000b, #[param(value_type = Boolean)] DisableActiveMigration = 0x000c, #[param(value_type = PreferredAddress)] PreferredAddress = 0x000d, #[param(value_type = VarInt, default = 2u32, bound = 2..=VARINT_MAX)] ActiveConnectionIdLimit = 0x000e, #[param(value_type = ConnectionId)] InitialSourceConnectionId = 0x000f, #[param(value_type = ConnectionId)] RetrySourceConnectionId = 0x0010, #[param(value_type = VarInt, default = 0u32)] MaxDatagramFrameSize = 0x0020, #[param(value_type = Boolean)] GreaseQuicBit = 0x2ab2, /// Genemta extension parameter. 
#[param(value_type = Bytes, default = 0u32)] ClientName = 0xffee, } impl ParameterId { pub fn belong_to(self, role: Role) -> Result<(), Error> { match self { ParameterId::OriginalDestinationConnectionId | ParameterId::StatelessResetToken | ParameterId::PreferredAddress | ParameterId::RetrySourceConnectionId if role != Role::Server => { Err(Error::InvalidParameterId(self, role)) } ParameterId::ClientName if role != Role::Client => { Err(Error::InvalidParameterId(self, role)) } _ => Ok(()), } } } impl std::fmt::LowerHex for ParameterId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:x}", VarInt::from(*self).into_u64()) } } #[derive(Default, Debug, Clone, PartialEq)] pub struct Parameters { pub(super) map: HashMap, _role: PhantomData, } impl Parameters { pub fn get(&self, id: ParameterId) -> Option where V: TryFrom, { (self.map.get(&id).cloned().or_else(|| id.default_value())) .and_then(|value| value.try_into().ok()) } pub fn contains(&self, id: ParameterId) -> bool { self.map.contains_key(&id) } pub fn is_empty(&self) -> bool { self.map.is_empty() } } impl Parameters { pub fn new() -> Self { Self::default() } pub fn set(&mut self, id: ParameterId, value: impl Into) -> Result<(), Error> { let role: Role = R::into_role(); id.belong_to(role)?; let value = value.into(); id.validate(&value)?; self.map.insert(id, value); Ok(()) } } pub type ClientParameters = Parameters; pub type ServerParameters = Parameters; impl ServerParameters { #[inline] pub fn is_0rtt_accepted(&self, server_params: &ServerParameters) -> bool { [ ParameterId::InitialMaxData, ParameterId::InitialMaxStreamDataBidiLocal, ParameterId::InitialMaxStreamDataBidiRemote, ParameterId::InitialMaxStreamDataUni, ParameterId::InitialMaxStreamsBidi, ParameterId::InitialMaxStreamsUni, ParameterId::ActiveConnectionIdLimit, ParameterId::MaxDatagramFrameSize, ] .into_iter() .all( |id| match (self.get::(id), server_params.get::(id)) { (Some(old_value), Some(new_value)) => old_value <= 
new_value, _ => unreachable!("Expected VarInt values for 0-RTT acceptance check"), }, ) } } #[derive(Debug, Clone, PartialEq, From, TryInto)] pub enum PeerParameters { Client(ClientParameters), Server(ServerParameters), } ================================================ FILE: qbase/src/param/error.rs ================================================ use std::ops::RangeInclusive; use nom::error::ErrorKind as NomErrorKind; use thiserror::Error; use crate::{ error::{ErrorKind as QuicErrorKind, QuicError}, frame::FrameType, param::{ParameterId, ParameterValueType}, role::Role, varint::VarInt, }; /// Error for QUIC parameters. #[derive(Debug, PartialEq, Eq, Error)] pub enum Error { #[error("Incomplete parameter id: {0}")] IncompleteParameterId(String), #[error("Parameter {0} is not defined")] UnknownParameterId(VarInt), #[error("Lack {1:?} for {0}")] LackParameterId(Role, ParameterId), #[error("{0:?} is not belong to {1}")] InvalidParameterId(ParameterId, Role), #[error("Incomplete value for {0:?}: {1}")] IncompleteValue(ParameterId, String), #[error("{0:?} is not supported for {1:?}")] InvalidValueType(ParameterId, ParameterValueType), #[error("{0:?}'s value {1} is out of bounds {2:?}")] OutOfBounds(ParameterId, u64, RangeInclusive), } impl From for QuicError { fn from(e: Error) -> Self { Self::new( QuicErrorKind::TransportParameter, FrameType::Crypto.into(), e.to_string(), ) } } impl nom::error::ParseError<&[u8]> for Error { fn from_error_kind(_input: &[u8], _kind: NomErrorKind) -> Self { unreachable!("QUIC parameter parser must always consume") } fn append(_input: &[u8], _kind: NomErrorKind, source: Self) -> Self { source } } ================================================ FILE: qbase/src/param/handy.rs ================================================ use std::time::Duration; use crate::param::ParameterId; pub fn client_parameters() -> super::ClientParameters { let mut params = super::ClientParameters::default(); for (id, value) in [ 
(ParameterId::InitialMaxStreamsBidi, 100u32), (ParameterId::InitialMaxStreamsUni, 100u32), (ParameterId::InitialMaxData, 1u32 << 20), (ParameterId::InitialMaxStreamDataBidiLocal, 1u32 << 20), (ParameterId::InitialMaxStreamDataBidiRemote, 1u32 << 20), (ParameterId::InitialMaxStreamDataUni, 1u32 << 20), (ParameterId::ActiveConnectionIdLimit, 10u32), ] { params.set(id, value).expect("unreachable"); } params .set(ParameterId::MaxIdleTimeout, Duration::from_secs(20)) .expect("unreachable"); params } pub fn server_parameters() -> super::ServerParameters { let mut params = super::ServerParameters::default(); for (id, value) in [ (ParameterId::InitialMaxStreamsBidi, 100u32), (ParameterId::InitialMaxStreamsUni, 100u32), (ParameterId::InitialMaxData, 1u32 << 20), (ParameterId::InitialMaxStreamDataBidiLocal, 1u32 << 20), (ParameterId::InitialMaxStreamDataBidiRemote, 1u32 << 20), (ParameterId::InitialMaxStreamDataUni, 1u32 << 20), (ParameterId::ActiveConnectionIdLimit, 10u32), ] { params.set(id, value).expect("unreachable"); } params .set(ParameterId::MaxIdleTimeout, Duration::from_secs(30)) .expect("unreachable"); params } ================================================ FILE: qbase/src/param/io.rs ================================================ use std::{fmt::Debug, time::Duration}; use bytes::Bytes; use nom::{Parser, multi::length_data}; use crate::{ cid::{ConnectionId, WriteConnectionId}, error::QuicError, param::{ core::{ParameterId, ParameterValue, ParameterValueType, Parameters, ServerParameters}, error::Error, preferred_address::{PreferredAddress, WirtePreferredAddress, be_preferred_address}, }, role::{IntoRole, RequiredParameters, Role}, token::{ResetToken, WriteResetToken, be_reset_token}, varint::{VarInt, WriteVarInt, be_varint}, }; /// A [`bytes::BufMut`] extension trait, makes buffer more friendly /// to write the parameter id. pub trait WriteParameterId: bytes::BufMut { /// Write the parameter id to the buffer. 
fn put_parameter_id(&mut self, param_id: ParameterId); } impl WriteParameterId for T { fn put_parameter_id(&mut self, param_id: ParameterId) { self.put_varint(&VarInt::from(param_id)); } } pub fn be_raw_parameter(input: &[u8]) -> nom::IResult<&[u8], (VarInt, &[u8])> { let (remain, param_id) = crate::varint::be_varint(input)?; let (remain, data) = length_data(be_varint).parse(remain)?; Ok((remain, (param_id, data))) } pub fn be_parameter_value(input: &[u8], id: ParameterId) -> nom::IResult<&[u8], ParameterValue> { use nom::combinator::map; match id.value_type() { ParameterValueType::VarInt => map(be_varint, ParameterValue::VarInt).parse(input), ParameterValueType::Boolean => Ok((input, ParameterValue::True)), ParameterValueType::Bytes => { Ok((&[], ParameterValue::Bytes(Bytes::copy_from_slice(input)))) } ParameterValueType::Duration => { map(be_varint, |v| Duration::from_millis(v.into_u64()).into()).parse(input) } ParameterValueType::ResetToken => { map(be_reset_token, ParameterValue::ResetToken).parse(input) } ParameterValueType::ConnectionId => Ok(( &[], ParameterValue::ConnectionId(ConnectionId::from_slice(input)), )), ParameterValueType::PreferredAddress => { map(be_preferred_address, ParameterValue::PreferredAddress).parse(input) } } } // A trait for writing parameters to the buffer. 
pub trait WriteParameter { fn put_bytes_parameter(&mut self, id: ParameterId, bytes: &Bytes); fn put_cid_parameter(&mut self, id: ParameterId, cid: &ConnectionId); fn put_duration_parameter(&mut self, id: ParameterId, dur: &Duration) { let value = VarInt::from_u128(dur.as_millis()).expect("Duration too large"); self.put_varint_parameter(id, &value); } fn put_bool_parameter(&mut self, id: ParameterId); fn put_preferred_address_parameter(&mut self, id: ParameterId, addr: &PreferredAddress); fn put_reset_token_parameter(&mut self, id: ParameterId, token: &ResetToken); fn put_varint_parameter(&mut self, id: ParameterId, value: &VarInt); fn put_parameter(&mut self, id: ParameterId, value: &ParameterValue) { match value { ParameterValue::Bytes(bytes) => self.put_bytes_parameter(id, bytes), ParameterValue::ConnectionId(cid) => self.put_cid_parameter(id, cid), ParameterValue::Duration(dur) => self.put_duration_parameter(id, dur), ParameterValue::True => self.put_bool_parameter(id), ParameterValue::PreferredAddress(addr) => { self.put_preferred_address_parameter(id, addr) } ParameterValue::ResetToken(token) => self.put_reset_token_parameter(id, token), ParameterValue::VarInt(varint) => self.put_varint_parameter(id, varint), } } } /// A [`bytes::BufMut`] extension trait, makes buffer more friendly /// to write parameters. 
impl WriteParameter for T { fn put_bytes_parameter(&mut self, id: ParameterId, bytes: &Bytes) { self.put_parameter_id(id); self.put_varint(&VarInt::try_from(bytes.len()).expect("param too large")); self.put_slice(bytes); } fn put_cid_parameter(&mut self, id: ParameterId, cid: &ConnectionId) { self.put_parameter_id(id); self.put_connection_id(cid); } fn put_bool_parameter(&mut self, id: ParameterId) { self.put_parameter_id(id); self.put_varint(&VarInt::from_u32(0)); } fn put_preferred_address_parameter(&mut self, id: ParameterId, addr: &PreferredAddress) { self.put_parameter_id(id); self.put_varint(&VarInt::try_from(addr.encoding_size()).expect("param too large")); self.put_preferred_address(addr); } fn put_reset_token_parameter(&mut self, id: ParameterId, token: &ResetToken) { self.put_parameter_id(id); self.put_varint(&VarInt::try_from(token.encoding_size()).expect("param too large")); self.put_reset_token(token); } fn put_varint_parameter(&mut self, id: ParameterId, value: &VarInt) { self.put_parameter_id(id); self.put_varint(&VarInt::try_from(value.encoding_size()).expect("param too large")); self.put_varint(value); } } pub trait WriteParameters { fn put_parameters(&mut self, params: &Parameters); } impl WriteParameters for T { fn put_parameters(&mut self, params: &Parameters) { for (id, value) in ¶ms.map { self.put_parameter(*id, value); } } } fn handle_nom_error(input: &[u8], nom_error: nom::Err) -> Error { assert!( matches!(nom_error, nom::Err::Incomplete(..)), "Only incomplete errors should occur, but {nom_error:?} happened for input: {input:?}" ); Error::IncompleteParameterId(format!("incomplete parameter data for input: {input:?}")) } impl Parameters { pub fn parse_from_bytes(mut buf: &[u8]) -> Result { let mut parameters = Self::default(); while !buf.is_empty() { let (param_id, param_value); (buf, (param_id, param_value)) = be_raw_parameter(buf).map_err(|nom_error| handle_nom_error(buf, nom_error))?; let param_id = match ParameterId::try_from(param_id) { 
Ok(param_id) => param_id, Err(unknown @ Error::UnknownParameterId(..)) => { tracing::warn!(target: "quic", "{unknown}, ignore"); continue; // Ignore unknown parameters } Err(e) => return Err(e.into()), }; ParameterId::belong_to(param_id, R::into_role())?; let (remain, param_value) = be_parameter_value(param_value, param_id) .map_err(|nom_error| handle_nom_error(param_value, nom_error))?; assert!(remain.is_empty(), "Parameter value should consume all data"); parameters.set(param_id, param_value)?; } for id in R::required_parameters() { if !parameters.contains(id) { return Err(Error::LackParameterId(R::into_role(), id).into()); } } Ok(parameters) } } impl ServerParameters { pub fn try_from_remembered_bytes(mut buf: &[u8]) -> Result { let mut parameters = Self::new(); while !buf.is_empty() { let (param_id, param_value); (buf, (param_id, param_value)) = be_raw_parameter(buf).map_err(|nom_error| handle_nom_error(buf, nom_error))?; let param_id = match ParameterId::try_from(param_id) { Ok(param_id) => param_id, Err(unknown @ Error::UnknownParameterId(..)) => { tracing::warn!(target: "quic", "{unknown}, ignore"); continue; // Ignore unknown parameters } Err(e) => return Err(e.into()), }; ParameterId::belong_to(param_id, Role::Server)?; let (remain, param_value) = be_parameter_value(param_value, param_id) .map_err(|nom_error| handle_nom_error(param_value, nom_error))?; assert!(remain.is_empty(), "Parameter value should consume all data"); parameters.set(param_id, param_value)?; } Ok(parameters) } } ================================================ FILE: qbase/src/param/preferred_address.rs ================================================ use std::net::{SocketAddrV4, SocketAddrV6}; use getset::{CopyGetters, MutGetters, Setters}; use nom::Parser; use crate::{ cid::{ConnectionId, WriteConnectionId, be_connection_id}, token::{ResetToken, WriteResetToken, be_reset_token}, }; /// The server's preferred address, which is used to effect /// a change in server address at the end of 
the handshake. /// /// See [section-18.2-4.31](https://datatracker.ietf.org/doc/html/rfc9000#section-18.2-4.32) /// and [figure-22](https://datatracker.ietf.org/doc/html/rfc9000#figure-22) /// for more details. #[derive(CopyGetters, Setters, MutGetters, Debug, PartialEq, Clone, Copy)] pub struct PreferredAddress { #[getset(get_copy = "pub", set = "pub")] address_v4: SocketAddrV4, #[getset(get_copy = "pub", set = "pub")] address_v6: SocketAddrV6, #[getset(get_copy = "pub", set = "pub")] connection_id: ConnectionId, #[getset(get_copy = "pub", set = "pub")] stateless_reset_token: ResetToken, } impl PreferredAddress { /// Create a new preferred address. pub fn new( address_v4: SocketAddrV4, address_v6: SocketAddrV6, connection_id: ConnectionId, stateless_reset_token: ResetToken, ) -> Self { Self { address_v4, address_v6, connection_id, stateless_reset_token, } } /// Returns the encoding size of the preferred address. pub fn encoding_size(&self) -> usize { 6 + 18 + self.connection_id.encoding_size() + self.stateless_reset_token.encoding_size() } } /// Parse the preferred address from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
pub fn be_preferred_address(input: &[u8]) -> nom::IResult<&[u8], PreferredAddress> { use nom::{bytes::streaming::take, combinator::map}; let (input, address_v4) = map(take(6usize), |buf: &[u8]| { let mut addr = [0u8; 4]; addr.copy_from_slice(&buf[..4]); let port = u16::from_be_bytes([buf[4], buf[5]]); SocketAddrV4::new(addr.into(), port) }) .parse(input)?; let (input, address_v6) = map(take(18usize), |buf: &[u8]| { let mut addr = [0u8; 16]; addr.copy_from_slice(&buf[..16]); let port = u16::from_be_bytes([buf[16], buf[17]]); SocketAddrV6::new(addr.into(), port, 0, 0) }) .parse(input)?; let (input, connection_id) = be_connection_id(input)?; let (input, stateless_reset_token) = be_reset_token(input)?; Ok(( input, PreferredAddress { address_v4, address_v6, connection_id, stateless_reset_token, }, )) } /// A [`bytes::BufMut`] extension trait, makes buffer more friendly /// to write the preferred address. pub trait WirtePreferredAddress: bytes::BufMut { /// Write the preferred address to the buffer. 
fn put_preferred_address(&mut self, addr: &PreferredAddress); } impl WirtePreferredAddress for T { fn put_preferred_address(&mut self, addr: &PreferredAddress) { self.put_slice(&addr.address_v4.ip().octets()); self.put_u16(addr.address_v4.port()); self.put_slice(&addr.address_v6.ip().octets()); self.put_u16(addr.address_v6.port()); self.put_connection_id(&addr.connection_id); self.put_reset_token(&addr.stateless_reset_token); } } ================================================ FILE: qbase/src/param.rs ================================================ use std::{ fmt::Debug, ops::{Deref, DerefMut}, sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll, Waker}, time::Duration, }; use crate::{ cid::ConnectionId, error::{Error, ErrorKind, QuicError}, frame::FrameType, role::Role, }; pub mod core; pub mod error; pub mod handy; pub mod io; pub mod preferred_address; pub use self::{ core::{ ClientParameters, ParameterId, ParameterValue, ParameterValueType, PeerParameters, ServerParameters, }, io::*, }; /// Requires that the connection IDs in the transport parameters of /// the received Initial packet must match those used during the /// connection establishment process. /// /// For the Initial packet received by the server from the client, /// the initial_source_connection_id in the client's Transport /// parameters must match the source connection id in that Initial packet. /// For the Initial packet received by the client from the server, /// not only must the server's Transport parameter /// initial_source_connection_id match the source connection id /// in that Initial packet, /// but also requires that the original_destination_connection_id matches the /// destination connection id in the first packet sent by the client. /// Specifically, if the server has responded with a Retry packet, /// then the server's Transport parameter retry_source_connection_id /// must match the source connection id in that Retry packet. 
///
/// See [Authenticating Connection IDs](https://datatracker.ietf.org/doc/html/rfc9000#name-authenticating-connection-i)
/// of [RFC9000](https://datatracker.ietf.org/doc/html/rfc9000)
/// for more details.
///
/// Whether client or server, after receiving the Initial packet from
/// the peer, these requirements must be set;
/// then after parsing the peer's Transport parameters, verify that
/// all these requirements are met.
/// If not met, it is considered a TransportParameters error.
#[derive(Debug, Clone, Copy)]
enum Requirements {
    Client {
        /// SCID of the server's Initial packet; `None` until that packet is seen.
        initial_scid: Option<ConnectionId>,
        /// SCID of the server's Retry packet, if the server sent one.
        retry_scid: Option<ConnectionId>,
        /// DCID of the very first packet the client sent.
        origin_dcid: ConnectionId,
    },
    Server {
        /// SCID of the client's Initial packet; `None` until that packet is seen.
        initial_scid: Option<ConnectionId>,
    },
}

/// Transport parameters for QUIC.
/// The transport parameters are used to negotiate the initial
/// settings of a QUIC connection.
///
/// They are exchanged in the Initial packets of the handshake,
/// including client and server transport parameters.
/// Client transport parameters and server transport parameters
/// exist independently and are not merged.
/// They each constrain the behavior of the remote peer.
///
/// For different roles, local transport parameters and remote
/// transport parameters differ.
/// For example, as a client, the local transport parameters
/// are client parameters, while remote transport parameters
/// are server parameters. The same applies to the server.
///
/// Note that client transport parameters and server transport
/// parameters are different, as some transport parameters can
/// only appear in server transport parameters.
/// Therefore, for a QUIC connection, the transport parameter
/// sets for both ends are defined as follows.
#[derive(Debug)]
pub struct Parameters {
    // Bitflags of CLIENT_READY | SERVER_READY.
    state: u8,
    client: Arc<ClientParameters>,
    server: Arc<ServerParameters>,
    // Server parameters remembered from a previous connection; client only, for 0-RTT.
    remembered: Option<Arc<ServerParameters>>,
    requirements: Requirements,
    // Wakers of tasks waiting for both parameter sets to become ready.
    wakers: Vec<Waker>,
}

impl Drop for Parameters {
    fn drop(&mut self) {
        // Never leave waiters hanging if the parameters are dropped early.
        self.wake_all();
    }
}

impl Parameters {
    const CLIENT_READY: u8 = 1;
    const SERVER_READY: u8 = 2;

    /// Creates a new client transport parameters, with the client
    /// parameters and remembered server parameters if exist.
    ///
    /// It will wait for the server transport parameters to be
    /// received and parsed.
    pub fn new_client(
        client: ClientParameters,
        remembered: Option<ServerParameters>,
        origin_dcid: ConnectionId,
    ) -> Self {
        Self {
            state: Self::CLIENT_READY,
            client: Arc::new(client),
            server: Arc::default(),
            remembered: remembered.map(Arc::new),
            requirements: Requirements::Client {
                origin_dcid,
                initial_scid: None,
                retry_scid: None,
            },
            wakers: Vec::with_capacity(2),
        }
    }

    /// Creates a new server transport parameters, with the server
    /// parameters.
    ///
    /// It will wait for the client transport parameters to be
    /// received and parsed.
    pub fn new_server(server: ServerParameters) -> Self {
        Self {
            state: Self::SERVER_READY,
            client: Arc::default(),
            server: Arc::new(server),
            remembered: None,
            requirements: Requirements::Server { initial_scid: None },
            wakers: Vec::with_capacity(2),
        }
    }

    /// Returns the local role, derived from which requirements variant is held.
    pub fn role(&self) -> Role {
        match self.requirements {
            Requirements::Client { .. } => Role::Client,
            Requirements::Server { .. } => Role::Server,
        }
    }

    /// Returns the client parameters, if they are ready.
    pub fn client(&self) -> Option<&Arc<ClientParameters>> {
        if self.state & Self::CLIENT_READY != 0 {
            Some(&self.client)
        } else {
            None
        }
    }

    /// Returns the server parameters, if they are ready.
    pub fn server(&self) -> Option<&Arc<ServerParameters>> {
        if self.state & Self::SERVER_READY != 0 {
            Some(&self.server)
        } else {
            None
        }
    }

    /// Returns the remembered server transport parameters if exist,
    /// which means the client connected the server, and stored the
    /// server transport parameters.
    ///
    /// It is meaningful only for the client, to send early data
    /// with 0Rtt packets before receiving the server transport params.
pub fn remembered(&self) -> Option<&Arc> { self.remembered.as_ref() } pub fn get_local>(&self, id: ParameterId) -> Option { match self.role() { Role::Client => self.client()?.get(id), Role::Server => self.server()?.get(id), } } pub fn get_remote>(&self, id: ParameterId) -> Option { match self.role() { Role::Client => self.server()?.get(id), Role::Server => self.client()?.get(id), } } // fn set_retry_scid(&mut self, cid: ConnectionId) { // assert_eq!(self.role(), Role::Server); // self.server.set_retry_source_connection_id(cid); // } pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { if self.state == Self::CLIENT_READY | Self::SERVER_READY { Poll::Ready(()) } else { self.wakers.push(cx.waker().clone()); Poll::Pending } } pub fn is_remote_params_received(&self) -> bool { match self.role() { Role::Client => !self.server.is_empty(), Role::Server => !self.client.is_empty(), } } /// Returns true if the remote transport parameters have been received and authed. /// /// It is usually used to avoid processing remote transport parameters /// more than once. pub fn is_remote_params_ready(&self) -> bool { self.state == Self::CLIENT_READY | Self::SERVER_READY } /// Being called when the remote transport parameters are received. /// It will parse and check the remote transport parameters, /// and wake all the wakers waiting for the remote transport parameters /// if the remote transport parameters are valid. pub fn recv_remote_params( &mut self, params: impl Into, ) -> Result<(), QuicError> { match params.into() { PeerParameters::Client(p) => { assert_eq!(self.role(), Role::Server); assert!(self.client.is_empty()); self.client = Arc::new(p); } PeerParameters::Server(p) => { assert_eq!(self.role(), Role::Client); assert!(self.server.is_empty()); self.server = Arc::new(p); } } // Because TLS and packet parsing are in parallel, // the scid of the peer end may not be set when the transmission parameters of the peer are obtained. 
// Therefore, if the scid of the other end is not set, authentication will not be performed first, // and authentication will be performed when it is set. if self.authenticate_cids()? { self.state = Self::CLIENT_READY | Self::SERVER_READY; self.remembered.take(); self.wake_all(); return Ok(()); } Ok(()) } fn wake_all(&mut self) { for waker in self.wakers.drain(..) { waker.wake(); } } /// No matter the client or server, after receiving the Initial /// packet from the peer, the initial_source_connection_id in /// the remote transport parameters must equal the source connection /// id in the received Initial packet. /// /// If the peer's transmission parameters have not been verified, /// it will be verified here. If verification fails, this method will /// return Err. pub fn initial_scid_from_peer_need_equal( &mut self, cid: ConnectionId, ) -> Result<(), QuicError> { let initial_scid = match &mut self.requirements { Requirements::Client { initial_scid, .. } => initial_scid, Requirements::Server { initial_scid } => initial_scid, }; assert!(initial_scid.replace(cid).is_none()); // Because the TLS handshak and packet parsing are in parallel, // the scid of the peer end may not be set when the transmission parameters of the peer are obtained. // Therefore, if the scid of the other end is not set, authentication will not be performed first, // and authentication will be performed when it is set. if self.is_remote_params_received() && self.authenticate_cids()? { self.state = Self::CLIENT_READY | Self::SERVER_READY; self.remembered.take(); self.wake_all(); return Ok(()); } Ok(()) } /// After receiving the Retry packet from the server, the /// retry_source_connection_id in the server transport parameters /// must equal the source connection id in the Retry packet. pub fn retry_scid_from_server_need_equal(&mut self, cid: ConnectionId) { match &mut self.requirements { Requirements::Client { retry_scid, .. } => *retry_scid = Some(cid), Requirements::Server { .. 
} => panic!("server shuold never call this"), } } pub fn initial_scid_from_peer(&self) -> Option { match self.requirements { Requirements::Client { initial_scid, .. } => initial_scid, Requirements::Server { initial_scid, .. } => initial_scid, } } fn authenticate_cids(&self) -> Result { fn param_error(reason: &'static str) -> QuicError { QuicError::new( ErrorKind::TransportParameter, FrameType::Crypto.into(), reason, ) } // Because TLS and packet parsing are in parallel, // the scid of the peer end may not be set when the transmission parameters of the peer are obtained. // Therefore, if the scid of the other end is not set, authentication will not be performed first, // and authentication will be performed when it is set. match self.requirements { Requirements::Client { initial_scid, retry_scid: _, origin_dcid, } => { let Some(initial_scid) = initial_scid else { return Ok(false); }; if self .server .get::(ParameterId::InitialSourceConnectionId) .expect("this value must be set") != initial_scid { return Err(param_error( "Initial Source Connection ID from server mismatch", )); } // 并不正确,要和intiial_scid一样地去验证 // if self.server.retry_source_connection_id() != retry_scid { // return Err(param_error("Retry Source Connection ID mismatch")); // } if self .server .get::(ParameterId::OriginalDestinationConnectionId) .expect("this value must be set") != origin_dcid { return Err(param_error("Original Destination Connection ID mismatch")); } Ok(true) } Requirements::Server { initial_scid } => { let Some(initial_scid) = initial_scid else { return Ok(false); }; if self .client .get::(ParameterId::InitialSourceConnectionId) .expect("this value must be set") != initial_scid { return Err(param_error( "Initial Source Connection ID from client mismatch", )); } Ok(true) } } } /// Returns None if the remote parameters are not ready. 
pub fn negotiated_max_idle_timeout(&self) -> Option { let local_max_idle_timeout = self.get_local(ParameterId::MaxIdleTimeout)?; let remote_max_idle_timeout = self.get_remote(ParameterId::MaxIdleTimeout)?; Some(match (local_max_idle_timeout, remote_max_idle_timeout) { // rfc: https://datatracker.ietf.org/doc/html/rfc9000#name-idle-timeout // Each endpoint advertises a max_idle_timeout, but the effective value // at an endpoint is computed as the minimum of the two advertised // values (or the sole advertised value, if only one endpoint advertises // a non-zero value). By announcing a max_idle_timeout, an endpoint // commits to initiating an immediate close (Section 10.2) if // it abandons the connection prior to the effective value. (Duration::ZERO, Duration::ZERO) => Duration::MAX, (Duration::ZERO, d) | (d, Duration::ZERO) => d, // rfc: https://datatracker.ietf.org/doc/html/rfc9000#name-idle-timeout // If a max_idle_timeout is specified by either endpoint in its // transport parameters (Section 18.2), the connection is silently // closed and its state is discarded when it remains idle for longer // than the minimum of the max_idle_timeout value advertised by both // endpoints. (d1, d2) => d1.min(d2), }) } } /// Shared transport parameter sets for both endpoints. /// /// The local transport parameters are set initially, while /// the remote transport parameters must wait until they are /// received through network transmission and can be parsed. /// After parsing, the peer parameters must be immediately /// verified to ensure they meet the requirements and validity /// checks. /// /// Note that a connection error may occur before receiving /// the remote transport parameters, such as network unreachable. /// In such cases, the entire connection parameters will be /// converted into an error state. 
#[derive(Debug, Clone)] pub struct ArcParameters(Arc>>); // ArcParameters::lock_guard(&self) -> Result; // pub struct ArcParametersGuard: impl Deref pub struct ParametersGuard<'a>(MutexGuard<'a, Result>); impl Deref for ParametersGuard<'_> { type Target = Parameters; fn deref(&self) -> &Self::Target { self.0.as_ref().expect("parameters must be valid") } } impl DerefMut for ParametersGuard<'_> { fn deref_mut(&mut self) -> &mut Self::Target { self.0.as_mut().expect("parameters must be valid") } } impl From for ArcParameters { fn from(params: Parameters) -> Self { Self(Arc::new(Mutex::new(Ok(params)))) } } impl ArcParameters { #[inline] pub fn lock_guard(&self) -> Result, Error> { let guard = self.0.lock().unwrap(); match guard.as_ref() { Ok(_) => Ok(ParametersGuard(guard)), Err(e) => Err(e.clone()), } } #[inline] pub async fn remote_ready(&self) -> Result, Error> { std::future::poll_fn(|cx| { let mut parameters = self.lock_guard()?; parameters.poll_ready(cx).map(|()| Ok(parameters)) }) .await } // /// Sets the retry source connection ID in the server // /// transport parameters. // /// // /// It is meaningful only for the client, because only // /// server can send the Retry packet. // pub fn set_retry_scid(&self, cid: ConnectionId) { // let mut guard = self.0.lock().unwrap(); // if let Ok(params) = guard.deref_mut() { // params.set_retry_scid(cid); // } // } /// When some connection error occurred, convert this parameters /// into error state. 
pub fn on_conn_error(&self, error: &Error) { let mut guard = self.0.lock().unwrap(); if guard.deref_mut().is_ok() { *guard = Err(error.clone()); } } } #[cfg(test)] mod tests { use std::sync::Arc; use super::*; use crate::varint::VarInt; fn create_test_client_params() -> ClientParameters { let mut params = ClientParameters::default(); params .set( ParameterId::InitialSourceConnectionId, ConnectionId::from_slice(b"client_test"), ) .unwrap(); params } fn create_test_server_params() -> ServerParameters { let mut params = ServerParameters::default(); params .set( ParameterId::InitialSourceConnectionId, ConnectionId::from_slice(b"server_test"), ) .unwrap(); params .set( ParameterId::OriginalDestinationConnectionId, ConnectionId::from_slice(b"original"), ) .unwrap(); params } #[test] fn test_parameters_new() { let client_params = create_test_client_params(); let params = Parameters::new_client(client_params, None, ConnectionId::from_slice(b"odcid")); assert_eq!(params.role(), Role::Client); assert_eq!(params.state, Parameters::CLIENT_READY); let server_params = create_test_server_params(); let params = Parameters::new_server(server_params); assert_eq!(params.role(), Role::Server); assert_eq!(params.state, Parameters::SERVER_READY); } #[test] fn test_authenticate_cids() { let client_params = create_test_client_params(); let odcid = ConnectionId::from_slice(b"odcid"); let mut params = Parameters::new_client(client_params, None, odcid); let server_cid = ConnectionId::from_slice(b"server_test"); params .initial_scid_from_peer_need_equal(server_cid) .unwrap(); params.server = Arc::new({ let mut server_params = ServerParameters::default(); server_params .set(ParameterId::InitialSourceConnectionId, server_cid) .unwrap(); server_params .set(ParameterId::OriginalDestinationConnectionId, odcid) .unwrap(); server_params }); assert!(params.authenticate_cids().is_ok()); } #[test] fn test_parameters_as_client() { let client_params = create_test_client_params(); let arc_params = 
ArcParameters::from(Parameters::new_client( client_params, None, ConnectionId::from_slice(b"odcid"), )); // Test accessing parameters through lock_guard let guard = arc_params.lock_guard().unwrap(); // Test local params assert!(matches!( guard.get_local::(ParameterId::MaxUdpPayloadSize), Some(value) if value.into_u64() >= 1200 )); // Test remembered params assert!(guard.remembered().is_none()); } #[test] fn test_validate_remote_params() { // Test invalid max_udp_payload_size assert_eq!( ClientParameters::parse_from_bytes(&[ 1, 1, 0, // max_idle_timeout 3, 2, 0x43, 0xE8, // max_udp_payload_size: 1000 4, 1, 0, // initial_max_data 5, 1, 0, // initial_max_stream_data_bidi_local 6, 1, 0, // initial_max_stream_data_bidi_remote 7, 1, 0, // initial_max_stream_data_uni 8, 1, 0, // initial_max_streams_bidi 9, 1, 0, // initial_max_streams_uni 10, 1, 3, // ack_delay_exponent 11, 1, 25, // max_ack_delay 14, 1, 2, // active_connection_id_limit 15, 0, // initial_source_connection_id 32, 4, 128, 0, 255, 255, // max_datagram_frame_size ]), Err(QuicError::new( ErrorKind::TransportParameter, FrameType::Crypto.into(), "MaxUdpPayloadSize's value 1000 is out of bounds 1200..=65527", )) ); } #[test] fn test_write_parameters() { let client_params = create_test_client_params(); let params = ArcParameters::from(Parameters::new_client( client_params, None, ConnectionId::from_slice(b"odcid"), )); // Test that we can access the parameters let guard = params.lock_guard().unwrap(); assert_eq!(guard.role(), Role::Client); } #[tokio::test] async fn test_arc_parameters_error_handling() { let arc_params = ArcParameters::from(Parameters::new_client( create_test_client_params(), None, ConnectionId::from_slice(b"odcid"), )); // Simulate connection error let error = QuicError::new( ErrorKind::TransportParameter, FrameType::Crypto.into(), "test error", ) .into(); arc_params.on_conn_error(&error); assert!(arc_params.lock_guard().is_err()); } } ================================================ FILE: 
qbase/src/role.rs
================================================
use std::{fmt, ops};

use crate::param::ParameterId;

/// Roles in the QUIC protocol, including client and server.
///
/// The least significant bit (0x01) of the [`StreamId`](crate::sid) identifies the initiator role of the stream.
/// Client-initiated streams have even-numbered stream IDs (with the bit set to 0),
/// and server-initiated streams have odd-numbered stream IDs (with the bit set to 1).
/// See [section-2.1-3](https://www.rfc-editor.org/rfc/rfc9000.html#section-2.1-3)
/// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html).
///
/// # Note
///
/// As a protocol capable of multiplexing streams, QUIC is different from traditional
/// HTTP protocols for clients and servers.
/// In the QUIC protocol, it is not only the client that can actively open a new stream;
/// the server can also actively open a new stream to push some data to the client.
/// In fact, in a new stream, the server can initiate an HTTP3 request to the client,
/// and the client, upon receiving the request, responds back to the server.
/// In this case, the client surprisingly plays the role of the traditional "server",
/// which is quite fascinating.
///
/// # Example
///
/// ```
/// use qbase::role::Role;
///
/// let local = Role::Client;
/// let peer = !local;
/// let is_client = matches!(local, Role::Client); // true
/// let is_server = matches!(peer, Role::Server); // true
/// ```
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum Role {
    /// The initiator of a connection
    Client = 0,
    /// The acceptor of a connection
    Server = 1,
}

impl fmt::Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(match *self {
            Self::Client => "client",
            Self::Server => "server",
        })
    }
}

impl ops::Not for Role {
    type Output = Self;

    /// `!role` yields the peer's role.
    fn not(self) -> Self {
        match self {
            Self::Client => Self::Server,
            Self::Server => Self::Client,
        }
    }
}

pub trait IntoRole {
    /// Convert the type into a [`Role`].
    fn into_role() -> Role;
}

pub trait RequiredParameters {
    /// The transport parameters this role is required to provide.
    fn required_parameters() -> impl IntoIterator<Item = ParameterId>;
}

/// Type-level marker for the client role.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct Client;

impl From<Client> for Role {
    fn from(_: Client) -> Self {
        Role::Client
    }
}

impl IntoRole for Client {
    fn into_role() -> Role {
        Role::Client
    }
}

impl RequiredParameters for Client {
    fn required_parameters() -> impl IntoIterator<Item = ParameterId> {
        [ParameterId::InitialSourceConnectionId].into_iter()
    }
}

/// Type-level marker for the server role.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct Server;

impl From<Server> for Role {
    fn from(_: Server) -> Self {
        Role::Server
    }
}

impl IntoRole for Server {
    fn into_role() -> Role {
        Role::Server
    }
}

impl RequiredParameters for Server {
    fn required_parameters() -> impl IntoIterator<Item = ParameterId> {
        [
            ParameterId::InitialSourceConnectionId,
            ParameterId::OriginalDestinationConnectionId,
        ]
        .into_iter()
    }
}

================================================
FILE: qbase/src/sid/handy.rs
================================================
use super::{ControlStreamsConcurrency, Dir};

/// Consistent concurrency strategy increase limits as streams are closed,
/// to keep the number of streams available to peers roughly consistent.
#[derive(Debug)] pub struct ConsistentConcurrency { max_streams: [u64; 2], } impl ConsistentConcurrency { pub fn new(initial_max_bi: u64, initial_max_uni: u64) -> Self { Self { max_streams: [initial_max_bi, initial_max_uni], } } } impl ControlStreamsConcurrency for ConsistentConcurrency { fn on_accept_streams(&mut self, _dir: Dir, _sid: u64) -> Option { None } fn on_end_of_stream(&mut self, dir: Dir, _sid: u64) -> Option { let idx = dir as usize; let new_limit = self.max_streams[idx] + 1; self.max_streams[idx] = new_limit; Some(new_limit) } fn on_streams_blocked(&mut self, _dir: Dir, _max_streams: u64) -> Option { None } } /// Demand concurrency strategy increase limits as long as receiving a /// [`StreamsBlockedFrame`](crate::frame::StreamsBlockedFrame). #[derive(Debug)] pub struct DemandConcurrency; impl ControlStreamsConcurrency for DemandConcurrency { fn on_accept_streams(&mut self, _dir: Dir, _sid: u64) -> Option { None } fn on_end_of_stream(&mut self, _dir: Dir, _sid: u64) -> Option { None } fn on_streams_blocked(&mut self, _dir: Dir, max_streams: u64) -> Option { Some(max_streams + 1) } } ================================================ FILE: qbase/src/sid/local_sid.rs ================================================ use std::{ collections::VecDeque, sync::{Arc, Mutex}, task::{Context, Poll, Waker}, }; use super::{Dir, Role, StreamId}; use crate::{ frame::{ MaxStreamsFrame, StreamsBlockedFrame, io::{ReceiveFrame, SendFrame}, }, net::tx::{ArcSendWakers, Signals}, sid::MAX_STREAMS_LIMIT, varint::VarInt, }; /// Local stream IDs management. 
#[derive(Debug)] struct LocalStreamIds { /// Our role role: Role, max: [u64; 2], unallocated: [u64; 2], /// Used for waiting for the MaxStream frame notification from peer when we have exhausted the creation of stream IDs wakers: [VecDeque; 2], /// The StreamsBlocked frames that will be sent to peer blocked: BLOCKED, tx_wakers: ArcSendWakers, } impl LocalStreamIds where BLOCKED: SendFrame + Clone + Send + 'static, { /// Create a new [`LocalStreamIds`] with the given role, /// and maximum number of streams that can be created in each [`Dir`]. fn new( role: Role, init_max_bi_streams: u64, init_max_uni_streams: u64, blocked: BLOCKED, tx_wakers: ArcSendWakers, ) -> Self { debug_assert!( role == Role::Client || (init_max_bi_streams == 0 && init_max_uni_streams == 0), "Server cannot remember the parameters" ); Self { role, max: [init_max_bi_streams, init_max_uni_streams], unallocated: [0, 0], wakers: [VecDeque::with_capacity(2), VecDeque::with_capacity(2)], blocked, tx_wakers, } } /// Returns local role. fn role(&self) -> Role { self.role } /// Returns the number of opened streams in the `dir` direction. fn opened_streams(&self, dir: Dir) -> u64 { self.unallocated[dir as usize] } /// Receive the [`MaxStreamsFrame`](`crate::frame::MaxStreamsFrame`) from peer, /// update the maximum stream ID that can be opened locally in the given direction. fn recv_max_streams_frame(&mut self, frame: MaxStreamsFrame) { let (dir, val) = match frame { MaxStreamsFrame::Bi(max) => (Dir::Bi, max.into_u64()), MaxStreamsFrame::Uni(max) => (Dir::Uni, max.into_u64()), }; self.increase_limit(dir, val); } fn increase_limit(&mut self, dir: Dir, val: u64) { assert!(val <= MAX_STREAMS_LIMIT); let max_streams = &mut self.max[dir as usize]; // RFC9000: MAX_STREAMS frames that do not increase the stream limit MUST be ignored. if *max_streams < val { // The rejected 0rtt stream can be sent again, as if new data was written. 
if *max_streams < self.unallocated[dir as usize] { self.tx_wakers.wake_all_by(Signals::WRITTEN); } for waker in self.wakers[dir as usize].drain(..) { waker.wake(); } *max_streams = val; } } fn poll_alloc_sid(&mut self, cx: &mut Context<'_>, dir: Dir) -> Poll> { let idx = dir as usize; let max = self.max[idx]; let unallocated = self.unallocated[idx]; if unallocated > MAX_STREAMS_LIMIT { Poll::Ready(None) } else if unallocated < max { self.unallocated[idx] += 1; Poll::Ready(Some(StreamId::new(self.role, dir, unallocated))) } else { // waiting for MAX_STREAMS frame from peer self.wakers[idx].push_back(cx.waker().clone()); // if Poll::Pending is returned, connection can send a STREAMS_BLOCKED frame to peer self.blocked.send_frame([StreamsBlockedFrame::with( dir, VarInt::from_u64(max).expect("max_streams limit must be less than VARINT_MAX"), )]); Poll::Pending } } pub fn revise_max_streams( &mut self, zero_rtt_rejected: bool, max_stream_bidi: u64, max_stream_uni: u64, ) { if zero_rtt_rejected { self.max = [0, 0]; } self.increase_limit(Dir::Bi, max_stream_bidi); self.increase_limit(Dir::Uni, max_stream_uni); } } /// Management of stream IDs that can ben allowed to use locally. /// /// The maximum stream ID that can be created is limited by the /// [`MaxStreamsFrame`](`crate::frame::MaxStreamsFrame`) from the peer. /// /// When the stream IDs in the `dir` direction are exhausted, /// a [`StreamsBlockedFrame`](`crate::frame::StreamsBlockedFrame`) will be sent to the peer. /// The generic parameter `BLOCKED` is the container of the [`StreamsBlockedFrame`] /// that will be sent to peer, it can be a channel, a queue, or a buffer, /// as long as it can send the [`StreamsBlockedFrame`] to peer. 
#[derive(Debug, Clone)] pub struct ArcLocalStreamIds(Arc>>); impl ArcLocalStreamIds where BLOCKED: SendFrame + Clone + Send + 'static, { /// Create a new [`ArcLocalStreamIds`] with the given role, /// and maximum number of streams that can be created in each direction, /// the `blocked` contains the [`StreamsBlockedFrame`] that will be sent to peer. pub fn new( role: Role, max_bidi: u64, max_uni: u64, blocked: BLOCKED, tx_wakers: ArcSendWakers, ) -> Self { Self(Arc::new(Mutex::new(LocalStreamIds::new( role, max_bidi, max_uni, blocked, tx_wakers, )))) } /// Returns local role pub fn role(&self) -> Role { self.0.lock().unwrap().role() } /// Returns the number of opened streams in the `dir` direction. /// /// If `is_0rtt` is true, this will return the stream open in 0rtt phase. /// /// If `is_0rtt` is false, the return value will not be greater than max_streams, /// that is, if 0rtt is rejected, the return value may be less than the number of open streams. /// This is the number of streams that can actually be sent in the 1rtt space. pub fn opened_streams(&self, dir: Dir) -> u64 { self.0.lock().unwrap().opened_streams(dir) } /// Receive the [`MaxStreamsFrame`](`crate::frame::MaxStreamsFrame`) from peer, /// and then update the maximum stream ID that can be allowed to use locally. /// /// The maximum stream ID that can be allowed to use is limited by peer. /// Therefore, it mainly depends on the peer's attitude /// and is subject to the [`MaxStreamsFrame`](`crate::frame::MaxStreamsFrame`) /// received from peer. pub fn recv_max_streams_frame(&self, frame: MaxStreamsFrame) { self.0.lock().unwrap().recv_max_streams_frame(frame); } /// Asynchronously allocate the next new [`StreamId`] in the `dir` direction. /// /// Return a bool indicating whether the stream is opened in 0rtt phase. /// /// When the application layer wants to proactively open a new stream, /// it needs to first apply to allocate the next unused [`StreamId`]. 
/// Note that streams on a QUIC connection usually have a maximum concurrency limit, /// so when requesting a [`StreamId`], it may not be possible to obtain one due to /// reaching the maximum concurrency limit. /// However, this is temporary. When the active current streams end, /// the peer will expand the maximum stream ID limit through a /// [`MaxStreamsFrame`](`crate::frame::MaxStreamsFrame`), /// allowing the allocation of the [`StreamId`] meanwhile. /// /// Return Pending when the stream IDs in the `dir` direction are exhausted, /// until receiving the [`MaxStreamsFrame`](`crate::frame::MaxStreamsFrame`) from peer. /// /// Return None if the stream IDs in the `dir` direction finally exceed 2^60, /// but it is very very hard to happen. pub fn poll_alloc_sid(&self, cx: &mut Context<'_>, dir: Dir) -> Poll> { self.0.lock().unwrap().poll_alloc_sid(cx, dir) } pub fn revise_max_streams( &self, zero_rtt_rejected: bool, max_stream_bidi: u64, max_stream_uni: u64, ) { self.0.lock().unwrap().revise_max_streams( zero_rtt_rejected, max_stream_bidi, max_stream_uni, ); } } impl ReceiveFrame for ArcLocalStreamIds where BLOCKED: SendFrame + Clone + Send + 'static, { type Output = (); fn recv_frame(&self, frame: MaxStreamsFrame) -> Result { self.recv_max_streams_frame(frame); Ok(()) } } #[cfg(test)] mod tests { use derive_more::Deref; use super::*; use crate::util::ArcAsyncDeque; #[derive(Clone, Deref, Default)] struct StreamsBlockedFrameTx(ArcAsyncDeque); impl SendFrame for StreamsBlockedFrameTx { fn send_frame>(&self, iter: I) { (&self.0).extend(iter); } } #[test] fn test_stream_id_new() { let sid = StreamId::new(Role::Client, Dir::Bi, 0); assert_eq!(sid, StreamId(0)); assert_eq!(sid.role(), Role::Client); assert_eq!(sid.dir(), Dir::Bi); } #[test] fn test_recv_max_stream_frames() { let local = ArcLocalStreamIds::new( Role::Client, 0, 0, StreamsBlockedFrameTx::default(), ArcSendWakers::default(), ); local.recv_max_streams_frame(MaxStreamsFrame::Bi(VarInt::from_u32(0))); let 
waker = futures::task::noop_waker(); let mut cx = Context::from_waker(&waker); assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Pending,); assert!(!local.0.lock().unwrap().wakers[0].is_empty()); local.recv_max_streams_frame(MaxStreamsFrame::Bi(VarInt::from_u32(1))); let _ = local.0.lock().unwrap().wakers[0].pop_front(); assert_eq!( local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Ready(Some(StreamId(0))) ); assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Pending); assert!(!local.0.lock().unwrap().wakers[0].is_empty()); local.recv_max_streams_frame(MaxStreamsFrame::Uni(VarInt::from_u32(2))); assert_eq!( local.poll_alloc_sid(&mut cx, Dir::Uni), Poll::Ready(Some(StreamId(2))) ); assert_eq!( local.poll_alloc_sid(&mut cx, Dir::Uni), Poll::Ready(Some(StreamId(6))) ); assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Uni), Poll::Pending); assert!(!local.0.lock().unwrap().wakers[1].is_empty()); } } ================================================ FILE: qbase/src/sid/remote_sid.rs ================================================ use std::sync::{Arc, Mutex}; use thiserror::Error; use super::{ControlStreamsConcurrency, Dir, Role, StreamId}; use crate::{ frame::{ MaxStreamsFrame, StreamsBlockedFrame, io::{ReceiveFrame, SendFrame}, }, varint::VarInt, }; /// Exceed the maximum stream ID limit error, /// similar with [`ErrorKind::StreamLimit`](`crate::error::ErrorKind::StreamLimit`). /// /// This error occurs when the stream ID in the received stream-related frames /// exceeds the maximum stream ID limit. #[derive(Debug, PartialEq, Error)] #[error("{0} exceed limit: {1}")] pub struct ExceedLimitError(StreamId, u64); /// Accept the stream ID received from peer, /// returned by [`ArcRemoteStreamIds::try_accept_sid`]. #[derive(Debug, PartialEq)] pub enum AcceptSid { /// Indicates that the stream ID is already exist. Old, /// Indicates that the stream ID is new and need to create. 
/// The `NeedCreate` inside indicates the range of stream IDs that need to be created together. New(NeedCreate), } /// The range of stream IDs that need to be created, /// see [`ArcRemoteStreamIds::try_accept_sid`] and [`AcceptSid::New`]. #[derive(Debug, PartialEq)] pub struct NeedCreate { start: StreamId, end: StreamId, } impl Iterator for NeedCreate { type Item = StreamId; fn next(&mut self) -> Option { if self.start > self.end { None } else { // Safety: Since being generated from "StreamIds", they could not overflow. let id = self.start; self.start = unsafe { self.start.next_unchecked() }; Some(id) } } } /// Remote stream IDs management. #[derive(Debug)] struct RemoteStreamIds { role: Role, // The role of the peer max: [u64; 2], // The maximum stream ID that limit peer to create unallocated: [StreamId; 2], // The stream ID that peer has not used ctrl: Box, // The strategy to control the concurrency of streams max_tx: MAX, // The channel to send the MAX_STREAMS frame to peer } impl RemoteStreamIds where MAX: SendFrame + Clone + Send + 'static, { /// Create a new [`RemoteStreamIds`] with the given role, /// and maximum number of streams that can be created by peer in each [`Dir`]. fn new( role: Role, max_bi: u64, max_uni: u64, max_tx: MAX, ctrl: Box, ) -> Self { Self { role, max: [max_bi, max_uni], unallocated: [ StreamId::new(role, Dir::Bi, 0), StreamId::new(role, Dir::Uni, 0), ], ctrl, max_tx, } } /// Returns the role of the peer. 
    fn role(&self) -> Role {
        self.role
    }

    /// Validate and accept a stream ID initiated by the peer.
    ///
    /// Errors with [`ExceedLimitError`] if `sid` is beyond the advertised limit.
    /// Returns [`AcceptSid::Old`] for an already-allocated ID, or
    /// [`AcceptSid::New`] with the contiguous range of IDs that must be
    /// created (out-of-order arrival may leave gaps below `sid`).
    fn try_accept_sid(&mut self, sid: StreamId) -> Result<AcceptSid, ExceedLimitError> {
        // Only stream IDs initiated by the peer should reach here.
        debug_assert_eq!(sid.role(), self.role);
        let idx = sid.dir() as usize;
        if sid.id() > self.max[idx] {
            return Err(ExceedLimitError(sid, self.max[idx]));
        }
        let cur = &mut self.unallocated[idx];
        if sid < *cur {
            Ok(AcceptSid::Old)
        } else {
            let start = *cur;
            // Everything from `start` up to and including `sid` becomes allocated.
            *cur = unsafe { sid.next_unchecked() };
            // Let the concurrency strategy decide whether to raise the limit;
            // if it does, advertise the new limit via a MAX_STREAMS frame.
            if let Some(max_streams) = self.ctrl.on_accept_streams(sid.dir(), sid.id()) {
                self.max[idx] = max_streams;
                self.max_tx.send_frame([MaxStreamsFrame::with(
                    sid.dir(),
                    VarInt::from_u64(max_streams)
                        .expect("max_streams must be less than VARINT_MAX"),
                )]);
            }
            Ok(AcceptSid::New(NeedCreate { start, end: sid }))
        }
    }

    /// Called when a peer-initiated stream ends (closed normally or reset);
    /// the strategy may raise the stream limit, announced via MAX_STREAMS.
    fn on_end_of_stream(&mut self, sid: StreamId) {
        // Ignore streams that were initiated locally.
        if sid.role() != self.role {
            return;
        }
        if let Some(max_streams) = self.ctrl.on_end_of_stream(sid.dir(), sid.id()) {
            self.max[sid.dir() as usize] = max_streams;
            self.max_tx.send_frame([MaxStreamsFrame::with(
                sid.dir(),
                VarInt::from_u64(max_streams).expect("max_streams must be less than VARINT_MAX"),
            )]);
        }
    }

    /// Handle a STREAMS_BLOCKED frame from the peer; the strategy may decide
    /// to raise the limit, which is then advertised via MAX_STREAMS.
    fn recv_streams_blocked_frame(&mut self, frame: StreamsBlockedFrame) {
        let (dir, max_streams) = match frame {
            StreamsBlockedFrame::Bi(max) => (Dir::Bi, max.into_u64()),
            StreamsBlockedFrame::Uni(max) => (Dir::Uni, max.into_u64()),
        };
        if let Some(max_streams) = self.ctrl.on_streams_blocked(dir, max_streams) {
            self.max[dir as usize] = max_streams;
            self.max_tx.send_frame([MaxStreamsFrame::with(
                dir,
                VarInt::from_u64(max_streams).expect("max_streams must be less than VARINT_MAX"),
            )]);
        }
    }
}

/// Shared remote stream IDs, mainly controls and monitors the stream IDs
/// in the received stream-related frames from peer.
///
/// Checks whether the stream IDs exceed the limit, and creates them if necessary.
/// And sends a [`MaxStreamsFrame`](`crate::frame::MaxStreamsFrame`)
/// to the peer to update the maximum stream ID limit in time.
/// /// # Note /// /// After receiving the peer's stream-related frames, /// due to possible out-of-order reception issues, /// the stream IDs in these frames may have gaps, /// i.e., they may not be continuous with the previous stream ID of the same type. /// So before a stream is created, /// all streams of the same type with lower-numbered stream IDs MUST be created. /// This ensures that the creation order for streams is consistent on both endpoints #[derive(Debug, Clone)] pub struct ArcRemoteStreamIds(Arc>>); impl ArcRemoteStreamIds where MAX: SendFrame + Clone + Send + 'static, { /// Create a new [`ArcRemoteStreamIds`] with the given role, /// and maximum number of streams that can be created by peer in each direction. /// /// The maximum number of streams that can be created by peer in each direction /// are `initial_max_streams_bidi` and `initial_max_sterams_uni` /// in local [`Parameters`](`crate::param::Parameters`). /// See [section-18.2-4.21](https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.21) /// and [section-18.2-4.23](https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.23) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. pub fn new( role: Role, max_bi: u64, max_uni: u64, max_tx: MAX, ctrl: Box, ) -> Self { Self(Arc::new(Mutex::new(RemoteStreamIds::new( role, max_bi, max_uni, max_tx, ctrl, )))) } /// Returns the role of the peer. pub fn role(&self) -> Role { self.0.lock().unwrap().role() } /// Try to accept the stream ID received from peer. /// /// Only if this stream ID must be created by peer, this function needs to be called. /// /// This stream ID may belong to an already existing stream or a new stream that does not yet exist. /// If it is the latter, a new stream needs to be created. /// Before a stream is created, all streams of the same type /// with lower-numbered stream IDs MUST be created. 
/// See [section-3.2-6](https://www.rfc-editor.org/rfc/rfc9000.html#section-3.2-6) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. /// /// # Return /// /// - Return [`ExceedLimitError`] if the stream ID exceeds the maximum stream ID limit. /// - Return [`AcceptSid::Old`] if the stream ID is already exist. /// - Return [`AcceptSid::New`] if the stream ID is new and need to create. /// The `NeedCreate` inside indicates the range of stream IDs that need to be created. pub fn try_accept_sid(&self, sid: StreamId) -> Result { self.0.lock().unwrap().try_accept_sid(sid) } #[inline] pub fn on_end_of_stream(&self, sid: StreamId) { self.0.lock().unwrap().on_end_of_stream(sid); } #[inline] pub fn recv_streams_blocked_frame(&self, frame: StreamsBlockedFrame) { self.0.lock().unwrap().recv_streams_blocked_frame(frame); } } impl ReceiveFrame for ArcRemoteStreamIds where MAX: SendFrame + Clone + Send + 'static, { type Output = (); fn recv_frame(&self, frame: StreamsBlockedFrame) -> Result { self.recv_streams_blocked_frame(frame); Ok(()) } } #[cfg(test)] mod tests { use derive_more::Deref; use super::*; use crate::{sid::handy::ConsistentConcurrency, util::ArcAsyncDeque}; #[derive(Clone, Deref, Default)] struct MaxStreamsFrameTx(ArcAsyncDeque); impl SendFrame for MaxStreamsFrameTx { fn send_frame>(&self, iter: I) { (&self.0).extend(iter); } } #[test] fn test_try_accept_sid() { let remote = ArcRemoteStreamIds::new( Role::Server, 10, 5, MaxStreamsFrameTx::default(), Box::new(ConsistentConcurrency::new(10, 5)), ); let result = remote.try_accept_sid(StreamId(21)); assert_eq!( result, Ok(AcceptSid::New(NeedCreate { start: StreamId(1), end: StreamId(21) })) ); assert_eq!(remote.0.lock().unwrap().unallocated[0], StreamId(25)); let result = remote.try_accept_sid(StreamId(25)); assert_eq!( result, Ok(AcceptSid::New(NeedCreate { start: StreamId(25), end: StreamId(25) })) ); assert_eq!(remote.0.lock().unwrap().unallocated[0], StreamId(29)); let result = 
remote.try_accept_sid(StreamId(41)); assert_eq!( result, Ok(AcceptSid::New(NeedCreate { start: StreamId(29), end: StreamId(41) })) ); assert_eq!(remote.0.lock().unwrap().unallocated[0], StreamId(45)); if let Ok(AcceptSid::New(mut range)) = result { assert_eq!(range.next(), Some(StreamId(29))); assert_eq!(range.next(), Some(StreamId(33))); assert_eq!(range.next(), Some(StreamId(37))); assert_eq!(range.next(), Some(StreamId(41))); assert_eq!(range.next(), None); } let result = remote.try_accept_sid(StreamId(65)); assert_eq!(result, Err(ExceedLimitError(StreamId(65), 10))); } } ================================================ FILE: qbase/src/sid.rs ================================================ use std::fmt; use super::{ frame::MaxStreamsFrame, varint::{VarInt, WriteVarInt, be_varint}, }; use crate::{ frame::{StreamsBlockedFrame, io::SendFrame}, net::tx::ArcSendWakers, role::Role, }; /// Sum type for stream directions. /// /// Streams can be unidirectional or bidirectional. /// Unidirectional streams carry data in one direction: from the initiator of the stream to its peer. /// Bidirectional streams allow for data to be sent in both directions. /// See [section-2.1-1](https://www.rfc-editor.org/rfc/rfc9000.html#section-2.1-1) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html). /// /// The second least significant bit (0x02) of the [`StreamId`] distinguishes between /// bidirectional streams (with the bit set to 0) and unidirectional streams (with the bit set to 1). /// See [section-2.1-4](https://www.rfc-editor.org/rfc/rfc9000.html#section-2.1-4) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html). 
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum Dir {
    /// Data flows in both directions
    Bi = 0,
    /// Data flows only from the stream's initiator
    Uni = 1,
}

impl fmt::Display for Dir {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(match *self {
            Self::Bi => "bidirectional",
            Self::Uni => "unidirectional",
        })
    }
}

/// Streams are identified within a connection by a numeric value,
/// referred to as the stream ID.
///
/// A stream ID is a 62-bit integer (0 to 2^62-1) that is unique for all streams on a connection.
/// Stream IDs are encoded as [`VarInt`].
/// A QUIC endpoint MUST NOT reuse a stream ID within a connection.
///
/// There are four types of streams in QUIC, divided according to the role and direction of the stream.
/// See [Stream ID Types](https://www.rfc-editor.org/rfc/rfc9000.html#name-stream-id-types)
/// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct StreamId(u64);

/// Maximum ID for each type of stream.
///
/// [`StreamId`] is encoded with [`VarInt`].
/// After removing the lowest 2 bits for direction and role,
/// the remaining 60 bits are used to represent the actual ID for each type of stream,
/// so its maximum range cannot exceed 2^60.
pub const MAX_STREAMS_LIMIT: u64 = (1 << 60) - 1;

impl StreamId {
    /// Create a new stream ID with the given role, direction, and ID.
    ///
    /// It is prohibited to directly create a StreamId from external sources.
    /// StreamId can only be allocated incrementally by proactively creating new streams locally,
    /// or accepting new streams opened by peer.
    pub fn new(role: Role, dir: Dir, id: u64) -> Self {
        assert!(id <= MAX_STREAMS_LIMIT);
        // Bit layout: | id (60 bits) | dir (1 bit, 0x2) | role (1 bit, 0x1) |
        Self((((id << 1) | (dir as u64)) << 1) | (role as u64))
    }

    /// Returns the role of this stream ID.
    pub fn role(&self) -> Role {
        // Client-initiated stream IDs are even; server-initiated are odd.
        if self.0 & 0x1 == 0 {
            Role::Client
        } else {
            Role::Server
        }
    }

    /// Returns the direction of this stream ID.
pub fn dir(&self) -> Dir { if self.0 & 2 == 0 { Dir::Bi } else { Dir::Uni } } /// Get the actual ID of this stream, removing the lowest 2 bits for direction and role. pub fn id(&self) -> u64 { self.0 >> 2 } unsafe fn next_unchecked(&self) -> Self { Self(self.0 + 4) } /// Return the encoding size of this stream ID. pub fn encoding_size(&self) -> usize { VarInt::from(*self).encoding_size() } } impl fmt::Display for StreamId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, "{} side {} stream {}", self.role(), self.dir(), self.id() ) } } impl From for StreamId { fn from(v: VarInt) -> Self { Self(v.into_u64()) } } impl From for VarInt { fn from(s: StreamId) -> Self { VarInt::from_u64(s.0).expect("stream id must be less than VARINT_MAX") } } impl From for u64 { fn from(s: StreamId) -> Self { s.0 } } /// Parse a stream ID from the input bytes, /// [nom](https://docs.rs/nom/6.2.1/nom/) parser style. pub fn be_streamid(input: &[u8]) -> nom::IResult<&[u8], StreamId> { use nom::{Parser, combinator::map}; map(be_varint, StreamId::from).parse(input) } /// A BufMut extension trait for writing a stream ID. pub trait WriteStreamId: bytes::BufMut { /// Write a stream ID to the buffer. fn put_streamid(&mut self, stream_id: &StreamId); } impl WriteStreamId for T { fn put_streamid(&mut self, stream_id: &StreamId) { self.put_varint(&(*stream_id).into()); } } /// Controls the concurrency of unidirectional and bidirectional streams created by the peer, /// primarily through [`StreamsBlockedFrame`] and [`MaxStreamsFrame`]. /// /// [RFC 9000](https://www.rfc-editor.org/rfc/rfc9000.html) /// leaves implementations to decide when and how many streams should be /// advertised to a peer via MAX_STREAMS. Implementations might choose to /// increase limits as streams are closed, to keep the number of streams /// available to peers roughly consistent. /// /// Implementations might also choose to increase limits as long as the /// peer needs to create new streams. 
/// /// See [controlling concurrency](https://www.rfc-editor.org/rfc/rfc9000.html#name-controlling-concurrency). /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. pub trait ControlStreamsConcurrency: fmt::Debug + Send + Sync { /// Called back upon accepting a new `dir` direction streams with stream id `sid` from peer, /// all previous inexistent `dir` direction streams should be opened by peer will also be created. /// /// Returns whether to increase the maximum stream ID limit, /// which will be communicated to the peer via a MAX_STREAMS frame in the future. /// If None is returned, it means there is no need to /// increase the MAX_STREAMS for the time being. #[must_use] fn on_accept_streams(&mut self, dir: Dir, sid: u64) -> Option; /// Called back upon a `dir` directional stream is ended, /// whether it is closed normally or reset abnormally. /// /// The `sid` is the stream ID of the ended `dir` direction stream. /// /// Returns whether to increase the maximum stream ID limit, /// which will be communicated to the peer via a MAX_STREAMS frame in the future. /// If None is returned, it means there is no need to /// increase the MAX_STREAMS for the time being. fn on_end_of_stream(&mut self, dir: Dir, sid: u64) -> Option; /// Called back upon receiving the StreamsBlocked frame, /// which indicates that the peer is limited to create more `dir` direction streams. /// /// It may optionally return an increased value for the `max_streams` /// for the `dir` directional streams. /// If None is returned, it means there is no need to increase /// the MAX_STREAMS for the time being. 
fn on_streams_blocked(&mut self, dir: Dir, max_streams: u64) -> Option; } impl ControlStreamsConcurrency for Box { fn on_accept_streams(&mut self, dir: Dir, sid: u64) -> Option { self.as_mut().on_accept_streams(dir, sid) } fn on_end_of_stream(&mut self, dir: Dir, sid: u64) -> Option { self.as_mut().on_end_of_stream(dir, sid) } fn on_streams_blocked(&mut self, dir: Dir, max_streams: u64) -> Option { self.as_mut().on_streams_blocked(dir, max_streams) } } pub trait ProductStreamsConcurrencyController: Send + Sync { fn init( &self, init_max_bidi_streams: u64, init_max_uni_streams: u64, ) -> Box; } impl ProductStreamsConcurrencyController for F where F: Fn(u64, u64) -> C + Send + Sync, C: ControlStreamsConcurrency + 'static, { #[inline] fn init( &self, init_max_bidi_streams: u64, init_max_uni_streams: u64, ) -> Box { Box::new((self)(init_max_bidi_streams, init_max_uni_streams)) } } pub mod handy; pub mod local_sid; pub use local_sid::ArcLocalStreamIds; pub mod remote_sid; pub use remote_sid::ArcRemoteStreamIds; /// Stream IDs management, including an [`ArcLocalStreamIds`] as local, /// and an [`ArcRemoteStreamIds`] as remote. #[derive(Debug, Clone)] pub struct StreamIds { pub local: ArcLocalStreamIds, pub remote: ArcRemoteStreamIds, } impl StreamIds where T: SendFrame + SendFrame + Clone + Send + 'static, { /// Create a new [`StreamIds`] with the given role, and maximum number of streams of each direction. /// /// The troublesome part is that the maximum number of streams that can be created locally /// is restricted by the peer's `initial_max_streams_uni` and `initial_max_streams_bidi` transport /// parameters, which are unknown at the beginning. /// Therefore, peer's `initial_max_streams_xx` can be set to 0 initially, /// and then updated later after obtaining the peer's `initial_max_streams_xx` setting. 
#[allow(clippy::too_many_arguments)] pub fn new( role: Role, local_max_bi: u64, local_max_uni: u64, remote_max_bi: u64, remote_max_uni: u64, sid_frames_tx: T, ctrl: Box, tx_wakers: ArcSendWakers, ) -> Self { // 缺省为0 let local = ArcLocalStreamIds::new( role, remote_max_bi, remote_max_uni, sid_frames_tx.clone(), tx_wakers, ); let remote = ArcRemoteStreamIds::new(!role, local_max_bi, local_max_uni, sid_frames_tx, ctrl); Self { local, remote } } } ================================================ FILE: qbase/src/time.rs ================================================ use std::sync::{Arc, Mutex, RwLock}; use thiserror::Error; use tokio::time::{Duration, Instant}; use crate::{frame::PingFrame, packet::PacketContent}; #[derive(Debug, Error)] #[error("Path has been idle for too long")] pub struct TimeOut; #[derive(Debug)] pub struct IdleConfig { max_idle_timeout: Duration, defer_idle_timeout: Duration, heartbeat_interval: Duration, } impl IdleConfig { fn suitable_heartbeat_interval(max_idle_timeout: Duration) -> Duration { if max_idle_timeout == Duration::ZERO { Duration::from_secs(30) } else { (max_idle_timeout / 2) .max(Duration::from_secs(1)) .min(Duration::from_secs(30)) } } // Creates a new `IdleTimer` with the specified maximum idle timeout and defer idle timeout. pub fn new(max_idle_timeout: Duration, defer_idle_timeout: Duration) -> Self { let heartbeat_interval = Self::suitable_heartbeat_interval(max_idle_timeout); Self { max_idle_timeout, defer_idle_timeout, heartbeat_interval, } } // Each endpoint advertises a max_idle_timeout, but the effective value at an endpoint // is computed as the minimum of the two advertised values (or the sole advertised value, // if only one endpoint advertises a non-zero value). // // Idle timeout is disabled when both endpoints omit this transport parameter or specify a value of 0. 
pub fn negotiate_max_idle_timeout(&mut self, max_idle_timeout: Duration) { match (self.max_idle_timeout, max_idle_timeout) { (_, Duration::ZERO) => (), (Duration::ZERO, remote) => self.max_idle_timeout = remote, (local, remote) => self.max_idle_timeout = local.min(remote), } self.heartbeat_interval = Self::suitable_heartbeat_interval(self.max_idle_timeout); } // Sets the interval for sending heartbeat packets. pub fn set_heartbeat_interval(&mut self, interval: Duration) { self.heartbeat_interval = interval; } } #[derive(Debug, Clone)] pub struct ArcIdleConfig(Arc>); impl ArcIdleConfig { // Creates a new `ArcIdleConfig` with the specified maximum idle timeout and defer idle timeout. pub fn new(max_idle_timeout: Duration, defer_idle_timeout: Duration) -> Self { ArcIdleConfig(Arc::new(RwLock::new(IdleConfig::new( max_idle_timeout, defer_idle_timeout, )))) } // Each endpoint advertises a max_idle_timeout, but the effective value at an endpoint // is computed as the minimum of the two advertised values (or the sole advertised value, // if only one endpoint advertises a non-zero value). // // Idle timeout is disabled when both endpoints omit this transport parameter or specify a value of 0. pub fn negotiate_max_idle_timeout(&self, max_idle_timeout: Duration) { self.0 .write() .unwrap() .negotiate_max_idle_timeout(max_idle_timeout); } // Sets the interval for sending heartbeat packets. 
    pub fn set_heartbeat_interval(&self, interval: Duration) {
        self.0.write().unwrap().set_heartbeat_interval(interval);
    }

    // Creates a per-path idle timer that shares this configuration.
    pub fn timer(&self) -> ArcIdleTimer {
        ArcIdleTimer(Arc::new(Mutex::new(IdleTimer {
            idle_config: self.clone(),
            heartbeat_times: 0,
            last_effective_comm: None,
            idle_begin_at: None,
        })))
    }

    fn defer_idle_timeout(&self) -> Duration {
        self.0.read().unwrap().defer_idle_timeout
    }

    fn heartbeat_interval(&self) -> Duration {
        self.0.read().unwrap().heartbeat_interval
    }

    // Whether the path has been idle since `idle_at` for longer than the
    // negotiated max idle timeout (a zero timeout disables the check).
    fn timeout_after(&self, idle_at: Instant) -> bool {
        let max_idle_timeout = self.0.read().unwrap().max_idle_timeout;
        max_idle_timeout != Duration::ZERO && idle_at.elapsed() > max_idle_timeout
    }
}

// A timer for each path to determine when to send heartbeat packets
// and when to delete the path due to idle timeout.
#[derive(Debug)]
pub struct IdleTimer {
    idle_config: ArcIdleConfig,
    // Heartbeats sent since the last effective communication.
    heartbeat_times: u32,
    // When a packet with an effective payload was last sent or received.
    last_effective_comm: Option<Instant>,
    // When the path entered the idle state (defer timeout exceeded).
    idle_begin_at: Option<Instant>,
}

impl IdleTimer {
    // Updates the timer when a packet is sent.
    pub fn on_sent(&mut self, packet_content: PacketContent) {
        // Only packets carrying an effective payload count as activity.
        if packet_content == PacketContent::EffectivePayload {
            self.last_effective_comm = Some(Instant::now());
            self.heartbeat_times = 0;
            self.idle_begin_at = None;
        }
    }

    // Updates the timer when a packet is received.
    pub fn on_rcvd(&mut self, packet_content: PacketContent) {
        if packet_content == PacketContent::EffectivePayload {
            self.last_effective_comm = Some(Instant::now());
            self.heartbeat_times = 0;
            self.idle_begin_at = None;
        }
        // NOTE(review): any received packet while already idle refreshes the
        // idle start point — presumably to keep the path alive as long as the
        // peer keeps sending anything at all; confirm this is intended.
        if self.idle_begin_at.is_some() {
            self.idle_begin_at = Some(Instant::now());
        }
    }

    // Checks health of the path and
    // determines whether a heartbeat packet needs to be sent.
    pub fn health(&mut self) -> Result<Option<PingFrame>, TimeOut> {
        if let Some(t) = self.last_effective_comm {
            let elapsed = t.elapsed();
            if elapsed > self.idle_config.defer_idle_timeout() {
                if self.idle_begin_at.is_none() {
                    // Entering the idle state: record when it began and send
                    // one final heartbeat before the idle countdown starts.
                    self.idle_begin_at = Some(Instant::now());
                    return Ok(Some(PingFrame)); // heartbeat for the last time
                }
            } else if elapsed > self.idle_config.heartbeat_interval() * (self.heartbeat_times + 1) {
                // Still inside the defer window: send periodic heartbeats,
                // counting them so each interval fires at most once.
                self.heartbeat_times += 1;
                return Ok(Some(PingFrame));
            }
        }
        // Once idle long enough (per the negotiated max idle timeout),
        // the path is considered dead.
        if self
            .idle_begin_at
            .is_some_and(|t| self.idle_config.timeout_after(t))
        {
            return Err(TimeOut);
        }
        Ok(None)
    }
}

// A shared timer for each path to determine when to send heartbeat packets
// and when to delete the path due to idle timeout.
#[derive(Debug, Clone)]
pub struct ArcIdleTimer(Arc<Mutex<IdleTimer>>);

impl ArcIdleTimer {
    // Updates the timer when a packet is sent.
    pub fn on_sent(&self, packet_content: PacketContent) {
        self.0.lock().unwrap().on_sent(packet_content);
    }

    // Updates the timer when a packet is received.
    pub fn on_rcvd(&self, packet_content: PacketContent) {
        self.0.lock().unwrap().on_rcvd(packet_content);
    }

    // Checks health of the path and
    // determines whether a heartbeat packet needs to be sent.
pub fn health(&self) -> Result, TimeOut> { self.0.lock().unwrap().health() } } ================================================ FILE: qbase/src/token.rs ================================================ use std::{ops::Deref, sync::Arc}; use bytes::BufMut; use derive_more::Deref; use nom::{IResult, bytes::complete::take}; use rand::RngExt; use crate::{ error::{ErrorKind, QuicError}, frame::{GetFrameType, NewTokenFrame, io::ReceiveFrame}, }; pub const RESET_TOKEN_SIZE: usize = 16; #[derive(Deref, Debug, Copy, Clone, Default, PartialEq, Eq, Hash)] pub struct ResetToken([u8; RESET_TOKEN_SIZE]); impl ResetToken { pub fn new(bytes: &[u8]) -> Self { Self(bytes.try_into().unwrap()) } pub fn random_gen() -> Self { let mut bytes = [0; RESET_TOKEN_SIZE]; rand::rng().fill(&mut bytes); Self(bytes) } pub fn encoding_size(&self) -> usize { RESET_TOKEN_SIZE } } pub fn be_reset_token(input: &[u8]) -> IResult<&[u8], ResetToken> { let (input, bytes) = take(RESET_TOKEN_SIZE)(input)?; Ok((input, ResetToken::new(bytes))) } pub trait WriteResetToken { fn put_reset_token(&mut self, token: &ResetToken); } impl WriteResetToken for T { fn put_reset_token(&mut self, token: &ResetToken) { self.put_slice(token.as_slice()); } } pub trait TokenSink: Send + Sync { fn sink(&self, server_name: &str, token: Vec); fn fetch_token(&self, server_name: &str) -> Vec; } pub trait TokenProvider: Send + Sync { fn gen_new_token(&self, server_name: &str) -> Vec; fn gen_retry_token(&self, server_name: &str) -> Vec; // A token sent in a NEW_TOKEN frame or a Retry packet MUST be constructed in // a way that allows the server to identify how it was provided to a client fn verify_token(&self, server_name: &str, token: &[u8]) -> bool; } pub enum TokenRegistry { Client((String, Arc)), Server(Arc), } #[derive(Clone)] pub struct ArcTokenRegistry(Arc); impl ArcTokenRegistry { pub fn with_sink(server_name: String, sink: Arc) -> Self { Self(Arc::new(TokenRegistry::Client((server_name, sink)))) } pub fn 
with_provider(provider: Arc) -> Self { Self(Arc::new(TokenRegistry::Server(provider))) } } impl Deref for ArcTokenRegistry { type Target = TokenRegistry; fn deref(&self) -> &Self::Target { self.0.deref() } } impl ReceiveFrame for ArcTokenRegistry { type Output = (); fn recv_frame(&self, frame: NewTokenFrame) -> Result { match self.deref() { TokenRegistry::Client((server_name, client)) => { client.sink(server_name, frame.token().to_vec()); Ok(()) } TokenRegistry::Server(_) => Err(QuicError::new( ErrorKind::ProtocolViolation, frame.frame_type().into(), "Server received NewTokenFrame", ) .into()), } } } pub mod handy { pub struct NoopTokenRegistry; impl super::TokenSink for NoopTokenRegistry { fn sink(&self, _: &str, _: Vec) {} fn fetch_token(&self, _: &str) -> Vec { Vec::with_capacity(0) } } impl super::TokenProvider for NoopTokenRegistry { fn gen_new_token(&self, _: &str) -> Vec { Vec::new() } fn gen_retry_token(&self, _: &str) -> Vec { Vec::new() } fn verify_token(&self, _: &str, _: &[u8]) -> bool { false } } } #[cfg(test)] mod tests { #[test] fn test_create_token() { super::ResetToken::new(&[0; 16]); } #[test] #[should_panic] fn test_creat_token_with_less_size() { super::ResetToken::new(&[0; 15]); } #[test] #[should_panic] fn test_creat_token_with_more_size() { super::ResetToken::new(&[0; 17]); } #[test] fn test_read_reset_token() { use nom::error::{Error, ErrorKind}; let buf = vec![0; 16]; let (remain, token) = super::be_reset_token(&buf).unwrap(); assert_eq!(remain.len(), 0); assert_eq!(token, super::ResetToken::new(&[0; 16])); let buf = vec![0; 15]; assert_eq!( super::be_reset_token(&buf), Err(nom::Err::Error(Error::new(&buf[..], ErrorKind::Eof))) ); } #[test] fn test_write_reset_token() { use super::WriteResetToken; let mut buf = vec![]; let token = super::ResetToken::new(&[0; 16]); buf.put_reset_token(&token); assert_eq!(buf, &[0; 16]); } } ================================================ FILE: qbase/src/util/async_deque.rs 
================================================ use std::{ collections::VecDeque, future::Future, pin::Pin, sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll, Waker}, }; /// AsyncDeque is a deque that can be used in async context. /// /// It is a wrapper around VecDeque, with the ability to be popped in async context. /// That is, when calling pop on an empty queue, /// it will suspend the current task until a new element is pushed in. /// In a sense, it is a combination of the sender and receiver ends of an mpsc channel, /// and the sender can insert in both directions. #[derive(Debug)] struct AsyncDeque { queue: Option>, waker: Option, } impl AsyncDeque { /// Insert an element at the back of the queue, /// and wake up the `pop` task registered by [AsyncDeque::poll_pop] if necessary. fn push_back(&mut self, value: T) { if let Some(queue) = &mut self.queue { queue.push_back(value); if let Some(waker) = self.waker.take() { waker.wake(); } } } /// Insert an element at the front of the deque, /// and wake up the `pop` task registered by [AsyncDeque::poll_pop] if necessary. fn push_front(&mut self, value: T) { if let Some(queue) = &mut self.queue { queue.push_front(value); if let Some(waker) = self.waker.take() { waker.wake(); } } } /// Poll the next element in the queue. /// /// If the deque is empty, the current `pop` will be suspended until a new element is pushed in. /// /// If the deque is closed, the `pop` task will get the final `None` element, /// indicating that the queue has been closed, /// and the `pop` task should stop. fn poll_pop(&mut self, cx: &mut Context<'_>) -> Poll> { match &mut self.queue { Some(queue) => { if let Some(frame) = queue.pop_front() { Poll::Ready(Some(frame)) } else if let Some(ref waker) = self.waker { if !waker.will_wake(cx.waker()) { panic!( "Multiple tasks are attempting to wait on the same AsyncDeque. This is a bug, place report it." 
); } self.waker = Some(cx.waker().clone()); // same waker, no need to update again Poll::Pending } else { // no waker, register the current waker self.waker = Some(cx.waker().clone()); Poll::Pending } } None => Poll::Ready(None), } } /// Return the number of elements in the queue. fn len(&self) -> usize { self.queue.as_ref().map(|v| v.len()).unwrap_or(0) } /// Return whether the queue is empty. fn is_empty(&self) -> bool { self.len() == 0 } /// Close the deque, and wake up the `pop` task registered by [AsyncDeque::poll_pop] nescessary. /// /// This will cause the `pop`` task get the final `None` element, /// indicating that the queue has been closed, /// and the `pop`` task should stop. /// /// # Examples pub fn close(&mut self) { self.queue = None; if let Some(waker) = self.waker.take() { waker.wake(); } } } impl Extend for AsyncDeque { fn extend>(&mut self, iter: I) { if let Some(queue) = &mut self.queue { queue.extend(iter); if let Some(waker) = self.waker.take() { waker.wake(); } } } } /// A shared deque that can be used in async context. /// /// It is a wrapper around VecDeque, with the ability to be popped in async context. /// That is, when calling pop on an empty queue, /// it will suspend the current task until a new element is pushed in. /// In a sense, it is a combination of the sender and receiver ends of an mpsc channel, /// and the sender can insert in both directions. #[derive(Debug)] pub struct ArcAsyncDeque(Arc>>); impl ArcAsyncDeque { /// Create a new [`ArcAsyncDeque`] with 8 as the default capacity. pub fn new() -> Self { Self(Arc::new(Mutex::new(AsyncDeque { queue: Some(VecDeque::with_capacity(8)), waker: None, }))) } /// Create a new [`ArcAsyncDeque`] with a given capacity. 
pub fn with_capacity(capacity: usize) -> Self { Self(Arc::new(Mutex::new(AsyncDeque { queue: Some(VecDeque::with_capacity(capacity)), waker: None, }))) } fn lock_guard(&self) -> MutexGuard<'_, AsyncDeque> { self.0.lock().unwrap() } /// Insert an element at the front of the queue, /// and wake up the `pop` task if registered by [ArcAsyncDeque::pop]. /// /// # Examples /// /// ``` /// use qbase::util::ArcAsyncDeque; /// /// let mut deque = ArcAsyncDeque::new(); /// deque.push_front(1); /// deque.push_front(2); /// assert_eq!(deque.len(), 2); /// ``` pub fn push_front(&self, value: T) { self.lock_guard().push_front(value); } /// Insert an element at the back of the queue, /// and wake up the `pop` task if registered by [ArcAsyncDeque::pop]. /// /// # Examples /// /// ``` /// use qbase::util::ArcAsyncDeque; /// /// let mut deque = ArcAsyncDeque::new(); /// deque.push_back(1); /// deque.push_back(2); /// assert_eq!(deque.len(), 2); /// ``` pub fn push_back(&self, value: T) { self.lock_guard().push_back(value); } /// Asynchronously pop the next element in the queue. /// /// If the deque is empty, the current `pop` will be suspended until a new element is pushed in. /// /// If the deque is closed, the `pop` task will get the final `None` element, /// indicating that the queue has been closed, /// and the `pop` task should stop. /// /// # Examples /// /// ``` /// use qbase::util::ArcAsyncDeque; /// /// #[tokio::test] /// async fn test() { /// let mut deque = ArcAsyncDeque::new(); /// /// tokio::spawn({ /// let deque = deque.clone(); /// async move { /// assert_eq!(deque.pop().await, Some(1)); /// } /// }); /// /// deque.push_back(1); /// } /// ``` pub fn pop(&self) -> Self { self.clone() } /// Poll pop the next element in the queue. /// /// If the deque is empty, the current `pop` will be suspended until a new element is pushed in. 
/// /// If the deque is closed, the `pop` task will get the final `None` element, /// indicating that the queue has been closed, /// and the `pop` task should stop. /// /// # Examples /// /// ``` /// use qbase::util::ArcAsyncDeque; /// use futures::task::{Poll, noop_waker}; /// /// let waker = noop_waker(); /// let mut cx = std::task::Context::from_waker(&waker); /// let mut deque = ArcAsyncDeque::new(); /// assert_eq!(deque.poll_pop(&mut cx), Poll::Pending); /// /// deque.push_back(1); /// assert_eq!(deque.poll_pop(&mut cx), Poll::Ready(Some(1))); /// assert_eq!(deque.poll_pop(&mut cx), Poll::Pending); /// ``` pub fn poll_pop(&self, cx: &mut Context<'_>) -> Poll> { self.lock_guard().poll_pop(cx) } /// Return the number of elements in the queue. pub fn len(&self) -> usize { self.lock_guard().len() } /// Return whether the queue is empty. /// /// # Examples /// /// ``` /// use qbase::util::ArcAsyncDeque; /// /// let mut deque = ArcAsyncDeque::new(); /// assert!(deque.is_empty()); /// /// deque.push_back(1); /// assert!(!deque.is_empty()); /// ``` pub fn is_empty(&self) -> bool { self.lock_guard().is_empty() } /// Close the deque, and wake up the `pop` task if registered by [ArcAsyncDeque::poll_pop]. /// /// This will cause the `pop` task get the final `None` element, /// indicating that the queue has been closed, /// and the `pop` task should stop. 
/// /// # Examples /// /// ``` /// use qbase::util::ArcAsyncDeque; /// /// #[tokio::test] /// async fn test() { /// let mut deque = ArcAsyncDeque::new(); /// /// tokio::spawn({ /// let deque = deque.clone(); /// async move { /// assert_eq!(deque.pop().await, Some(1)); /// assert_eq!(deque.pop().await, None); /// } /// }); /// /// deque.push_back(1); /// deque.close(); /// } /// ``` pub fn close(&self) { self.lock_guard().close(); } } impl Default for ArcAsyncDeque { fn default() -> Self { Self::new() } } impl Clone for ArcAsyncDeque { fn clone(&self) -> Self { Self(self.0.clone()) } } impl Future for ArcAsyncDeque { type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.poll_pop(cx) } } impl futures::Stream for ArcAsyncDeque { type Item = T; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.poll_pop(cx) } } impl Extend for &ArcAsyncDeque { fn extend>(&mut self, iter: I) { self.0.lock().unwrap().extend(iter); } } #[cfg(test)] mod tests { use futures::FutureExt; use super::*; #[tokio::test] async fn push_pop() { let deque = ArcAsyncDeque::new(); assert!(deque.is_empty()); deque.push_back(1); deque.push_back(2); assert_eq!(deque.len(), 2); assert_eq!(deque.pop().await, Some(1)); assert_eq!(deque.pop().await, Some(2)); let deque = ArcAsyncDeque::with_capacity(2); deque.push_back(1); deque.push_front(2); assert_eq!(deque.len(), 2); assert_eq!(deque.pop().await, Some(2)); assert_eq!(deque.pop().await, Some(1)); } #[tokio::test] async fn close() { let deque = ArcAsyncDeque::new(); assert!(deque.is_empty()); deque.push_back(1); deque.push_back(2); assert_eq!(deque.len(), 2); deque.close(); assert!(deque.is_empty()); assert_eq!(deque.pop().await, None); } #[tokio::test] async fn wake() { let deque = ArcAsyncDeque::new(); tokio::select! 
{
    item = deque.pop() => {
        assert_eq!(item, Some(1));
    }
    _ = async {
        deque.push_back(1);
        std::future::pending::<()>().await;
    } => unreachable!()
}

let deque = ArcAsyncDeque::new();
tokio::select! {
    item = deque.pop() => {
        assert_eq!(item, Some(1));
    }
    _ = async {
        deque.push_back(1);
        std::future::pending::<()>().await;
    } => unreachable!()
}
}

#[tokio::test]
async fn cancel() {
    let deque = ArcAsyncDeque::new();
    // register Waker
    let poll = core::future::poll_fn(|cx| Poll::Ready(deque.pop().poll_unpin(cx))).await;
    assert_eq!(poll, Poll::Pending);
    // pop directly
    (&deque).extend([654]);
    let poll = core::future::poll_fn(|cx| Poll::Ready(deque.pop().poll_unpin(cx))).await;
    assert_eq!(poll, Poll::Ready(Some(654)));
    // register new Waker
    let poll = core::future::poll_fn(|cx| Poll::Ready(deque.pop().poll_unpin(cx))).await;
    assert_eq!(poll, Poll::Pending);
    // replace cancelled Waker: same task, so its ok
    let poll = core::future::poll_fn(|cx| Poll::Ready(deque.pop().poll_unpin(cx))).await;
    assert_eq!(poll, Poll::Pending);
}

#[tokio::test]
async fn racing() {
    let deque: ArcAsyncDeque<()> = ArcAsyncDeque::new();
    let consumer = tokio::spawn(deque.pop());
    tokio::task::yield_now().await;
    let abuse = tokio::spawn(deque.pop());
    tokio::task::yield_now().await;
    // will not be woken up
    _ = consumer;
    // should panic
    assert!(abuse.await.is_err());
}
}


================================================
FILE: qbase/src/util/bound_queue.rs
================================================
use std::{
    self,
    future::poll_fn,
    sync::{Arc, Mutex},
};

use futures::{SinkExt, StreamExt, channel::mpsc};

/// Shared inner state of a [`BoundQueue`]: the sender half plus the receiver
/// half behind a mutex so `recv` can be called through a shared reference.
#[derive(Debug)]
struct BoundQueueInner<T> {
    tx: mpsc::Sender<T>,
    rx: Mutex<mpsc::Receiver<T>>,
}

/// A cloneable, bounded queue built on `futures::channel::mpsc`.
#[derive(Debug)]
pub struct BoundQueue<T>(Arc<BoundQueueInner<T>>);

// Manual impl so that `T` need not be `Clone`.
impl<T> Clone for BoundQueue<T> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<T> BoundQueue<T> {
    /// Create a queue with the given buffer size.
    #[inline]
    pub fn new(size: usize) -> Self {
        let (tx, rx) = mpsc::channel(size);
        Self(Arc::new(BoundQueueInner { tx, rx: rx.into() }))
    }

    /// Non-blocking send; fails when the queue is full or disconnected.
    // The sender is cloned because `try_send` needs `&mut self` on the sender.
    #[inline]
    pub fn try_send(&self, item: T) -> Result<(), mpsc::TrySendError<T>> {
        self.0.tx.clone().try_send(item)
    }

    /// Asynchronous send; waits while the queue is full.
    #[inline]
    pub async fn send(&self, item: T) -> Result<(), mpsc::SendError> {
        self.0.tx.clone().send(item).await
    }

    /// Receive the next item; `None` once the queue is closed and drained.
    #[inline]
    pub async fn recv(&self) -> Option<T> {
        poll_fn(|cx| self.0.rx.lock().unwrap().poll_next_unpin(cx)).await
    }

    /// Close the channel; pending `recv` calls will see `None` after draining.
    #[inline]
    pub fn close(&self) {
        self.0.tx.clone().close_channel();
    }

    #[inline]
    pub fn is_closed(&self) -> bool {
        self.0.tx.is_closed()
    }

    /// Whether two handles refer to the same underlying queue.
    #[inline]
    pub fn same_queue(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.0, &other.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_send_receive() {
        let queue = Arc::new(BoundQueue::new(2));
        tokio::spawn({
            let queue = queue.clone();
            async move {
                assert!(queue.send(1).await.is_ok());
                assert!(queue.send(2).await.is_ok());
            }
        });
        assert_eq!(queue.recv().await, Some(1));
        assert_eq!(queue.recv().await, Some(2));
    }
}


================================================
FILE: qbase/src/util/data.rs
================================================
use bytes::{BufMut, Bytes, BytesMut};

/// A piece of logically contiguous data that knows its length and can be
/// flattened into a single [`Bytes`] buffer.
pub trait ContinuousData {
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool;
    fn to_bytes(&self) -> Bytes;
}

/// Two adjacent slices treated as one logical buffer
/// (e.g. the two halves returned by a ring buffer).
pub type DataPair<'a> = (&'a [u8], &'a [u8]);

impl ContinuousData for DataPair<'_> {
    #[inline]
    fn len(&self) -> usize {
        self.0.len() + self.1.len()
    }

    #[inline]
    fn is_empty(&self) -> bool {
        self.0.is_empty() && self.1.is_empty()
    }

    #[inline]
    fn to_bytes(&self) -> Bytes {
        Bytes::from([self.0, self.1].concat())
    }
}

impl ContinuousData for [u8] {
    #[inline]
    fn len(&self) -> usize {
        <[u8]>::len(self)
    }

    #[inline]
    fn is_empty(&self) -> bool {
        <[u8]>::is_empty(self)
    }

    #[inline]
    fn to_bytes(&self) -> Bytes {
        Bytes::copy_from_slice(self)
    }
}

impl<const N: usize> ContinuousData for [u8; N] {
    #[inline]
    fn len(&self) -> usize {
        N
    }

    #[inline]
    fn is_empty(&self) -> bool {
        N == 0
    }

    #[inline]
    fn to_bytes(&self) -> Bytes {
        Bytes::copy_from_slice(self)
    }
}

impl ContinuousData for Vec<u8> {
    // Inherent `Vec::len` takes precedence over the trait method, so this
    // does not recurse.
    #[inline]
    fn len(&self) -> usize {
        self.len()
    }

    #[inline]
    fn is_empty(&self) -> bool {
self.is_empty() } #[inline] fn to_bytes(&self) -> Bytes { Bytes::copy_from_slice(self) } } impl ContinuousData for Bytes { #[inline] fn len(&self) -> usize { self.len() } #[inline] fn is_empty(&self) -> bool { self.is_empty() } #[inline] fn to_bytes(&self) -> Bytes { self.clone() } } pub type NonData = (); impl ContinuousData for NonData { #[inline] fn len(&self) -> usize { 0 } #[inline] fn is_empty(&self) -> bool { true } #[inline] fn to_bytes(&self) -> Bytes { Bytes::new() } } impl ContinuousData for &D { #[inline] fn len(&self) -> usize { D::len(*self) } #[inline] fn is_empty(&self) -> bool { D::is_empty(*self) } #[inline] fn to_bytes(&self) -> Bytes { D::to_bytes(*self) } } impl ContinuousData for [D] { #[inline] fn len(&self) -> usize { self.iter().map(|d| d.len()).sum() } #[inline] fn is_empty(&self) -> bool { self.iter().all(|d| d.is_empty()) } #[inline] fn to_bytes(&self) -> Bytes { self.iter() .fold(BytesMut::with_capacity(self.len()), |mut acc, d| { acc.extend(d.to_bytes()); acc }) .freeze() } } impl ContinuousData for [D; N] { #[inline] fn len(&self) -> usize { <[D]>::len(self) } #[inline] fn is_empty(&self) -> bool { <[D]>::is_empty(self) } #[inline] fn to_bytes(&self) -> Bytes { <[D]>::to_bytes(self) } } pub trait WriteData: BufMut { fn put_data(&mut self, data: &D); } impl WriteData> for T { #[inline] fn put_data(&mut self, data: &DataPair<'_>) { self.put_slice(data.0); self.put_slice(data.1); } } impl WriteData<[u8]> for T { #[inline] fn put_data(&mut self, data: &[u8]) { self.put_slice(data) } } impl WriteData<[u8; N]> for T { #[inline] fn put_data(&mut self, data: &[u8; N]) { self.put_slice(data) } } impl WriteData for T { #[inline] fn put_data(&mut self, data: &Bytes) { self.put_slice(data); } } impl WriteData for T { #[inline] fn put_data(&mut self, &(): &()) {} } impl WriteData<&D> for T where T: BufMut + WriteData, { #[inline] fn put_data(&mut self, data: &&D) { >::put_data(self, data); } } impl WriteData<[D]> for T where T: BufMut + WriteData, 
{ #[inline] fn put_data(&mut self, data: &[D]) { for data in data { self.put_data(data); } } } impl WriteData<[D; N]> for T where T: BufMut + WriteData, { #[inline] fn put_data(&mut self, data: &[D; N]) { >::put_data(self, data); } } ================================================ FILE: qbase/src/util/index_deque.rs ================================================ use std::{ collections::VecDeque, ops::{Index, IndexMut}, }; use thiserror::Error; /// The index error type for [`IndexDeque`]. #[derive(Debug, Clone, Copy, PartialEq, Eq, Error)] pub enum IndexError { #[error("The index {0} exceed the limit {1}")] ExceedLimit(u64, u64), #[error("The index {0} is less than the offset {1}")] TooSmall(u64, u64), } /// A first-in-first-out queue indexed by the enqueue sequence number. /// /// For [`VecDeque`], the index of elements starts from 0 even after they are dequeued. /// However, for [`IndexDeque`], the index is the enqueue sequence number. /// Even if some elements have been dequeued, /// the enqueue index of other elements in IndexDeque remains unchanged. /// /// - `T` is the type of elements in the queue. /// - `LIMIT` is the maximum limit of the enqueue index. /// /// [`IndexDeque`] is useful in many places in QUIC implementation, /// such as recording packet sending history. #[derive(Debug)] pub struct IndexDeque { deque: VecDeque, offset: u64, } impl Default for IndexDeque { fn default() -> Self { Self { deque: VecDeque::default(), offset: 0, } } } impl IndexDeque { /// Create a new empty IndexDeque with the specified capacity. /// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let deque: IndexDeque = IndexDeque::with_capacity(10); /// ``` pub fn with_capacity(capacity: usize) -> Self { Self { deque: VecDeque::with_capacity(capacity), offset: 0, } } /// Returns true if the queue is empty. 
/// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// assert!(deque.is_empty()); /// deque.push_back(1).unwrap(); /// assert!(!deque.is_empty()); /// ``` pub fn is_empty(&self) -> bool { self.deque.is_empty() } /// Returns the number of elements in the queue. /// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// assert_eq!(deque.len(), 0); /// deque.push_back(1).unwrap(); /// assert_eq!(deque.len(), 1); /// ``` pub fn len(&self) -> usize { self.deque.len() } /// Returns the enqueue sequence number of the first element in the queue. /// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// assert_eq!(deque.offset(), 0); /// deque.push_back(1).unwrap(); /// assert_eq!(deque.offset(), 0); /// deque.pop_front(); /// assert_eq!(deque.offset(), 1); /// ``` pub fn offset(&self) -> u64 { self.offset } /// Returns the next enqueue sequence number of the queue. /// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// assert_eq!(deque.largest(), 0); /// deque.push_back(1).unwrap(); /// assert_eq!(deque.largest(), 1); /// ``` pub fn largest(&self) -> u64 { self.offset + self.deque.len() as u64 } /// Returns true if the queue contains the specified enqueue index. /// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// assert!(!deque.contain(0)); /// deque.push_back(1).unwrap(); /// assert!(deque.contain(0)); /// assert!(!deque.contain(1)); /// ``` pub fn contain(&self, idx: u64) -> bool { idx >= self.offset && idx < self.largest() } /// Provides a reference to an element at the specified enqueue index. 
/// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// deque.push_back(1).unwrap(); /// deque.push_back(2).unwrap(); /// deque.push_back(3).unwrap(); /// assert_eq!(deque.get(1), Some(&2)); /// assert_eq!(deque.get(3), None); /// ``` pub fn get(&self, idx: u64) -> Option<&T> { if self.contain(idx) { Some(&self.deque[(idx - self.offset) as usize]) } else { None } } /// Provides a mutable reference to an element at the specified enqueue index. /// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// deque.push_back(1).unwrap(); /// deque.push_back(2).unwrap(); /// deque.push_back(3).unwrap(); /// assert_eq!(deque[1], 2); /// if let Some(v) = deque.get_mut(1) { /// *v = 4; /// } /// assert_eq!(deque[1], 4); /// ``` pub fn get_mut(&mut self, idx: u64) -> Option<&mut T> { if self.contain(idx) { Some(&mut self.deque[(idx - self.offset) as usize]) } else { None } } /// Append an element to the end of the queue and return the enqueue index of the element. /// If it exceeds the maximum limit of the enqueue index, return [`IndexError`]. /// /// # Examples /// /// ``` /// use qbase::util::{IndexDeque, IndexError}; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// assert_eq!(deque.push_back(1), Ok(0)); /// assert_eq!(deque.push_back(2), Ok(1)); /// assert_eq!(deque.push_back(3), Ok(2)); /// assert_eq!(deque.push_back(4), Err(IndexError::ExceedLimit(3, 2))); /// ``` pub fn push_back(&mut self, value: T) -> Result { let next_idx = self.offset.overflowing_add(self.deque.len() as u64); if next_idx.1 || next_idx.0 > LIMIT { Err(IndexError::ExceedLimit(next_idx.0, LIMIT)) } else { self.deque.push_back(value); Ok(self.deque.len() as u64 - 1 + self.offset) } } /// Returns None if the queue is empty; otherwise, returns /// the first element in the queue along with its enqueue index. 
/// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// assert_eq!(deque.pop_front(), None); /// /// deque.push_back(1).unwrap(); /// assert_eq!(deque.pop_front(), Some((0, 1))); /// assert!(deque.is_empty()); /// ``` pub fn pop_front(&mut self) -> Option<(u64, T)> { self.deque.pop_front().map(|v| { let offset = self.offset; self.offset += 1; (offset, v) }) } pub fn front(&self) -> Option<(u64, &T)> { self.deque.front().map(|v| (self.offset, v)) } pub fn back(&self) -> Option<(u64, &T)> { self.deque.back().map(|v| (self.largest() - 1, v)) } /// Returns a front-to-back iterator. /// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// deque.push_back(1).unwrap(); /// deque.push_back(2).unwrap(); /// deque.push_back(3).unwrap(); /// let b: &[_] = &[&1, &2, &3]; /// let c: Vec<&u64> = deque.iter().collect(); /// assert_eq!(b, c.as_slice()); /// ``` pub fn iter(&self) -> impl DoubleEndedIterator { self.deque.iter() } /// Returns a front-to-back iterator that returns mutable references. /// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// deque.push_back(1).unwrap(); /// deque.push_back(2).unwrap(); /// deque.push_back(3).unwrap(); /// for num in deque.iter_mut() { /// *num += 1; /// } /// let b: &[_] = &[&mut 2, &mut 3, &mut 4]; /// assert_eq!(deque.iter_mut().collect::>().as_slice(), b); /// ``` pub fn iter_mut(&mut self) -> impl DoubleEndedIterator { self.deque.iter_mut() } /// Returns a front-to-back iterator that returns the enqueue index along with the references. 
/// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// deque.push_back(1).unwrap(); /// deque.push_back(2).unwrap(); /// deque.push_back(3).unwrap(); /// for (idx, num) in deque.enumerate() { /// assert_eq!(idx + 1, *num); /// } /// ``` pub fn enumerate(&self) -> impl DoubleEndedIterator { self.deque .iter() .enumerate() .map(|(idx, item)| (self.offset + idx as u64, item)) } /// Returns a front-to-back iterator that returns /// the enqueue index along with the mutable references. /// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// deque.push_back(1).unwrap(); /// deque.push_back(2).unwrap(); /// deque.push_back(3).unwrap(); /// for (idx, num) in deque.enumerate_mut() { /// *num = *num + idx; /// } /// let b: &[_] = &[(0, &mut 1), (1, &mut 3), (2, &mut 5)]; /// assert_eq!(deque.enumerate_mut().collect::>().as_slice(), b); /// ``` pub fn enumerate_mut(&mut self) -> impl DoubleEndedIterator { self.deque .iter_mut() .enumerate() .map(|(idx, item)| (self.offset + idx as u64, item)) } /// Shortens the queue, dropping the first `n` elements. /// /// If `n` is greater or equal to the queue's length, this method will clear the queue. /// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// deque.push_back(1).unwrap(); /// deque.push_back(2).unwrap(); /// deque.push_back(3).unwrap(); /// deque.advance(2); /// assert_eq!(deque.len(), 1); /// assert_eq!(deque.offset(), 2); /// assert_eq!(deque[2], 3); /// ``` pub fn advance(&mut self, n: usize) { self.offset += n as u64; let _ = self.deque.drain(..n); } /// Removes the elements from the queue until the enqueue index is equal to `end`. /// Returns a front-to-back iterator over the removed elements. 
/// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// deque.push_back(1).unwrap(); /// deque.push_back(2).unwrap(); /// deque.push_back(3).unwrap(); /// let b: &[_] = &[1, 2]; /// assert_eq!(deque.drain_to(2).collect::>().as_slice(), b); /// assert_eq!(deque.offset(), 2); /// ``` pub fn drain_to(&mut self, end: u64) -> impl DoubleEndedIterator + '_ { #[cfg(not(test))] debug_assert!(end >= self.offset && end <= self.offset + self.deque.len() as u64); // avoid end < self.offset let end = std::cmp::max(end, self.offset); let offset = self.offset; // avoid end > self.offset + self.deque.len() self.offset = std::cmp::min(end, offset + self.deque.len() as u64); let end = (self.offset - offset) as usize; self.deque.drain(..end) } /// Force to reset the first enqueue index of the queue to `new_offset`. /// Then, it will affect the enqueue sequence numbers of all subsequent elements. /// /// Be careful to use this method, you must know what you are doing. /// /// # Examples /// /// ``` /// use qbase::util::IndexDeque; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// deque.reset_offset(5); /// assert_eq!(deque.largest(), 5); /// deque.push_back(1).unwrap(); /// assert_eq!(deque[5], 1); /// ``` pub fn reset_offset(&mut self, new_offset: u64) { // assert!(self.is_empty() && new_offset >= self.offset); self.offset = new_offset; } } impl Extend for IndexDeque { fn extend>(&mut self, iter: I) { self.deque.extend(iter) } } impl IndexDeque { /// Inserts an element at the specified enqueue index `idx`, /// returns the origin element at the index if it exists. /// /// It will insert the default value in the gap /// between the current largest index and the `idx`, /// if the `idx` is greater than the current largest index. /// /// Returns [`IndexError`] if the enqueue index is less than the offset or exceeds the maximum limit. 
/// /// # Examples /// /// ``` /// use qbase::util::{IndexDeque, IndexError}; /// /// let mut deque: IndexDeque = IndexDeque::default(); /// let old_value = deque.insert(1, 2).unwrap(); /// assert_eq!(old_value, None); /// assert_eq!(deque[0], u64::default()); /// assert_eq!(deque[1], 2); /// /// let result = deque.insert(4, 5); /// assert_eq!(result, Err(IndexError::ExceedLimit(4, 3))); /// ``` pub fn insert(&mut self, idx: u64, value: T) -> Result, IndexError> { if idx > LIMIT { Err(IndexError::ExceedLimit(idx, LIMIT)) } else if idx < self.offset { Err(IndexError::TooSmall(idx, self.offset)) } else { let pos = (idx - self.offset) as usize; if pos < self.deque.len() { return Ok(Some(std::mem::replace(&mut self.deque[pos], value))); } if pos > self.deque.len() { self.deque.resize(pos, T::default()); } self.deque.push_back(value); Ok(None) } } /// Modifies the deque in-place so that offset() is equal to new_offset, either by /// removing excess elements from the back or by appending clones of value to the back. 
pub fn resize(&mut self, new_end: u64, value: T) -> Result<(), IndexError> { if new_end < self.offset { Err(IndexError::TooSmall(new_end, self.offset)) } else if new_end > LIMIT { Err(IndexError::ExceedLimit(new_end, LIMIT)) } else { let len = new_end.saturating_sub(self.offset); self.deque.resize(len as usize, value.clone()); Ok(()) } } } impl Index for IndexDeque { type Output = T; fn index(&self, index: u64) -> &Self::Output { &self.deque[(index - self.offset) as usize] } } impl IndexMut for IndexDeque { fn index_mut(&mut self, index: u64) -> &mut Self::Output { &mut self.deque[(index - self.offset) as usize] } } #[cfg(test)] mod tests { use super::*; #[test] fn test_index_queue() { let mut deque = IndexDeque::::default(); for i in 0..10 { assert_eq!(deque.push_back(i + 1), Ok(i)); } assert_eq!(deque.offset, 0); for i in 0..10 { assert_eq!(deque.pop_front(), Some((i, i + 1))); assert_eq!(deque.offset, i + 1); } assert_eq!(deque.pop_front(), None); assert_eq!(deque.offset, 10); for i in 10..20 { assert_eq!(deque.push_back(i + 1), Ok(i)); } assert_eq!(deque.push_back(21), Err(IndexError::ExceedLimit(20, 19))); assert_eq!(deque.offset, 10); assert!(!deque.contain(0)); assert!(!deque.contain(9)); assert!(deque.contain(10)); assert!(deque.contain(19)); assert!(!deque.contain(21)); assert_eq!(deque[10], 11); assert_eq!(deque[19], 20); assert_eq!(deque.drain_to(10).count(), 0); let mut i = 10; for item in deque.drain_to(15) { i += 1; assert_eq!(item, i); } assert_eq!(i, 15); assert!(deque.contain(15)); assert_eq!(deque.offset, 15); assert_eq!(deque.drain_to(30).count(), 5); assert_eq!(deque.offset, 20); assert!(deque.is_empty()); } #[test] fn test_insert() { let mut deque = IndexDeque::::default(); deque.insert(10, 11).unwrap(); assert_eq!(deque.offset, 0); assert_eq!(deque.len(), 11); for i in 0..10 { assert_eq!(deque[i], u64::default()); } assert_eq!(deque[10], 11); } #[test] fn test_skip() { let mut deque = IndexDeque::::default(); for i in 0..10 { 
assert_eq!(deque.push_back(i), Ok(i)); } assert_eq!(deque.offset, 0); deque.advance(5); assert_eq!(deque.offset, 5); deque.enumerate().for_each(|(idx, item)| { assert_eq!(idx, *item); }); } #[test] fn test_reset_offset() { let mut deque = IndexDeque::::default(); deque.reset_offset(5); assert_eq!(deque.offset, 5); for i in 0..10 { assert_eq!(deque.push_back(i), Ok(i + 5)); } for i in 0..10 { assert_eq!(deque.pop_front(), Some((i + 5, i))); } } #[test] fn test_reset_offset_with_content() { let mut deque = IndexDeque::::default(); for i in 0..5 { assert_eq!(deque.push_back(i), Ok(i)); } deque.reset_offset(10); deque.enumerate().for_each(|(idx, item)| { assert_eq!(idx, *item + 10); }); } #[test] fn test_reset_offset_panic2() { let mut deque = IndexDeque::::default(); for i in 0..10 { assert_eq!(deque.push_back(i), Ok(i)); } for i in 0..5 { assert_eq!(deque.pop_front(), Some((i, i))); } assert_eq!(deque.offset, 5); deque.reset_offset(3); deque.enumerate().for_each(|(idx, item)| { assert_eq!(idx + 2, *item); }); } #[test] fn test_resize() { let mut deque = IndexDeque::::default(); for i in 0..10 { assert_eq!(deque.push_back(i), Ok(i)); } assert_eq!(deque.offset, 0); deque.resize(15, 10).unwrap(); assert_eq!(deque.offset, 0); assert_eq!(deque.len(), 15); for i in 10..15 { assert_eq!(deque[i], 10); } deque.resize(5, 10).unwrap(); assert_eq!(deque.offset, 0); assert_eq!(deque.len(), 5); for i in 0..5 { assert_eq!(deque[i], i); } assert_eq!(deque.resize(20, 10), Err(IndexError::ExceedLimit(20, 19))); for i in 0..5 { assert_eq!(deque.pop_front(), Some((i, i))); } assert_eq!(deque.resize(0, 10), Err(IndexError::TooSmall(0, 5))); } } ================================================ FILE: qbase/src/util/unique_id.rs ================================================ use std::{ hash::Hash, sync::atomic::{AtomicUsize, Ordering}, }; use derive_more::Into; /// Opque, hashable, unique ID type. 
#[derive(Debug, Clone, Copy, Into, PartialEq, Eq, Hash)] pub struct UniqueId(usize); /// Thread safe, lock free unique ID generator. #[derive(Debug)] pub struct UniqueIdGenerator(AtomicUsize); impl Default for UniqueIdGenerator { fn default() -> Self { Self::new() } } impl UniqueIdGenerator { /// Create a new `UniqueIdGenerator` /// /// # Example /// /// ``` /// use qbase::util::UniqueIdGenerator; /// /// let generator = UniqueIdGenerator::new(); /// let id1 = generator.generate(); /// let id2 = generator.generate(); /// assert_ne!(id1, id2); /// ``` pub const fn new() -> Self { UniqueIdGenerator(AtomicUsize::new(1)) } /// Generated a new `UniqueId` starting from a specific value /// /// # Example /// /// ``` /// use qbase::util::UniqueIdGenerator; /// /// let generator = UniqueIdGenerator::new(); /// let id1 = generator.generate(); /// let id2 = generator.generate(); /// assert_ne!(id1, id2); /// ``` pub fn generate(&self) -> UniqueId { let id = self.0.fetch_add(1, Ordering::Relaxed); assert_ne!(id, 0, "UniqueId overflow"); UniqueId(id) } } #[cfg(test)] mod tests { use std::{collections::HashSet, sync::Arc, thread}; use super::*; #[test] fn test_unique_id_basic() { let generator = UniqueIdGenerator::new(); let id1 = generator.generate(); let id2 = generator.generate(); assert_ne!(id1, id2); assert_eq!(id1.0, 1); assert_eq!(id2.0, 2); } #[test] fn test_unique_id_hash() { let generator = UniqueIdGenerator::new(); let id1 = generator.generate(); let id2 = generator.generate(); let mut set = HashSet::new(); set.insert(id1); set.insert(id2); assert_eq!(set.len(), 2); } #[test] fn test_unique_id_clone_copy() { let generator = UniqueIdGenerator::new(); let id1 = generator.generate(); let id2 = id1; // Copy assert_eq!(id1, id2); } #[test] fn test_thread_safety() { let generator = Arc::new(UniqueIdGenerator::new()); let mut handles = vec![]; // 启动多个线程同时生成ID for _ in 0..10 { let generator = Arc::clone(&generator); let handle = thread::spawn(move || { let mut ids = 
Vec::new(); for _ in 0..100 { ids.push(generator.generate()); } ids }); handles.push(handle); } // 收集所有生成的ID let mut all_ids = HashSet::new(); for handle in handles { let ids = handle.join().unwrap(); for id in ids { assert!(all_ids.insert(id), "Duplicate ID found: {id:?}"); } } // 应该有1000个唯一的ID assert_eq!(all_ids.len(), 1000); } #[test] fn test_default_generator() { let gen1 = UniqueIdGenerator::new(); let gen2 = UniqueIdGenerator::new(); assert_eq!(gen1.generate(), gen2.generate()) } } ================================================ FILE: qbase/src/util/wakers.rs ================================================ use std::{ mem, sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll, Wake, Waker}, usize, }; use smallvec::SmallVec; #[derive(Debug, Clone)] pub struct WakerVec { wakers: SmallVec<[Waker; N]>, } impl Default for WakerVec { fn default() -> Self { Self::new() } } impl WakerVec { pub const fn new() -> Self { Self { wakers: SmallVec::new_const(), } } pub fn register(&mut self, waker: &Waker) { if !self.wakers.iter().any(|w| w.will_wake(waker)) { self.wakers.push(waker.clone()); } } pub fn wake_all(&mut self) { for waker in self.wakers.drain(..) 
{ waker.wake(); } } } impl Drop for WakerVec { fn drop(&mut self) { self.wake_all(); } } #[derive(Debug)] pub struct Wakers { wakers: Mutex>, } impl Wake for Wakers { fn wake(self: Arc) { self.wake_all(); } fn wake_by_ref(self: &Arc) { self.wake_all(); } } impl Default for Wakers { fn default() -> Self { Self::new() } } impl Wakers { pub const fn new() -> Self { Self { wakers: Mutex::new(WakerVec::new()), } } fn lock(&self) -> MutexGuard<'_, WakerVec> { self.wakers.lock().expect("Wakers mutex poisoned") } pub fn register(&self, waker: &Waker) { self.lock().register(waker) } pub fn wake_all(&self) { { mem::replace(&mut *self.lock(), WakerVec::new()) }.wake_all() } pub fn to_waker(self: &Arc) -> Waker { Waker::from(self.clone()) } pub fn combine_with( self: &Arc, cx: &mut Context<'_>, poll: impl FnOnce(&mut Context<'_>) -> Poll, ) -> Poll { self.register(cx.waker()); poll(&mut Context::from_waker(&self.to_waker())) } } ================================================ FILE: qbase/src/util.rs ================================================ mod async_deque; pub use async_deque::ArcAsyncDeque; mod bound_queue; pub use bound_queue::BoundQueue; mod data; pub use data::{ContinuousData, DataPair, NonData, WriteData}; mod index_deque; pub use index_deque::{IndexDeque, IndexError}; mod unique_id; pub use unique_id::{UniqueId, UniqueIdGenerator}; mod wakers; pub use wakers::{WakerVec, Wakers}; ================================================ FILE: qbase/src/varint.rs ================================================ use std::{cmp::Ordering, convert::TryFrom, fmt}; /// An integer less than 2^62 /// /// Values of this type are suitable for encoding as QUIC variable-length integer. 
/// It would be neat if we could express to Rust that the top two bits are available for use as enum /// discriminants /// /// See [variable-length integers](https://www.rfc-editor.org/rfc/rfc9000.html#name-variable-length-integer-enc) /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html) for more details. #[derive(Default, Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct VarInt(u64); /// The maximum value that can be represented by a QUIC variable-length integer. pub const VARINT_MAX: u64 = 0x3fff_ffff_ffff_ffff; /// The number of bytes that a QUIC variable-length integer can be encoded in. /// /// [`VarInt`] doesn't need to be encoded on the minimum number of bytes necessary, /// with the sole exception of the Frame Type field. pub enum EncodeBytes { One = 1, Two = 2, Four = 4, Eight = 8, } impl VarInt { /// The largest representable value pub const MAX: Self = Self(VARINT_MAX); /// The largest encoded value length pub const MAX_SIZE: usize = 8; /// Construct a `VarInt` from a [`u32`]. pub const fn from_u32(x: u32) -> Self { Self(x as u64) } /// Construct a `VarInt` from a [`u64`]. /// Succeeds if `x` < 2^62. pub const fn from_u64(x: u64) -> Result { if x < (1 << 62) { Ok(Self(x)) } else { Err(err::Overflow(x as _)) } } /// Create a VarInt from a [`u64`] without ensuring it's in range /// /// # Safety /// /// `x` must be less than 2^62. pub unsafe fn from_u64_unchecked(x: u64) -> Self { Self(x) } /// Construct a `VarInt` from a [`u128`]. /// Succeeds if `x` < 2^62. 
pub fn from_u128(x: u128) -> Result { if x < (1 << 62) { Ok(Self(x as _)) } else { Err(err::Overflow(x)) } } /// Extract the integer value pub fn into_u64(self) -> u64 { self.0 } /// Compute the number of bytes needed to encode this value pub fn encoding_size(self) -> usize { let x = self.0; if x < (1 << 6) { 1 } else if x < (1 << 14) { 2 } else if x < (1 << 30) { 4 } else if x < (1 << 62) { 8 } else { unreachable!("malformed VarInt"); } } } impl From for u64 { fn from(x: VarInt) -> Self { x.0 } } impl From for VarInt { fn from(x: u8) -> Self { Self(x.into()) } } impl From for VarInt { fn from(x: u16) -> Self { Self(x.into()) } } impl From for VarInt { fn from(x: u32) -> Self { Self(x.into()) } } impl TryFrom for VarInt { type Error = err::Overflow; fn try_from(x: u128) -> Result { Self::from_u128(x) } } impl TryFrom for VarInt { type Error = err::Overflow; /// Succeeds if `x` < 2^62 fn try_from(x: u64) -> Result { Self::from_u64(x) } } impl TryFrom for VarInt { type Error = err::Overflow; /// Succeeds if `x` < 2^62 fn try_from(x: usize) -> Result { Self::try_from(x as u64) } } impl nom::ToUsize for VarInt { fn to_usize(&self) -> usize { self.0 as usize } } impl PartialEq for VarInt { fn eq(&self, other: &u64) -> bool { self.0.eq(other) } } impl PartialOrd for VarInt { fn partial_cmp(&self, other: &u64) -> Option { self.0.partial_cmp(other) } } impl fmt::Display for VarInt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } /// Error module for VarInt pub mod err { use std::fmt::Debug; use thiserror::Error; /// Overflow error indicating that a value exceeds 2^62 #[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] #[error("Value({0}) too large for varint encoding")] pub struct Overflow(pub(super) u128); } use bytes::BufMut; use nom::{IResult, Parser, bits::streaming::take, combinator::flat_map, error::Error}; /// Parse a variable-length integer from the input buffer, /// [nom](https://docs.rs/nom/latest/nom/) parser style. 
/// /// ## Example /// ``` /// use qbase::varint::be_varint; /// /// let input = &[0b01000000, 0x01][..]; /// let result = be_varint(input); /// assert_eq!(result, Ok((&[][..], 1u32.into()))); /// ``` pub fn be_varint(input: &[u8]) -> IResult<&[u8], VarInt> { flat_map(take(2usize), |prefix: u8| { take::<&[u8], u64, usize, Error<(&[u8], usize)>>((8 << prefix) - 2) }) .parse((input, 0)) .map_err(|err| match err { nom::Err::Incomplete(needed) => { nom::Err::Incomplete(needed.map(|n| n.get().div_ceil(8) - input.len())) } _ => unreachable!(), }) .map(|((buf, _), value)| (buf, VarInt(value))) } /// A [`bytes::BufMut`] extension trait, makes buffer more friendly to write VarInt. pub trait WriteVarInt: BufMut { /// Write a variable-length integer. /// /// `put_varint` will write the smallest number of bytes needed to represent the value. /// `encode_varint` will write the specified number of bytes, and panic if the specified number of bytes /// is less than the smallest number of bytes needed to repressent the value. /// /// # Example /// ```rust /// use bytes::BufMut; /// use qbase::varint::{EncodeBytes, VarInt, WriteVarInt}; /// /// let val = VarInt::from_u32(1); /// let mut encode_buf = [0u8; 8]; /// /// let mut buf = &mut encode_buf[..]; /// buf.put_varint(&val); /// assert_eq!(buf.len(), 7); /// assert_eq!(encode_buf[0..1], [0x01]); /// /// let mut buf = &mut encode_buf[..]; /// buf.encode_varint(&val, EncodeBytes::Two); /// assert_eq!(buf.len(), 6); /// assert_eq!(encode_buf[0..2], [0x40, 0x01]); /// ``` fn put_varint(&mut self, value: &VarInt); /// Write a variable-length integer with specified number of bytes. 
fn encode_varint(&mut self, value: &VarInt, nbytes: EncodeBytes); } // 所有的BufMut都可以调用put_varint来写入VarInt了 impl WriteVarInt for T { fn put_varint(&mut self, value: &VarInt) { let x = value.0; if x < 1u64 << 6 { self.put_u8(x as u8); } else if x < 1u64 << 14 { self.put_u16((0b01 << 14) | x as u16); } else if x < 1u64 << 30 { self.put_u32((0b10 << 30) | x as u32); } else if x < 1u64 << 62 { self.put_u64((0b11 << 62) | x); } else { unreachable!("malformed VarInt") } } fn encode_varint(&mut self, value: &VarInt, nbytes: EncodeBytes) { match nbytes { EncodeBytes::One => { assert!(value.0 < 1u64 << 6); self.put_u8(value.0 as u8); } EncodeBytes::Two => { assert!(value.0 < 1u64 << 14); self.put_u16((0b01 << 14) | value.0 as u16); } EncodeBytes::Four => { assert!(value.0 < 1u64 << 30); self.put_u32((0b10 << 30) | value.0 as u32); } EncodeBytes::Eight => { assert!(value.0 < 1u64 << 62); self.put_u64((0b11 << 62) | value.0); } } } } #[cfg(test)] mod tests { use super::{EncodeBytes, VarInt, WriteVarInt}; #[test] fn test_equal() { let val = VarInt(0); assert_eq!(val, 0); assert!(val == 0); assert!(val != 1) } #[test] fn test_be_varint() { { let buf = &[0b00000001u8, 0x01][..]; let r = super::be_varint(buf); assert_eq!(r, Ok((&[0x01][..], VarInt(1)))); } { let buf = &[0b01000000u8, 0x06u8][..]; let r = super::be_varint(buf); assert_eq!(r, Ok((&[][..], VarInt(6)))); } { let buf = &[0b10000000u8, 1, 1, 1][..]; let r = super::be_varint(buf); assert_eq!(r, Ok((&[][..], VarInt(0x010101)))); } { let buf = &[0b11000000u8, 1, 1, 1, 1, 1, 1, 1][..]; let r = super::be_varint(buf); assert_eq!(r, Ok((&[][..], VarInt(0x01010101010101)))); } { let buf = &[0b11000000u8, 0x06u8][..]; let r = super::be_varint(buf); assert_eq!(r, Err(nom::Err::Incomplete(nom::Needed::new(6)))); } } fn assert_put_varint_eq(val: u64, expected: &[u8]) { let val = VarInt::from_u64(val).unwrap(); let mut buf = vec![]; buf.put_varint(&val); assert_eq!(buf, expected); } #[test] fn test_put_varint() { 
assert_put_varint_eq(0x0000_0000_0000_0000, &[0]); assert_put_varint_eq(0x0000_0000_0000_003F, &[0x3F]); assert_put_varint_eq(0x0000_0000_0000_0040, &[0x40, 0x40]); assert_put_varint_eq(0x0000_0000_0000_3FFF, &[0x7F, 0xFF]); assert_put_varint_eq(0x0000_0000_0000_4000, &[0x80, 0x00, 0x40, 0x00]); assert_put_varint_eq(0x0000_0000_3FFF_FFFF, &[0xBF, 0xFF, 0xFF, 0xFF]); assert_put_varint_eq( 0x0000_0000_4000_0000, &[0xC0, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00], ); assert_put_varint_eq( 0x3FFF_FFFF_FFFF_FFFF, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF], ); } #[test] fn test_encode_varint() { let val = VarInt::from_u32(1); let mut encode_buf = [0u8; 8]; let mut buf = &mut encode_buf[..]; buf.put_varint(&val); assert_eq!(buf.len(), 7); assert_eq!(encode_buf[0..1], [0x01]); let mut buf = &mut encode_buf[..]; buf.encode_varint(&val, EncodeBytes::Two); assert_eq!(buf.len(), 6); assert_eq!(encode_buf[0..2], [0x40, 0x01]); } } ================================================ FILE: qcongestion/Cargo.toml ================================================ [package] name = "qcongestion" version = "0.5.0" edition.workspace = true description = "Congestion control in QUIC, a part of dquic" readme.workspace = true repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true rust-version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] thiserror = { workspace = true } tracing = { workspace = true } qbase = { workspace = true } qevent = { workspace = true } rand = { workspace = true } tokio = { workspace = true, features = ["rt", "sync", "time", "macros"] } [dev-dependencies] tokio = { workspace = true, features = ["test-util"] } ================================================ FILE: qcongestion/src/algorithm/bbr/delivery_rate.rs ================================================ // https://tools.ietf.org/html/draft-cheng-iccrg-delivery-rate-estimation-01 
use std::time::{Duration, Instant}; use crate::packets::{AckedPackets, SentPacket}; #[derive(Debug)] pub struct Rate { delivered: usize, delivered_time: Instant, first_sent_time: Instant, // Packet number of the last sent packet with app limited. end_of_app_limited: u64, // Packet number of the last sent packet. last_sent_packet: u64, // Packet number of the largest acked packet. largest_acked: u64, // Sample of rate estimation. rate_sample: RateSample, } impl Default for Rate { fn default() -> Self { let now = tokio::time::Instant::now(); Rate { delivered: 0, delivered_time: now, first_sent_time: now, end_of_app_limited: 0, last_sent_packet: 0, largest_acked: 0, rate_sample: RateSample::default(), } } } impl Rate { // 3.2. Transmitting or retransmitting a data packet pub fn on_packet_sent( &mut self, pkt: &mut SentPacket, bytes_in_flight: usize, bytes_lost: u64, ) { // No packets in flight. if bytes_in_flight == 0 { self.first_sent_time = pkt.time_sent; self.delivered_time = pkt.time_sent; } pkt.first_sent_time = self.first_sent_time; pkt.delivered_time = self.delivered_time; pkt.delivered = self.delivered; pkt.is_app_limited = self.app_limited(); pkt.tx_in_flight = bytes_in_flight; pkt.lost = bytes_lost; self.last_sent_packet = pkt.packet_number; } // Update the delivery rate sample when a packet is acked. 
pub fn update_rate_sample(&mut self, pkt: &AckedPackets, now: Instant) { self.delivered += pkt.size; self.delivered_time = now; if self.rate_sample.prior_time.is_none() || pkt.delivered > self.rate_sample.prior_delivered { self.rate_sample.prior_delivered = pkt.delivered; self.rate_sample.prior_time = Some(pkt.delivered_time); self.rate_sample.is_app_limited = pkt.is_app_limited; self.rate_sample.send_elapsed = pkt.time_sent.saturating_duration_since(pkt.first_sent_time); self.rate_sample.rtt = pkt.rtt; self.rate_sample.ack_elapsed = self .delivered_time .saturating_duration_since(pkt.delivered_time); self.first_sent_time = pkt.time_sent; } self.largest_acked = self.largest_acked.max(pkt.pn); } pub fn generate_rate_sample(&mut self) { // End app-limited phase if bubble is ACKed and gone. if self.app_limited() && self.largest_acked > self.end_of_app_limited { self.update_app_limited(false); } if self.rate_sample.prior_time.is_some() { let interval = self .rate_sample .send_elapsed .max(self.rate_sample.ack_elapsed); self.rate_sample.delivered = self .delivered .saturating_sub(self.rate_sample.prior_delivered); self.rate_sample.interval = interval; if !interval.is_zero() { // Fill in rate_sample with a rate sample. 
self.rate_sample.delivery_rate = (self.rate_sample.delivered as f64 / interval.as_secs_f64()) as u64; } } } pub fn update_app_limited(&mut self, v: bool) { self.end_of_app_limited = if v { self.last_sent_packet.max(1) } else { 0 } } pub fn app_limited(&mut self) -> bool { self.end_of_app_limited != 0 } pub fn delivered(&self) -> usize { self.delivered } pub fn sample_delivery_rate(&self) -> u64 { self.rate_sample.delivery_rate } pub fn sample_rtt(&self) -> Duration { self.rate_sample.rtt } pub fn sample_is_app_limited(&self) -> bool { self.rate_sample.is_app_limited } } #[derive(Default, Debug)] struct RateSample { delivery_rate: u64, is_app_limited: bool, interval: Duration, delivered: usize, prior_delivered: usize, prior_time: Option, send_elapsed: Duration, ack_elapsed: Duration, rtt: Duration, } #[cfg(test)] mod tests { use super::*; #[test] fn test_rate() { let mut rate = Rate::default(); let now = Instant::now(); let mut sents: Vec = (0..5) .map(|i| SentPacket { packet_number: i, sent_bytes: 100, time_sent: now, ..Default::default() }) .collect(); for sent in &mut sents { let pkt_num = sent.packet_number; rate.on_packet_sent(sent, (pkt_num * 100) as usize, 0); } let delay = Duration::from_millis(100); let recv_ack_time = now + delay; for _ in 0..3 { let sent = sents.pop().unwrap(); let mut acked: AckedPackets = sent.into(); acked.rtt = delay; rate.update_rate_sample(&acked, recv_ack_time); rate.generate_rate_sample(); } // 300 / 0.1 assert_eq!(rate.sample_delivery_rate(), 3000); assert_eq!(rate.sample_rtt(), delay); assert!(!rate.sample_is_app_limited()); } } ================================================ FILE: qcongestion/src/algorithm/bbr/min_max.rs ================================================ use std::fmt::Debug; #[derive(Copy, Clone, Debug)] pub(super) struct MinMax { /// round count, not a timestamp window: u64, samples: [MinMaxSample; 3], } impl MinMax { fn fill(&mut self, sample: MinMaxSample) { self.samples.fill(sample); } pub(super) fn 
    update_max(&mut self, current_round: u64, measurement: u64) -> u64 {
        let sample = MinMaxSample {
            time: current_round,
            value: measurement,
        };

        if self.samples[0].value == 0 /* uninitialised */
            || /* found new max? */ sample.value >= self.samples[0].value
            || /* nothing left in window? */ sample.time - self.samples[2].time > self.window
        {
            self.fill(sample); /* forget earlier samples */
            return self.samples[0].value;
        }

        if sample.value >= self.samples[1].value {
            self.samples[2] = sample;
            self.samples[1] = sample;
        } else if sample.value >= self.samples[2].value {
            self.samples[2] = sample;
        }

        self.subwin_update(sample);
        self.samples[0].value
    }

    /* As time advances, update the 1st, 2nd, and 3rd choices. */
    fn subwin_update(&mut self, sample: MinMaxSample) {
        let dt = sample.time - self.samples[0].time;
        if dt > self.window {
            /*
             * Passed entire window without a new sample so make 2nd
             * choice the new sample & 3rd choice the new 2nd choice.
             * we may have to iterate this since our 2nd choice
             * may also be outside the window (we checked on entry
             * that the third choice was in the window).
             */
            self.samples[0] = self.samples[1];
            self.samples[1] = self.samples[2];
            self.samples[2] = sample;
            if sample.time - self.samples[0].time > self.window {
                self.samples[0] = self.samples[1];
                self.samples[1] = self.samples[2];
                self.samples[2] = sample;
            }
        } else if self.samples[1].time == self.samples[0].time && dt > self.window / 4 {
            /*
             * We've passed a quarter of the window without a new sample
             * so take a 2nd choice from the 2nd quarter of the window.
             */
            self.samples[2] = sample;
            self.samples[1] = sample;
        } else if self.samples[2].time == self.samples[1].time && dt > self.window / 2 {
            /*
             * We've passed half the window without finding a new sample
             * so take a 3rd choice from the last half of the window
             */
            self.samples[2] = sample;
        }
    }
}

impl Default for MinMax {
    fn default() -> Self {
        Self {
            window: 10,
            samples: [Default::default(); 3],
        }
    }
}

#[derive(Debug, Copy, Clone, Default)]
struct MinMaxSample {
    /// round number, not a timestamp
    time: u64,
    value: u64,
}


================================================
FILE: qcongestion/src/algorithm/bbr/model.rs
================================================
use std::time::Instant;

// 4.1. Maintaining the Network Path Model
// This model includes two estimated parameters: self.BtlBw, and self.RTprop.
use super::{Bbr, RTPROP_FILTER_LEN};
use crate::packets::AckedPackets;

impl Bbr {
    // 4.1.1.3. Tracking Time for the self.BtlBw Max Filter
    // Upon connection initialization:
    pub(super) fn init_round_counting(&mut self) {
        self.next_round_delivered = 0;
        self.round_count = 0;
        self.is_round_start = false;
    }

    // Upon receiving an ACK for a given data packet:
    fn update_round(&mut self, packet: &AckedPackets) {
        if packet.delivered >= self.next_round_delivered {
            self.next_round_delivered = self.delivery_rate.delivered();
            self.round_count += 1;
            self.is_round_start = true;
            self.packet_conservation = false;
        } else {
            self.is_round_start = false;
        }
    }

    // 4.1.1.5. Updating the BBR.BtlBw Max Filter
    pub(super) fn update_btlbw(&mut self, packet: &AckedPackets) {
        self.update_round(packet);

        // App-limited samples may only raise the estimate, never lower it.
        if self.delivery_rate.sample_delivery_rate() >= self.btlbw
            || !self.delivery_rate.sample_is_app_limited()
        {
            self.btlbw = self
                .btlbwfilter
                .update_max(self.round_count, self.delivery_rate.sample_delivery_rate());
        }
    }

    // 4.1.2.2. BBR.RTprop Min Filter
    pub(super) fn update_rtprop(&mut self) {
        let sample_rtt = self.delivery_rate.sample_rtt();
        // NOTE(review): `tokio::time::Instant` here vs the `std::time::Instant`
        // imported above — confirm `rtprop_stamp`'s declared type matches.
        let now = tokio::time::Instant::now();
        self.is_rtprop_expired =
            now.saturating_duration_since(self.rtprop_stamp) > RTPROP_FILTER_LEN;

        // Accept smaller samples always; accept any sample after expiry.
        if !sample_rtt.is_zero() && (sample_rtt <= self.rtprop || self.is_rtprop_expired) {
            self.rtprop = sample_rtt;
            self.rtprop_stamp = now;
        }
    }
}


================================================
FILE: qcongestion/src/algorithm/bbr/parameters.rs
================================================
// 4.2. BBR Control Parameters
// BBR uses three distinct but interrelated control parameters: pacing rate,
// send quantum, and congestion window (cwnd).

use std::time::Duration;

use super::{
    Bbr, BbrStateMachine, INITIAL_CWND, MIN_PIPE_CWND_PKTS, MINIMUM_WINDOW_PACKETS, MSS,
    SEND_QUANTUM_THRESHOLD_PACING_RATE,
};
use crate::rtt::INITIAL_RTT;

impl Bbr {
    // 4.2.1. Pacing Rate
    pub(super) fn init_pacing_rate(&mut self) {
        let srtt = INITIAL_RTT;

        let nominal_bandwidth = INITIAL_CWND as f64 / srtt.as_secs_f64();
        self.pacing_rate = (self.pacing_gain * nominal_bandwidth) as u64;
    }

    pub(super) fn set_pacing_rate(&mut self) {
        self.set_pacing_rate_with_gain(self.pacing_gain);
    }

    pub(super) fn set_pacing_rate_with_gain(&mut self, pacing_gain: f64) {
        let rate = (pacing_gain * self.btlbw as f64) as u64;
        // Until the pipe is filled, only ever raise the pacing rate.
        if self.is_filled_pipe || rate > self.pacing_rate {
            self.pacing_rate = rate;
        }
    }

    // 4.2.2. Send Quantum
    pub(super) fn set_send_quantum(&mut self) {
        let floor = if self.pacing_rate < SEND_QUANTUM_THRESHOLD_PACING_RATE {
            MSS
        } else {
            2 * MSS
        };
        // BBR.send_quantum = min(BBR.pacing_rate * 1ms, 64KBytes)
        self.send_quantum = (self.pacing_rate / 1000).clamp(floor as u64, 64 * 1024);
    }

    // 4.2.3. Congestion Window
    // 4.2.3.2. Target cwnd
    pub fn inflight(&self, gain: f64) -> u64 {
        // No RTprop sample yet: fall back to the initial window.
        if self.rtprop == Duration::MAX {
            return INITIAL_CWND;
        }
        let quanta = 3 * self.send_quantum;
        let estimated_bdp = self.btlbw as f64 * self.rtprop.as_secs_f64();
        (gain * estimated_bdp) as u64 + quanta
    }

    fn update_target_cwnd(&mut self) {
        self.target_cwnd = self.inflight(self.cwnd_gain);
    }

    // 4.2.3.4 Modulating cwnd in Loss Recovery
    pub(super) fn save_cwnd(&mut self) {
        self.prior_cwnd = if !self.in_recovery && self.state != BbrStateMachine::ProbeRTT {
            self.cwnd
        } else {
            self.cwnd.max(self.prior_cwnd)
        }
    }

    pub fn restore_cwnd(&mut self) {
        self.cwnd = self.cwnd.max(self.prior_cwnd)
    }

    fn modulate_cwnd_for_recovery(&mut self, bytes_in_flight: u64) {
        if self.newly_lost_bytes > 0 {
            self.cwnd = self
                .cwnd
                .saturating_sub(self.newly_lost_bytes)
                .max((MSS * MINIMUM_WINDOW_PACKETS) as u64);
        }

        if self.packet_conservation {
            self.cwnd = self.cwnd.max(bytes_in_flight + self.newly_acked_bytes);
        }
    }

    // 4.2.3.5 Modulating cwnd in ProbeRTT
    fn modulate_cwnd_for_probe_rtt(&mut self) {
        if self.state == BbrStateMachine::ProbeRTT {
            self.cwnd = self.cwnd.min(self.min_pipe_cwnd());
        }
    }

    // 4.2.3.6. Core cwnd Adjustment Mechanism
    pub(super) fn set_cwnd(&mut self) {
        let bytes_in_flight = self.bytes_in_flight;

        self.update_target_cwnd();
        self.modulate_cwnd_for_recovery(bytes_in_flight);

        if !self.packet_conservation {
            if self.is_filled_pipe {
                self.cwnd = self.target_cwnd.min(self.cwnd + self.newly_acked_bytes);
            } else if self.cwnd < self.target_cwnd
                || self.delivery_rate.delivered() < INITIAL_CWND as usize
            {
                self.cwnd += self.newly_acked_bytes;
            }
            self.cwnd = self.cwnd.max(self.min_pipe_cwnd());
        }

        self.modulate_cwnd_for_probe_rtt();
    }

    /// The minimal cwnd value BBR tries to target, in bytes
    pub(super) fn min_pipe_cwnd(&self) -> u64 {
        (MIN_PIPE_CWND_PKTS * MSS) as u64
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_init_pacing_rate() {
        let mut bbr = Bbr::new();
        bbr.init();
        assert_eq!(
            bbr.pacing_rate,
            (bbr.pacing_gain * INITIAL_CWND as f64 / INITIAL_RTT.as_secs_f64()) as u64
        );
    }

    #[test]
    fn test_bbr_set_pacing_rate() {
        let mut bbr = Bbr::new();
        bbr.btlbw = 1000;
        bbr.is_filled_pipe = true;
        bbr.set_pacing_rate();
        assert_eq!(bbr.pacing_rate, (bbr.btlbw as f64 * bbr.pacing_gain) as u64);
    }

    #[test]
    fn test_bbr_set_send_quantum() {
        let mut bbr = Bbr::new();
        bbr.pacing_rate = SEND_QUANTUM_THRESHOLD_PACING_RATE + 1;
        bbr.set_send_quantum();
        assert_eq!(bbr.send_quantum, (2 * MSS) as u64);

        bbr.pacing_rate = SEND_QUANTUM_THRESHOLD_PACING_RATE - 1;
        bbr.set_send_quantum();
        assert_eq!(bbr.send_quantum, MSS as u64);

        bbr.pacing_rate = 120_000_000;
        bbr.set_send_quantum();
        assert_eq!(bbr.send_quantum, 64 * 1024);

        bbr.pacing_rate = 10_000_000;
        bbr.set_send_quantum();
        assert_eq!(bbr.send_quantum, 10000);
    }

    #[test]
    fn test_bbr_inflight() {
        let mut bbr = Bbr::new();
        bbr.btlbw = 10_000_000;
        bbr.rtprop = Duration::from_millis(100);
        let bdp = bbr.inflight(1.0);
        assert_eq!(bdp, 1_000_000);

        bbr.send_quantum = 64 * 1024;
        let bdp = bbr.inflight(1.0);
        assert_eq!(bdp, 1_000_000 + bbr.send_quantum * 3);
    }

    #[test]
    fn test_bbr_modulate_cwnd_for_recovery() {
        let mut bbr = Bbr::new();
        bbr.cwnd = 10000;
        bbr.packet_conservation = false;
        bbr.newly_lost_bytes = 1000;
        // when packet lost cwnd sub lost_bytes
        bbr.modulate_cwnd_for_recovery(9000);
        assert_eq!(bbr.cwnd, 9000);

        bbr.packet_conservation = true;
        bbr.newly_lost_bytes = 0;
        bbr.cwnd = 10000;
        // when packet conservation cwnd add newly_acked_bytes
        bbr.modulate_cwnd_for_recovery(9000);
        bbr.newly_acked_bytes = 1000;
        assert_eq!(bbr.cwnd, 10000);
    }

    #[test]
    fn test_modulate_cwnd_for_probe_rtt() {
        let mut bbr = Bbr::new();
        bbr.cwnd = 10000;
        // min(4 * MSS, cwnd)
        bbr.state = BbrStateMachine::ProbeRTT;
        bbr.modulate_cwnd_for_probe_rtt();
        assert_eq!(bbr.cwnd, (4 * MSS) as u64);
    }

    #[test]
    fn test_bbr_set_cwnd() {
        let mut bbr = Bbr::new();
        bbr.bytes_in_flight = 1000;
        bbr.packet_conservation = false;
        bbr.btlbw = 10_000_000; // 10Mbps
        bbr.rtprop = Duration::from_millis(100);
        bbr.newly_acked_bytes = 4000;
        // pacing_rate = btlbw * pacing_gain
        // target_cwnd = (btlbw * rtt) * cwnd_gain + quantum
        // init cwnd < tartget_cwnd
        // when receive ack, adjust cwnd
        bbr.set_cwnd();
        assert_eq!(bbr.cwnd, 100000);
    }
}


================================================
FILE: qcongestion/src/algorithm/bbr/state.rs
================================================
use std::time::Instant;

use super::{Bbr, BbrStateMachine, HIGH_GAIN, PROBE_RTT_DURATION};
use crate::rtt::INITIAL_RTT;

// BBRGainCycleLen: the number of phases in the BBR ProbeBW gain cycle: 8.
const GAIN_CYCLE_LEN: usize = 8;

// Pacing Gain Cycles. Each phase normally lasts for roughly BBR.RTprop.
const PACING_GAIN_CYCLE: [f64; GAIN_CYCLE_LEN] = [1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];

impl Bbr {
    pub(super) fn init(&mut self) {
        self.rtprop = INITIAL_RTT;
        // NOTE(review): tokio instant assigned to a field used with the std
        // `Instant` import above — confirm the field's actual type.
        self.rtprop_stamp = tokio::time::Instant::now();
        self.probe_rtt_done_stamp = None;
        self.probe_rtt_round_done = false;
        self.packet_conservation = false;
        self.prior_cwnd = 0;
        self.is_idle_restart = false;

        self.init_round_counting();
        self.init_full_pipe();
        self.init_pacing_rate();
        self.enter_startup();
    }

    // 4.3.2.1. Startup Dynamics
    pub(crate) fn enter_startup(&mut self) {
        self.state = BbrStateMachine::Startup;
        self.pacing_gain = HIGH_GAIN;
        self.cwnd_gain = HIGH_GAIN;
    }

    // 4.3.2.2. Estimating When Startup has Filled the Pipe
    fn init_full_pipe(&mut self) {
        self.is_filled_pipe = false;
        self.full_bw = 0;
        self.full_bw_count = 0;
    }

    // Leave Startup for Drain once three consecutive rounds show no
    // bandwidth growth.
    pub(super) fn check_full_pipe(&mut self) {
        if self.is_filled_pipe || !self.is_round_start || self.delivery_rate.app_limited() {
            // no need to check for a full pipe now
            return;
        }

        // BBR.BtlBw still growing?
        if self.btlbw as f64 >= self.full_bw as f64 * 1.25 {
            // record new baseline level
            self.full_bw = self.btlbw;
            self.full_bw_count = 0;
        }

        self.full_bw_count += 1;
        if self.full_bw_count >= 3 {
            self.is_filled_pipe = true;
        }
    }

    // 4.3.3. Drain
    fn enter_drain(&mut self) {
        self.state = BbrStateMachine::Drain;
        self.pacing_gain = 1.0 / HIGH_GAIN; // pace slowly
        self.cwnd_gain = HIGH_GAIN; // maintain cwnd
    }

    pub(super) fn check_drain(&mut self) {
        if self.state == BbrStateMachine::Startup && self.is_filled_pipe {
            self.enter_drain()
        }

        if self.state == BbrStateMachine::Drain && self.bytes_in_flight <= self.inflight(1.0) {
            self.enter_probe_bw();
        }
    }

    // 4.3.4. ProbeBW
    pub fn enter_probe_bw(&mut self) {
        self.state = BbrStateMachine::ProbeBW;
        self.pacing_gain = 1.0;
        self.cwnd_gain = 2.0;
        // Start from a random phase; the arithmetic plus the advance below
        // yields any index except 1 (the 0.75 drain phase).
        self.cycle_index = GAIN_CYCLE_LEN - 1 - rand::rng().random_range(0..GAIN_CYCLE_LEN - 1);
        self.advance_cycle_phase()
    }

    // On each ACK BBR runs BBRCheckCyclePhase(), to see if it's time to
    // advance to the next gain cycle phase:
    pub(super) fn check_cycle_phase(&mut self) {
        if self.state == BbrStateMachine::ProbeBW && self.is_next_cycle_phase() {
            self.advance_cycle_phase();
        }
    }

    fn advance_cycle_phase(&mut self) {
        self.cycle_stamp = tokio::time::Instant::now();
        self.cycle_index = (self.cycle_index + 1) % GAIN_CYCLE_LEN;
        self.pacing_gain = PACING_GAIN_CYCLE[self.cycle_index];
    }

    // Should we advance to the next gain-cycle phase?
    fn is_next_cycle_phase(&mut self) -> bool {
        let now = tokio::time::Instant::now();
        let is_full_length = now.saturating_duration_since(self.cycle_stamp) > self.rtprop;

        // pacing_gain == 1.0: hold for one rtprop.
        if (self.pacing_gain - 1.0).abs() < f64::EPSILON {
            return is_full_length;
        }

        // pacing_gain > 1: hold at least rtprop AND see loss, or inflight
        // reaching 5/4 * estimated_BDP.
        if self.pacing_gain > 1.0 {
            return is_full_length
                && (self.newly_lost_bytes > 0
                    || self.prior_bytes_in_flight >= self.inflight(self.pacing_gain));
        }

        // pacing_gain < 1: hold at least rtprop, or until inflight drops to
        // the estimated_BDP.
        is_full_length || self.prior_bytes_in_flight <= self.inflight(1.0)
    }

    // 4.3.4.4. Restarting From Idle
    pub(super) fn handle_restart_from_idle(&mut self) {
        if self.bytes_in_flight == 0 && self.delivery_rate.app_limited() {
            self.is_idle_restart = true;
            if self.state == BbrStateMachine::ProbeBW {
                self.set_pacing_rate_with_gain(1.0);
            }
        }
    }

    // 4.3.5. ProbeRTT
    pub(super) fn check_probe_rtt(&mut self) {
        if self.state != BbrStateMachine::ProbeRTT && self.is_rtprop_expired && !self.is_idle_restart
        {
            self.enter_probe_rtt();

            self.save_cwnd();
            self.probe_rtt_done_stamp = None;
        }

        if self.state == BbrStateMachine::ProbeRTT {
            self.handle_probe_rtt();
        }

        self.is_idle_restart = false;
    }

    fn enter_probe_rtt(&mut self) {
        self.state = BbrStateMachine::ProbeRTT;
        self.pacing_gain = 1.0;
        self.cwnd_gain = 1.0;
    }

    fn handle_probe_rtt(&mut self) {
        // C.app_limited = (BW.delivered + packets_in_flight) ? : 1
        self.delivery_rate.update_app_limited(true);

        let now = tokio::time::Instant::now();
        if let Some(probe_rtt_done_stamp) = self.probe_rtt_done_stamp {
            if self.is_round_start {
                self.probe_rtt_round_done = true;
            }

            if self.probe_rtt_round_done && now >= probe_rtt_done_stamp {
                self.rtprop_stamp = now;

                self.restore_cwnd();
                self.exit_probe_rtt(now);
            }
        } else if self.bytes_in_flight <= self.min_pipe_cwnd() {
            // Inflight has drained to the floor: start the ProbeRTT timer.
            self.probe_rtt_done_stamp = Some(now + PROBE_RTT_DURATION);
            self.probe_rtt_round_done = false;
            self.next_round_delivered = self.delivery_rate.delivered();
        }
    }

    fn exit_probe_rtt(&mut self, _: Instant) {
        if self.is_filled_pipe {
            self.enter_probe_bw();
        } else {
            self.enter_startup();
        }
    }
}

#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use crate::algorithm::bbr::{
        BbrStateMachine, HIGH_GAIN, INITIAL_CWND, MSS, tests::simulate_round_trip,
    };

    #[test]
    fn test_bbr_init() {
        let mut bbr = super::Bbr::new();
        bbr.init();
        assert_eq!(bbr.state, BbrStateMachine::Startup);
        assert_eq!(bbr.pacing_gain, HIGH_GAIN);
        assert_eq!(bbr.cwnd_gain, HIGH_GAIN);
        assert_eq!(bbr.cwnd, INITIAL_CWND);
    }

    #[test]
    fn test_bbr_enter_startup() {
        let mut bbr = super::Bbr::new();
        bbr.enter_startup();
        assert_eq!(bbr.state, BbrStateMachine::Startup);
        assert_eq!(bbr.pacing_gain, HIGH_GAIN);
        assert_eq!(bbr.cwnd_gain, HIGH_GAIN);
    }

    #[test]
    fn test_bbr_check_full_pipe() {
        let mut bbr = super::Bbr::new();
        let mut now = tokio::time::Instant::now();
        let rtt = Duration::from_millis(100);
        simulate_round_trip(&mut bbr, now, rtt, 0, 10, MSS);
        now += Duration::from_secs(1);
        simulate_round_trip(&mut bbr, now, rtt, 0, 10, MSS);
        assert_eq!(bbr.btlbw, (10 * 10 * MSS) as u64);
        bbr.check_full_pipe();
        assert!(!bbr.is_filled_pipe);

        now += Duration::from_secs(1);
        simulate_round_trip(&mut bbr, now, rtt, 0, 10, MSS);
        assert_eq!(bbr.btlbw, (10 * 10 * MSS) as u64);
        bbr.check_full_pipe();
        assert!(!bbr.is_filled_pipe);

        now += Duration::from_secs(1);
        simulate_round_trip(&mut bbr, now, rtt, 0, 10, MSS);
        bbr.check_full_pipe();
        assert!(bbr.is_filled_pipe);
    }

    #[test]
    fn test_bbr_check_drain() {
        let mut bbr = super::Bbr::new();
        bbr.init();
        bbr.is_filled_pipe = true;
        bbr.bytes_in_flight = 100;
        bbr.check_drain();
        assert_eq!(bbr.state, BbrStateMachine::Drain);

        let mut bbr = super::Bbr::new();
        bbr.init();
        bbr.is_filled_pipe = true;
        bbr.check_drain();
        assert_eq!(bbr.state, BbrStateMachine::ProbeBW);
    }

    #[test]
    fn test_bbr_enter_probe_bw() {
        let mut bbr = super::Bbr::new();
        bbr.init();
        bbr.enter_probe_bw();
        assert_eq!(bbr.state, BbrStateMachine::ProbeBW);
        assert_eq!(bbr.cwnd_gain, 2.0);
    }

    #[test]
    fn test_bbr_advance_cycle_phase() {
        let mut bbr = super::Bbr::new();
        bbr.init();
        bbr.cycle_index = 0;
        bbr.advance_cycle_phase();
        assert_eq!(bbr.pacing_gain, 0.75);

        bbr.cycle_index = 7;
        bbr.advance_cycle_phase();
        assert_eq!(bbr.pacing_gain, 1.25)
    }

    #[test]
    fn test_bbr_is_next_cycle_phase() {
        let mut bbr = super::Bbr::new();
        bbr.init();
        bbr.enter_probe_bw();

        let now = Instant::now();
        bbr.pacing_gain = 1.0;
        bbr.cycle_stamp = now - Duration::from_secs(1);
        assert!(bbr.is_next_cycle_phase());

        bbr.pacing_gain = 0.75;
        bbr.cycle_stamp = now - Duration::from_secs(1);
        bbr.prior_bytes_in_flight = 100;
        assert!(bbr.is_next_cycle_phase());

        bbr.pacing_gain = 1.25;
        bbr.cycle_stamp = now - Duration::from_secs(1);
        assert!(bbr.is_next_cycle_phase());
    }

    #[test]
    fn test_restart_from_idle() {
        let mut bbr = super::Bbr::new();
        bbr.init();
        bbr.bytes_in_flight = 0;
bbr.handle_restart_from_idle(); assert!(!bbr.is_idle_restart); } } ================================================ FILE: qcongestion/src/algorithm/bbr.rs ================================================ use std::{ collections::VecDeque, time::{Duration, Instant}, }; use delivery_rate::Rate; use min_max::MinMax; use qevent::quic::recovery::RecoveryMetricsUpdated; use super::Control; use crate::packets::AckedPackets; mod delivery_rate; mod min_max; pub(crate) mod model; pub(crate) mod parameters; pub(crate) mod state; const MSS: usize = 1200; // RTpropFilterLen: A constant specifying the length of the RTProp min // filter window, RTpropFilterLen is `10` secs. const RTPROP_FILTER_LEN: Duration = Duration::from_secs(10); // BBRHighGain: A constant specifying the minimum gain value that will // allow the sending rate to double each round (`2/ln(2)` ~= `2.89`), used // in Startup mode for both BBR.pacing_gain and BBR.cwnd_gain. const HIGH_GAIN: f64 = 2.89; // ProbeRTTDuration: A constant specifying the minimum duration for // which ProbeRTT state holds inflight to BBRMinPipeCwnd or fewer // packets: 200 ms. const PROBE_RTT_DURATION: Duration = Duration::from_millis(200); // Pacing rate threshold for select different send quantum. Default `1.2Mbps`. const SEND_QUANTUM_THRESHOLD_PACING_RATE: u64 = 1_200_000 / 8; // Initial congestion window in bytes. pub(crate) const INITIAL_CWND: u64 = 80 * MSS as u64; // The minimal cwnd value BBR tries to target using: 4 packets, or 4 * SMSS const MIN_PIPE_CWND_PKTS: usize = 4; const MINIMUM_WINDOW_PACKETS: usize = 2; // BBR State // // https://datatracker.ietf.org/doc/html/draft-cardwell-iccrg-bbr-congestion-control-00#section-3.4 #[derive(Debug, PartialEq, Eq)] enum BbrStateMachine { Startup, Drain, ProbeBW, ProbeRTT, } pub(crate) struct Bbr { // StateMachine state: BbrStateMachine, // BBR.pacing_rate: The current pacing rate for a BBR flow, which // controls inter-packet spacing. 
pacing_rate: u64, // BBR.send_quantum: The maximum size of a data aggregate scheduled and // transmitted together. send_quantum: u64, // Cwnd: The transport sender's congestion window, which limits the // amount of data in flight. cwnd: u64, // BBR.BtlBw: BBR's estimated bottleneck bandwidth available to the transport // flow, estimated from the maximum delivery rate sample in a sliding window. btlbw: u64, // BBR.BtlBwFilter: The max filter used to estimate BBR.BtlBw. btlbwfilter: MinMax, // Delivery rate. delivery_rate: Rate, // BBR.RTprop: BBR's estimated two-way round-trip propagation delay of path, // estimated from the windowed minimum recent round-trip delay sample. rtprop: Duration, // BBR.rtprop_stamp: The wall clock time at which the current BBR.RTProp // sample was obtained. rtprop_stamp: Instant, // BBR.rtprop_expired: A boolean recording whether the BBR.RTprop has // expired and is due for a refresh with an application idle period or a // transition into ProbeRTT state. is_rtprop_expired: bool, // BBR.pacing_gain: The dynamic gain factor used to scale BBR.BtlBw to // produce BBR.pacing_rate. pacing_gain: f64, // BBR.cwnd_gain: The dynamic gain factor used to scale the estimated // BDP to produce a congestion window (cwnd). cwnd_gain: f64, // BBR.round_count: Count of packet-timed round trips. round_count: u64, // BBR.round_start: A boolean that BBR sets to true once per packet- // timed round trip, on ACKs that advance BBR.round_count. is_round_start: bool, // BBR.next_round_delivered: packet.delivered value denoting the end of // a packet-timed round trip. next_round_delivered: usize, // Estimator of full pipe. // BBR.filled_pipe: A boolean that records whether BBR estimates that it // has ever fully utilized its available bandwidth ("filled the pipe"). is_filled_pipe: bool, // Baseline level delivery rate for full pipe estimator. full_bw: u64, // The number of round for full pipe estimator without much growth. 
full_bw_count: u64, // Timestamp when ProbeRTT state ends. probe_rtt_done_stamp: Option, // Whether a roundtrip in ProbeRTT state ends. probe_rtt_round_done: bool, // Whether in packet sonservation mode. packet_conservation: bool, // Cwnd before loss recovery. prior_cwnd: u64, // Whether restarting from idle. is_idle_restart: bool, // Last time when cycle_index is updated. cycle_stamp: Instant, // Current index of pacing_gain_cycle[]. cycle_index: usize, // The upper bound on the volume of data BBR allows in flight. target_cwnd: u64, // Whether in the recovery mode. in_recovery: bool, // Time of the last recovery event starts. recovery_epoch_start: Option, // Ack time. ack_time: Instant, // Newly marked lost data size in bytes. newly_lost_bytes: u64, // lost data size in total bytes. bytes_lost_in_total: u64, // Newly acked data size in bytes. newly_acked_bytes: u64, // The last P.delivered in bytes. packet_delivered: u64, // The last P.sent_time to determine whether exit recovery. last_ack_packet_sent_time: Instant, // The amount of data that was in flight before processing this ACK. prior_bytes_in_flight: u64, // The sum of the size in bytes of all sent packets that contain at least // one ack-eliciting or PADDING frame and have not been acknowledged or // declared lost. The size does not include IP or UDP overhead. 
pub bytes_in_flight: u64, } impl From<&Bbr> for RecoveryMetricsUpdated { fn from(value: &Bbr) -> Self { qevent::build!(RecoveryMetricsUpdated { congestion_window: value.cwnd, bytes_in_flight: value.bytes_in_flight, pacing_rate: value.pacing_rate, custom_fields: Map { // AI补全 delivery_rate: value.delivery_rate.sample_delivery_rate(), packet_delivered: value.packet_delivered, newly_acked_bytes: value.newly_acked_bytes, newly_lost_bytes: value.newly_lost_bytes, bytes_lost_in_total: value.bytes_lost_in_total, } }) } } impl Bbr { pub fn new() -> Self { let now = Instant::now(); let mut bbr = Bbr { state: BbrStateMachine::Startup, pacing_rate: 0, send_quantum: 0, cwnd: INITIAL_CWND, btlbw: 0, btlbwfilter: MinMax::default(), delivery_rate: Rate::default(), rtprop: Duration::MAX, rtprop_stamp: now, is_rtprop_expired: false, pacing_gain: HIGH_GAIN, cwnd_gain: HIGH_GAIN, round_count: 0, is_round_start: false, next_round_delivered: 0, is_filled_pipe: false, full_bw: 0, full_bw_count: 0, probe_rtt_done_stamp: None, probe_rtt_round_done: false, packet_conservation: false, prior_cwnd: 0, is_idle_restart: false, cycle_stamp: now, cycle_index: 0, target_cwnd: 0, in_recovery: false, recovery_epoch_start: None, ack_time: now, newly_lost_bytes: 0, newly_acked_bytes: 0, last_ack_packet_sent_time: now, prior_bytes_in_flight: 0, packet_delivered: 0, bytes_in_flight: 0, bytes_lost_in_total: 0, }; bbr.on_connection_init(); bbr } } impl Control for Bbr { fn on_packet_sent_cc(&mut self, packet: &crate::packets::SentPacket) { todo!() } fn on_packet_acked(&mut self, acked_packet: &crate::packets::SentPacket) { todo!() } fn on_packets_lost(&mut self, lost_packets: &[crate::packets::SentPacket]) { todo!() } fn on_congestion_event(&mut self, sent_time: &Instant) { todo!() } fn congestion_window(&self) -> usize { todo!() } fn pacing_rate(&self) -> Option { todo!() } } impl Bbr { // 3.5.1. Initialization fn on_connection_init(&mut self) { self.init(); } // 3.5.2. 
Per-ACK Steps fn update_model_and_state(&mut self, ack: &mut AckedPackets) { self.update_btlbw(ack); self.check_cycle_phase(); self.check_full_pipe(); self.check_drain(); self.update_rtprop(); self.check_probe_rtt(); } fn update_control_parameters(&mut self) { self.set_pacing_rate(); self.set_send_quantum(); self.set_cwnd(); } // 3.5.3. Per-Transmit Steps fn on_transmit(&mut self) { self.handle_restart_from_idle(); } } #[cfg(test)] mod tests { use std::{ collections::VecDeque, time::{Duration, Instant}, }; use crate::{ algorithm::bbr::{BbrStateMachine, HIGH_GAIN, INITIAL_CWND, MSS}, packets::{AckedPackets, SentPacket}, rtt::INITIAL_RTT, }; #[test] fn test_bbr_init() { let mut bbr = super::Bbr::new(); bbr.init(); assert_eq!(bbr.state, BbrStateMachine::Startup); assert_eq!(bbr.pacing_gain, HIGH_GAIN); assert_eq!(bbr.cwnd_gain, HIGH_GAIN); assert_eq!(bbr.cycle_index, 0); assert_eq!(bbr.cwnd, INITIAL_CWND); assert_eq!(bbr.bytes_in_flight, 0); assert_eq!( bbr.pacing_rate, (bbr.pacing_gain * INITIAL_CWND as f64 / INITIAL_RTT.as_secs_f64()) as u64 ); } #[test] fn test_bbr_sent() { let mut bbr = super::Bbr::new(); for _ in 0..10 { let mut sent = SentPacket { sent_bytes: MSS, ..Default::default() }; bbr.on_sent(&mut sent, MSS); } assert_eq!(bbr.bytes_in_flight, 10 * MSS as u64); } #[test] fn test_bbr_ack() { let mut bbr = super::Bbr::new(); let mut now = Instant::now(); let rtt = Duration::from_millis(100); simulate_round_trip(&mut bbr, now, rtt, 0, 10, MSS); assert_eq!(bbr.bytes_in_flight, 0); assert_eq!(bbr.delivery_rate.delivered(), 10 * MSS); assert_eq!( bbr.delivery_rate.sample_delivery_rate(), (10 * 10 * MSS) as u64 ); now += Duration::from_secs(1); // next roud // generate btlbw simulate_round_trip(&mut bbr, now, rtt, 10, 40, MSS); assert_eq!(bbr.delivery_rate.delivered(), 40 * MSS); assert_eq!( bbr.delivery_rate.sample_delivery_rate(), (30 * 10 * MSS) as u64 ); assert_eq!(bbr.btlbw, (10 * 10 * MSS) as u64); assert_eq!( bbr.pacing_rate, (bbr.pacing_gain * 
INITIAL_CWND as f64 / INITIAL_RTT.as_secs_f64()) as u64 ); now += Duration::from_secs(1); // update btlbw simulate_round_trip(&mut bbr, now, rtt, 40, 60, MSS); assert_eq!( bbr.delivery_rate.sample_delivery_rate(), (20 * 10 * MSS) as u64 ); assert_eq!(bbr.btlbw, (3 * 10 * 10 * MSS) as u64); assert_eq!(bbr.pacing_rate, (bbr.btlbw as f64 * bbr.pacing_gain) as u64); } pub(super) fn simulate_round_trip( bbr: &mut super::Bbr, start_time: Instant, rtt: Duration, start: usize, end: usize, packet_size: usize, ) { let mut acks = VecDeque::with_capacity(end - start); for i in start..end { let mut sent: SentPackets = SentPackets { packet_number: i as u64, sent_bytes: packet_size, time_sent: start_time, ..Default::default() }; bbr.on_sent(&mut sent, 0); let mut ack: AckedPackets = sent.into(); ack.rtt = rtt; acks.push_back(ack); } // let ack_time = start_time + rtt; bbr.on_ack(acks); } } ================================================ FILE: qcongestion/src/algorithm/new_reno.rs ================================================ use std::sync::{ Arc, atomic::{AtomicU16, Ordering}, }; use qbase::{Epoch, frame::AckFrame}; use qevent::quic::recovery::RecoveryMetricsUpdated; use tokio::time::Instant; use crate::{ algorithm::Control, packets::{SentPacket, State}, }; const INFINITRE_SSTHRESH: usize = usize::MAX; pub(crate) struct NewReno { max_datagram_size: Arc, ecn_ce_counters: [u64; Epoch::count()], bytes_in_flight: usize, congestion_window: usize, congestion_recovery_start_time: Option, ssthresh: usize, } impl From<&NewReno> for RecoveryMetricsUpdated { fn from(reno: &NewReno) -> Self { qevent::build!(RecoveryMetricsUpdated { congestion_window: reno.congestion_window as u64, ssthresh: reno.ssthresh as u64, }) } } impl NewReno { /// B.3. 
Initialization pub(crate) fn new(max_datagram_size: Arc) -> Self { // The upper bound for the initial window will be // min (10*MSS, max (2*MSS, 14600)) // See https://datatracker.ietf.org/doc/html/rfc6928#autoid-3 let mtu = max_datagram_size.load(Ordering::Relaxed); let initial_window = (mtu * 10).min((mtu * 2).max(14600)); NewReno { max_datagram_size, ecn_ce_counters: [0, 0, 0], congestion_window: initial_window as usize, bytes_in_flight: 0, congestion_recovery_start_time: None, ssthresh: INFINITRE_SSTHRESH, } } /// B.4. On Packet Sent /// OnPacketSentCC(sent_bytes): /// . bytes_in_flight += sent_bytes fn on_packet_sent_cc(&mut self, sent_bytes: usize) { self.bytes_in_flight += sent_bytes; } /// B.5. On Packet Acknowledgment /// InCongestionRecovery(sent_time): /// return sent_time <= congestion_recovery_start_time fn in_congestion_recovery(&self, sent_time: &Instant) -> bool { self.congestion_recovery_start_time .map(|recovery_start_time| *sent_time <= recovery_start_time) .unwrap_or(false) } /// OnPacketAcked(acked_packet): /// if (!acked_packet.in_flight): /// return; /// // Remove from bytes_in_flight. /// bytes_in_flight -= acked_packet.sent_bytes /// // Do not increase congestion_window if application /// // limited or flow control limited. /// if (IsAppOrFlowControlLimited()) /// return /// // Do not increase congestion window in recovery period. /// if (InCongestionRecovery(acked_packet.time_sent)): /// return /// if (congestion_window < ssthresh): /// // Slow start. /// congestion_window += acked_packet.sent_bytes /// else: /// // Congestion avoidance. 
/// congestion_window += /// max_datagram_size * acked_packet.sent_bytes /// / congestion_window fn on_packet_acked(&mut self, acked_packet: &SentPacket) { if !acked_packet.count_for_cc { return; } // 如果不是 inflight 状态,说明已经丢包重传了 if acked_packet.state == State::Inflight { self.bytes_in_flight = self.bytes_in_flight.saturating_sub(acked_packet.sent_bytes); } // 如果是 Retranmit 状态,又被 ack, 把拥塞窗口加回来 if self.in_congestion_recovery(&acked_packet.time_sent) { qevent::event!({ RecoveryMetricsUpdated::from(&*self) }); return; } if self.congestion_window < self.ssthresh { self.congestion_window += acked_packet.sent_bytes; } else { self.congestion_window += self.max_datagram_size() * acked_packet.sent_bytes / self.congestion_window; } qevent::event!({ RecoveryMetricsUpdated::from(&*self) }); } /// B.6. On New Congestion Event /// OnCongestionEvent(sent_time): /// // No reaction if already in a recovery period. /// if (InCongestionRecovery(sent_time)): /// return /// // Enter recovery period. /// congestion_recovery_start_time = now() /// ssthresh = congestion_window * kLossReductionFactor /// congestion_window = max(ssthresh, kMinimumWindow) /// // A packet can be sent to speed up loss recovery. /// MaybeSendOnePacket() fn on_congestion_event(&mut self, sent_time: &Instant) { if self.in_congestion_recovery(sent_time) { return; } let now = tokio::time::Instant::now(); self.congestion_recovery_start_time = Some(now); // WARN: will be zero self.ssthresh = self.congestion_window - self.max_datagram_size(); // The RECOMMENDED value is 2 * max_datagram_size. // See https://datatracker.ietf.org/doc/html/rfc9002#name-initial-and-minimum-congest self.congestion_window = self.ssthresh.max(2 * self.max_datagram_size()); // A packet can be sent to speed up loss recovery. // self.maybe_send_packet(1); qevent::event!({ RecoveryMetricsUpdated::from(&*self) }); } /// B.7. 
Process ECN Information /// ProcessECN(ack, pn_space): /// // If the ECN-CE counter reported by the peer has increased, /// // this could be a new congestion event. /// if (ack.ce_counter > ecn_ce_counters[pn_space]): /// ecn_ce_counters[pn_space] = ack.ce_counter /// sent_time = sent_packets[ack.largest_acked].time_sent /// OnCongestionEvent(sent_time) fn process_ecn(&mut self, ack: &AckFrame, sent_time: &Instant, epoch: Epoch) { if let Some(ecn) = ack.ecn() && ecn.ce() > self.ecn_ce_counters[epoch] { self.ecn_ce_counters[epoch] = ecn.ce(); self.on_congestion_event(sent_time); } } /// B.8. On Packets Lost /// OnPacketsLost(lost_packets): /// sent_time_of_last_loss = 0 /// // Remove lost packets from bytes_in_flight. /// for lost_packet in lost_packets: /// if lost_packet.in_flight: /// bytes_in_flight -= lost_packet.sent_bytes /// sent_time_of_last_loss = /// max(sent_time_of_last_loss, lost_packet.time_sent) /// // Congestion event if in-flight packets were lost /// if (sent_time_of_last_loss != 0): /// OnCongestionEvent(sent_time_of_last_loss) /// // Reset the congestion window if the loss of these /// // packets indicates persistent congestion. /// // Only consider packets sent after getting an RTT sample. 
/// if (first_rtt_sample == 0): /// return /// pc_lost = [] /// for lost in lost_packets: /// if lost.time_sent > first_rtt_sample: /// pc_lost.insert(lost) /// if (InPersistentCongestion(pc_lost)): /// congestion_window = kMinimumWindow /// congestion_recovery_start_time = 0 fn on_packets_lost( &mut self, lost_packets: &mut dyn Iterator, persistent_lost: bool, ) { let mut sent_time_last_loss: Option = None; for lost_packet in lost_packets { if lost_packet.count_for_cc { self.bytes_in_flight = self.bytes_in_flight.saturating_sub(lost_packet.sent_bytes); sent_time_last_loss = sent_time_last_loss .map(|t| t.max(lost_packet.time_sent)) .or(Some(lost_packet.time_sent)); } } if let Some(time) = sent_time_last_loss { self.on_congestion_event(&time); } if persistent_lost { // WARN: will be zero self.ssthresh = self.congestion_window >> 1; self.congestion_window = self.ssthresh.max(2 * self.max_datagram_size()); self.congestion_recovery_start_time = None; } } /// RemoveFromBytesInFlight(discarded_packets): /// // Remove any unacknowledged packets from flight. 
/// foreach packet in discarded_packets: /// if packet.in_flight /// bytes_in_flight -= size fn remove_from_bytes_in_flight( &mut self, discard_packets: &mut dyn Iterator, ) { for packet in discard_packets { if packet.count_for_cc && packet.state != State::Retransmitted { self.bytes_in_flight -= packet.sent_bytes; } } } fn max_datagram_size(&self) -> usize { self.max_datagram_size.load(Ordering::Relaxed) as usize } } impl Control for NewReno { fn on_packet_sent_cc(&mut self, packet: &SentPacket) { self.on_packet_sent_cc(packet.sent_bytes); } fn on_packet_acked(&mut self, acked_packet: &SentPacket) { self.on_packet_acked(acked_packet); } fn on_packets_lost( &mut self, lost_packets: &mut dyn Iterator, persistent_lost: bool, ) { self.on_packets_lost(lost_packets, persistent_lost); } fn congestion_window(&self) -> usize { self.congestion_window } fn pacing_rate(&self) -> Option { None } fn remove_from_bytes_in_flight(&mut self, packets: &mut dyn Iterator) { self.remove_from_bytes_in_flight(packets); } fn process_ecn(&mut self, ack: &AckFrame, sent_time: &Instant, epoch: Epoch) { self.process_ecn(ack, sent_time, epoch); } } /* #[cfg(test)] mod tests { use super::*; use crate::packets::SentPacket; #[test] fn test_reno_init() { let reno = NewReno::new(); assert_eq!(reno.cwnd, INIT_CWND); assert_eq!(reno.ssthresh, super::INFINITRE_SSTHRESH); assert_eq!(reno.recovery_start_time, None); } #[test] fn test_reno_slow_start() { let mut reno = NewReno::new(); let acks = generate_acks(0, 10); // first roud trip reno.on_ack(acks); assert_eq!(reno.cwnd, 20 * MSS as u64); // second roud trip let acks = generate_acks(10, 30); reno.on_ack(acks); assert_eq!(reno.cwnd, 40 * MSS as u64); } #[test] fn test_reno_congestion_avoidance() { let mut reno = NewReno::new(); reno.ssthresh = 30 * MSS as u64; let acks = generate_acks(0, 20); let pre_cwnd = reno.cwnd(); // slow start reno.on_ack(acks); assert_eq!(reno.cwnd, pre_cwnd + 20 * MSS as u64); let pre_cwnd = reno.cwnd(); let acks = 
generate_acks(20, 60); // congestion avoidance // increase a MSS when bytes_acked is greater than cwnd reno.on_ack(acks); assert_eq!(reno.cwnd, pre_cwnd + MSS as u64); } #[test] fn test_reno_congestion_event() { let mut reno = NewReno::new(); let now = Instant::now(); reno.ssthresh = 20 * MSS as u64; let acks = generate_acks(0, 10); reno.on_ack(acks); assert_eq!(reno.cwnd, 20 * MSS as u64); assert_eq!(reno.recovery_start_time, None); let time_lost = now + std::time::Duration::from_millis(100); let lost = SentPacket { packet_number: 11, sent_bytes: MSS, time_sent: now, ..Default::default() }; reno.on_congestion_event(&lost); assert_eq!(reno.cwnd, 10 * MSS as u64); assert_eq!(reno.ssthresh, 10 * MSS as u64); assert_eq!(reno.recovery_start_time, Some(time_lost)); } fn generate_acks(start: usize, end: usize) -> VecDeque { let mut acks = VecDeque::with_capacity(end - start); for i in start..end { let sent = SentPacket { packet_number: i as u64, sent_bytes: MSS, time_sent: Instant::now(), ..Default::default() }; let ack: AckedPackets = sent.into(); acks.push_back(ack); } acks } } */ ================================================ FILE: qcongestion/src/algorithm.rs ================================================ use qbase::{Epoch, frame::AckFrame}; use tokio::time::Instant; use crate::packets::SentPacket; // pub(crate) mod bbr; pub(crate) mod new_reno; /// The [`Algorithm`] enum represents different congestion control algorithms that can be used. 
pub enum Algorithm { Bbr, NewReno, } pub trait Control: Send { fn on_packet_sent_cc(&mut self, packet: &SentPacket); fn on_packet_acked(&mut self, acked_packet: &SentPacket); fn on_packets_lost( &mut self, lost_packets: &mut dyn Iterator, persistent_lost: bool, ); fn process_ecn(&mut self, ack: &AckFrame, sent_time: &Instant, epoch: Epoch); fn congestion_window(&self) -> usize; fn pacing_rate(&self) -> Option; fn remove_from_bytes_in_flight(&mut self, packets: &mut dyn Iterator); } ================================================ FILE: qcongestion/src/congestion.rs ================================================ use std::sync::{Arc, Mutex}; use qbase::{ Epoch, frame::AckFrame, net::tx::{ArcSendWaker, Signals}, }; use qevent::quic::recovery::PacketLostTrigger; use tokio::time::{Duration, Instant}; use crate::{ Algorithm, Feedback, MSS, TooManyPtos, algorithm::{Control, new_reno::NewReno}, pacing::{self, Pacer}, packets::{PacketSpace, SentPacket}, rtt::{ArcRtt, INITIAL_RTT}, status::PathStatus, }; const INIT_CWND: usize = MSS * 10; const PACKET_THRESHOLD: usize = 3; /// Imple RFC 9002 Appendix A. Loss Recovery /// See [Appendix A](https://datatracker.ietf.org/doc/html/rfc9002#name-loss-recovery-pseudocode) pub struct CongestionController { algorithm: Box, // The Round-Trip Time (RTT) estimator. rtt: ArcRtt, loss_detection_timer: Option, // The number of times a PTO has been sent without receiving an acknowledgment. // Use to pto backoff pto_count: u32, max_ack_delay: Duration, packet_spaces: [PacketSpace; Epoch::count()], // pacer is used to control the burst rate pacer: pacing::Pacer, // The waker to notify when the controller is ready to send. pending_burst: bool, // epoch packet trackers trackers: [Arc; 3], need_send_ack_eliciting_packets: [usize; Epoch::count()], path_status: PathStatus, tx_waker: ArcSendWaker, } impl CongestionController { /// A.4. 
Initialization fn init( algorithm: Algorithm, max_ack_delay: Duration, trackers: [Arc; 3], path_status: PathStatus, tx_waker: ArcSendWaker, ) -> Self { let algorithm: Box = match algorithm { Algorithm::Bbr => todo!("implement BBR"), Algorithm::NewReno => Box::new(NewReno::new(path_status.pmtu())), }; let now = Instant::now(); CongestionController { algorithm, rtt: ArcRtt::new(), loss_detection_timer: None, pto_count: 0, max_ack_delay, packet_spaces: [ PacketSpace::with_epoch(Epoch::Initial, Duration::ZERO), PacketSpace::with_epoch(Epoch::Handshake, Duration::ZERO), PacketSpace::with_epoch(Epoch::Data, max_ack_delay), ], pacer: Pacer::new(INITIAL_RTT, INIT_CWND, path_status.mtu(), now, None), pending_burst: false, trackers, need_send_ack_eliciting_packets: [0; Epoch::count()], path_status, tx_waker, } } /// A.5. On Sending a Packet /// OnPacketSent(packet_number, pn_space, ack_eliciting, /// in_flight, sent_bytes): /// sent_packets[pn_space][packet_number].packet_number = /// packet_number /// sent_packets[pn_space][packet_number].time_sent = now() /// sent_packets[pn_space][packet_number].ack_eliciting = /// ack_eliciting /// sent_packets[pn_space][packet_number].in_flight = in_flight /// sent_packets[pn_space][packet_number].sent_bytes = sent_bytes /// if (in_flight): /// if (ack_eliciting): /// time_of_last_ack_eliciting_packet[pn_space] = now() /// OnPacketSentCC(sent_bytes) /// SetLossDetectionTimer() pub fn on_packet_sent( &mut self, packet_number: u64, epoch: Epoch, ack_eliciting: bool, in_flight: bool, sent_bytes: usize, ) { let now = Instant::now(); let sent = SentPacket::new(packet_number, now, ack_eliciting, in_flight, sent_bytes); if in_flight { if ack_eliciting { self.packet_spaces[epoch].time_of_last_ack_eliciting_packet = Some(now); self.need_send_ack_eliciting_packets[epoch] = self.need_send_ack_eliciting_packets[epoch].saturating_sub(1); } self.algorithm.on_packet_sent_cc(&sent); self.packet_spaces[epoch] .loss_time .get_or_insert_with(|| now + 
self.rtt.loss_delay()); self.set_loss_detection_timer(); } self.packet_spaces[epoch].sent_packets.push_back(sent); self.pacer.on_sent(sent_bytes); } /// A.6. On Receiving a Datagram /// OnDatagramReceived(datagram): /// // If this datagram unblocks the server, arm the /// // PTO timer to avoid deadlock. /// if (server was at anti-amplification limit): /// SetLossDetectionTimer() /// if loss_detection_timer.timeout < now(): /// // Execute PTO if it would have expired /// // while the amplification limit applied. /// OnLossDetectionTimeout() pub fn on_datagram_rcvd(&mut self) { // If this datagram unblocks the server, arm the PTO timer to avoid deadlock. if self.path_status.is_at_anti_amplification_limit() { let now = Instant::now(); self.set_loss_detection_timer(); if self.loss_detection_timer.is_some_and(|t| t < now) { // Execute PTO if it would have expired while the amplification limit applied. self.on_loss_detection_timeout(); } } } /// A.7. On Receiving an Acknowledgment /// OnAckReceived(ack, pn_space): /// if (largest_acked_packet[pn_space] == infinite): /// largest_acked_packet[pn_space] = ack.largest_acked /// else: /// largest_acked_packet[pn_space] = /// max(largest_acked_packet[pn_space], ack.largest_acked) /// /// // DetectAndRemoveAckedPackets finds packets that are newly /// // acknowledged and removes them from sent_packets. /// newly_acked_packets = /// DetectAndRemoveAckedPackets(ack, pn_space) /// // Nothing to do if there are no newly acked packets. /// if (newly_acked_packets.empty()): /// return /// /// // Update the RTT if the largest acknowledged is newly acked /// // and at least one ack-eliciting was newly acked. /// if (newly_acked_packets.largest().packet_number == /// ack.largest_acked && /// IncludesAckEliciting(newly_acked_packets)): /// latest_rtt = /// now() - newly_acked_packets.largest().time_sent /// UpdateRtt(ack.ack_delay) /// /// // Process ECN information if present. 
/// if (ACK frame contains ECN information): /// ProcessECN(ack, pn_space) /// /// lost_packets = DetectAndRemoveLostPackets(pn_space) /// if (!lost_packets.empty()): /// OnPacketsLost(lost_packets) /// OnPacketsAcked(newly_acked_packets) /// /// // Reset pto_count unless the client is unsure if /// // the server has validated the client's address. /// if (PeerCompletedAddressValidation()): /// pto_count = 0 /// SetLossDetectionTimer() pub fn on_ack_rcvd(&mut self, epoch: Epoch, ack_frame: &AckFrame, now: Instant) { self.packet_spaces[epoch].update_largest_acked_packet(ack_frame.largest()); match self.packet_spaces[epoch].on_ack_rcvd(ack_frame, &mut self.algorithm) { None => return, Some(newly_acked_packets) => { let (largest_pn, largest_time_sent) = newly_acked_packets.largest; if largest_pn == ack_frame.largest() && newly_acked_packets.include_ack_eliciting { self.rtt.update( now - largest_time_sent, Duration::from_micros(ack_frame.delay()), self.path_status.is_handshake_confirmed(), ); } // Process ECN information if present. if ack_frame.ecn().is_some() { self.process_ecn(ack_frame, &largest_time_sent, epoch) } } } let mut loss_pns = self.packet_spaces[epoch] .detect_lost_packets(self.rtt.loss_delay(), PACKET_THRESHOLD, &mut self.algorithm) .peekable(); if loss_pns.peek().is_some() { self.rtt.try_backoff_rtt(); self.trackers[epoch].may_loss(PacketLostTrigger::TimeThreshold, &mut loss_pns); } if self.peer_completed_address_validation() { self.pto_count = 0; } self.set_loss_detection_timer(); } /// A.8. Setting the Loss Detection Timer /// SetLossDetectionTimer(): /// earliest_loss_time, _ = GetLossTimeAndSpace() /// if (earliest_loss_time != 0): /// // Time threshold loss detection. /// loss_detection_timer.update(earliest_loss_time) /// return /// /// if (server is at anti-amplification limit): /// // The server's timer is not set if nothing can be sent. 
/// loss_detection_timer.cancel() /// return /// /// if (no ack-eliciting packets in flight && /// PeerCompletedAddressValidation()): /// // There is nothing to detect lost, so no timer is set. /// // However, the client needs to arm the timer if the /// // server might be blocked by the anti-amplification limit. /// loss_detection_timer.cancel() /// return /// /// timeout, _ = GetPtoTimeAndSpace() /// loss_detection_timer.update(timeout) fn set_loss_detection_timer(&mut self) { if let Some((earliest_loss_time, _)) = self.get_loss_time_and_epoch() { self.loss_detection_timer = Some(earliest_loss_time); return; } if self.path_status.is_at_anti_amplification_limit() { self.loss_detection_timer = None; return; } if self.no_ack_eliciting_in_flight() && self.peer_completed_address_validation() { self.loss_detection_timer = None; return; } self.loss_detection_timer = self.get_pto_time_and_epoch().map(|(timeout, _)| timeout); } // A.9. On Timeout /// OnLossDetectionTimeout(): /// earliest_loss_time, pn_space = GetLossTimeAndSpace() /// if (earliest_loss_time != 0): /// // Time threshold loss Detection /// lost_packets = DetectAndRemoveLostPackets(pn_space) /// assert(!lost_packets.empty()) /// OnPacketsLost(lost_packets) /// SetLossDetectionTimer() /// return /// /// if (no ack-eliciting packets in flight): /// assert(!PeerCompletedAddressValidation()) /// // Client sends an anti-deadlock packet: Initial is padded /// // to earn more anti-amplification credit, /// // a Handshake packet proves address ownership. /// if (has Handshake keys): /// SendOneAckElicitingHandshakePacket() /// else: /// SendOneAckElicitingPaddedInitialPacket() /// else: /// // PTO. Send new data if available, else retransmit old data. /// // If neither is available, send a single PING frame. 
/// _, pn_space = GetPtoTimeAndSpace() /// SendOneOrTwoAckElicitingPackets(pn_space) /// /// pto_count++ /// SetLossDetectionTimer() fn on_loss_detection_timeout(&mut self) -> u32 { if let Some((_, epoch)) = self.get_loss_time_and_epoch() { let mut loss_pns = self.packet_spaces[epoch] .detect_lost_packets(self.rtt.loss_delay(), PACKET_THRESHOLD, &mut self.algorithm) .peekable(); if loss_pns.peek().is_some() { self.rtt.try_backoff_rtt(); self.trackers[epoch].may_loss(PacketLostTrigger::TimeThreshold, &mut loss_pns); } self.set_loss_detection_timer(); return self.pto_count; } if self.no_ack_eliciting_in_flight() { // assert!(!self.peer_completed_address_validation()); if self.path_status.has_handshake_key() { // Send an anti-deadlock packet: Initial is padded // to earn more anti-amplification credit, // a Handshake packet proves address ownership. self.send_ack_eliciting_packet(Epoch::Handshake, 1); } else { self.send_ack_eliciting_packet(Epoch::Initial, 1); } } else { // PTO. Send new data if available, else retransmit old data. // If neither is available, send a single PING frame. 
if let Some((_, epoch)) = self.get_pto_time_and_epoch() { self.send_ack_eliciting_packet(epoch, 1); } } self.pto_count += 1; self.set_loss_detection_timer(); self.pto_count } /// GetLossTimeAndSpace(): /// time = loss_time[Initial] /// space = Initial /// for pn_space in [ Handshake, ApplicationData ]: /// if (time == 0 || loss_time[pn_space] < time): /// time = loss_time[pn_space]; /// space = pn_space /// return time, space fn get_loss_time_and_epoch(&self) -> Option<(Instant, Epoch)> { self.packet_spaces .iter() .zip(Epoch::iter()) .filter(|(space, _)| space.loss_time.is_some()) .map(|(space, epoch)| (space.loss_time.unwrap(), *epoch)) .min_by_key(|(loss_time, _)| *loss_time) } // GetPtoTimeAndSpace(): // duration = (smoothed_rtt + max(4 * rttvar, kGranularity)) // * (2 ^ pto_count) // // Anti-deadlock PTO starts from the current time // if (no ack-eliciting packets in flight): // assert(!PeerCompletedAddressValidation()) // if (has handshake keys): // return (now() + duration), Handshake // else: // return (now() + duration), Initial // pto_timeout = infinite // pto_space = Initial // for space in [ Initial, Handshake, ApplicationData ]: // if (no ack-eliciting packets in flight in space): // continue; // if (space == ApplicationData): // // Skip Application Data until handshake confirmed. // if (handshake is not confirmed): // return pto_timeout, pto_space // // Include max_ack_delay and backoff for Application Data. 
// duration += max_ack_delay * (2 ^ pto_count) // // t = time_of_last_ack_eliciting_packet[space] + duration // if (t < pto_timeout): // pto_timeout = t // pto_space = space // return pto_timeout, pto_space fn get_pto_time_and_epoch(&self) -> Option<(Instant, Epoch)> { let mut duration = self.rtt.base_pto(self.pto_count); let now = Instant::now(); if self.no_ack_eliciting_in_flight() { // assert!(!self.peer_completed_address_validation()); if self.path_status.has_handshake_key() { return Some((now + duration, Epoch::Handshake)); } else { return Some((now + duration, Epoch::Initial)); } } let mut pto_time = None; for &epoch in Epoch::iter() { if self.packet_spaces[epoch].no_ack_eliciting_in_flight() { continue; } if epoch == Epoch::Data { // An endpoint MUST NOT set its PTO timer for the Application Data // packet number epoch until the handshake is confirmed if !self.path_status.is_handshake_confirmed() { return pto_time; } duration += self.max_ack_delay * (1 << self.pto_count); } let t = self.packet_spaces[epoch] .time_of_last_ack_eliciting_packet .unwrap() + duration; if pto_time.is_none() || pto_time.is_some_and(|(pto_time, _)| t < pto_time) { pto_time = Some((t, epoch)); } } pto_time } fn no_ack_eliciting_in_flight(&self) -> bool { Epoch::iter().all(|epoch| self.packet_spaces[*epoch].no_ack_eliciting_in_flight()) } /// PeerCompletedAddressValidation(): /// // Assume clients validate the server's address implicitly. /// if (endpoint is server): /// return true /// // Servers complete address validation when a /// // protected packet is received. 
/// return has received Handshake ACK || /// handshake confirmed fn peer_completed_address_validation(&self) -> bool { self.path_status.is_server() || self.path_status.has_received_handshake_ack() || self.path_status.is_handshake_confirmed() } fn process_ecn(&mut self, ack: &AckFrame, sent_time: &Instant, epoch: Epoch) { self.algorithm.process_ecn(ack, sent_time, epoch); } fn send_ack_eliciting_packet(&mut self, epoch: Epoch, count: usize) { self.need_send_ack_eliciting_packets[epoch] += count; self.tx_waker.wake_by(Signals::PING); } #[inline] fn need_ack(&self) -> bool { Epoch::iter().any(|&epoch| self.packet_spaces[epoch].rcvd_packets.need_ack().is_some()) } #[inline] fn send_quota(&mut self) -> usize { let now = Instant::now(); self.pacer.schedule( self.rtt.smoothed_rtt(), self.algorithm.congestion_window(), self.path_status.mtu(), now, self.algorithm.pacing_rate(), ) } //OnPacketNumberSpaceDiscarded(pn_space): // assert(pn_space != ApplicationData) // RemoveFromBytesInFlight(sent_packets[pn_space]) // sent_packets[pn_space].clear() // // Reset the loss detection and PTO timer // time_of_last_ack_eliciting_packet[pn_space] = 0 // loss_time[pn_space] = 0 // pto_count = 0 // SetLossDetectionTimer() fn discard_epoch(&mut self, epoch: Epoch) { assert!(epoch != Epoch::Data); self.packet_spaces[epoch].discard(&mut self.algorithm); self.loss_detection_timer = None; self.pto_count = 0; self.set_loss_detection_timer(); } fn get_pto(&self, epoch: Epoch) -> Duration { let mut pto_time = self.rtt.base_pto(self.pto_count); if epoch == Epoch::Data { pto_time += self.max_ack_delay * (1 << self.pto_count); } pto_time } } #[derive(Clone)] pub struct ArcCC(Arc>); impl ArcCC { pub fn new( algorithm: Algorithm, max_ack_delay: Duration, trackers: [Arc; 3], path_status: PathStatus, tx_waker: ArcSendWaker, ) -> Self { ArcCC(Arc::new(Mutex::new(CongestionController::init( algorithm, max_ack_delay, trackers, path_status, tx_waker, )))) } } impl super::Transport for ArcCC { fn 
do_tick(&self) -> Result<(), TooManyPtos> { let now = Instant::now(); let mut guard = self.0.lock().unwrap(); if guard.loss_detection_timer.is_some_and(|t| t <= now) { let pto_count = guard.on_loss_detection_timeout(); if pto_count > 6 { return Err(TooManyPtos(pto_count)); } } if guard.pending_burst && guard.send_quota() >= guard.path_status.mtu() { guard.pending_burst = false; guard.tx_waker.wake_by(Signals::CONGESTION); } if guard.need_ack() { guard.tx_waker.wake_by(Signals::TRANSPORT); } Ok(()) } fn send_quota(&self) -> Result { let mut guard = self.0.lock().unwrap(); let send_quota = guard.send_quota(); if send_quota >= guard.path_status.mtu() { Ok(send_quota) } else { guard.pending_burst = true; Err(Signals::CONGESTION) } } fn retransmit_and_expire_time(&self, epoch: Epoch) -> (Duration, Duration) { let guard = self.0.lock().unwrap(); ( // 尽量让路径先发起重传 guard.rtt.loss_delay() + guard.rtt.rttvar(), guard.get_pto(epoch), ) } fn need_ack(&self, epoch: Epoch) -> Option<(u64, Instant)> { let guard = self.0.lock().unwrap(); guard.packet_spaces[epoch].rcvd_packets.need_ack() } fn on_pkt_sent( &self, epoch: Epoch, pn: u64, is_ack_eliciting: bool, sent_bytes: usize, in_flight: bool, ack: Option, ) { let mut guard = self.0.lock().unwrap(); guard.on_packet_sent(pn, epoch, is_ack_eliciting, in_flight, sent_bytes); if let Some(largest_acked) = ack { guard.packet_spaces[epoch] .rcvd_packets .on_ack_sent(pn, largest_acked); } // See [Section 17.2.2.1](https://www.rfc-editor.org/rfc/rfc9000#name-abandoning-initial-packets) if epoch == Epoch::Handshake && !guard.path_status.is_server() { guard.discard_epoch(Epoch::Initial); } } fn on_ack_rcvd(&self, epoch: Epoch, ack_frame: &AckFrame) { let mut guard = self.0.lock().unwrap(); let now = Instant::now(); guard.on_ack_rcvd(epoch, ack_frame, now); // See [Section 17.2.2.1](https://www.rfc-editor.org/rfc/rfc9000#name-abandoning-initial-packets) if epoch == Epoch::Handshake && guard.path_status.is_server() { 
guard.discard_epoch(Epoch::Initial); } } fn on_pkt_rcvd(&self, epoch: Epoch, pn: u64, is_ack_eliciting: bool) { if !is_ack_eliciting { return; } let mut guard = self.0.lock().unwrap(); guard.packet_spaces[epoch].rcvd_packets.on_pkt_rcvd(pn); guard.on_datagram_rcvd(); } fn get_pto(&self, epoch: Epoch) -> Duration { let guard = self.0.lock().unwrap(); guard.get_pto(epoch) } fn discard_epoch(&self, epoch: Epoch) { let mut guard = self.0.lock().unwrap(); guard.discard_epoch(epoch); } fn need_send_ack_eliciting(&self, epoch: Epoch) -> usize { let guard = self.0.lock().unwrap(); guard.need_send_ack_eliciting_packets[epoch] } fn grant_anti_amplification(&self) { let guard = self.0.lock().unwrap(); guard.path_status.release_anti_amplification_limit(); } } #[cfg(test)] mod tests {} ================================================ FILE: qcongestion/src/lib.rs ================================================ use qbase::{Epoch, frame::AckFrame, net::tx::Signals}; use qevent::quic::recovery::PacketLostTrigger; use thiserror::Error; use tokio::time::{Duration, Instant}; mod algorithm; pub use algorithm::Algorithm; mod congestion; pub use congestion::ArcCC; mod pacing; mod packets; mod rtt; mod status; pub use status::{HandshakeStatus, PathStatus}; /// default datagram size in bytes. pub const MSS: usize = 1200; #[derive(Debug, Clone, Copy, Error)] #[error("Too many PTOs: {0}")] pub struct TooManyPtos(u32); /// The [`Transport`] trait defines the interface for congestion control algorithms. pub trait Transport { /// Performs a periodic tick to drive the congestion control algorithm. fn do_tick(&self) -> Result<(), TooManyPtos>; /// Returns how many bytes can be sent at the moment. /// If the congestion controller is not ready, returns an signal that should be waited for. fn send_quota(&self) -> Result; /// Gets the retransmission and expiration time for the given epoch. 
fn retransmit_and_expire_time(&self, epoch: Epoch) -> (Duration, Duration); /// Records the sending of a packet, which may affect congestion control state. /// # Parameters /// - `pn`: The packet number of the sent packet. /// - `is_ack_eliciting`: A boolean indicating whether the packet is ack-eliciting. /// - `sent_bytes`: The number of bytes sent in this packet. /// - `in_flight`: A boolean indicating whether the packet is considered in-flight. /// - `ack`: An optional `u64` representing the largest acknowledged packet number if an AckFrame was included. fn on_pkt_sent( &self, epoch: Epoch, pn: u64, is_ack_eliciting: bool, sent_bytes: usize, in_flight: bool, ack: Option, ); /// Records the receipt of a packet, which may influence future packet transmissions. /// # Parameters /// - `pn`: The packet number of the received packet. /// - `is_ack_elicition`: A boolean indicating whether the received packet is ack-eliciting. fn on_pkt_rcvd(&self, space: Epoch, pn: u64, is_ack_elicition: bool); /// Checks if an AckFrame should be sent in the next packet for the given epoch. /// # Returns /// An [`Option`] containing the largest packet ID and the time it was received if an AckFrame is needed. fn need_ack(&self, space: Epoch) -> Option<(u64, Instant)>; /// Checks if an ack-eliciting packet should be sent for the given epoch. fn need_send_ack_eliciting(&self, space: Epoch) -> usize; /// Updates the congestion control state upon receiving an AckFrame. fn on_ack_rcvd(&self, space: Epoch, ack_frame: &AckFrame); /// Retrieves the current path's PTO duration. /// # Returns /// The current PTO duration for the given epoch. fn get_pto(&self, epoch: Epoch) -> Duration; /// Discards the congestion control state for the specified epoch. fn discard_epoch(&self, epoch: Epoch); /// Releases the anti-amplification limit for this path. 
fn grant_anti_amplification(&self); } /// The [`Feedback`] trait defines the interface for packet tracking pub trait Feedback: Send + Sync { /// Indicates that a packet with the specified packet number may have been lost. /// # Parameters /// - `pn`: The packet number of the potentially lost packet. fn may_loss(&self, trigger: PacketLostTrigger, pns: &mut dyn Iterator); } ================================================ FILE: qcongestion/src/pacing.rs ================================================ use tokio::time::{Duration, Instant}; // The burst interval in milliseconds const BURST_INTERVAL: Duration = Duration::from_millis(10); const MIN_BURST_SIZE: usize = 10; const MAX_BURST_SIZE: usize = 1280; // Using a value for N that is small, but at least 1 (for example, 1.25) // ensures that variations in RTT do not result in underutilization of the congestion window. const N: f64 = 1.25; pub(super) struct Pacer { capacity: usize, cwnd: usize, tokens: usize, last_burst_time: Instant, rate: Option, } impl Pacer { pub(super) fn new( smoothed_rtt: Duration, cwnd: usize, mtu: usize, now: Instant, rate: Option, ) -> Self { let capacity = Pacer::calculate_capacity(smoothed_rtt, cwnd, mtu, rate); Pacer { capacity, cwnd, tokens: capacity, last_burst_time: now, rate, } } pub(super) fn on_sent(&mut self, packet_size: usize) { self.tokens = self.tokens.saturating_sub(packet_size); } // Schedule and return the packet size to send, max size is mtu pub(super) fn schedule( &mut self, srtt: Duration, cwnd: usize, mtu: usize, now: Instant, rate: Option, ) -> usize { // Update capacity if cwnd or rate has changed if self.cwnd != cwnd || rate != self.rate { self.capacity = Pacer::calculate_capacity(srtt, cwnd, mtu, rate); self.tokens = self.tokens.min(self.capacity); } self.cwnd = cwnd; self.rate = rate; let rate = match rate { Some(r) => r, // RFC 9002 7.7. 
Pacing // rate = N * congestion_window / smoothed_rtt None => (N * cwnd as f64 / srtt.as_secs_f64()) as usize, }; // Update the last_burst_time and tokens let elapsed = now.duration_since(self.last_burst_time); // TODO: 时间间隔有上限 // elapsed.max(srtt.as_secs_f64() * 2); let new_token = elapsed.as_secs_f64() * rate as f64; self.tokens = self .tokens .saturating_add(new_token as usize) .min(self.capacity); self.last_burst_time = now; self.tokens } fn calculate_capacity( smoothed_rtt: Duration, cwnd: usize, mtu: usize, rate: Option, ) -> usize { let rtt = smoothed_rtt.as_nanos().max(1); let capacity = match rate { // Use the provided rate to calculate the capacity Some(r) => (r as f64 * BURST_INTERVAL.as_secs_f64()) as usize, // Use cwnd and smoothed_rtt to calculate the capacity None => ((cwnd as u128 * BURST_INTERVAL.as_nanos()) / rtt) as usize, }; capacity.clamp(MIN_BURST_SIZE * mtu, MAX_BURST_SIZE * mtu) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_pacer_initialization() { let now = Instant::now(); let pacer = Pacer::new( Duration::from_millis(100), 10, 1500, // MTU now, Some(1_000_000), ); // min capacity is 15KB assert_eq!(pacer.capacity, 15_000); assert_eq!(pacer.tokens, pacer.capacity); assert_eq!(pacer.last_burst_time, now); // if rate is None capacity = cwnd * brust_interval / rtt let pacer = Pacer::new(Duration::from_millis(100), 2_000_000, 1500, now, None); assert_eq!(pacer.capacity, 200_000); let pacer = Pacer::new( Duration::from_millis(100), 2_000_000, 1500, now, Some(18_000_000), // 18_000 kB/s ); // 180KB assert_eq!(pacer.capacity, 180_000); } #[test] fn test_on_sent() { let mut pacer = Pacer::new( Duration::from_millis(100), 10, 1500, Instant::now(), Some(1_000_000), ); // token 15_000 assert_eq!(pacer.tokens, 15_000); pacer.on_sent(1500); // 发送一个 MTU 大小的数据包 assert_eq!(pacer.tokens, 15_000 - 1500); pacer.on_sent(20_000); assert_eq!(pacer.tokens, 0); } #[test] fn test_schedule_no_rate() { let srtt = Duration::from_millis(100); let mut cwnd 
= 2_000_000; // 2MB let mtu: usize = 1500; let mut update_time = Instant::now(); let mut pacer = Pacer::new(srtt, cwnd, mtu, update_time, None); // token = 200_000 pacer.on_sent(20_000); assert_eq!(pacer.tokens, 180_000); // rate = 1.25 * cwnd / srtt // after 20 ms update_time += BURST_INTERVAL * 2; let packet_size = pacer.schedule(srtt, cwnd, mtu, update_time, None); assert_eq!(pacer.tokens, 200_000); assert_eq!(packet_size, 200_000); pacer.on_sent(1500 * 13); assert_eq!(pacer.tokens, 180_500); // add token update_time += BURST_INTERVAL; let packet_size = pacer.schedule(srtt, cwnd, mtu, update_time, None); assert_eq!(pacer.capacity, 200_000); assert_eq!(pacer.tokens, 200_000); assert_eq!(packet_size, 200_000); // change cwnd, change capacity cwnd = 1_500_000; // 1.5 MB let packet_size = pacer.schedule(srtt, cwnd, mtu, update_time, None); assert_eq!(pacer.capacity, 150_000); assert_eq!(pacer.tokens, 150_000); assert_eq!(packet_size, 150_000); } #[test] fn test_schedule_with_rate() { let srtt = Duration::from_millis(100); let cwnd = 2_000_000; // 2MB let mtu: usize = 1500; let mut update_time = Instant::now(); // 16MB/s let mut rate = Some(16_000_000); let mut pacer = Pacer::new(srtt, cwnd, mtu, update_time, rate); assert_eq!(pacer.capacity, 160_000); let size = pacer.schedule(srtt, cwnd, mtu, update_time, rate); assert_eq!(size, 160_000); pacer.on_sent(150_000); let size = pacer.schedule(srtt, cwnd, mtu, update_time, rate); assert_eq!(size, 10_000); // update rate to update capacity // 1 MB rate = Some(1_000_000); let size = pacer.schedule(srtt, cwnd, mtu, update_time, rate); assert_eq!(size, 10_000); assert_eq!(pacer.capacity, 15_000); update_time += BURST_INTERVAL; let size = pacer.schedule(srtt, cwnd, mtu, update_time, rate); assert_eq!(pacer.tokens, 15_000); assert_eq!(size, 15_000); } } ================================================ FILE: qcongestion/src/packets.rs ================================================ use std::{cmp::Ordering, 
collections::VecDeque, time::Duration}; use qbase::{Epoch, frame::AckFrame}; use tokio::time::Instant; use crate::algorithm::Control; #[derive(Default, PartialEq, Eq, Clone, Debug)] pub(crate) enum State { #[default] Inflight, Acked, Retransmitted, } #[derive(Eq, Clone, Debug)] pub struct SentPacket { pub(crate) packet_number: u64, pub(crate) time_sent: Instant, pub(crate) ack_eliciting: bool, pub(crate) sent_bytes: usize, pub(crate) state: State, pub(crate) count_for_cc: bool, } impl SentPacket { pub(crate) fn new( packet_number: u64, time_sent: Instant, ack_eliciting: bool, count_for_cc: bool, sent_bytes: usize, ) -> Self { SentPacket { packet_number, time_sent, ack_eliciting, count_for_cc, sent_bytes, state: State::Inflight, } } } impl PartialOrd for SentPacket { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl PartialEq for SentPacket { fn eq(&self, other: &Self) -> bool { self.packet_number == other.packet_number } } impl Ord for SentPacket { fn cmp(&self, other: &Self) -> Ordering { self.packet_number.cmp(&other.packet_number) } } /// The [`RcvdRecords`] struct is used to maintain records of received packets for each epoch. /// It tracks acknowledged packets and determines when an ACK frame should be sent. /// It also retires packets that have been acknowledged by an ACK frame that has already sent and which has been confirmed by the peer. 
#[derive(Debug)] pub(crate) struct RcvdRecords { epoch: Epoch, ack_immedietly: bool, latest_rcvd_time: Option, largest_rcvd_packet: Option<(u64, Instant)>, max_ack_delay: Duration, } impl RcvdRecords { pub(crate) fn new(epoch: Epoch, max_ack_delay: Duration) -> Self { Self { epoch, ack_immedietly: false, latest_rcvd_time: None, largest_rcvd_packet: None, max_ack_delay, } } pub(crate) fn on_pkt_rcvd(&mut self, pn: u64) { // An endpoint MUST acknowledge all ack-eliciting Initial and Handshake packets immediately if self.epoch == Epoch::Initial || self.epoch == Epoch::Handshake { self.ack_immedietly = true; } // See [Section 13.2.1](https://www.rfc-editor.org/rfc/rfc9000.html#name-sending-ack-frames) // An endpoint SHOULD generate and send an ACK frame without delay when it receives an ack-eliciting packet either: // 1. When the received packet has a packet number less than another ack-eliciting packet that has been received // 2. when the packet has a packet number larger than the highest-numbered ack-eliciting packet that has been // received and there are missing packets between that packet and this packet. let now = Instant::now(); if self.latest_rcvd_time.is_none() { self.latest_rcvd_time = Some(now); } self.ack_immedietly |= self .largest_rcvd_packet .is_some_and(|(largest_pn, _)| pn < largest_pn); self.largest_rcvd_packet = self.largest_rcvd_packet .map_or(Some((pn, now)), |(largest_pn, time)| { if pn > largest_pn { Some((pn, now)) } else { Some((largest_pn, time)) } }); } /// Checks whether an ACK frame needs to be sent. /// Returns [`Some`] if it's time to send an ACK based on the maximum delay. pub(crate) fn need_ack(&self) -> Option<(u64, Instant)> { let now = Instant::now(); if self.ack_immedietly { return self.largest_rcvd_packet; } if self .latest_rcvd_time .is_some_and(|t| t + self.max_ack_delay < now) { return self.largest_rcvd_packet; } None } /// Called when an ACK is sent. /// Updates the last ACK sent information and resets the `need_ack` flag. 
pub(crate) fn on_ack_sent(&mut self, _pn: u64, _largest_acked: u64) { self.largest_rcvd_packet = None; self.latest_rcvd_time = None; self.ack_immedietly = false; } } // bbr_packet: VecDeque pub(crate) struct PacketSpace { pub(crate) largest_acked_packet: Option, pub(crate) time_of_last_ack_eliciting_packet: Option, pub(crate) loss_time: Option, pub(crate) sent_packets: VecDeque, pub(crate) rcvd_packets: RcvdRecords, pub(crate) max_ack_delay: Duration, } pub(crate) struct NewlyAckedPackets { pub(crate) include_ack_eliciting: bool, pub(crate) largest: (u64, Instant), } impl PacketSpace { pub(crate) fn with_epoch(epoch: Epoch, max_ack_delay: Duration) -> Self { Self { largest_acked_packet: None, time_of_last_ack_eliciting_packet: None, loss_time: None, sent_packets: VecDeque::with_capacity(4), rcvd_packets: RcvdRecords::new(epoch, max_ack_delay), max_ack_delay, } } pub(crate) fn update_largest_acked_packet(&mut self, pn: u64) { self.largest_acked_packet = self.largest_acked_packet.map(|n| n.max(pn)).or(Some(pn)); } pub(crate) fn on_ack_rcvd( &mut self, ack_frame: &AckFrame, algorithm: &mut Box, ) -> Option { if self.sent_packets.is_empty() { return None; } let mut include_ack_eliciting = false; let mut largest_acked = None; let mut index = self .sent_packets .binary_search_by(|p| p.packet_number.cmp(&ack_frame.largest())) .unwrap_or_else(|i| i.saturating_sub(1)); for range in ack_frame.iter() { for pn in range.rev() { while index > 0 && self.sent_packets[index].packet_number > pn { index = index.saturating_sub(1); } if self.sent_packets[index].packet_number == pn && self.sent_packets[index].state != State::Acked { algorithm.on_packet_acked(&self.sent_packets[index]); self.sent_packets[index].state = State::Acked; include_ack_eliciting |= self.sent_packets[index].ack_eliciting; largest_acked = largest_acked .map(|(n, t)| { if n < pn { (pn, self.sent_packets[index].time_sent) } else { (n, t) } }) .or(Some((pn, self.sent_packets[index].time_sent))); } } } while self 
.sent_packets .front() .is_some_and(|sent| sent.state == State::Acked || sent.state == State::Retransmitted) { self.sent_packets.pop_front(); } Some(NewlyAckedPackets { include_ack_eliciting, largest: largest_acked?, }) } pub(crate) fn no_ack_eliciting_in_flight(&self) -> bool { self.sent_packets .iter() .all(|sent| !sent.ack_eliciting || sent.state != State::Inflight) } pub(crate) fn detect_lost_packets( &mut self, loss_delay: Duration, packet_threshold: usize, algorithm: &mut Box, ) -> impl Iterator + use<> { // assert!(self.largest_acked_packet.is_some()); self.loss_time = None; let now = Instant::now(); let lost_sent_time = now - loss_delay - self.max_ack_delay; let largest_acked = self.largest_acked_packet.unwrap_or(0); let largest_index = self .sent_packets .binary_search_by(|p| p.packet_number.cmp(&largest_acked)) .unwrap_or_else(|i| i.saturating_sub(1)); let loss: Vec<_> = self .sent_packets .iter_mut() .enumerate() .filter(|(_, pkt)| pkt.state == State::Inflight) .map(move |(idx, unacked)| { if unacked.time_sent < lost_sent_time || largest_index >= idx + packet_threshold { unacked.state = State::Retransmitted; Ok((idx, &*unacked)) } else { Err(unacked.time_sent + loss_delay) } }) .filter_map(|result| match result { Ok(t) => Some(t), Err(time) => { self.loss_time = self.loss_time.map_or(Some(time), |t| Some(t.min(time))); None } }) .collect(); const PERSISTENT_LOSS_THRESHOLD: usize = 3; let persistent_lost = loss .iter() .map(|(idx, _)| idx) .try_fold((None, 0), |(prev, count), &idx| { let lost_count = prev.map_or(0, |p| (idx - p == 1) as usize * (count + 1)); if lost_count + 1 >= PERSISTENT_LOSS_THRESHOLD { Err(()) } else { Ok((Some(idx), lost_count)) } }) .is_err(); let (packet_numbers, loss_packet): (Vec<_>, Vec<_>) = loss .into_iter() .map(|(_, pkt)| (pkt.packet_number, pkt)) .unzip(); if !loss_packet.is_empty() { algorithm.on_packets_lost(&mut loss_packet.into_iter(), persistent_lost); } packet_numbers.into_iter() } pub(crate) fn discard(&mut self, 
algorithm: &mut Box) { let mut remove_from_inflight = self .sent_packets .iter() .filter(|sent| sent.state == State::Inflight); algorithm.remove_from_bytes_in_flight(&mut remove_from_inflight); self.sent_packets.clear(); self.time_of_last_ack_eliciting_packet = None; self.loss_time = None; } } #[cfg(test)] mod tests { use std::{ sync::{Arc, atomic::AtomicU16}, vec, }; use super::*; use crate::algorithm::new_reno::NewReno; #[test] fn test_packet_space() { let mut packet_space = PacketSpace::with_epoch(Epoch::Initial, Duration::from_millis(100)); // let now = Instant::now(); for i in 0..10 { packet_space.sent_packets.push_back(SentPacket::new( i, Instant::now(), true, true, 1200, )); } // ack 9 ~ 4, 1 ~ 0 loss 2,3 let ack_frame = AckFrame::new( 9_u32.into(), 100_u32.into(), 5_u32.into(), vec![(1_u32.into(), 1_u32.into())], None, ); let mut reno: Box = Box::new(NewReno::new(Arc::new(AtomicU16::new(1200)))); packet_space.on_ack_rcvd(&ack_frame, &mut reno); // init 12000, ack 8 packet 12000 + 8 * MSS = 21600 assert_eq!(reno.congestion_window(), 21600); packet_space.largest_acked_packet = Some(ack_frame.largest()); let loss = packet_space.detect_lost_packets(Duration::from_millis(100), 3, &mut reno); assert_eq!(loss.collect::>(), vec![2, 3]); // loss 2, 3 cwnd = 21600 - MSS assert_eq!(reno.congestion_window(), 20400); for i in 10..15 { packet_space.sent_packets.push_back(SentPacket::new( i, Instant::now(), true, true, 1200, )); } for i in 20..25 { packet_space.sent_packets.push_back(SentPacket::new( i, Instant::now(), false, true, 1200, )); } // ack 24 ~ 20 13 // loss 10, 11,12,14 let ack_frame = AckFrame::new( 24_u32.into(), 100_u32.into(), 5_u32.into(), vec![(4_u32.into(), 0_u32.into())], None, ); packet_space.on_ack_rcvd(&ack_frame, &mut reno); packet_space.largest_acked_packet = Some(ack_frame.largest()); assert_eq!(reno.congestion_window(), 20817); packet_space.largest_acked_packet = Some(ack_frame.largest()); let loss = 
packet_space.detect_lost_packets(Duration::from_millis(100), 3, &mut reno); assert_eq!(loss.collect::>(), vec![10, 11, 12, 14]); assert_eq!(reno.congestion_window(), (20817 - 1200) / 2); } #[tokio::test(flavor = "current_thread")] async fn test_rcvd_records() { let mut rcvd_records = RcvdRecords::new(Epoch::Data, Duration::from_millis(100)); for i in 0..10 { rcvd_records.on_pkt_rcvd(i); } tokio::time::pause(); tokio::time::advance(Duration::from_millis(100)).await; assert_eq!(rcvd_records.need_ack().unwrap().0, 9); rcvd_records.on_ack_sent(9, 9); assert_eq!(rcvd_records.need_ack(), None); tokio::time::resume(); rcvd_records.on_pkt_rcvd(10); assert_eq!(rcvd_records.need_ack(), None); rcvd_records.on_pkt_rcvd(15); assert_eq!(rcvd_records.need_ack(), None); rcvd_records.on_pkt_rcvd(11); assert_eq!(rcvd_records.need_ack().unwrap().0, 15); } } ================================================ FILE: qcongestion/src/rtt.rs ================================================ use std::sync::{Arc, Mutex}; use qevent::quic::recovery::RecoveryMetricsUpdated; use tokio::time::{Duration, Instant}; pub const INITIAL_RTT: Duration = Duration::from_millis(33); pub const MAX_INITIAL_RTT: Duration = Duration::from_millis(333); const GRANULARITY: Duration = Duration::from_millis(1); const TIME_THRESHOLD: f32 = 1.125; #[derive(Debug, Clone)] pub struct Rtt { max_ack_delay: Duration, first_rtt_sample: Option, latest_rtt: Duration, smoothed_rtt: Duration, rttvar: Duration, min_rtt: Duration, } impl From<&Rtt> for RecoveryMetricsUpdated { fn from(rtt: &Rtt) -> Self { qevent::build!(RecoveryMetricsUpdated { smoothed_rtt: rtt.smoothed_rtt.as_secs_f32() * 1000.0, min_rtt: rtt.min_rtt.as_secs_f32() * 1000.0, latest_rtt: rtt.latest_rtt.as_secs_f32() * 1000.0, rtt_variance: rtt.rttvar.as_secs_f32() * 1000.0, }) } } impl Default for Rtt { fn default() -> Self { Self { max_ack_delay: Duration::from_millis(0), first_rtt_sample: None, latest_rtt: Duration::from_millis(0), smoothed_rtt: INITIAL_RTT, 
rttvar: INITIAL_RTT / 2, min_rtt: Duration::from_millis(0), } } } impl Rtt { fn update( &mut self, latest_rtt: Duration, mut ack_delay: Duration, is_handshake_confirmed: bool, ) { self.latest_rtt = latest_rtt; if self.first_rtt_sample.is_none() { self.min_rtt = latest_rtt; self.smoothed_rtt = latest_rtt; self.rttvar = latest_rtt / 2; self.first_rtt_sample = Some(tokio::time::Instant::now()); } else { // min_rtt ignores acknowledgment delay. self.min_rtt = std::cmp::min(self.min_rtt, latest_rtt); // Limit ack_delay by max_ack_delay after handshake confirmation. if is_handshake_confirmed { ack_delay = std::cmp::min(ack_delay, self.max_ack_delay); } // Adjust for acknowledgment delay if plausible. let mut adjusted_rtt = latest_rtt; if latest_rtt >= self.min_rtt + ack_delay { adjusted_rtt = latest_rtt - ack_delay; } let abs_diff = self.smoothed_rtt.abs_diff(adjusted_rtt); self.rttvar = self.rttvar.mul_f32(0.75) + abs_diff.mul_f32(0.25); self.smoothed_rtt = self.smoothed_rtt.mul_f32(0.875) + adjusted_rtt.mul_f32(0.125); } let event = RecoveryMetricsUpdated::from(&*self); qevent::event!(event); } fn loss_delay(&self) -> Duration { std::cmp::max( std::cmp::max(self.latest_rtt, self.smoothed_rtt).mul_f32(TIME_THRESHOLD), GRANULARITY, ) } /// duration = (smoothed_rtt + max(4 * rttvar, kGranularity)) /// * (2 ^ pto_count) fn base_pto(&self, pto_count: u32) -> Duration { self.smoothed_rtt + std::cmp::max(4 * self.rttvar, GRANULARITY) * (1 << pto_count) } fn try_backoff_rtt(&mut self) { if self.first_rtt_sample.is_some() { return; } self.smoothed_rtt = self .smoothed_rtt .mul_f32(TIME_THRESHOLD) .min(MAX_INITIAL_RTT); self.rttvar = self.smoothed_rtt / 2; tracing::trace!(target: "quic", "Back off initial RTT {}ms", self.smoothed_rtt.as_millis()); } } #[derive(Debug, Clone, Default)] pub struct ArcRtt(Arc>); /// 对外只需暴露ArcRtt,Rtt成为内部实现 impl ArcRtt { pub fn new() -> Self { Self(Arc::new(Mutex::new(Rtt::default()))) } pub fn update(&self, latest_rtt: Duration, ack_delay: Duration, 
is_handshake_confirmed: bool) { self.0 .lock() .unwrap() .update(latest_rtt, ack_delay, is_handshake_confirmed); } pub fn loss_delay(&self) -> Duration { self.0.lock().unwrap().loss_delay() } pub fn smoothed_rtt(&self) -> Duration { self.0.lock().unwrap().smoothed_rtt } pub fn rttvar(&self) -> Duration { self.0.lock().unwrap().rttvar } pub fn base_pto(&self, pto_count: u32) -> Duration { self.0.lock().unwrap().base_pto(pto_count) } /// Backs off initial RTT on loss before first RTT sample pub fn try_backoff_rtt(&self) { self.0.lock().unwrap().try_backoff_rtt(); } } #[cfg(test)] mod tests {} ================================================ FILE: qcongestion/src/status.rs ================================================ use std::sync::{ Arc, atomic::{AtomicBool, AtomicU16, Ordering}, }; #[derive(Debug)] pub struct HandshakeStatus { is_server: AtomicBool, has_handshake_key: AtomicBool, has_received_handshake_ack: AtomicBool, is_handshake_confirmed: AtomicBool, } impl HandshakeStatus { pub fn new(is_server: bool) -> Self { Self { is_server: AtomicBool::new(is_server), has_handshake_key: AtomicBool::new(false), has_received_handshake_ack: AtomicBool::new(false), is_handshake_confirmed: AtomicBool::new(false), } } } impl HandshakeStatus { pub fn got_handshake_key(&self) { self.has_handshake_key.store(true, Ordering::Relaxed); } pub fn received_handshake_ack(&self) { self.has_received_handshake_ack .store(true, Ordering::Relaxed); } pub fn handshake_confirmed(&self) { self.is_handshake_confirmed.store(true, Ordering::Relaxed); } } #[derive(Clone)] pub struct PathStatus { handshake: Arc, is_at_anti_amplification_limit: Arc, pmtu: Arc, } impl PathStatus { pub fn new(handshake: Arc, pmut: Arc) -> Self { Self { handshake, is_at_anti_amplification_limit: Arc::new(AtomicBool::new(true)), pmtu: pmut, } } pub(crate) fn is_server(&self) -> bool { self.handshake.is_server.load(Ordering::Relaxed) } pub(crate) fn has_handshake_key(&self) -> bool { 
self.handshake.has_handshake_key.load(Ordering::Relaxed) } pub(crate) fn has_received_handshake_ack(&self) -> bool { self.handshake .has_received_handshake_ack .load(Ordering::Relaxed) } pub(crate) fn is_handshake_confirmed(&self) -> bool { self.handshake .is_handshake_confirmed .load(Ordering::Relaxed) } pub(crate) fn is_at_anti_amplification_limit(&self) -> bool { self.is_at_anti_amplification_limit.load(Ordering::Relaxed) } pub fn release_anti_amplification_limit(&self) { self.is_at_anti_amplification_limit .store(false, Ordering::Release); } pub fn enter_anti_amplification_limit(&self) { self.is_at_anti_amplification_limit .store(true, Ordering::Release); } pub(super) fn pmtu(&self) -> Arc { self.pmtu.clone() } pub(crate) fn mtu(&self) -> usize { self.pmtu.load(Ordering::Relaxed) as usize } } ================================================ FILE: qconnection/Cargo.toml ================================================ [package] name = "qconnection" version = "0.5.0" edition.workspace = true description = "Encapsulation of QUIC connections, a part of dquic" readme.workspace = true repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true rust-version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] bytes = { workspace = true } dashmap = { workspace = true } derive_more = { workspace = true, features = [ "as_ref", "deref", "display", "from", ] } enum_dispatch = { workspace = true } futures = { workspace = true } qbase = { workspace = true } qcongestion = { workspace = true } qresolve = { workspace = true } qevent = { workspace = true } qrecovery = { workspace = true } ring = { workspace = true } qtraversal = { workspace = true } rand = { workspace = true } rustls = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["rt", "sync", "time", "macros"] } tokio-util = { workspace = true, features = 
["rt"] } tracing = { workspace = true } qinterface = { workspace = true } x509-parser = { workspace = true } [target.'cfg(any(unix, windows))'.dependencies] qinterface = { workspace = true, features = ["qudp"] } # features: datagram qdatagram = { workspace = true, optional = true } [features] default = ["datagram"] datagram = ["dep:qdatagram"] telemetry = ["qevent/telemetry"] ================================================ FILE: qconnection/src/builder.rs ================================================ use std::{ net::SocketAddr, sync::{Arc, atomic::AtomicBool}, time::Duration, }; use qbase::{ cid::{ConnectionId, GenUniqueCid}, error::Error, net::tx::{ArcSendWakers, Signals}, packet::keys::ArcZeroRttKeys, param::{ArcParameters, ClientParameters, ParameterId, Parameters, ServerParameters}, role::{IntoRole, Role}, sid::{ ControlStreamsConcurrency, ProductStreamsConcurrencyController, handy::DemandConcurrency, }, time::ArcIdleConfig, token::{ArcTokenRegistry, TokenProvider, TokenSink}, }; use qcongestion::HandshakeStatus; use qdatagram::DatagramFlow; use qevent::{ GroupID, quic::{ Owner, transport::{ParametersRestored, ParametersSet}, }, telemetry::{Instrument, QLog, handy::NoopLogger}, }; use qinterface::{ component::{ location::Locations, route::{QuicRouter, RcvdPacketQueue}, }, io::{ProductIO, handy::DEFAULT_IO_FACTORY}, manager::InterfaceManager, }; use qrecovery::crypto::CryptoStream; use qtraversal::punch::puncher::ArcPuncher; use rustls::{ ClientConfig as TlsClientConfig, ServerConfig as TlsServerConfig, crypto::CryptoProvider, }; use tokio::sync::mpsc; use tracing::Instrument as _; use crate::{ ArcLocalCids, ArcReliableFrameDeque, ArcRemoteCids, CidRegistry, Components, Connection, ConnectionState, DataJournal, DataStreams, FlowController, Handshake, QuicRouterRegistry, RawHandshake, SpecificComponents, events::{ArcEventBroker, EmitEvent, Event}, path::ArcPathContexts, space::{ Spaces, data::DataSpace, handshake::HandshakeSpace, initial::InitialSpace, 
spawn_deliver_and_parse, }, state::ArcConnState, tls::{ AcceptAllClientAuther, ArcSendLock, ArcTlsHandshake, AuthClient, ClientTlsSession, ServerTlsSession, TlsHandshakeInfo, TlsSession, }, traversal::PunchTransaction, }; impl Connection { pub fn new_client(server_name: String, token_sink: Arc) -> ClientFoundation { ClientFoundation { server_name: server_name.clone(), token_registry: ArcTokenRegistry::with_sink(server_name, token_sink), client_params: ClientParameters::default(), } } pub fn new_server(token_provider: Arc) -> ServerFoundation { ServerFoundation { token_registry: ArcTokenRegistry::with_provider(token_provider), server_params: ServerParameters::default(), client_auther: Box::new(AcceptAllClientAuther), } } } pub struct ClientFoundation { server_name: String, token_registry: ArcTokenRegistry, client_params: ClientParameters, } impl ClientFoundation { pub fn with_parameters(mut self, params: ClientParameters) -> Self { self.client_params = params; self } } pub struct ServerFoundation { token_registry: ArcTokenRegistry, server_params: ServerParameters, client_auther: Box, } impl ServerFoundation { pub fn with_parameters(mut self, params: ServerParameters) -> Self { self.server_params = params; self } pub fn with_client_auther(mut self, authers: Box) -> Self { self.client_auther = authers; self } } pub struct ConnectionFoundation { foundation: Foundation, tls_config: TlsConfig, ifaces: Arc, iface_factory: Arc, quic_router: Arc, locations: Arc, stun_servers: Arc<[SocketAddr]>, streams_ctrl: Box, defer_idle_timeout: Duration, } pub type ClientConnectionFoundation = ConnectionFoundation; pub type ServerConnectionFoundation = ConnectionFoundation; impl ClientFoundation { pub fn with_tls_config( self, tls_config: TlsClientConfig, ) -> ConnectionFoundation { ConnectionFoundation { foundation: self, tls_config, ifaces: InterfaceManager::global().clone(), iface_factory: Arc::new(DEFAULT_IO_FACTORY), quic_router: QuicRouter::global().clone(), locations: 
Arc::new(Locations::new()), stun_servers: Arc::new([]), streams_ctrl: Box::new(DemandConcurrency), // ZST cause no alloc defer_idle_timeout: Duration::ZERO, } } } impl ConnectionFoundation { pub fn with_streams_concurrency_strategy(self, strategy_factory: &F) -> Self where F: ProductStreamsConcurrencyController + ?Sized, { let client_params = &self.foundation.client_params; let init_max_bidi_streams = client_params .get(ParameterId::InitialMaxStreamsBidi) .expect("unreachable: default value will be got if the value unset"); let init_max_uni_streams = client_params .get(ParameterId::InitialMaxStreamsUni) .expect("unreachable: default value will be got if the value unset"); ConnectionFoundation { streams_ctrl: strategy_factory.init(init_max_bidi_streams, init_max_uni_streams), ..self } } pub fn with_zero_rtt(mut self, enabled: bool) -> Self { self.tls_config.enable_early_data = enabled; self } } impl ServerFoundation { pub fn with_tls_config( self, tls_config: TlsServerConfig, ) -> ConnectionFoundation { ConnectionFoundation { foundation: self, tls_config, ifaces: InterfaceManager::global().clone(), iface_factory: Arc::new(DEFAULT_IO_FACTORY), quic_router: QuicRouter::global().clone(), locations: Arc::new(Locations::new()), stun_servers: Arc::new([]), streams_ctrl: Box::new(DemandConcurrency), // ZST cause no alloc defer_idle_timeout: Duration::ZERO, } } } impl ConnectionFoundation { pub fn with_streams_concurrency_strategy(self, strategy_factory: &F) -> Self where F: ProductStreamsConcurrencyController + ?Sized, { let server_params = &self.foundation.server_params; let init_max_bidi_streams = server_params .get(ParameterId::InitialMaxStreamsBidi) .expect("unreachable: default value will be got if the value unset"); let init_max_uni_streams = server_params .get(ParameterId::InitialMaxStreamsUni) .expect("unreachable: default value will be got if the value unset"); ConnectionFoundation { streams_ctrl: strategy_factory.init(init_max_bidi_streams, init_max_uni_streams), 
..self } } pub fn with_zero_rtt(mut self, enabled: bool) -> Self { match enabled { true => self.tls_config.max_early_data_size = 0xffffffff, false => self.tls_config.max_early_data_size = 0, } self } } impl ConnectionFoundation { pub fn with_iface_factory(mut self, factory: Arc) -> Self { self.iface_factory = factory; self } pub fn with_iface_manager(mut self, ifaces: Arc) -> Self { self.ifaces = ifaces; self } pub fn with_quic_router(mut self, quic_router: Arc) -> Self { self.quic_router = quic_router; self } pub fn with_locations(mut self, locations: Arc) -> Self { self.locations = locations; self } pub fn with_stun_servers(mut self, stun_servers: Arc<[SocketAddr]>) -> Self { self.stun_servers = stun_servers; self } pub fn with_defer_idle_timeout(mut self, timeout: Duration) -> Self { self.defer_idle_timeout = timeout; self } } fn initial_keys_with( crypto_provider: &Arc, origin_dcid: &ConnectionId, side: rustls::Side, version: rustls::quic::Version, ) -> rustls::quic::Keys { crypto_provider .cipher_suites .iter() .find_map(|cs| match (cs.suite(), cs.tls13()) { (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => { Some(suite.quic_suite()) } _ => None, }) .flatten() .expect("crypto provider does not provide supported cipher suite") .keys(origin_dcid, side, version) } impl ConnectionFoundation { pub fn with_cids(self, origin_dcid: ConnectionId) -> PendingConnection { let initial_keys = initial_keys_with( self.tls_config.crypto_provider(), &origin_dcid, rustls::Side::Client, crate::tls::QUIC_VERSION, ); let rcvd_pkt_q = Arc::new(RcvdPacketQueue::new()); let tx_wakers = ArcSendWakers::default(); let reliable_frames = ArcReliableFrameDeque::with_capacity_and_wakers(8, tx_wakers.clone()); let quic_router_registry = self .quic_router .registry_on_issuing_scid(rcvd_pkt_q.clone(), reliable_frames.clone()); let initial_scid = quic_router_registry.gen_unique_cid(); let mut client_params = self.foundation.client_params; _ = 
client_params.set(ParameterId::InitialSourceConnectionId, initial_scid); let host = self .foundation .server_name .split_once(':') .map(|(h, _)| h) .unwrap_or(&self.foundation.server_name) .to_string(); let tls_session = ClientTlsSession::init(host, Arc::new(self.tls_config), &client_params) .expect("Failed to initialize TLS handshake"); let zero_rtt_keys = ArcZeroRttKeys::new_pending(Role::Client); // if zero rtt enabled && loadede remembered parameters && zero rtt keys is available let parameters = match tls_session.load_zero_rtt() { Some((remembered_parameters, avaliable_zero_rtt_keys)) => { qevent::event!(ParametersRestored { client_parameters: &remembered_parameters, }); zero_rtt_keys.set_keys(avaliable_zero_rtt_keys); Parameters::new_client(client_params, Some(remembered_parameters), origin_dcid) } None => Parameters::new_client(client_params, None, origin_dcid), }; PendingConnection { interfaces: self.ifaces, iface_factory: self.iface_factory, quic_router: self.quic_router, locations: self.locations, stun_servers: self.stun_servers, rcvd_pkt_q, defer_idle_timeout: self.defer_idle_timeout, role: Role::Client, origin_dcid, initial_scid, tx_wakers, send_lock: ArcSendLock::unrestricted(), reliable_frames, quicrouter_registry: quic_router_registry, parameters, token_registry: self.foundation.token_registry, tls_session: TlsSession::Client(tls_session), initial_keys, zero_rtt_keys, streams_ctrl: self.streams_ctrl, specific: SpecificComponents::Client {}, qlogger: Arc::new(NoopLogger), } } } impl ConnectionFoundation { pub fn with_cids(self, origin_dcid: ConnectionId) -> PendingConnection { let initial_keys = initial_keys_with( self.tls_config.crypto_provider(), &origin_dcid, rustls::Side::Server, crate::tls::QUIC_VERSION, ); let rcvd_pkt_q = Arc::new(RcvdPacketQueue::new()); let tx_wakers = ArcSendWakers::default(); let reliable_frames = ArcReliableFrameDeque::with_capacity_and_wakers(8, tx_wakers.clone()); let quic_router_registry = self .quic_router 
.registry_on_issuing_scid(rcvd_pkt_q.clone(), reliable_frames.clone()); let initial_scid = quic_router_registry.gen_unique_cid(); let odcid_router_entry = self .quic_router .insert(origin_dcid.into(), rcvd_pkt_q.clone()); let mut server_params = self.foundation.server_params; _ = server_params.set(ParameterId::InitialSourceConnectionId, initial_scid); _ = server_params.set(ParameterId::OriginalDestinationConnectionId, origin_dcid); let tls_session = ServerTlsSession::init( Arc::new(self.tls_config), &server_params, self.foundation.client_auther, ) .expect("Failed to initialize TLS handshake"); // TODO: tls创建的错误处理 PendingConnection { interfaces: self.ifaces, iface_factory: self.iface_factory, quic_router: self.quic_router, locations: self.locations, stun_servers: self.stun_servers, rcvd_pkt_q, defer_idle_timeout: self.defer_idle_timeout, role: Role::Server, origin_dcid, initial_scid, tx_wakers, send_lock: tls_session.send_lock().clone(), reliable_frames, quicrouter_registry: quic_router_registry, parameters: Parameters::new_server(server_params), token_registry: self.foundation.token_registry, tls_session: TlsSession::Server(tls_session), initial_keys, zero_rtt_keys: ArcZeroRttKeys::new_pending(Role::Server), streams_ctrl: self.streams_ctrl, specific: SpecificComponents::Server { odcid_router_entry: Arc::new(odcid_router_entry), using_odcid: Arc::new(AtomicBool::new(true)), }, qlogger: Arc::new(NoopLogger), } } } pub struct PendingConnection { interfaces: Arc, iface_factory: Arc, quic_router: Arc, locations: Arc, stun_servers: Arc<[SocketAddr]>, rcvd_pkt_q: Arc, defer_idle_timeout: Duration, role: Role, origin_dcid: ConnectionId, initial_scid: ConnectionId, send_lock: ArcSendLock, tx_wakers: ArcSendWakers, reliable_frames: ArcReliableFrameDeque, quicrouter_registry: QuicRouterRegistry, parameters: Parameters, token_registry: ArcTokenRegistry, tls_session: TlsSession, initial_keys: rustls::quic::Keys, zero_rtt_keys: ArcZeroRttKeys, streams_ctrl: Box, specific: 
SpecificComponents, qlogger: Arc, } fn init_stream_and_datagram( local_params: &qbase::param::core::Parameters, remote_params: &qbase::param::core::Parameters, reliable_frames: ArcReliableFrameDeque, streams_ctrl: Box, tx_wakers: ArcSendWakers, metrics: qbase::metric::ArcConnectionMetrics, ) -> (DataStreams, FlowController, DatagramFlow) { assert_ne!(LR::into_role(), RR::into_role()); let flow_ctrl = FlowController::new( remote_params .get(ParameterId::InitialMaxData) .expect("unreachable: default value will be got if the value unset"), local_params .get(ParameterId::InitialMaxData) .expect("unreachable: default value will be got if the value unset"), reliable_frames.clone(), tx_wakers.clone(), ); let data_streams = DataStreams::new( LR::into_role(), local_params, remote_params, streams_ctrl, reliable_frames.clone(), tx_wakers.clone(), Some(metrics), ); let datagram_flow = DatagramFlow::new( local_params .get(ParameterId::MaxDatagramFrameSize) .expect("unreachable: default value will be got if the value unset"), tx_wakers.clone(), ); (data_streams, flow_ctrl, datagram_flow) } impl PendingConnection { pub fn with_qlog(mut self, qlogger: Arc) -> Self { self.qlogger = qlogger; self } pub fn run(self) -> Connection { let (event_broker, events) = mpsc::unbounded_channel(); let group_id = GroupID::from(self.origin_dcid); let qlog_span = self.qlogger.new_trace(self.role.into(), group_id.clone()); let tracing_span = tracing::debug_span!(parent: None, "connection", role = %self.role, odcid = %group_id); let _span = (qlog_span.enter(), tracing_span.clone().entered()); tracing::trace!(parameters=?self.parameters, "starting new connection"); let conn_state = ArcConnState::new(); let event_broker = ArcEventBroker::new(conn_state.clone(), event_broker); let quic_handshake = Handshake::new( RawHandshake::new(self.role, self.reliable_frames.clone()), Arc::new(HandshakeStatus::new(self.role == Role::Server)), event_broker.clone(), ); let local_cids = 
ArcLocalCids::new(self.initial_scid, self.quicrouter_registry); let remote_cids = ArcRemoteCids::new( self.parameters .get_local(ParameterId::ActiveConnectionIdLimit) .expect("unreachable: default value will be got if the value unset"), self.reliable_frames.clone(), ); let cid_registry = CidRegistry::new(self.role, self.origin_dcid, local_cids, remote_cids); let spaces = Spaces::new( InitialSpace::new(self.initial_keys.into()), HandshakeSpace::new(), DataSpace::new(self.zero_rtt_keys), ); let crypto_streams = [ CryptoStream::new(self.tx_wakers.clone()), CryptoStream::new(self.tx_wakers.clone()), CryptoStream::new(self.tx_wakers.clone()), ]; let metrics = Arc::new(qbase::metric::ConnectionMetrics::default()); let (data_streams, flow_ctrl, datagram_flow) = match self.role { Role::Client => init_stream_and_datagram( self.parameters.client().unwrap(), self.parameters .remembered() .map(|p| p.as_ref()) .unwrap_or(&ServerParameters::default()), self.reliable_frames.clone(), self.streams_ctrl, self.tx_wakers.clone(), metrics.clone(), ), Role::Server => init_stream_and_datagram( self.parameters.server().unwrap(), &ClientParameters::default(), self.reliable_frames.clone(), self.streams_ctrl, self.tx_wakers.clone(), metrics.clone(), ), }; let puncher = ArcPuncher::new( self.reliable_frames.clone(), PunchTransaction::new(cid_registry.clone()), spaces.data().clone(), self.interfaces.clone(), self.iface_factory, self.quic_router.clone(), self.stun_servers.clone(), ); let max_idle_timeout = self .parameters .get_local(ParameterId::MaxIdleTimeout) .expect("Duration::ZERO if not specified"); let components = Components { interfaces: self.interfaces, locations: self.locations, rcvd_pkt_q: self.rcvd_pkt_q, conn_state, idle_config: ArcIdleConfig::new(max_idle_timeout, self.defer_idle_timeout), paths: ArcPathContexts::new(self.tx_wakers.clone(), event_broker.clone()), send_lock: self.send_lock, tls_handshake: ArcTlsHandshake::new(self.tls_session), quic_handshake, parameters: 
ArcParameters::from(self.parameters), token_registry: self.token_registry, cid_registry, spaces, crypto_streams, reliable_frames: self.reliable_frames, data_streams, flow_ctrl, datagram_flow, event_broker, metrics, specific: self.specific, puncher, }; spawn_tls_handshake(&components, self.tx_wakers.clone()); spawn_deliver_and_parse(&components); let connection_state = Arc::new(ConnectionState { state: Ok(components).into(), qlog_span, tracing_span, }); spawn_drive_connection(events, connection_state.clone()); Connection(connection_state) } } fn spawn_tls_handshake(components: &Components, tx_wakers: ArcSendWakers) { let task = components.tls_handshake.clone().start( components.parameters.clone(), components.quic_handshake.clone(), components.crypto_streams.clone(), ( components.spaces.handshake().keys(), components.spaces.data().zero_rtt_keys(), components.spaces.data().one_rtt_keys(), ), tls_fin_handler( components.parameters.clone(), components.data_streams.clone(), components.flow_ctrl.clone(), components.spaces.data().journal().clone(), components.cid_registry.local.clone(), components.idle_config.clone(), tx_wakers, ), ); let event_broker = components.event_broker.clone(); let task = async move { if let Err(Error::Quic(e)) = task.await { event_broker.emit(Event::Failed(e)); } }; // Terminates when the QUIC connection closes and the event broker shuts down. 
tokio::spawn(task.instrument_in_current().in_current_span()); } fn tls_fin_handler( parameters: ArcParameters, data_streams: DataStreams, flow_ctrl: FlowController, data_journal: DataJournal, local_cids: ArcLocalCids, idle_config: ArcIdleConfig, tx_wakers: ArcSendWakers, ) -> impl FnOnce(&TlsHandshakeInfo) -> Result<(), Error> + Send { fn apply_parameters( data_streams: &DataStreams, flow_ctrl: &FlowController, // datagram_flow data_journal: &DataJournal, local_cids: &ArcLocalCids, idle_config: &ArcIdleConfig, zero_rtt_rejected: bool, remote_parameters: Arc>, ) -> Result<(), Error> { // accept InitialMaxStreamsBidi, InitialMaxStreamUni, // InitialMaxStreamDataBidiLocal, InitialMaxStreamDataBidiRemote, InitialMaxStreamDataUni, data_streams.revise_params(zero_rtt_rejected, remote_parameters.as_ref()); // accept InitialMaxData: flow_ctrl.sender.revise_max_data( zero_rtt_rejected, remote_parameters .get(ParameterId::InitialMaxData) .expect("unreachable: default value will be got if the value unset"), ); // accept ActiveConnectionIdLimit local_cids.set_limit( remote_parameters .get(ParameterId::ActiveConnectionIdLimit) .expect("unreachable: default value will be got if the value unset"), )?; data_journal.of_rcvd_packets().revise_max_ack_delay( remote_parameters .get(ParameterId::MaxAckDelay) .expect("unreachable: default value will be got if the value unset"), ); idle_config.negotiate_max_idle_timeout( remote_parameters .get(ParameterId::MaxIdleTimeout) .expect("Duration::ZERO if not specified"), ); Ok(()) } move |info| { let zero_rtt_rejected = info .zero_rtt_accepted() .map(|accepted| !accepted) .unwrap_or(false); let parameters = parameters.lock_guard()?; if parameters.role() == Role::Client { if zero_rtt_rejected { debug_assert_eq!(parameters.role(), Role::Client); tracing::trace!(target: "quic", "0-RTT is not enabled, or not accepted by the server."); } else { tracing::trace!(target: "quic", "0-RTT is enabled and accepted by the server."); } } match 
parameters.role() { Role::Client => { let remote_parameters = parameters .server() .expect("client and server parameters has been ready") .clone(); drop(parameters); qevent::event!(ParametersSet { owner: Owner::Remote, server_parameters: &remote_parameters, }); apply_parameters( &data_streams, &flow_ctrl, &data_journal, &local_cids, &idle_config, zero_rtt_rejected, remote_parameters, )?; } Role::Server => { let remote_parameters = parameters .client() .expect("client and server parameters has been ready") .clone(); drop(parameters); qevent::event!(ParametersSet { owner: Owner::Remote, client_parameters: &remote_parameters, }); apply_parameters( &data_streams, &flow_ctrl, &data_journal, &local_cids, &idle_config, zero_rtt_rejected, remote_parameters, )?; } } tx_wakers.wake_all_by(Signals::TLS_FIN); Result::<_, Error>::Ok(()) } } fn spawn_drive_connection(mut events: mpsc::UnboundedReceiver, state: Arc) { tokio::spawn( async move { while let Some(event) = events.recv().await { match event { Event::Handshaked => {} Event::Failed(quic_error) => _ = state.enter_closing(quic_error), Event::ApplicationClose(_app_error) => {} Event::Closed(ccf) => _ = state.enter_draining(ccf), Event::StatelessReset => {} Event::Terminated => {} } } } .instrument_in_current() .in_current_span(), ); } ================================================ FILE: qconnection/src/events.rs ================================================ use std::sync::Arc; use qbase::{ self, error::{AppError, QuicError}, frame::ConnectionCloseFrame, }; use qevent::quic::connectivity::BaseConnectionStates; use tokio::sync::mpsc; use crate::state::ArcConnState; /// The events that can be emitted by a quic connection #[derive(Debug, Clone, PartialEq, Eq)] pub enum Event { // The connection is handshaked Handshaked, // An Error occurred during the connection, will enter the closing state Failed(QuicError), // The connection is closed by application, just a notification ApplicationClose(AppError), // Received a 
// connection close frame, will enter the draining state
Closed(ConnectionCloseFrame),
// Received a stateless reset, will enter the draining state
StatelessReset,
// The connection is terminated completely
Terminated,
}

/// Sink for connection [`Event`]s.
///
/// Implementors consume the events emitted during a connection's lifetime
/// (handshake completion, failure, close, termination, ...).
pub trait EmitEvent: Send + Sync {
    /// Deliver a single event to this sink.
    fn emit(&self, event: Event);
}

/// An [`EmitEvent`] that first drives the shared connection state machine and
/// then forwards the event to an inner broker.
///
/// Forwarding is suppressed (early `return`) when the corresponding state
/// transition reports it already happened, which deduplicates repeated events.
// NOTE(review): generic parameters appear stripped by the text extraction here
// (the field is presumably `Arc<dyn EmitEvent>` and `new` presumably takes
// `E: EmitEvent + 'static`) — confirm against the original source file.
#[derive(Clone)]
pub struct ArcEventBroker {
    // Shared connection state machine, updated before events are forwarded.
    conn_state: ArcConnState,
    // The wrapped sink that ultimately receives the (deduplicated) events.
    raw_broker: Arc,
}

impl ArcEventBroker {
    /// Wrap `event_broker` so every emitted event also advances `conn_state`.
    pub fn new(conn_state: ArcConnState, event_broker: E) -> Self {
        Self {
            conn_state,
            raw_broker: Arc::new(event_broker),
        }
    }
}

impl EmitEvent for ArcEventBroker {
    fn emit(&self, event: Event) {
        // Each arm attempts the matching state transition; `None` means the
        // transition already occurred, so the duplicate event is swallowed.
        match &event {
            Event::Handshaked => {
                if self.conn_state.enter_handshaked().is_none() {
                    return;
                }
            }
            Event::Failed(error) => {
                if self.conn_state.enter_closing(error).is_none() {
                    return;
                }
            }
            Event::ApplicationClose(error) => {
                if self.conn_state.enter_closing(error).is_none() {
                    return;
                }
            }
            Event::Closed(ccf) => {
                if self.conn_state.enter_draining(ccf).is_none() {
                    return;
                }
            }
            Event::Terminated => {
                // Terminal state: no dedup check, always recorded as Closed.
                let terminated_state = BaseConnectionStates::Closed;
                self.conn_state.update(terminated_state.into());
            }
            Event::StatelessReset => todo!("unsupported"),
        };
        tracing::debug!(target: "quic", new_state = ?event, "connection state changed");
        self.raw_broker.emit(event);
    }
}

// Convenience sink: push events into an unbounded channel, ignoring send
// errors (the receiver may already be dropped during teardown).
impl EmitEvent for mpsc::UnboundedSender {
    fn emit(&self, event: Event) {
        _ = self.send(event);
    }
}

#[cfg(test)]
mod tests {
    use tokio::sync::mpsc;

    use super::*;

    #[test]
    fn test_emit_event() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        tx.emit(Event::Handshaked);
        assert_eq!(rx.try_recv().unwrap(), Event::Handshaked);
    }
}

================================================ FILE: qconnection/src/handshake.rs ================================================

use std::{ops::Deref, sync::Arc};

use qbase::{
    error::Error,
    frame::{
        HandshakeDoneFrame,
        io::{ReceiveFrame, SendFrame},
    },
    role::Role,
};
use qcongestion::HandshakeStatus;

use crate::{
    events::{ArcEventBroker, EmitEvent, Event},
    path::ArcPathContexts,
};

pub type RawHandshake =
qbase::handshake::Handshake;

/// A wrapper of [`qbase::handshake::Handshake`] that will emit [`Event::Handshaked`] when the handshake is done.
///
/// Read the documentation of [`qbase::handshake::Handshake`] for more information.
// NOTE(review): generic parameter lists appear stripped by the text extraction
// in this chunk (this is presumably `Handshake<T>` / `RawHandshake<T>` /
// `Arc<HandshakeStatus>`) — confirm against the original source file.
#[derive(Clone)]
pub struct Handshake
where
    T: SendFrame + Clone,
{
    // The underlying qbase handshake state machine.
    inner: RawHandshake,
    // Shared handshake status used to inform the congestion controller.
    inform_cc: Arc,
    // Broker through which `Event::Handshaked` is emitted exactly once.
    broker: ArcEventBroker,
}

impl Handshake
where
    T: SendFrame + Clone,
{
    /// Bundle the raw handshake with the congestion-control status and the
    /// event broker that will be notified on completion.
    pub fn new(
        raw: RawHandshake,
        inform_cc: Arc,
        broker: ArcEventBroker,
    ) -> Self {
        Self {
            inner: raw,
            inform_cc,
            broker,
        }
    }

    /// Server side: if the handshake is done, confirm it to the congestion
    /// controller, discard the Initial/Handshake spaces on all paths, and emit
    /// [`Event::Handshaked`]. Returns whether the handshake was done.
    pub fn discard_spaces_on_server_handshake_done(&self, paths: &ArcPathContexts) -> bool {
        let is_server_done = self.inner.done();
        if is_server_done {
            self.inform_cc.handshake_confirmed();
            paths.discard_initial_and_handshake_space();
            self.broker.emit(Event::Handshaked);
        }
        is_server_done
    }

    /// The role (client/server) of the underlying handshake.
    pub fn role(&self) -> Role {
        self.inner.role()
    }

    /// Shared handshake status handle (as given to the congestion controller).
    pub fn status(&self) -> Arc {
        self.inform_cc.clone()
    }

    /// Client side: build a frame receiver that performs the same
    /// confirm/discard/emit sequence when a HANDSHAKE_DONE frame arrives.
    pub fn discard_spaces_on_client_handshake_done(
        &self,
        paths: ArcPathContexts,
    ) -> HandshakeDoneReceiver {
        HandshakeDoneReceiver {
            handshake: self.clone(),
            paths,
        }
    }
}

/// Receiver for the HANDSHAKE_DONE frame on the client side; see
/// [`Handshake::discard_spaces_on_client_handshake_done`].
pub struct HandshakeDoneReceiver
where
    T: SendFrame + Clone,
{
    handshake: Handshake,
    paths: ArcPathContexts,
}

impl ReceiveFrame for HandshakeDoneReceiver
where
    T: SendFrame + Clone,
{
    type Output = ();

    fn recv_frame(&self, frame: HandshakeDoneFrame) -> Result<(), Error> {
        // `recv_frame` returns true only on the first (state-changing) receipt.
        if self.handshake.inner.recv_frame(frame)?
{ self.handshake.inform_cc.handshake_confirmed(); self.paths.discard_initial_and_handshake_space(); self.handshake.broker.emit(Event::Handshaked); } Ok(()) } } impl Deref for Handshake where T: SendFrame + Clone, { type Target = HandshakeStatus; fn deref(&self) -> &Self::Target { &self.inform_cc } } ================================================ FILE: qconnection/src/lib.rs ================================================ pub mod builder; pub mod events; pub mod handshake; pub mod path; pub mod space; pub mod state; pub mod termination; pub mod tls; mod traversal; pub mod tx; pub mod prelude { pub use qbase::{ cid::ConnectionId, error::{AppError, Error, ErrorKind, QuicError}, frame::ConnectionCloseFrame, net::{addr::*, route::*}, param::ParameterId, role::{Client, IntoRole, Role, Server}, sid::{ControlStreamsConcurrency, ProductStreamsConcurrencyController, StreamId}, varint::VarInt, }; #[cfg(feature = "datagram")] pub use qdatagram::{DatagramReader, DatagramWriter}; pub use qinterface::{ bind_uri::BindUri, io::{IO, IoExt}, }; pub use qrecovery::{recv::StopSending, send::CancelStream, streams::error::StreamError}; pub mod handy { pub use qbase::{param::handy::*, sid::handy::*, token::handy::*}; pub use qevent::telemetry::handy::*; pub use qinterface::io::handy::*; } pub use crate::{ Connection, StreamReader, StreamWriter, tls::{ AuthClient, ClientAgentVerifyResult, ClientNameVerifyResult, LocalAgent, RemoteAgent, SignError, VerifyError, }, }; } // Re-export dependencies use std::{ borrow::Cow, fmt::Debug, future::Future, io, net::SocketAddr, sync::{Arc, RwLock, atomic::AtomicBool}, }; pub use ::{qbase, qdatagram, qevent, qinterface, qrecovery, qtraversal}; use derive_more::From; use enum_dispatch::enum_dispatch; use events::{ArcEventBroker, EmitEvent, Event}; use futures::{FutureExt, TryFutureExt}; use path::ArcPathContexts; use qbase::{ cid, error::{AppError, Error, ErrorKind, QuicError}, flow, frame::{ConnectionCloseFrame, CryptoFrame, Frame, ReliableFrame, 
StreamFrame}, net::{ addr::EndpointAddr, route::{Link, Pathway}, }, param::{ArcParameters, ParameterId}, role::Role, sid::StreamId, time::ArcIdleConfig, token::ArcTokenRegistry, }; use qdatagram::DatagramFlow; #[cfg(feature = "datagram")] use qdatagram::{DatagramReader, DatagramWriter}; use qevent::{ quic::{Owner, connectivity::ConnectionClosed}, telemetry::Instrument, }; use qinterface::{ bind_uri::BindUri, component::{ location::Locations, route::{self, QuicRouterEntry, RcvdPacketQueue}, }, manager::InterfaceManager, }; use qrecovery::{ crypto::CryptoStream, journal, recv, reliable, send, streams::{self, Ext}, }; use space::Spaces; use state::ArcConnState; use termination::Termination; use tls::ArcSendLock; use tracing::Instrument as _; use crate::{ path::{CreatePathFailure, PathDeactivated}, space::data::DataSpace, termination::Terminator, tls::{ArcTlsHandshake, LocalAgent, RemoteAgent}, traversal::PunchTransaction, }; /// The kind of frame which guaratend to be received by peer. /// /// The bundle of [`StreamFrame`], [`CryptoFrame`], and [`ReliableFrame`]. #[derive(Debug, Clone, From, Eq, PartialEq)] #[enum_dispatch(EncodeSize, FrameFeture)] pub enum GuaranteedFrame { Stream(StreamFrame), Crypto(CryptoFrame), Reliable(ReliableFrame), } impl<'f, D> TryFrom<&'f Frame> for GuaranteedFrame { type Error = &'f Frame; fn try_from(frame: &'f Frame) -> Result { Ok(match ReliableFrame::try_from(frame) { Ok(reliable) => Self::Reliable(reliable), Err(Frame::Crypto(crypto, _data)) => Self::Crypto(*crypto), Err(Frame::Stream(stream, _data)) => Self::Stream(*stream), Err(frame) => return Err(frame), }) } } /// For initial space, only reliable transmission of crypto frames is required. pub type InitialJournal = journal::Journal; /// For handshake space, only reliable transmission of crypto frames is required. 
pub type HandshakeJournal = journal::Journal; /// For data space, reliable transmission of [`GuaranteedFrame`] (crypto frames, stream frames and reliable frames) is required. pub type DataJournal = journal::Journal; pub type ArcReliableFrameDeque = reliable::ArcReliableFrameDeque; pub type QuicRouterRegistry = route::QuicRouterRegistry; pub type ArcLocalCids = cid::ArcLocalCids; pub type ArcRemoteCids = cid::ArcRemoteCids; pub type CidRegistry = cid::Registry; pub type ArcDcidCell = cid::ArcCidCell; pub type FlowController = flow::FlowController; pub type Credit<'a> = flow::Credit<'a, ArcReliableFrameDeque>; pub type Handshake = handshake::Handshake; pub type RawHandshake = handshake::RawHandshake; pub type DataStreams = streams::DataStreams; pub type StreamReader = recv::Reader>; pub type StreamWriter = send::Writer>; pub type ArcPuncher = qtraversal::punch::puncher::ArcPuncher; #[derive(Clone)] pub struct Components { // TODO: delete this interfaces: Arc, locations: Arc, rcvd_pkt_q: Arc, conn_state: ArcConnState, idle_config: ArcIdleConfig, paths: ArcPathContexts, send_lock: ArcSendLock, tls_handshake: ArcTlsHandshake, quic_handshake: Handshake, parameters: ArcParameters, token_registry: ArcTokenRegistry, cid_registry: CidRegistry, spaces: Spaces, crypto_streams: [CryptoStream; 3], reliable_frames: ArcReliableFrameDeque, data_streams: DataStreams, flow_ctrl: FlowController, datagram_flow: DatagramFlow, event_broker: ArcEventBroker, metrics: qbase::metric::ArcConnectionMetrics, specific: SpecificComponents, puncher: ArcPuncher, } #[derive(Clone)] pub enum SpecificComponents { Client {}, Server { using_odcid: Arc, odcid_router_entry: Arc, }, } /// expand Impl_Future![Type] to `impl Future + Send + use<>` macro_rules! Impl_Future { [$ty:ty] => { impl Future + Send + use<> }; } impl Components { pub fn role(&self) -> Role { match self.specific { SpecificComponents::Client { .. } => Role::Client, SpecificComponents::Server { .. 
} => Role::Server, } } /// Gets the connection metrics for tracking data volumes. pub fn metrics(&self) -> &qbase::metric::ArcConnectionMetrics { &self.metrics } #[allow(clippy::type_complexity)] pub fn open_bi_stream( &self, ) -> Impl_Future![Result, Error>] { let zero_rtt_avaliable = self.spaces.data().is_zero_rtt_avaliable(); let tls_handshake = self.tls_handshake.clone(); let data_streams = self.data_streams.clone(); let parameters = self.parameters.clone(); async move { if !zero_rtt_avaliable { tls_handshake.info().await?; } data_streams.open_bi(¶meters).await } .instrument_in_current() .in_current_span() } pub fn open_uni_stream(&self) -> Impl_Future![Result, Error>] { let zero_rtt_avaliable = self.spaces.data().is_zero_rtt_avaliable(); let tls_handshake = self.tls_handshake.clone(); let data_streams = self.data_streams.clone(); let parameters = self.parameters.clone(); async move { if !zero_rtt_avaliable { tls_handshake.info().await?; } data_streams.open_uni(¶meters).await } .instrument_in_current() .in_current_span() } #[allow(clippy::type_complexity)] pub fn accept_bi_stream( &self, ) -> Impl_Future![Result<(StreamId, (StreamReader, StreamWriter)), Error>] { let data_streams = self.data_streams.clone(); let parameters = self.parameters.clone(); async move { data_streams.accept_bi(¶meters).await } .instrument_in_current() .in_current_span() } pub fn accept_uni_stream(&self) -> Impl_Future![Result<(StreamId, StreamReader), Error>] { let data_streams = self.data_streams.clone(); async move { data_streams.accept_uni().await } .instrument_in_current() .in_current_span() } #[cfg(feature = "datagram")] #[deprecated] pub fn datagram_reader(&self) -> io::Result { self.datagram_flow.reader() } #[cfg(feature = "datagram")] #[deprecated] pub fn datagram_writer(&self) -> Impl_Future![io::Result] { let params = self.parameters.clone(); let datagram_flow = self.datagram_flow.clone(); async move { let max_datagram_frame_size = params .remote_ready() .await? 
.get_remote(ParameterId::MaxDatagramFrameSize) .expect("unreachable: default value will be got if the value unset"); datagram_flow.writer(max_datagram_frame_size) } .instrument_in_current() .in_current_span() } pub fn add_path( &self, bind_uri: BindUri, link: Link, pathway: Pathway, ) -> Result<(), CreatePathFailure> { self.get_or_try_create_path(bind_uri, link, pathway, false) .map(|_| ()) } pub fn del_path(&self, pathway: &Pathway) { self.paths.remove(pathway, &PathDeactivated::App); } pub fn local_agent(&self) -> Impl_Future![Result, Error>] { let tls_handshake = self.tls_handshake.clone(); async move { match tls_handshake.info().await?.as_ref() { tls::TlsHandshakeInfo::Client { local_agent, .. } => Ok(local_agent.clone()), tls::TlsHandshakeInfo::Server { local_agent, .. } => Ok(Some(local_agent.clone())), } } .instrument_in_current() .in_current_span() } pub fn remote_agent(&self) -> Impl_Future![Result, Error>] { let tls_handshake = self.tls_handshake.clone(); async move { match tls_handshake.info().await?.as_ref() { tls::TlsHandshakeInfo::Client { remote_agent, .. } => { Ok(Some(remote_agent.clone())) } tls::TlsHandshakeInfo::Server { remote_agent, .. } => Ok(remote_agent.clone()), } } .instrument_in_current() .in_current_span() } } impl Components { pub fn enter_closing(self, error: Error) -> Termination { qevent::event!(ConnectionClosed { owner: Owner::Local, error: &error, // TODO: trigger }); self.data_streams.on_conn_error(&error); self.datagram_flow.on_conn_error(&error); self.tls_handshake.on_conn_error(&error); self.parameters.on_conn_error(&error); tokio::spawn( { let pto_duration = self.paths.max_pto_duration().unwrap_or_default(); let event_broker = self.event_broker.clone(); async move { tokio::time::sleep(pto_duration).await; event_broker.emit(Event::Terminated); } } .instrument_in_current() .in_current_span(), ); match self.send_lock.is_permitted() { // If permitted, we can send ccf packets. 
true => { let terminator = Arc::new(Terminator::new(error.clone().into(), &self)); tokio::spawn( async move { self.spaces.send_ccf_packets(terminator.as_ref()).await } .instrument_in_current() .in_current_span(), ); } // No need to send packets, just clear the paths. false => { // TODO: check the remote of close spaces self.paths.clear(); } } Termination::closing(error, self.cid_registry.local, self.rcvd_pkt_q) } pub fn enter_draining(self, ccf: ConnectionCloseFrame) -> Termination { qevent::event!(ConnectionClosed { owner: Owner::Local, ccf: &ccf // TODO: trigger }); let error = ccf.clone().into(); self.data_streams.on_conn_error(&error); self.datagram_flow.on_conn_error(&error); self.tls_handshake.on_conn_error(&error); self.parameters.on_conn_error(&error); tokio::spawn( { let pto_duration = self.paths.max_pto_duration().unwrap_or_default(); let event_broker = self.event_broker.clone(); async move { tokio::time::sleep(pto_duration).await; event_broker.emit(Event::Terminated); } } .instrument_in_current() .in_current_span(), ); match self.send_lock.is_permitted() { // If permitted, we can send ccf packets. true => { let terminator = Arc::new(Terminator::new(ccf, &self)); tokio::spawn( async move { self.spaces.send_ccf_packets(terminator.as_ref()).await } .instrument_in_current() .in_current_span(), ); } // No need to send packets, just clear the paths. false => { self.paths.clear(); } } // No need to receive packets, just close all queues. 
self.rcvd_pkt_q.close_all(); Termination::draining(error, self.cid_registry.local) } } struct ConnectionState { state: RwLock>, qlog_span: qevent::telemetry::Span, tracing_span: tracing::Span, } impl ConnectionState { // called by event pub fn enter_closing(&self, error: QuicError) -> Result<(), Error> { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); let mut conn = self.state.write().unwrap(); let core_conn = conn.as_ref().map_err(|t| t.error())?; *conn = Err(core_conn.clone().enter_closing(error.into())); Ok(()) } pub fn application_close( &self, reason: impl Into>, code: u64, ) -> Result<(), Error> { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); let mut conn = self.state.write().unwrap(); let core_conn = conn.as_ref().map_err(|t| t.error())?; let error_code = code.try_into().expect("application error code overflow"); let error = AppError::new(error_code, reason); let event = Event::ApplicationClose(error.clone()); core_conn.event_broker.emit(event); *conn = Err(core_conn.clone().enter_closing(error.into())); Ok(()) } pub fn enter_draining(&self, ccf: ConnectionCloseFrame) -> bool { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); let mut conn = self.state.write().unwrap(); match conn.as_mut() { Ok(core_conn) => { *conn = Err(core_conn.clone().enter_draining(ccf)); true } Err(termination) => termination.enter_draining(), } } fn try_map_components(&self, op: impl FnOnce(&Components) -> T) -> Result { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); self.state .read() .unwrap() .as_ref() .map(op) .map_err(|termination| termination.error()) } fn try_map_components_future( &self, op: M, ) -> impl Future> + Send + use where F: Future + Send, M: FnOnce(&Components) -> F, { match self.try_map_components(op) { Ok(future) => future.map(Ok).left_future(), Err(error) => std::future::ready(error).map(Err).right_future(), } } fn validate(&self) -> Result<(), Error> { let _span = (self.qlog_span.enter(), 
self.tracing_span.enter()); let mut conn = self.state.write().unwrap(); let core_conn = conn.as_ref().map_err(|e| e.error())?; let validate = 'validate: { if core_conn.paths.is_empty() { let error = QuicError::with_default_fty(ErrorKind::NoViablePath, "No viable path exist"); break 'validate Err(error); } Ok(()) }; if let Err(error) = validate { core_conn.event_broker.emit(Event::Failed(error.clone())); let termination = core_conn.clone().enter_closing(error.into()); let error = termination.error(); *conn = Err(termination); return Err(error); } Ok(()) } } impl Drop for ConnectionState { fn drop(&mut self) { let _span = self.tracing_span.enter(); if self.validate().is_ok() && self.application_close("", 0).is_ok() { #[cfg(debug_assertions)] tracing::warn!(target: "quic", "connection is still active when dropped, close it automatically."); #[cfg(not(debug_assertions))] tracing::debug!(target: "quic", "connection is still active when dropped, close it automatically."); } } } #[derive(Clone)] pub struct Connection(Arc); impl Connection { pub fn role(&self) -> Result { self.0.try_map_components(|core_conn| core_conn.role()) } /// Close the connection with application close frame. /// /// Return error if the connection is already closed. pub fn close(&self, reason: impl Into>, code: u64) -> Result<(), Error> { self.0.application_close(reason, code) } /// Gets the connection metrics for tracking data volumes. /// /// Returns the metrics that track: /// - pending_send_bytes: Data written by application but not yet sent /// - sent_unacked_bytes: Data sent but not yet acknowledged /// - sent_acked_bytes: Data sent and acknowledged pub fn metrics(&self) -> Result { self.0 .try_map_components(|core_conn| core_conn.metrics().clone()) } #[allow(clippy::type_complexity)] pub fn open_bi_stream( &self, ) -> Impl_Future![Result, Error>] { self.0 .try_map_components_future(|core_conn| core_conn.open_bi_stream()) .map(|result| result?) 
} pub fn open_uni_stream(&self) -> Impl_Future![Result, Error>] { self.0 .try_map_components_future(|core_conn| core_conn.open_uni_stream()) .map(|result| result?) } #[allow(clippy::type_complexity)] pub fn accept_bi_stream( &self, ) -> Impl_Future![Result<(StreamId, (StreamReader, StreamWriter)), Error>] { self.0 .try_map_components_future(|core_conn| core_conn.accept_bi_stream()) .map(|result| result?) } pub fn accept_uni_stream(&self) -> Impl_Future![Result<(StreamId, StreamReader), Error>] { self.0 .try_map_components_future(|core_conn| core_conn.accept_uni_stream()) .map(|result| result?) } #[cfg(feature = "datagram")] #[deprecated] #[allow(deprecated)] pub fn datagram_reader(&self) -> Result, Error> { self.0 .try_map_components(|core_conn| core_conn.datagram_reader()) } #[cfg(feature = "datagram")] #[deprecated] #[allow(deprecated)] pub async fn datagram_writer(&self) -> Result, Error> { Ok(self .0 .try_map_components(|core_conn| core_conn.datagram_writer())? .await) } pub fn add_path( &self, bind_uri: BindUri, link: Link, pathway: Pathway, ) -> Result<(), CreatePathFailure> { self.0 .try_map_components(|core_conn| core_conn.add_path(bind_uri, link, pathway)) .unwrap_or_else(|cc| Err(CreatePathFailure::ConnectionClosed(cc))) } pub fn del_path(&self, pathway: &Pathway) -> Result<(), Error> { self.0 .try_map_components(|core_conn| core_conn.del_path(pathway)) } pub fn origin_dcid(&self) -> Result { self.0 .try_map_components(|core_conn| core_conn.cid_registry.origin_dcid()) } pub fn handshaked(&self) -> Impl_Future![Result<(), Error>] { self.0 .try_map_components_future(|core_conn| core_conn.conn_state.handshaked()) .map(|result| result?) } pub fn terminated(&self) -> Impl_Future![Error] { self.0 .try_map_components_future(|core_conn| core_conn.conn_state.terminated()) .map(|(Ok(error) | Err(error))| error) } pub fn local_agent(&self) -> Impl_Future![Result, Error>] { self.0 .try_map_components_future(|core_conn| core_conn.local_agent()) .map(|result| result?) 
} pub fn remote_agent(&self) -> Impl_Future![Result, Error>] { self.0 .try_map_components_future(|core_conn| core_conn.remote_agent()) .map(|result| result?) } pub fn server_name(&self) -> Impl_Future![Result] { self.0 .try_map_components_future(|core_conn| match core_conn.role() { Role::Client => core_conn .remote_agent() .map_ok(|agent| agent.unwrap().name().to_owned()) .left_future(), Role::Server => core_conn .local_agent() .map_ok(|agent| agent.unwrap().name().to_owned()) .right_future(), }) .map(|result| result?) } pub fn add_local_endpoint(&self, bind: BindUri, addr: EndpointAddr) -> Result<(), Error> { self.0 .try_map_components(|core_conn| core_conn.add_local_endpoint(bind, addr)) } pub fn add_peer_endpoint( &self, addr: EndpointAddr, source: qresolve::Source, ) -> Result<(), Error> { self.0 .try_map_components(|core_conn| core_conn.add_peer_endpoint(addr, source)) } pub fn remove_address(&self, addr: SocketAddr) -> Result<(), Error> { self.0 .try_map_components(|core_conn| core_conn.remove_address(addr)) } pub fn subscribe_local_address(&self) -> Result<(), Error> { self.0 .try_map_components(|core_conn| core_conn.subscribe_local_address()) } pub fn path_context(&self) -> Result { self.0 .try_map_components(|core_conn| core_conn.paths.clone()) } /// Check if the connection is still valid. /// /// Return error if no viable path exists, or the connection is closed. pub fn validate(&self) -> Result<(), Error> { self.0.validate() } } ================================================ FILE: qconnection/src/path/aa.rs ================================================ use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering}; use qbase::net::tx::{ArcSendWaker, Signals}; pub const DEFAULT_ANTI_FACTOR: usize = 3; /// Therefore, after receiving packets from an address that is not yet validated, /// an endpoint MUST limit the amount of data it sends to the unvalidated address /// to N(three) times the amount of data received from that address. 
#[derive(Debug)] pub struct AntiAmplifier { // Each time data is received, credit is increased; // each time data is sent, credit is consumed. credit: AtomicUsize, // If the credit is exhausted, it needs to wait until // new data is received before it can continue to send. tx_waker: ArcSendWaker, state: AtomicU8, } impl AntiAmplifier { const NORMAL: u8 = 0; const GRANTED: u8 = 1; const ABORTED: u8 = 2; pub fn new(tx_waker: ArcSendWaker) -> Self { Self { credit: AtomicUsize::new(0), tx_waker, state: AtomicU8::new(0), } } /// Store N * amount of credit pub fn on_rcvd(&self, amount: usize) { if self.state.load(Ordering::Acquire) != Self::NORMAL { return; } self.credit.fetch_add(amount * N, Ordering::AcqRel); self.tx_waker.wake_by(Signals::CREDIT); } /// This function must only be called by one at a time, and the amount of data sent /// must be feed back to the anti-amplifier before poll_apply can be called again. pub fn balance(&self) -> Result, Signals> { match self.state.load(Ordering::Acquire) { Self::GRANTED => Ok(Some(usize::MAX)), Self::ABORTED => Ok(None), Self::NORMAL => { let credit = self.credit.load(Ordering::Acquire); if credit == 0 { // 再次检查,以防grant、abort在self.waker赋值前被调用,导致任务死掉 let state = self.state.load(Ordering::Acquire); if state == Self::NORMAL { Err(Signals::CREDIT) } else { self.tx_waker.wake_by(Signals::CREDIT); if state == Self::GRANTED { Ok(Some(usize::MAX)) } else { Ok(None) } } } else { Ok(Some(credit)) } } _ => unreachable!(), } } pub fn on_sent(&self, amount: usize) { if self.state.load(Ordering::Acquire) == Self::NORMAL { self.credit.fetch_sub(amount, Ordering::AcqRel); } } pub fn grant(&self) { if self .state .compare_exchange( Self::NORMAL, Self::GRANTED, Ordering::AcqRel, Ordering::Acquire, ) .is_ok() { self.tx_waker.wake_by(Signals::CREDIT); } } pub fn abort(&self) { if self .state .compare_exchange( Self::NORMAL, Self::ABORTED, Ordering::AcqRel, Ordering::Acquire, ) .is_ok() { self.tx_waker.wake_by(Signals::CREDIT); } } } #[cfg(test)] 
mod tests {
    use super::*;

    #[test]
    fn test_deposit_and_poll_apply() {
        let waker = ArcSendWaker::new();
        let anti_amplifier = AntiAmplifier::<3>::new(waker);

        // Initially, no credit: the sender must wait for received data.
        assert_eq!(anti_amplifier.balance(), Err(Signals::CREDIT));

        // Deposit 1 unit of data, should add 3 units of credit.
        anti_amplifier.on_rcvd(1);
        assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 3);

        // Balancing reports the full 3 units of credit without consuming them.
        assert_eq!(anti_amplifier.balance(), Ok(Some(3)));
        assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 3);

        anti_amplifier.on_sent(3);

        // No credit left, should signal that more received data is needed.
        assert_eq!(anti_amplifier.balance(), Err(Signals::CREDIT));
    }

    #[test]
    fn test_multiple_deposits() {
        let waker = ArcSendWaker::new();
        let anti_amplifier = AntiAmplifier::<3>::new(waker);

        // Deposit 1 unit of data, should add 3 units of credit.
        anti_amplifier.on_rcvd(1);
        assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 3);

        // Deposit another 1 unit of data, should add another 3 units of credit.
        anti_amplifier.on_rcvd(1);
        assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 6);

        // Balancing reports all 6 units of credit without consuming them.
        assert_eq!(anti_amplifier.balance(), Ok(Some(6)));
        assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 6);

        // After sending 5 units, credit is reduced by 5.
        anti_amplifier.on_sent(5);
        assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 1);
    }
}

================================================ FILE: qconnection/src/path/burst.rs ================================================
use std::{
    io,
    ops::Deref,
    sync::{Arc, atomic::Ordering::Acquire},
};

use bytes::BufMut;
use derive_more::From;
use qbase::{
    Epoch, GetEpoch,
    cid::{BorrowedCid, ConnectionId},
    frame::PingFrame,
    net::tx::{ArcSendWaker, Signals},
    packet::{
        AssemblePacket, Package, PacketContent, PacketInfo, ProductHeader,
        header::{
            long::{HandshakeHeader, InitialHeader, ZeroRttHeader, io::LongHeaderBuilder},
            short::OneRttHeader,
        },
        io::{Packages, PadProbe, PadTo20,
PadToFull, Repeat}, signal::SpinBit, }, role::Role, token::TokenRegistry, }; use qcongestion::{ArcCC, Transport}; use qinterface::io::IO; use qrecovery::journal::{AckPackege, ArcRcvdJournal, Journal}; use qtraversal::packet::{ForwardHeader, WriteForwardHeader}; use crate::{ ArcDcidCell, ArcReliableFrameDeque, CidRegistry, Components, path::{AntiAmplifier, Constraints}, space::{Spaces, data::DataSpace, handshake::HandshakeSpace, initial::InitialSpace}, tls::ArcTlsHandshake, tx::PacketWriter, }; // /// Trait alias // pub trait PackageIntoSpacePacketWriter>: // for<'b, 's> Package> // { // } // impl, P> PackageIntoSpacePacketWriter for P where // P: for<'b, 's> Package> // { // } // pn space? pub trait PacketSpace { type JournalFrame; fn new_packet<'b, 's>( &'s self, header: H, cc: &ArcCC, buffer: &'b mut [u8], ) -> Result, Signals>; } pub struct Burst { path: Arc, initial_token: Vec, cid_registry: CidRegistry, spin: bool, spaces: Spaces, tls_handshake: ArcTlsHandshake, } impl super::Path { pub fn new_burst(self: &Arc, components: &Components) -> Burst { Burst { path: self.clone(), initial_token: match components.token_registry.deref() { TokenRegistry::Client((server_name, token_sink)) => { token_sink.fetch_token(server_name) } TokenRegistry::Server(..) 
=> vec![], }, cid_registry: components.cid_registry.clone(), spin: false, // TODO spaces: components.spaces.clone(), tls_handshake: components.tls_handshake.clone(), } } } // 用双层Result #[derive(From)] pub enum BurstError { Signals(Signals), PathDeactived, } pub struct PacketsAssembler<'a> { cc: &'a ArcCC, constraints: Constraints, cid_registry: &'a CidRegistry, borrowed_dcid: Result, Signals>, initial_token: &'a [u8], spin: SpinBit, } impl<'a> PacketsAssembler<'a> { fn new( cid_registry: &'a CidRegistry, dcid_cell: &'a ArcDcidCell, anti_amplifier: &AntiAmplifier, cc: &'a ArcCC, tx_waker: ArcSendWaker, initial_token: &'a [u8], spin: impl Into, ) -> Result, BurstError> { let send_quota = cc.send_quota()?; let Some(credit_limit) = anti_amplifier.balance()? else { return Err(BurstError::PathDeactived); }; let Some(borrowed_dcid) = dcid_cell.borrow_cid(tx_waker).transpose() else { return Err(BurstError::PathDeactived); }; let constraints = Constraints::new(credit_limit, send_quota); Ok(Self { cid_registry, borrowed_dcid, cc, constraints, initial_token, spin: spin.into(), }) } fn initial_scid(&self) -> Result { self.cid_registry .local .initial_scid() .ok_or(Signals::empty()) } fn applied_dcid(&self) -> Result { self.borrowed_dcid.as_deref().copied().map_err(|e| *e) } /// Return the connection ID that used to send the initial and zero rtt packets. /// /// dquic implements multi-path handshake feature, the client creates many paths and sends initial packets. /// /// Client will only use origin_dcid to send initial and zero rtt packets. /// /// The client and server must negotiate a handshake path and assign the initial dcid to this path /// to prevent the unique connection ID from being obtained by an invalid path, causing the connection to fail. /// /// The client and server choose the path where they receive the first initial packet as the handshake path. /// The server will only return the initial packet on the handshake path to negotiate the handshake path. 
/// /// Therefore, for the server, it can only send the initial packet with the connection ID assigned to the path. /// This manifests itself during the handshake as sending the initial packet only on the first path. fn initial_dcid(&self) -> Result { match self.cid_registry.role() { Role::Client => Ok(self.cid_registry.origin_dcid()), Role::Server => self.applied_dcid(), } } pub fn commit(&mut self, sent_bytes: usize, pkt_info: PacketInfo) { self.constraints.commit(sent_bytes, pkt_info.in_flight()); self.cc.on_pkt_sent( pkt_info.epoch().expect("todo"), pkt_info.packet_number(), pkt_info.ack_eliciting(), sent_bytes, pkt_info.in_flight(), pkt_info.largest_ack(), ); } } impl ProductHeader for PacketsAssembler<'_> { fn new_header(&self) -> Result { Ok( LongHeaderBuilder::with_cid(self.initial_dcid()?, self.initial_scid()?) .initial(self.initial_token.to_vec()), ) } } impl ProductHeader for PacketsAssembler<'_> { fn new_header(&self) -> Result { Ok(LongHeaderBuilder::with_cid(self.initial_dcid()?, self.initial_scid()?).zero_rtt()) } } impl ProductHeader for PacketsAssembler<'_> { fn new_header(&self) -> Result { Ok(LongHeaderBuilder::with_cid(self.applied_dcid()?, self.initial_scid()?).handshake()) } } impl ProductHeader for PacketsAssembler<'_> { fn new_header(&self) -> Result { Ok(OneRttHeader::new(self.spin, self.applied_dcid()?)) } } impl<'a> PacketsAssembler<'a> { pub fn assemble<'s, 'b, H, Space, P>( &mut self, space: &'s Space, data_sources: P, buffer: &'b mut [u8], packet_content: &mut PacketContent, ) -> Result where Self: ProductHeader, Space: PacketSpace + GetEpoch, Space::JournalFrame: 's, P: Package>, { let buffer = self.constraints.constrain(buffer); let mut packet = space.new_packet(self.new_header()?, self.cc, buffer)?; *packet_content += packet.assemble_packet(&mut Packages((data_sources, PadTo20)))?; let (sent_bytes, props) = packet.encrypt_and_protect_packet(); self.commit(sent_bytes, props); Result::<_, Signals>::Ok(sent_bytes) } } pub type 
PackageIntoSpace = dyn for<'b, 's> Package>::JournalFrame>> + Send; pub struct DataSources { initial: Box>, zero_rtt: Box>, handshake: Box>, one_rtt: Box>, } impl Components { pub(super) fn packages(&self) -> DataSources { let initial_packages = self.crypto_streams[Epoch::Initial] .outgoing() .package(Epoch::Initial); let zero_rtt_packages = Packages(( // repeat to send multi reliable frames in one packet Repeat(self.reliable_frames.clone()), // repeat to send multi stream frames in one packet Repeat( self.data_streams .package(self.flow_ctrl.sender.clone(), true), ), // TODO: datagram )); let handshake_packages = self.crypto_streams[Epoch::Handshake] .outgoing() .package(Epoch::Handshake); let one_rtt_packages = Packages(( self.crypto_streams[Epoch::Data] .outgoing() .package(Epoch::Data), // repeat to send multi reliable frames in one packet Repeat(self.reliable_frames.clone()), // repeat to send multi stream frames in one packet Repeat( self.data_streams .package(self.flow_ctrl.sender.clone(), false), ), // TODO: datagram )); DataSources { initial: Box::new(initial_packages), zero_rtt: Box::new(zero_rtt_packages), handshake: Box::new(handshake_packages), one_rtt: Box::new(one_rtt_packages), } } } fn ack_package<'s, S, F>(space: &'s S, cc: &ArcCC) -> AckPackege<'s> where S: GetEpoch + AsRef>, F: 's, { // (1) may_loss被调用时cc已经被锁定,may_loss会尝试锁定sent_journal // (2) PacketMemory会持有sent_journal的guard,而need_ack会尝试锁定cc // 在PacketMemory存在时尝试锁定cc,可能会和 (1) 冲突: // (1)持有cc,要锁定sent_journal;(2)持有sent_journal要锁定cc // 在多线程的情况下,可能会发生死锁。所以提前调用need_ack,避免交叉导致死锁 ArcRcvdJournal::ack_package(space.as_ref().as_ref(), cc.need_ack(space.epoch())) } impl Burst { fn assembler<'a>(&'a self) -> Result, BurstError> { PacketsAssembler::new( &self.cid_registry, &self.path.dcid_cell, &self.path.anti_amplifier, &self.path.cc, self.path.tx_waker.clone(), &self.initial_token, self.spin, ) } fn load_spaces( &self, DataSources { initial: initial_data_sources, zero_rtt: zero_rtt_data_sources, handshake: 
handshake_data_sources, one_rtt: one_rtt_data_sources, }: &mut DataSources, mut buffer: &mut [u8], ) -> Result<(usize, PacketContent), BurstError> { let Self { path, spaces, tls_handshake, .. } = self; let initial_space = spaces.initial().as_ref(); let handshake_space = spaces.handshake().as_ref(); let data_space = spaces.data().as_ref(); let origin = buffer.remaining_mut(); let mut packet_content = PacketContent::default(); let mut assembler = self.assembler()?; let mut signals = Signals::empty(); let Ok(tls_fin) = tls_handshake.is_finished() else { return Err(BurstError::PathDeactived); }; match assembler.assemble( initial_space, &mut Packages((ack_package(initial_space, &path.cc), initial_data_sources)), buffer, &mut packet_content, ) { Ok(bytes_sent) => buffer = buffer[bytes_sent..].as_mut(), Err(s) => signals |= s, }; let loaded_initial = buffer.remaining_mut() != origin; if !tls_fin { match assembler.assemble::( data_space, zero_rtt_data_sources, buffer, &mut packet_content, ) { Ok(bytes_sent) => buffer = buffer[bytes_sent..].as_mut(), Err(s) => signals |= s, } } match assembler.assemble( handshake_space, &mut Packages(( ack_package(handshake_space, &path.cc), handshake_data_sources, )), buffer, &mut packet_content, ) { Ok(bytes_sent) => buffer = buffer[bytes_sent..].as_mut(), Err(s) => signals |= s, } if tls_fin { let result = if path.validated.load(Acquire) { assembler.assemble::( data_space, &mut Packages(( ack_package(data_space, &path.cc), &path.challenge_sndbuf, &path.response_sndbuf, one_rtt_data_sources, loaded_initial.then_some(PadToFull), PadProbe, )), buffer, &mut packet_content, ) } else { assembler.assemble::( data_space, &mut Packages(( ack_package(data_space, &path.cc), &path.challenge_sndbuf, &path.response_sndbuf, loaded_initial.then_some(PadToFull), PadProbe, )), buffer, &mut packet_content, ) }; match result { Ok(bytes_sent) => buffer = buffer[bytes_sent..].as_mut(), Err(s) => signals |= s, } } if loaded_initial { 
assert!(buffer.remaining_mut() != origin); buffer.put_bytes(0, buffer.remaining_mut()); return Ok((origin, packet_content)); } let sent_bytes = origin - buffer.remaining_mut(); (sent_bytes > 0) .then_some((sent_bytes, packet_content)) .ok_or(BurstError::Signals(signals)) } } struct PingSource { need_send_ack_eliciting: usize, } impl Package for PingSource where PingFrame: Package, { fn dump(&mut self, target: &mut Target) -> Result { if self.need_send_ack_eliciting > 0 { return PingFrame.dump(target); } // TODO: refactor signal names Err(Signals::PING) } } fn ping_package(cc: &ArcCC, epoch: Epoch) -> PingSource { // avoid deadlock, same as ack_package PingSource { need_send_ack_eliciting: cc.need_send_ack_eliciting(epoch), } } impl Burst { fn load_ping(&self, buffer: &mut [u8]) -> Result<(usize, PacketContent), BurstError> { let Self { spaces, path, .. } = self; let mut assembler = self.assembler()?; let mut signals = Signals::empty(); let mut packet_content = PacketContent::default(); for &epoch in Epoch::iter().rev() { let result = match epoch { Epoch::Data => { let ack_package = ack_package(spaces.data().as_ref(), &path.cc); let ping_package = ping_package(&path.cc, epoch); assembler.assemble::( spaces.data().as_ref(), &mut Packages((ack_package, ping_package, PadToFull)), buffer, &mut packet_content, ) } Epoch::Handshake => { let ack_package = ack_package(spaces.handshake().as_ref(), &path.cc); let ping_package = ping_package(&path.cc, epoch); assembler.assemble( spaces.handshake().as_ref(), &mut Packages((ack_package, ping_package, PadToFull)), buffer, &mut packet_content, ) } Epoch::Initial => { let ack_package = ack_package(spaces.initial().as_ref(), &path.cc); let ping_package = ping_package(&path.cc, epoch); assembler.assemble( spaces.initial().as_ref(), &mut Packages((ack_package, ping_package, PadToFull)), buffer, &mut packet_content, ) } }; match result { Ok(sent_bytes) => return Ok((sent_bytes, packet_content)), Err(s) => signals |= s, } } 
Err(BurstError::Signals(signals)) } fn load_heartbeat(&self, buffer: &mut [u8]) -> Result<(usize, PacketContent), BurstError> { let Self { spaces, path, .. } = self; let mut assembler = self.assembler()?; let mut packet_content = PacketContent::default(); match assembler.assemble::( spaces.data().as_ref(), &path.heartbeat_sndbuf, buffer, &mut packet_content, ) { Ok(sent_bytes) => Ok((sent_bytes, packet_content)), Err(s) => Err(BurstError::Signals(s)), } } pub async fn burst<'b>( &self, data_sources: &mut DataSources, buffers: &'b mut Vec>, ) -> Result>, BurstError> { let Ok(max_segments) = self.path.interface.max_segments() else { return Err(BurstError::PathDeactived); }; let Ok(max_segment_size) = self.path.interface.max_segment_size() else { return Err(BurstError::PathDeactived); }; if buffers.len() < max_segments { buffers.resize_with(max_segments, || vec![0; max_segment_size]); } use core::ops::ControlFlow::*; let reversed_size = ForwardHeader::encoding_size(&self.path.pathway); let (Break(result) | Continue(result)) = buffers .iter_mut() .map(move |buffer| { if buffer.len() < max_segment_size { buffer.resize(max_segment_size, 0); } &mut buffer[..max_segment_size] }) .map(move |segment| { let buffer_size = segment.len().min(self.path.mtu() as _); let buffer = &mut segment[..buffer_size][reversed_size..]; self.load_spaces(data_sources, buffer) .inspect(|(_, packet_content)| { self.path.idle_timer.on_sent(*packet_content); }) .or_else(|error| match error { BurstError::Signals(signals) => { self.load_ping(buffer).map_err(|e| match e { BurstError::Signals(s) => BurstError::Signals(signals | s), e @ BurstError::PathDeactived => e, }) } e @ BurstError::PathDeactived => Err(e), }) .or_else(|error| match error { BurstError::Signals(signals) => { self.load_heartbeat(buffer).map_err(|e| match e { BurstError::Signals(s) => BurstError::Signals(signals | s), e @ BurstError::PathDeactived => e, }) } e @ BurstError::PathDeactived => Err(e), }) .map(|(packet_size, _)| { if 
reversed_size > 0 { let (mut header, payload) = segment.split_at_mut(reversed_size); let forward_hdr = ForwardHeader::new( 0, // FIXME: unwrap &self.path.pathway, payload, ); tracing::trace!(?forward_hdr, link=%self.path.link(),"put forward header"); header.put_forward_header(&forward_hdr); } io::IoSlice::new(&segment[..reversed_size + packet_size]) }) }) .try_fold( Ok(Vec::with_capacity(max_segments)), |segments, load_result| match (segments, load_result) { (Ok(segments), Err(signals)) if segments.is_empty() => Break(Err(signals)), (Ok(segments), Err(_signals)) => Break(Ok(segments)), (Ok(mut segments), Ok(segment)) if segment.len() < segments.last().copied().unwrap_or_default() => { segments.push(segment.len()); Break(Ok(segments)) } (Ok(mut segments), Ok(segment)) => { segments.push(segment.len()); Continue(Ok(segments)) } (Err(_), _) => unreachable!("segments should not be Err in this context"), }, ); Ok(result? .iter() .zip(buffers) .map(|(&len, buffer)| io::IoSlice::new(&buffer[..len])) .collect()) } } ================================================ FILE: qconnection/src/path/drive.rs ================================================ use qcongestion::Transport; use tokio::time::Duration; use crate::{path::PathDeactivated, tls::ArcTlsHandshake}; impl super::Path { pub async fn drive(&self, _tls_handshake: ArcTlsHandshake) -> Result<(), PathDeactivated> { loop { tokio::time::sleep(Duration::from_millis(10)).await; if let Some(frame) = self.idle_timer.health()? 
{ self.heartbeat_sndbuf.write(frame); } self.cc.do_tick()?; } } } ================================================ FILE: qconnection/src/path/error.rs ================================================ use derive_more::From; use qbase::{error::Error as QuicError, time::TimeOut}; use qcongestion::TooManyPtos; use qinterface::bind_uri::BindUri; use thiserror::Error; use crate::path::validate::ValidateFailure; #[derive(Debug, From, Error)] pub enum CreatePathFailure { #[error("Network interface not found for bind URI: {0}")] NoInterface(BindUri), #[error("Connection is closed")] ConnectionClosed(QuicError), } #[derive(Debug, From, Error)] pub enum PathDeactivated { #[error("Path validation failed")] Invalid(#[source] ValidateFailure), #[error(transparent)] Idle(TimeOut), #[error("Lost path state")] Lost(#[source] TooManyPtos), #[error("Failed to send packets on path")] Io(#[source] std::io::Error), #[error("Manually removed by application")] App, } ================================================ FILE: qconnection/src/path/paths.rs ================================================ use std::{ future::Future, sync::{Arc, Mutex, Weak}, time::Duration, }; use dashmap::DashMap; use derive_more::Deref; use qbase::{ Epoch, cid::ConnectionId, error::{ErrorKind, QuicError}, net::{addr::EndpointAddr, route::Pathway, tx::ArcSendWakers}, }; use qcongestion::Transport; use qevent::telemetry::Instrument; use tokio_util::task::AbortOnDropHandle; use tracing::Instrument as _; use super::Path; use crate::{ ArcRemoteCids, events::{ArcEventBroker, EmitEvent, Event}, path::{CreatePathFailure, PathDeactivated}, }; #[derive(Deref)] pub struct PathContext { #[deref] path: Arc, _task: AbortOnDropHandle<()>, } #[derive(Clone)] pub struct ArcPathContexts { paths: Arc>, tx_wakers: ArcSendWakers, broker: ArcEventBroker, initial_path: Arc>>>, } impl ArcPathContexts { pub fn new(tx_wakers: ArcSendWakers, broker: ArcEventBroker) -> Self { Self { paths: Default::default(), tx_wakers, broker, 
initial_path: Arc::default(), } } pub fn assign_handshake_path( &self, path: &Arc, remote_cids: &ArcRemoteCids, initial_dcid: ConnectionId, ) -> bool { let mut handshake_path = self.initial_path.lock().unwrap(); if handshake_path.is_some() { return false; } remote_cids.apply_initial_dcid(initial_dcid, &path.dcid_cell); *handshake_path = Some(Arc::downgrade(path)); true } pub fn handshake_path(&self) -> Option> { self.initial_path .lock() .unwrap() .clone() .expect("unreachable: Handshake packet received before first initial packet processed") .upgrade() } pub fn get_or_try_create_with( &self, pathway: Pathway, try_create: impl FnOnce() -> Result<(Arc, T), CreatePathFailure>, ) -> Result, CreatePathFailure> where T: Future> + Send + 'static, { match self.paths.entry(pathway) { dashmap::Entry::Occupied(occupied_entry) => Ok(occupied_entry.get().path.clone()), dashmap::Entry::Vacant(vacant_entry) => { let (path, task) = try_create()?; self.tx_wakers.insert(pathway, &path.tx_waker); let paths = self.clone(); let task = AbortOnDropHandle::new(tokio::spawn( async move { let reason = task.await.unwrap_err(); paths.remove(&pathway, &reason); } .instrument_in_current() .in_current_span(), )); Ok(vacant_entry .insert(PathContext { path, _task: task }) .clone()) } } } pub fn get(&self, pathway: &Pathway) -> Option> { self.paths.get(pathway).map(|p| p.path.clone()) } pub fn remove(&self, pathway: &Pathway, reason: &PathDeactivated) { if self.paths.remove(pathway).is_some() { self.tx_wakers.remove(pathway); tracing::debug!(target: "quic", %pathway, %reason, "path deactivated"); if self.is_empty() { let error = QuicError::with_default_fty( ErrorKind::NoViablePath, format!("No viable path exist, last path removed because: {reason}"), ); self.broker.emit(Event::Failed(error)); } } } pub fn is_empty(&self) -> bool { self.paths.is_empty() } pub fn max_pto_duration(&self) -> Option { self.paths.iter().map(|p| p.cc().get_pto(Epoch::Data)).max() } pub fn paths)>>(&self) -> C { 
self.paths .iter() .map(|p| (*p.key(), p.path.clone())) .collect() } pub fn discard_initial_and_handshake_space(&self) { self.paths.iter().for_each(|p| { p.cc().discard_epoch(Epoch::Initial); p.cc().discard_epoch(Epoch::Handshake); }); } pub fn clear(&self) { self.paths.clear(); } pub fn on_path_validated(&self, pathway: Pathway) { if matches!(pathway.remote(), EndpointAddr::Direct { .. }) { self.paths.iter().for_each(|p| { if matches!(p.pathway.remote(), EndpointAddr::Direct { .. }) { p.path.deactivate(); } }); } } } ================================================ FILE: qconnection/src/path/util.rs ================================================ use std::{ pin::Pin, sync::Mutex, task::{Context, Poll}, }; use bytes::BufMut; use futures::StreamExt; use qbase::{ net::tx::{ArcSendWaker, Signals}, packet::{Package, PacketContent}, util::ArcAsyncDeque, }; /// A buffer that contains a single frame to be sent. /// /// This struct impl [`Default`], and the `new` method is not provided. pub struct SendBuffer { item: Mutex>, tx_waker: ArcSendWaker, } impl SendBuffer { pub fn new(tx_waker: ArcSendWaker) -> Self { Self { item: Default::default(), tx_waker, } } /// Write a frame to the buffer. /// /// [`SendBuffer`] can only buffer one frame at a time. If you write a new frame to the buffer before the previous /// frame is sent, the previous frame will be overwritten. pub fn write(&self, frame: T) { self.tx_waker.wake_by(Signals::TRANSPORT); *self.item.lock().unwrap() = Some(frame); } } impl SendBuffer { /// Try load the frame to be sent into the `packet`. pub fn try_load_frames_into(&self, packet: &mut P) -> Result<(), Signals> where for<'a> &'a F: Package

, { let mut guard = self.item.lock().unwrap(); match guard.as_ref() { Some(mut frame) => { frame.dump(packet)?; guard.take().unwrap(); Ok(()) } None => Err(Signals::TRANSPORT), } } } impl Package

for &SendBuffer where for<'a> &'a F: Package

, { #[inline] fn dump(&mut self, into: &mut P) -> Result { self.try_load_frames_into(into)?; Ok(PacketContent::EffectivePayload) } } /// A buffer to cache received frames. /// /// /// [`Stream`] is implemented for this struct, you can use it as a stream to receive frames. /// /// You can also use the [`RecvBuffer::receive`] method to wait for a frame to be received. /// /// # Example /// ```rust /// use qconnection::path::RecvBuffer; /// use futures::StreamExt; /// # async fn demo() { /// let rcv_buf = RecvBuffer::default(); /// /// tokio::spawn({ /// let rcv_buf = rcv_buf.clone(); /// async move { /// let value = rcv_buf.receive().await; /// assert_eq!(value, Some(42u32)); /// } /// }); /// /// rcv_buf.write(42u32); /// # } /// ``` /// /// [`Stream`]: futures::Stream /// [`Future`]: core::future::Future #[derive(Clone, Debug, Default)] pub struct RecvBuffer(ArcAsyncDeque); impl RecvBuffer { /// Create a new empty [`RecvBuffer`]. pub fn new() -> Self { Self(ArcAsyncDeque::with_capacity(2)) } /// Write a frame to the buffer. pub fn write(&self, value: T) { self.0.push_back(value); } /// Waiting for a frame to be received. pub async fn receive(&self) -> Option { let mut this = self; this.next().await } /// Dismiss the buffer /// /// Append received frames will be Ignored, existing frames will be dropped, the future will return `None`. pub fn dismiss(&self) { self.0.close(); } } impl futures::Stream for RecvBuffer { type Item = T; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.0.poll_pop(cx) } } impl futures::Stream for &RecvBuffer { type Item = T; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.0.poll_pop(cx) } } /// The constraints for sending data, appllied to the data buffer. #[derive(Debug, Clone, Copy)] pub struct Constraints { /// Credit limit which is from anti-amplification attack limit, appllied to all data, including the packet header. 
    ///
    /// When the verification is passed, the limit will be removed, and the value is `usize::MAX`.
    // Credit derived from the anti-amplification defence; once address
    // verification passes, the limit is lifted (represented as `usize::MAX`).
    // It applies to all data, including the packet header.
    credit_limit: usize,
    /// Send quota, which is from the congestion control algorithm. As time goes by, the amount of data that should be
    /// sent.
    ///
    /// It is applied to ack-eliciting data packets, unless the packet only sends Padding/Ack/Ccf frames.
    // Send quota granted by the congestion controller: the amount of data this
    // burst is allowed to send, growing as time passes.
    // It applies to ack-eliciting packets, except packets carrying only
    // Padding/Ack/Ccf frames.
    send_quota: usize,
}

impl Constraints {
    /// Create a new [`Constraints`] with the given credit limit and send quota.
    pub fn new(credit_limit: usize, send_quota: usize) -> Self {
        Self {
            credit_limit,
            send_quota,
        }
    }

    /// Return whether the constraints are available (more frames can be sent).
    ///
    /// The conditions for ending is the credit limit is used up. Even if the send quota is not used up, packets that
    /// only contain Padding/Ack/Ccf can still be sent.
    ///
    // End conditions:
    // - the anti-amplification credit is exhausted, or
    // - credit remains but the send quota is exhausted
    //   + in that case packets carrying only Ack frames may still be sent
    pub fn is_available(&self) -> bool {
        self.credit_limit > 0
    }

    /// Constrain the buffer, make it smaller than the limit and quota.
    pub fn constrain<'b>(&self, buf: &'b mut [u8]) -> &'b mut [u8] {
        let min_len = buf
            .remaining_mut()
            .min(self.credit_limit)
            .min(self.send_quota);
        &mut buf[..min_len]
    }

    /// How many bytes may still be written under both constraints.
    pub fn available(&self) -> usize {
        self.credit_limit.min(self.send_quota)
    }

    /// Commit consumption of credit limit and send quota.
    ///
    /// The `len` is how much data was written to the constrained buffer; `in_flight` instructs whether the send quota
    /// should be consumed (only ack-eliciting, in-flight packets count against the quota).
    ///
    /// See [section-12.4-14.4.1](https://rfc-editor.org/rfc/rfc9000.html#section-12.4-14.4.1)
    /// and [table 3](https://rfc-editor.org/rfc/rfc9000.html#table-3)
    /// of [RFC9000](https://rfc-editor.org/rfc/rfc9000.html) for more details.
    pub fn commit(&mut self, len: usize, in_flight: bool) {
        self.credit_limit = self.credit_limit.saturating_sub(len);
        if in_flight {
            self.send_quota = self.send_quota.saturating_sub(len);
        }
    }
}

/// The struct that can be constrained by the [`Constraints`], usually a buffer.
pub trait ApplyConstraints {
    /// Apply the [`Constraints`] on the struct.
    fn apply(self, constraints: &Constraints) -> Self;
}

impl ApplyConstraints for &mut [u8] {
    fn apply(self, constraints: &Constraints) -> Self {
        constraints.constrain(self)
    }
}

================================================ FILE: qconnection/src/path/validate.rs ================================================
use std::{sync::atomic::Ordering, time::Duration};

use qbase::{frame::PathChallengeFrame, net::tx::Signals};
use qcongestion::Transport;
use thiserror::Error;
use tokio::time::Instant;

/// Why path validation (RFC 9000 §8.2) did not succeed.
#[derive(Debug, Error, Clone, Copy)]
pub enum ValidateFailure {
    #[error(
        "Path validation abort due to path inactivity by other reasons(usually connection closed)"
    )]
    PathInactive,
    #[error("Path validation failed after {0} ms", elapsed.as_millis())]
    Timeout { elapsed: Duration },
}

impl super::Path {
    /// Mark this path as validated and wake the transmit task so frames
    /// blocked on `Signals::PATH_VALIDATE` can be sent.
    pub fn validated(&self) {
        self.validated.store(true, Ordering::Release);
        self.tx_waker.wake_by(Signals::PATH_VALIDATE);
    }

    /// Run path validation: send a PATH_CHALLENGE and wait for the matching
    /// PATH_RESPONSE, retrying with a fresh PTO-based timeout on each attempt.
    pub async fn validate(&self) -> Result<(), ValidateFailure> {
        let challenge = PathChallengeFrame::random();
        let start = Instant::now();
        // Stop-and-wait retransmission of the challenge.
        // NOTE(review): the original comment said "at most 3 times" but the
        // loop allows 30 attempts — confirm the intended retry count.
        for _ in 0..30 {
            let timeout_duration = self.cc().get_pto(qbase::Epoch::Data);
            self.challenge_sndbuf.write(challenge);
            match tokio::time::timeout(timeout_duration, self.response_rcvbuf.receive()).await {
                Ok(Some(response)) if *response == *challenge => {
                    self.validated();
                    self.anti_amplifier.grant();
                    tracing::debug!(target: "quic", pathway=%self.pathway, "path validated successfully");
                    return Ok(());
                }
                // External state changed (receive buffer dismissed), so this
                // validation task is obsolete.
                Ok(None) => return Err(ValidateFailure::PathInactive),
                // Timed out or received a non-matching response: resend the
                // challenge (stop-and-wait).
                _ => continue,
            }
        }
        Err(ValidateFailure::Timeout {
            elapsed:
start.elapsed(),
        })
    }
}

================================================ FILE: qconnection/src/path.rs ================================================
use std::{
    io,
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicU16, Ordering},
    },
};

use qbase::{
    Epoch,
    error::Error,
    frame::{PathChallengeFrame, PathResponseFrame, PingFrame, io::ReceiveFrame},
    net::{
        route::{Line, Link, Pathway, Route},
        tx::ArcSendWaker,
    },
    packet::PacketContent,
    param::ParameterId,
    time::ArcIdleTimer,
};
use qcongestion::{Algorithm, ArcCC, Feedback, HandshakeStatus, MSS, PathStatus, Transport};
use qevent::{quic::connectivity::PathAssigned, telemetry::Instrument};
use qinterface::{
    Interface,
    bind_uri::BindUri,
    io::{IO, IoExt},
};
use tokio::time::Duration;

mod aa;
mod burst;
mod drive;
pub mod error;
pub mod paths;
pub mod util;
mod validate;

pub use aa::*;
pub use burst::PacketSpace;
pub use error::*;
pub use paths::*;
use tokio_util::task::AbortOnDropHandle;
use tracing::Instrument as _;
pub use util::*;

use crate::{ArcDcidCell, Components, path::burst::BurstError};

// pub mod burst;

/// One network path of a connection: a concrete link/pathway pair plus its
/// congestion controller, anti-amplification state, validation buffers and
/// transmit waker.
///
/// NOTE(review): the field type parameters below were stripped by extraction
/// and have been reconstructed from usage in `Path::new`, `validate.rs` and
/// the `ReceiveFrame` impls — confirm against the original source.
pub struct Path {
    interface: Interface,
    // Set once PATH_RESPONSE matching our PATH_CHALLENGE is received.
    validated: AtomicBool,
    active: AtomicBool,
    link: Link,
    pathway: Pathway,
    cc: ArcCC,
    dcid_cell: ArcDcidCell,
    anti_amplifier: AntiAmplifier,
    idle_timer: ArcIdleTimer,
    heartbeat_sndbuf: SendBuffer<PingFrame>,
    challenge_sndbuf: SendBuffer<PathChallengeFrame>,
    response_sndbuf: SendBuffer<PathResponseFrame>,
    response_rcvbuf: RecvBuffer<PathResponseFrame>,
    tx_waker: ArcSendWaker,
    // Path MTU; initialized to MSS, read by `mtu()`.
    pmtu: Arc<AtomicU16>,
    status: PathStatus,
}

impl Components {
    /// Fetch the path for `pathway`, or create it (plus its driver task) if absent.
    ///
    /// Creation wires up three background futures — validation, driving
    /// (heartbeats etc.) and bursting (packet assembly/sending) — raced via
    /// `tokio::select!`; whichever fails first deactivates the path.
    // NOTE(review): the success type of the returned Result was stripped by
    // extraction; `Arc<Path>` is inferred from `paths.get_or_try_create_with`
    // swallowing the `(path, task)` pair — confirm.
    pub fn get_or_try_create_path(
        &self,
        bind_uri: BindUri,
        link: Link,
        pathway: Pathway,
        is_probed: bool,
    ) -> Result<Arc<Path>, CreatePathFailure> {
        let try_create = || {
            let interface = self
                .interfaces
                .borrow(&bind_uri)
                .ok_or(CreatePathFailure::NoInterface(bind_uri))?;
            let dcid_cell = self.cid_registry.remote.apply_dcid();
            let max_ack_delay = self
                .parameters
                .lock_guard()?
                .get_local(ParameterId::MaxAckDelay)
                .expect("unreachable: default value will be got if the value unset");
            let is_initial_path = self.conn_state.try_entry_attempted(self, link)?;
            qevent::event!(PathAssigned {
                path_id: pathway.to_string(),
                path_local: link.src,
                path_remote: link.dst,
            });
            let path = Arc::new(Path::new(
                interface,
                link,
                pathway,
                dcid_cell,
                max_ack_delay,
                self.idle_config.timer(),
                // One loss-feedback tracker per packet-number space.
                [
                    Arc::new(
                        self.spaces
                            .initial()
                            .tracker(self.crypto_streams[Epoch::Initial].clone()),
                    ),
                    Arc::new(
                        self.spaces
                            .handshake()
                            .tracker(self.crypto_streams[Epoch::Handshake].clone()),
                    ),
                    Arc::new(self.spaces.data().tracker(
                        self.crypto_streams[Epoch::Data].clone(),
                        self.data_streams.clone(),
                        self.reliable_frames.clone(),
                    )),
                ],
                self.quic_handshake.status(),
            ));
            let validate = {
                let path = path.clone();
                let paths = self.paths.clone();
                let tls_handshake = self.tls_handshake.clone();
                let conn_state = self.conn_state.clone();
                async move {
                    // A locally initiated (non-probed) path starts with full
                    // anti-amplification credit.
                    if !is_probed {
                        path.grant_anti_amplification();
                    }
                    if tls_handshake.info().await.is_err() {
                        return Ok(());
                    }
                    match paths.handshake_path() {
                        // The path the handshake completed on is implicitly validated.
                        Some(handshake_path) if Arc::ptr_eq(&handshake_path, &path) => {
                            path.validated();
                            Ok(())
                        }
                        _ => {
                            if conn_state.handshaked().await.is_err() {
                                return Ok(());
                            }
                            path.validate().await
                        }
                    }
                }
            };
            let drive = {
                let path = path.clone();
                let tls_handshake = self.tls_handshake.clone();
                async move { path.drive(tls_handshake).await }
            };
            let burst = {
                let path = path.clone();
                let mut packages = self.packages();
                let burst = path.new_burst(self);
                async move {
                    let mut buffers = vec![];
                    loop {
                        match burst.burst(&mut packages, &mut buffers).await {
                            Ok(segments) => path.send_packets(&segments).await?,
                            // Nothing sendable right now: park until the
                            // blocking condition is lifted.
                            Err(BurstError::Signals(s)) => path.tx_waker.wait_for(s).await,
                            Err(BurstError::PathDeactived) => return io::Result::Ok(()),
                        }
                    }
                }
            };
            let task = async move {
                // First sub-task to fail determines why the path deactivated.
                Err(tokio::select! {
                    Ok(Err(e)) = AbortOnDropHandle::new(tokio::spawn(validate.instrument_in_current().in_current_span())) => PathDeactivated::from(e),
                    Ok(Err(e)) = AbortOnDropHandle::new(tokio::spawn(drive.instrument_in_current().in_current_span())) => e,
                    Ok(Err(e)) = AbortOnDropHandle::new(tokio::spawn(burst.instrument_in_current().in_current_span())) => PathDeactivated::from(e),
                })
            };
            let task =
                Instrument::instrument(task, qevent::span!(@current, path=pathway.to_string()))
                    .in_current_span();
            tracing::trace!(target: "quic", %pathway, %link, is_probed, is_initial_path, "add new path");
            Ok((path, task))
        };
        self.paths.get_or_try_create_with(pathway, try_create)
    }
}

impl Path {
    /// Assemble a new path. Pure construction: no background task is spawned here.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        interface: Interface,
        link: Link,
        pathway: Pathway,
        dcid_cell: ArcDcidCell,
        max_ack_delay: Duration,
        idle_timer: ArcIdleTimer,
        feedbacks: [Arc<dyn Feedback>; 3],
        handshake_status: Arc<HandshakeStatus>,
    ) -> Self {
        let pmtu = Arc::new(AtomicU16::new(MSS as u16));
        let path_status = PathStatus::new(handshake_status, pmtu.clone());
        let tx_waker = ArcSendWaker::new();
        let cc = ArcCC::new(
            Algorithm::NewReno,
            max_ack_delay,
            feedbacks,
            path_status.clone(),
            tx_waker.clone(),
        );
        Self {
            interface,
            link,
            pathway,
            cc,
            dcid_cell,
            validated: AtomicBool::new(false),
            active: AtomicBool::new(true),
            anti_amplifier: AntiAmplifier::new(tx_waker.clone()),
            idle_timer,
            heartbeat_sndbuf: SendBuffer::new(tx_waker.clone()),
            challenge_sndbuf: SendBuffer::new(tx_waker.clone()),
            response_sndbuf: SendBuffer::new(tx_waker.clone()),
            response_rcvbuf: Default::default(),
            tx_waker,
            pmtu,
            status: path_status,
        }
    }

    /// The congestion controller of this path.
    pub fn cc(&self) -> &ArcCC {
        &self.cc
    }

    /// Book-keeping for a received packet: anti-amplification credit, idle
    /// timer and congestion-controller notification.
    pub fn on_packet_rcvd(
        &self,
        epoch: Epoch,
        pn: u64,
        size: usize,
        packet_content: PacketContent,
    ) {
        self.anti_amplifier.on_rcvd(size);
        if size > 0 {
            self.status.release_anti_amplification_limit();
        }
        self.idle_timer.on_rcvd(packet_content);
        self.cc()
            .on_pkt_rcvd(epoch, pn, packet_content.is_ack_eliciting());
    }

    /// Lift the anti-amplification limit (address considered validated).
    pub fn grant_anti_amplification(&self) {
        self.anti_amplifier.grant();
        self.cc().grant_anti_amplification();
    }

    /// Current path MTU.
    pub fn mtu(&self) -> u16 {
        self.pmtu.load(Ordering::Acquire)
    }

    /// Send assembled datagrams, charging the anti-amplification budget first.
    pub async fn send_packets(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<()> {
        self.anti_amplifier
            .on_sent(bufs.iter().map(|s| s.len()).sum());
        if self.anti_amplifier.balance().is_err() {
            self.status.enter_anti_amplification_limit();
        }
        let line = Line::new(self.link, 64, None, self.mtu());
        let route = Route::new(self.pathway, line);
        self.interface.sendmmsg(bufs, route).await
    }

    pub fn deactivate(&self) {
        self.active.store(false, Ordering::Release);
    }

    pub fn active(&self) {
        self.active.store(true, Ordering::Release);
    }

    pub fn link(&self) -> &Link {
        &self.link
    }

    pub fn pathway(&self) -> &Pathway {
        &self.pathway
    }

    pub fn bind_uri(&self) -> BindUri {
        self.interface.bind_uri()
    }
}

impl Drop for Path {
    fn drop(&mut self) {
        // Unblock any pending `validate()` so the task can observe `None` and exit.
        self.response_rcvbuf.dismiss();
    }
}

impl ReceiveFrame<PathChallengeFrame> for Path {
    type Output = ();

    // A peer challenge is answered by echoing it back as a PATH_RESPONSE.
    fn recv_frame(&self, frame: PathChallengeFrame) -> Result<Self::Output, Error> {
        self.response_sndbuf.write(frame.into());
        Ok(())
    }
}

impl ReceiveFrame<PathResponseFrame> for Path {
    type Output = ();

    // Responses are queued for the `validate()` task to match against its challenge.
    fn recv_frame(&self, frame: PathResponseFrame) -> Result<Self::Output, Error> {
        self.response_rcvbuf.write(frame);
        Ok(())
    }
}

================================================ FILE: qconnection/src/space/data.rs ================================================
use std::sync::Arc;

use qbase::{
    Epoch, GetEpoch,
    error::{Error, QuicError},
    frame::{
        ConnectionCloseFrame, Frame, ReliableFrame,
        io::{ReceiveFrame, SendFrame},
    },
    net::{
        route::{Link, Pathway},
        tx::Signals,
    },
    packet::{
        self,
        header::{GetType, OneRttHeader, long::ZeroRttHeader},
        io::PacketSpace,
        keys::{ArcOneRttKeys, ArcZeroRttKeys, DirectionalKeys},
        r#type::Type,
    },
    util::BoundQueue,
};
use qcongestion::{ArcCC, Feedback, Transport};
use qevent::{
    quic::{
        PacketHeader, PacketType, QuicFramesCollector,
        recovery::{PacketLost, PacketLostTrigger},
    },
    telemetry::Instrument,
};
use qinterface::{
    bind_uri::BindUri,
    component::route::{CipherPacket, PlainPacket},
};
use
qrecovery::crypto::CryptoStream; use tokio::sync::mpsc; use crate::{ ArcReliableFrameDeque, Components, DataJournal, DataStreams, GuaranteedFrame, events::{ArcEventBroker, EmitEvent, Event}, path::{self, Path, error::CreatePathFailure}, space::{ AckDataSpace, FlowControlledDataStreams, assemble_closing_packet, filter_odcid_packet, pipe, read_plain_packet, }, state, termination::Terminator, tx::{PacketWriter, TrivialPacketWriter}, }; pub type CipherZeroRttPacket = CipherPacket; pub type PlainZeroRttPacket = PlainPacket; pub type ReceivedZeroRttFrom = (CipherZeroRttPacket, (BindUri, Pathway, Link)); pub type CipherOneRttPacket = CipherPacket; pub type PlainOneRttPacket = PlainPacket; pub type ReceivedOneRttFrom = (CipherOneRttPacket, (BindUri, Pathway, Link)); pub struct DataSpace { zero_rtt_keys: ArcZeroRttKeys, one_rtt_keys: ArcOneRttKeys, journal: DataJournal, } impl AsRef for DataSpace { fn as_ref(&self) -> &DataJournal { &self.journal } } impl DataSpace { pub fn new(zero_rtt_keys: ArcZeroRttKeys) -> Self { Self { zero_rtt_keys, one_rtt_keys: ArcOneRttKeys::new_pending(), journal: DataJournal::with_capacity(16, None), } } pub async fn decrypt_0rtt_packet( &self, packet: CipherZeroRttPacket, ) -> Option> { // TODO: client should never received 0rtt packet... 
match self.zero_rtt_keys.get_decrypt_keys()?.await { Some(keys) => { packet.decrypt_long_packet(keys.header.as_ref(), keys.packet.as_ref(), |pn| { self.journal.of_rcvd_packets().decode_pn(pn) }) } None => { packet.drop_on_key_unavailable(); None } } } pub async fn decrypt_1rtt_packet( &self, packet: CipherOneRttPacket, ) -> Option> { match self.one_rtt_keys.get_remote_keys().await { Some((hpk, pk)) => packet.decrypt_short_packet(hpk.as_ref(), &pk, |pn| { self.journal.of_rcvd_packets().decode_pn(pn) }), None => { packet.drop_on_key_unavailable(); None } } } pub fn is_one_rtt_keys_ready(&self) -> bool { self.one_rtt_keys.get_local_keys().is_some() } pub fn is_zero_rtt_avaliable(&self) -> bool { self.zero_rtt_keys.get_encrypt_keys().is_some() } pub fn one_rtt_keys(&self) -> ArcOneRttKeys { self.one_rtt_keys.clone() } pub fn zero_rtt_keys(&self) -> ArcZeroRttKeys { self.zero_rtt_keys.clone() } pub(crate) fn journal(&self) -> &DataJournal { &self.journal } pub fn tracker( &self, crypto_stream: CryptoStream, streams: DataStreams, reliable_frames: ArcReliableFrameDeque, ) -> DataTracker { DataTracker { journal: self.journal.clone(), crypto_stream, streams, reliable_frames, } } } impl GetEpoch for DataSpace { fn epoch(&self) -> Epoch { Epoch::Data } } impl path::PacketSpace for DataSpace { type JournalFrame = GuaranteedFrame; fn new_packet<'b, 's>( &'s self, header: ZeroRttHeader, cc: &ArcCC, buffer: &'b mut [u8], ) -> Result, Signals> { if self.one_rtt_keys.get_local_keys().is_some() { return Err(Signals::TLS_FIN); // should 1rtt } let Some(keys) = self.zero_rtt_keys.get_encrypt_keys() else { return Err(Signals::empty()); // no 0rtt keys, just skip 0rtt }; let (retran_timeout, expire_timeout) = cc.retransmit_and_expire_time(Epoch::Data); PacketWriter::new_long( header, buffer, keys, self.journal.as_ref(), retran_timeout, expire_timeout, ) } } impl PacketSpace for DataSpace { type PacketAssembler<'a> = TrivialPacketWriter<'a, 'a, GuaranteedFrame>; #[inline] fn 
new_packet<'a>( &'a self, header: ZeroRttHeader, buffer: &'a mut [u8], ) -> Result, Signals> { if self.one_rtt_keys.get_local_keys().is_some() { return Err(Signals::TLS_FIN); // should 1rtt } let Some(keys) = self.zero_rtt_keys.get_encrypt_keys() else { return Err(Signals::empty()); // no 0rtt keys, just skip 0rtt }; TrivialPacketWriter::new_long(header, buffer, keys, self.journal.as_ref()) } } impl path::PacketSpace for DataSpace { type JournalFrame = GuaranteedFrame; fn new_packet<'b, 's>( &'s self, header: OneRttHeader, cc: &ArcCC, buffer: &'b mut [u8], ) -> Result, Signals> { let (hpk, pk) = self.one_rtt_keys.get_local_keys().ok_or(Signals::KEYS)?; let (key_phase, pk) = pk.lock_guard().get_local(); let (retran_timeout, expire_timeout) = cc.retransmit_and_expire_time(Epoch::Data); PacketWriter::new_short( header, buffer, DirectionalKeys { header: hpk, packet: pk, }, key_phase, self.journal.as_ref(), retran_timeout, expire_timeout, ) } } impl PacketSpace for DataSpace { type PacketAssembler<'a> = TrivialPacketWriter<'a, 'a, GuaranteedFrame>; #[inline] fn new_packet<'a>( &'a self, header: OneRttHeader, buffer: &'a mut [u8], ) -> Result, Signals> { let (hpk, pk) = self.one_rtt_keys.get_local_keys().ok_or(Signals::KEYS)?; let (key_phase, pk) = pk.lock_guard().get_local(); TrivialPacketWriter::new_short( header, buffer, DirectionalKeys { header: hpk, packet: pk, }, key_phase, self.journal.as_ref(), ) } } fn frame_dispathcer( space: &DataSpace, components: &Components, event_broker: &ArcEventBroker, ) -> impl for<'p> Fn(Frame, Type, &'p Path) + use<> { let (ack_frames_entry, rcvd_ack_frames) = mpsc::unbounded_channel(); // 连接级的 let (max_data_frames_entry, rcvd_max_data_frames) = mpsc::unbounded_channel(); let (data_blocked_frames_entry, rcvd_data_blocked_frames) = mpsc::unbounded_channel(); let (new_cid_frames_entry, rcvd_new_cid_frames) = mpsc::unbounded_channel(); let (retire_cid_frames_entry, rcvd_retire_cid_frames) = mpsc::unbounded_channel(); let 
(handshake_done_frames_entry, rcvd_handshake_done_frames) = mpsc::unbounded_channel(); let (new_token_frames_entry, rcvd_new_token_frames) = mpsc::unbounded_channel(); // 数据级的 let (crypto_frames_entry, rcvd_crypto_frames) = mpsc::unbounded_channel(); let (stream_ctrl_frames_entry, rcvd_stream_ctrl_frames) = mpsc::unbounded_channel(); let (stream_frames_entry, rcvd_stream_frames) = mpsc::unbounded_channel(); #[cfg(feature = "datagram")] let (datagram_frames_entry, rcvd_datagram_frames) = mpsc::unbounded_channel(); let (punch_frames_entry, rcvd_punch_frames) = mpsc::unbounded_channel(); let (punch_hello_frames_entry, rcvd_punch_hello_frames) = mpsc::unbounded_channel(); let flow_controlled_data_streams = FlowControlledDataStreams::new( components.data_streams.clone(), components.flow_ctrl.clone(), ); // Assemble the pipelines of frame processing pipe( rcvd_retire_cid_frames, components.cid_registry.local.clone(), event_broker.clone(), ); pipe( rcvd_new_cid_frames, components.cid_registry.remote.clone(), event_broker.clone(), ); pipe( rcvd_max_data_frames, components.flow_ctrl.sender.clone(), event_broker.clone(), ); pipe( rcvd_data_blocked_frames, components.flow_ctrl.recver.clone(), event_broker.clone(), ); pipe( rcvd_handshake_done_frames, components .quic_handshake .discard_spaces_on_client_handshake_done(components.paths.clone()), event_broker.clone(), ); pipe( rcvd_crypto_frames, components.crypto_streams[space.epoch()].incoming(), event_broker.clone(), ); pipe( rcvd_stream_ctrl_frames, flow_controlled_data_streams.clone(), event_broker.clone(), ); pipe( rcvd_stream_frames, flow_controlled_data_streams, event_broker.clone(), ); #[cfg(feature = "datagram")] pipe( rcvd_datagram_frames, components.datagram_flow.clone(), event_broker.clone(), ); pipe( rcvd_ack_frames, AckDataSpace::new( &space.journal, components.data_streams.clone(), &components.crypto_streams[space.epoch()], ), event_broker.clone(), ); pipe( rcvd_new_token_frames, 
components.token_registry.clone(), event_broker.clone(), ); pipe(rcvd_punch_frames, components.clone(), event_broker.clone()); pipe( rcvd_punch_hello_frames, components.clone(), event_broker.clone(), ); let event_broker = event_broker.clone(); let rcvd_joural = space.journal.of_rcvd_packets(); move |frame: Frame, pty: packet::Type, path: &Path| match frame { Frame::Ack(f) => { path.cc().on_ack_rcvd(Epoch::Data, &f); rcvd_joural.on_rcvd_ack(&f); _ = ack_frames_entry.send(f) } Frame::NewToken(f) => _ = new_token_frames_entry.send(f), Frame::MaxData(f) => _ = max_data_frames_entry.send(f), Frame::NewConnectionId(f) => _ = new_cid_frames_entry.send(f), Frame::RetireConnectionId(f) => _ = retire_cid_frames_entry.send(f), Frame::HandshakeDone(f) => { // See [Section 4.1.2](https://datatracker.ietf.org/doc/html/rfc9001#handshake-confirmed) _ = handshake_done_frames_entry.send(f) } Frame::DataBlocked(f) => _ = data_blocked_frames_entry.send(f), Frame::PathChallenge(f) => _ = path.recv_frame(f), Frame::PathResponse(f) => _ = path.recv_frame(f), Frame::StreamCtl(f) => _ = stream_ctrl_frames_entry.send(f), Frame::Stream(f, data) => _ = stream_frames_entry.send((f, data)), Frame::Crypto(f, bytes) => _ = crypto_frames_entry.send((f, bytes)), #[cfg(feature = "datagram")] Frame::Datagram(f, data) => _ = datagram_frames_entry.send((f, data)), Frame::Close(f) if matches!(pty, Type::Short(_)) => event_broker.emit(Event::Closed(f)), Frame::AddAddress(frame) => { _ = punch_frames_entry.send(( path.bind_uri().clone(), *path.pathway(), *path.link(), ReliableFrame::AddAddress(frame), )) } Frame::RemoveAddress(frame) => { _ = punch_frames_entry.send(( path.bind_uri().clone(), *path.pathway(), *path.link(), ReliableFrame::RemoveAddress(frame), )) } Frame::PunchMeNow(frame) => { _ = punch_frames_entry.send(( path.bind_uri().clone(), *path.pathway(), *path.link(), ReliableFrame::PunchMeNow(frame), )) } Frame::PunchHello(frame) => { _ = punch_hello_frames_entry.send(( path.bind_uri().clone(), 
*path.pathway(), *path.link(), frame, )) } Frame::PunchDone(frame) => { _ = punch_frames_entry.send(( path.bind_uri().clone(), *path.pathway(), *path.link(), ReliableFrame::PunchDone(frame), )) } _ => {} } } async fn parse_normal_zero_rtt_packet( (packet, (bind_uri, pathway, link)): ReceivedZeroRttFrom, space: &DataSpace, components: &Components, dispatch_frame: impl Fn(Frame, Type, &Path), ) -> Result<(), Error> { let Some(packet) = space.decrypt_0rtt_packet(packet).await.transpose()? else { return Ok(()); }; let path = match components.get_or_try_create_path(bind_uri, link, pathway, true) { Ok(path) => path, Err(CreatePathFailure::ConnectionClosed(..)) => { packet.drop_on_conenction_closed(); return Ok(()); } Err(CreatePathFailure::NoInterface(..)) => { packet.drop_on_interface_not_found(); return Ok(()); } }; let Some(packet) = filter_odcid_packet(packet, &components.specific) else { return Ok(()); }; let packet_content = read_plain_packet(&packet, |frame| { dispatch_frame(frame, packet.get_type(), &path); })?; space.journal.of_rcvd_packets().on_rcvd_pn( packet.pn(), packet_content.is_ack_eliciting(), path.cc().get_pto(Epoch::Data), ); path.on_packet_rcvd(Epoch::Data, packet.pn(), packet.size(), packet_content); Result::<(), Error>::Ok(()) } async fn parse_normal_one_rtt_packet( (packet, (bind_uri, pathway, link)): ReceivedOneRttFrom, space: &DataSpace, components: &Components, dispatch_frame: impl Fn(Frame, Type, &Path), ) -> Result<(), Error> { let Some(packet) = space.decrypt_1rtt_packet(packet).await.transpose()? 
else { return Ok(()); }; let path = match components.get_or_try_create_path(bind_uri, link, pathway, true) { Ok(path) => path, Err(CreatePathFailure::ConnectionClosed(..)) => { packet.drop_on_conenction_closed(); return Ok(()); } Err(CreatePathFailure::NoInterface(..)) => { packet.drop_on_interface_not_found(); return Ok(()); } }; let Some(packet) = filter_odcid_packet(packet, &components.specific) else { return Ok(()); }; components .quic_handshake .discard_spaces_on_server_handshake_done(&components.paths); let packet_content = read_plain_packet(&packet, |frame| { dispatch_frame(frame, packet.get_type(), &path); })?; space.journal.of_rcvd_packets().on_rcvd_pn( packet.pn(), packet_content.is_ack_eliciting(), path.cc().get_pto(Epoch::Data), ); path.on_packet_rcvd(Epoch::Data, packet.pn(), packet.size(), packet_content); Result::<(), Error>::Ok(()) } fn parse_closing_one_rtt_packet( space: &DataSpace, packet: CipherOneRttPacket, ) -> Option { let (hpk, pk) = space.one_rtt_keys.remote_keys()?; let packet = packet .decrypt_short_packet(hpk.as_ref(), &pk, |pn| { space.journal.of_rcvd_packets().decode_pn(pn) }) .and_then(Result::ok)?; let mut ccf = None; _ = read_plain_packet(&packet, |frame| { ccf = ccf.take().or(match frame { Frame::Close(ccf) => Some(ccf), _ => None, }); }); ccf } pub async fn deliver_and_parse_packets( zeor_rtt_packets: BoundQueue, one_rtt_packets: BoundQueue, space: Arc, components: Components, event_broker: ArcEventBroker, ) { let conn_state = &components.conn_state; let dispatch_frame = frame_dispathcer(&space, &components, &event_broker); let normal_deliver_and_parse_zero_rtt_loop = async { while let Some(form) = zeor_rtt_packets.recv().await { let span = qevent::span!(@current, path=form.1.2.to_string()); let parse = parse_normal_zero_rtt_packet(form, &space, &components, &dispatch_frame); if let Err(Error::Quic(error)) = Instrument::instrument(parse, span).await { event_broker.emit(Event::Failed(error)); }; } }; let 
normal_deliver_and_parse_one_rtt_loop = async { while let Some(form) = one_rtt_packets.recv().await { let span = qevent::span!(@current, path=form.1.2.to_string()); let parse = parse_normal_one_rtt_packet(form, &space, &components, &dispatch_frame); if let Err(Error::Quic(error)) = Instrument::instrument(parse, span).await { event_broker.emit(Event::Failed(error)); }; } }; let normal_deliver_and_parse_loops = async { if components.tls_handshake.info().await.is_err() { return; } tokio::join!( normal_deliver_and_parse_zero_rtt_loop, normal_deliver_and_parse_one_rtt_loop, ); }; let ccf = tokio::select! { // deliver and parse packets. complete when packet queue closed _ = normal_deliver_and_parse_loops => return, // connection terminated(enter closing/draining state) error = conn_state.terminated() => match conn_state.current() { // entered closing_state, keep receiving packets, and send ccf state if state == Some(state::CLOSING) => ConnectionCloseFrame::from(error), // entered other state, do nothing _ => return } }; let terminator = Terminator::new(ccf, &components); // Release the primary connection state drop(components); zeor_rtt_packets.close(); while let Some((packet, (_bind_uri, pathway, _link))) = one_rtt_packets.recv().await { if let Some(ccf) = parse_closing_one_rtt_packet(&space, packet) { event_broker.emit(Event::Closed(ccf)); } if terminator.should_send() { terminator .try_send_on(pathway, |buffer, ccf| { assemble_closing_packet::( space.as_ref(), &terminator, buffer, ccf, ) }) .await } } } pub struct DataTracker { journal: DataJournal, crypto_stream: CryptoStream, streams: DataStreams, reliable_frames: ArcReliableFrameDeque, } impl Feedback for DataTracker { fn may_loss(&self, trigger: PacketLostTrigger, pns: &mut dyn Iterator) { let sent_jornal = self.journal.of_sent_packets(); let crypto_outgoing = self.crypto_stream.outgoing(); let mut sent_packets = sent_jornal.rotate(); for pn in pns { let mut may_lost_frames = QuicFramesCollector::::new(); for 
frame in sent_packets.may_loss_packet(pn) { match frame { GuaranteedFrame::Crypto(frame) => { may_lost_frames.extend([&frame]); crypto_outgoing.may_loss_data(&frame); } GuaranteedFrame::Stream(frame) => { may_lost_frames.extend([&frame]); self.streams.may_loss_data(&frame); } GuaranteedFrame::Reliable(frame) => { may_lost_frames.extend([&frame]); self.reliable_frames.send_frame([frame]); } }; } qevent::event!(PacketLost { header: PacketHeader { // TOOD: 如果只有支持0rtt,这里就不一定是1rtt了 packet_type: PacketType::OneRTT, packet_number: pn }, frames: may_lost_frames, trigger }); } } } ================================================ FILE: qconnection/src/space/handshake.rs ================================================ use std::sync::Arc; use qbase::{ Epoch, GetEpoch, error::{Error, QuicError}, frame::{ConnectionCloseFrame, CryptoFrame, Frame}, net::tx::Signals, packet::{header::long::HandshakeHeader, io::PacketSpace, keys::ArcKeys}, util::BoundQueue, }; use qcongestion::{Feedback, Transport}; use qevent::{ quic::{ PacketHeader, PacketType, QuicFramesCollector, recovery::{PacketLost, PacketLostTrigger}, }, telemetry::Instrument, }; use qinterface::component::route::{CipherPacket, PlainPacket, Way}; use qrecovery::crypto::CryptoStream; use tokio::sync::mpsc; use crate::{ Components, HandshakeJournal, events::{ArcEventBroker, EmitEvent, Event}, path::{self, Path, error::CreatePathFailure}, space::{ AckHandshakeSpace, assemble_closing_packet, filter_odcid_packet, pipe, read_plain_packet, }, state, termination::Terminator, tx::{PacketWriter, TrivialPacketWriter}, }; pub type CipherHanshakePacket = CipherPacket; pub type PlainHandshakePacket = PlainPacket; pub type ReceivedFrom = (CipherHanshakePacket, Way); pub struct HandshakeSpace { keys: ArcKeys, journal: HandshakeJournal, } impl AsRef for HandshakeSpace { fn as_ref(&self) -> &HandshakeJournal { &self.journal } } impl HandshakeSpace { pub fn new() -> Self { Self { keys: ArcKeys::new_pending(), journal: 
HandshakeJournal::with_capacity(16, None), } } pub fn keys(&self) -> ArcKeys { self.keys.clone() } pub async fn decrypt_packet( &self, packet: CipherHanshakePacket, ) -> Option> { match self.keys.get_remote_keys().await { Some(keys) => packet.decrypt_long_packet( keys.remote.header.as_ref(), keys.remote.packet.as_ref(), |pn| self.journal.of_rcvd_packets().decode_pn(pn), ), None => { packet.drop_on_key_unavailable(); None } } } pub fn tracker(&self, crypto_stream: CryptoStream) -> HandshakeTracker { HandshakeTracker { journal: self.journal.clone(), crypto_stream, } } } impl Default for HandshakeSpace { fn default() -> Self { Self::new() } } impl GetEpoch for HandshakeSpace { fn epoch(&self) -> Epoch { Epoch::Handshake } } impl path::PacketSpace for HandshakeSpace { type JournalFrame = CryptoFrame; fn new_packet<'b, 's>( &'s self, header: HandshakeHeader, cc: &qcongestion::ArcCC, buffer: &'b mut [u8], ) -> Result, Signals> { let keys = self.keys.get_local_keys().ok_or(Signals::KEYS)?; let (retran_timeout, expire_timeout) = cc.retransmit_and_expire_time(Epoch::Handshake); PacketWriter::new_long( header, buffer, keys.local.clone(), self.journal.as_ref(), retran_timeout, expire_timeout, ) } } impl PacketSpace for HandshakeSpace { type PacketAssembler<'a> = TrivialPacketWriter<'a, 'a, CryptoFrame>; #[inline] fn new_packet<'a>( &'a self, header: HandshakeHeader, buffer: &'a mut [u8], ) -> Result, Signals> { let keys = self.keys.get_local_keys().ok_or(Signals::KEYS)?; TrivialPacketWriter::new_long(header, buffer, keys.local, self.journal.as_ref()) } } fn frame_dispathcer( space: &HandshakeSpace, components: &Components, event_broker: &ArcEventBroker, ) -> impl for<'p> Fn(Frame, &'p Path) + use<> { let (crypto_frames_entry, rcvd_crypto_frames) = mpsc::unbounded_channel(); let (ack_frames_entry, rcvd_ack_frames) = mpsc::unbounded_channel(); pipe( rcvd_crypto_frames, components.crypto_streams[space.epoch()].incoming(), event_broker.clone(), ); pipe( rcvd_ack_frames, 
AckHandshakeSpace::new(&space.journal, &components.crypto_streams[space.epoch()]), event_broker.clone(), ); let inform_cc = components.quic_handshake.status(); let event_broker = event_broker.clone(); let rcvd_joural = space.journal.of_rcvd_packets(); move |frame: Frame, path: &Path| match frame { Frame::Ack(f) => { path.cc().on_ack_rcvd(Epoch::Handshake, &f); rcvd_joural.on_rcvd_ack(&f); _ = ack_frames_entry.send(f); inform_cc.received_handshake_ack(); } Frame::Close(f) => event_broker.emit(Event::Closed(f)), Frame::Crypto(f, bytes) => _ = crypto_frames_entry.send((f, bytes)), Frame::Padding(_) | Frame::Ping(_) => {} _ => unreachable!("unexpected frame: {:?} in handshake packet", frame), } } async fn parse_normal_packet( (packet, (bind_uri, pathway, link)): ReceivedFrom, space: &HandshakeSpace, components: &Components, dispatch_frame: impl Fn(Frame, &Path), ) -> Result<(), Error> { let Some(packet) = space.decrypt_packet(packet).await.transpose()? else { return Ok(()); }; let path = match components.get_or_try_create_path(bind_uri, link, pathway, true) { Ok(path) => path, Err(CreatePathFailure::ConnectionClosed(..)) => { packet.drop_on_conenction_closed(); return Ok(()); } Err(CreatePathFailure::NoInterface(..)) => { packet.drop_on_interface_not_found(); return Ok(()); } }; let Some(packet) = filter_odcid_packet(packet, &components.specific) else { return Ok(()); }; // See [RFC 9000 section 8.1](https://www.rfc-editor.org/rfc/rfc9000.html#name-address-validation-during-c) // Once an endpoint has successfully processed a Handshake packet from the peer, it can consider the peer // address to have been validated. 
// It may have already been verified using tokens in the Handshake space path.grant_anti_amplification(); let packet_content = read_plain_packet(&packet, |frame| dispatch_frame(frame, &path))?; space.journal.of_rcvd_packets().on_rcvd_pn( packet.pn(), packet_content.is_ack_eliciting(), path.cc().get_pto(Epoch::Handshake), ); path.on_packet_rcvd(Epoch::Handshake, packet.pn(), packet.size(), packet_content); Result::<(), Error>::Ok(()) } fn parse_closing_packet( space: &HandshakeSpace, packet: CipherHanshakePacket, ) -> Option { // TOOD: improve Keys let remote_keys = space.keys.get_local_keys()?.remote; let packet = packet .decrypt_long_packet( remote_keys.header.as_ref(), remote_keys.packet.as_ref(), |pn| space.journal.of_rcvd_packets().decode_pn(pn), ) .and_then(Result::ok)?; let mut ccf = None; _ = read_plain_packet(&packet, |frame| { ccf = ccf.take().or(match frame { Frame::Close(ccf) => Some(ccf), _ => None, }); }); ccf } pub async fn deliver_and_parse_packets( packets: BoundQueue, space: Arc, components: Components, event_broker: ArcEventBroker, ) { let conn_state = &components.conn_state; let dispatch_frame = frame_dispathcer(&space, &components, &event_broker); let normal_deliver_and_parse_loop = async { while let Some(form) = packets.recv().await { let span = qevent::span!(@current, path=form.1.2.to_string()); let parse = parse_normal_packet(form, &space, &components, &dispatch_frame); if let Err(Error::Quic(error)) = Instrument::instrument(parse, span).await { event_broker.emit(Event::Failed(error)); }; } }; let ccf = tokio::select! { // deliver and parse packets. 
complete when packet queue closed _ = normal_deliver_and_parse_loop => return, // connection terminated(enter closing/draining state) error = conn_state.terminated() => match conn_state.current() { // entered closing_state, keep receiving packets, and send ccf state if state == Some(state::CLOSING) => ConnectionCloseFrame::from(error), // entered other state, do nothing _ => return } }; let terminator = Terminator::new(ccf, &components); // Release the primary connection state drop(components); while let Some((packet, (_bind_uri, pathway, _link))) = packets.recv().await { if let Some(ccf) = parse_closing_packet(&space, packet) { event_broker.emit(Event::Closed(ccf)); } if terminator.should_send() { terminator .try_send_on(pathway, |buffer, ccf| { assemble_closing_packet(space.as_ref(), &terminator, buffer, ccf) }) .await } } } pub struct HandshakeTracker { journal: HandshakeJournal, crypto_stream: CryptoStream, } impl Feedback for HandshakeTracker { fn may_loss(&self, trigger: PacketLostTrigger, pns: &mut dyn Iterator) { let sent_jornal = self.journal.of_sent_packets(); let outgoing = self.crypto_stream.outgoing(); let mut sent_packets = sent_jornal.rotate(); for pn in pns { let mut may_lost_frames = QuicFramesCollector::::new(); for frame in sent_packets.may_loss_packet(pn) { may_lost_frames.extend([&frame]); outgoing.may_loss_data(&frame); } qevent::event!(PacketLost { header: PacketHeader { packet_type: PacketType::Handshake, packet_number: pn }, frames: may_lost_frames, trigger }); } } } ================================================ FILE: qconnection/src/space/initial.rs ================================================ use std::{ops::Deref, sync::Arc}; use qbase::{ Epoch, GetEpoch, error::{Error, QuicError}, frame::{ConnectionCloseFrame, CryptoFrame, Frame}, net::tx::Signals, packet::{ header::{GetScid, long::InitialHeader}, io::PacketSpace, keys::{ArcKeys, Keys}, }, token::TokenRegistry, util::BoundQueue, }; use qcongestion::{Feedback, Transport}; use 
qevent::{ quic::{ PacketHeader, PacketType, QuicFramesCollector, recovery::{PacketLost, PacketLostTrigger}, }, telemetry::Instrument, }; use qinterface::component::route::{CipherPacket, PlainPacket, Way}; use qrecovery::crypto::CryptoStream; use tokio::sync::mpsc; use crate::{ Components, InitialJournal, events::{ArcEventBroker, EmitEvent, Event}, path::{self, Path, error::CreatePathFailure}, space::{ AckInitialSpace, assemble_closing_packet, filter_odcid_packet, pipe, read_plain_packet, }, state, termination::Terminator, tx::{PacketWriter, TrivialPacketWriter}, }; pub type CipherInitialPacket = CipherPacket; pub type PlainInitialPacket = PlainPacket; pub type ReceivedFrom = (CipherInitialPacket, Way); pub struct InitialSpace { keys: ArcKeys, journal: InitialJournal, } impl AsRef for InitialSpace { fn as_ref(&self) -> &InitialJournal { &self.journal } } impl InitialSpace { // Initial keys应该是预先知道的,或者传入dcid,可以构造出来 pub fn new(keys: Keys) -> Self { let journal = InitialJournal::with_capacity(16, None); Self { keys: ArcKeys::with_keys(keys), journal, } } pub fn keys(&self) -> ArcKeys { self.keys.clone() } pub async fn decrypt_packet( &self, packet: CipherInitialPacket, ) -> Option> { match self.keys.get_remote_keys().await { Some(keys) => packet.decrypt_long_packet( keys.remote.header.as_ref(), keys.remote.packet.as_ref(), |pn| self.journal.of_rcvd_packets().decode_pn(pn), ), None => { packet.drop_on_key_unavailable(); None } } } pub fn tracker(&self, crypto_stream: CryptoStream) -> InitialTracker { InitialTracker { journal: self.journal.clone(), crypto_stream, } } } impl GetEpoch for InitialSpace { fn epoch(&self) -> Epoch { Epoch::Initial } } impl path::PacketSpace for InitialSpace { type JournalFrame = CryptoFrame; fn new_packet<'b, 's>( &'s self, header: InitialHeader, cc: &qcongestion::ArcCC, buffer: &'b mut [u8], ) -> Result, Signals> { let keys = self.keys.get_local_keys().ok_or(Signals::KEYS)?; let (retran_timeout, expire_timeout) = 
cc.retransmit_and_expire_time(Epoch::Handshake); PacketWriter::new_long( header, buffer, keys.local, self.journal.as_ref(), retran_timeout, expire_timeout, ) } } impl PacketSpace for InitialSpace { type PacketAssembler<'a> = TrivialPacketWriter<'a, 'a, CryptoFrame>; #[inline] fn new_packet<'a>( &'a self, header: InitialHeader, buffer: &'a mut [u8], ) -> Result, Signals> { let keys = self.keys.get_local_keys().ok_or(Signals::KEYS)?; TrivialPacketWriter::new_long(header, buffer, keys.local, self.journal.as_ref()) } } fn frame_dispathcer( space: &InitialSpace, components: &Components, event_broker: &ArcEventBroker, ) -> impl for<'p> Fn(Frame, &'p Path) + use<> { let (crypto_frames_entry, rcvd_crypto_frames) = mpsc::unbounded_channel(); let (ack_frames_entry, rcvd_ack_frames) = mpsc::unbounded_channel(); pipe( rcvd_crypto_frames, components.crypto_streams[space.epoch()].incoming(), event_broker.clone(), ); pipe( rcvd_ack_frames, AckInitialSpace::new(&space.journal, &components.crypto_streams[space.epoch()]), event_broker.clone(), ); let event_broker = event_broker.clone(); let rcvd_joural = space.journal.of_rcvd_packets(); move |frame: Frame, path: &Path| match frame { Frame::Ack(f) => { path.cc().on_ack_rcvd(Epoch::Initial, &f); rcvd_joural.on_rcvd_ack(&f); _ = ack_frames_entry.send(f); } Frame::Close(f) => event_broker.emit(Event::Closed(f)), Frame::Crypto(f, bytes) => _ = crypto_frames_entry.send((f, bytes)), Frame::Padding(_) | Frame::Ping(_) => {} _ => unreachable!("unexpected frame: {:?} in initial packet", frame), } } async fn parse_normal_packet( (packet, (bind_uri, pathway, link)): ReceivedFrom, space: &InitialSpace, components: &Components, dispatch_frame: impl Fn(Frame, &Path), ) -> Result<(), Error> { let parameters = &components.parameters; let paths = &components.paths; let remote_cids = &components.cid_registry.remote; let validate_token = { let token_registry = &components.token_registry; let tls_handshake = &components.tls_handshake; |initial_token: 
&[u8], path: &Path| { if let TokenRegistry::Server(provider) = token_registry.deref() && let Ok(Some(server_name)) = tls_handshake.server_name() && provider.verify_token(server_name.as_ref(), initial_token) { path.grant_anti_amplification(); } } }; // rfc9000 7.2: // if subsequent Initial packets include a different Source Connection ID, they MUST be discarded. This avoids // unpredictable outcomes that might otherwise result from stateless processing of multiple Initial packets // with different Source Connection IDs. if matches!(parameters.lock_guard()?.initial_scid_from_peer(), Some(scid) if scid != *packet.scid()) { packet.drop_on_scid_unmatch(); return Ok(()); } let Some(packet) = space.decrypt_packet(packet).await.transpose()? else { return Ok(()); }; let path = match components.get_or_try_create_path(bind_uri, link, pathway, true) { Ok(path) => path, Err(CreatePathFailure::ConnectionClosed(..)) => { packet.drop_on_conenction_closed(); return Ok(()); } Err(CreatePathFailure::NoInterface(..)) => { packet.drop_on_interface_not_found(); return Ok(()); } }; let Some(packet) = filter_odcid_packet(packet, &components.specific) else { return Ok(()); }; let packet_content = read_plain_packet(&packet, |frame| dispatch_frame(frame, &path))?; space.journal.of_rcvd_packets().on_rcvd_pn( packet.pn(), packet_content.is_ack_eliciting(), path.cc().get_pto(Epoch::Initial), ); path.on_packet_rcvd(Epoch::Initial, packet.pn(), packet.size(), packet_content); // Negotiate handshake path if paths.assign_handshake_path(&path, remote_cids, *packet.scid()) { parameters .lock_guard()? .initial_scid_from_peer_need_equal(*packet.scid())?; } // See [RFC 9000 section 8.1](https://www.rfc-editor.org/rfc/rfc9000.html#name-address-validation-during-c) // A server might wish to validate the client address before starting the cryptographic handshake. // QUIC uses a token in the Initial packet to provide address validation prior to completing the handshake. 
// This token is delivered to the client during connection establishment with a Retry packet (see Section 8.1.2) // or in a previous connection using the NEW_TOKEN frame (see Section 8.1.3). if !packet.token().is_empty() { validate_token(packet.token(), &path); } Result::<(), Error>::Ok(()) } fn parse_closing_packet( space: &InitialSpace, packet: CipherInitialPacket, ) -> Option { // TOOD: improve Keys let remote_keys = space.keys.get_local_keys()?.remote; let packet = packet .decrypt_long_packet( remote_keys.header.as_ref(), remote_keys.packet.as_ref(), |pn| space.journal.of_rcvd_packets().decode_pn(pn), ) .and_then(Result::ok)?; let mut ccf = None; _ = read_plain_packet(&packet, |frame| { ccf = ccf.take().or(match frame { Frame::Close(ccf) => Some(ccf), _ => None, }); }); ccf } pub async fn deliver_and_parse_packets( packets: BoundQueue, space: Arc, components: Components, event_broker: ArcEventBroker, ) { let conn_state = &components.conn_state; let dispatch_frame = frame_dispathcer(&space, &components, &event_broker); let normal_deliver_and_parse_loop = async { while let Some(form) = packets.recv().await { let span = qevent::span!(@current, path=form.1.2.to_string()); let parse = parse_normal_packet(form, &space, &components, &dispatch_frame); if let Err(Error::Quic(error)) = Instrument::instrument(parse, span).await { event_broker.emit(Event::Failed(error)); }; } }; let ccf = tokio::select! { // deliver and parse packets. 
complete when packet queue closed _ = normal_deliver_and_parse_loop => return, // connection terminated(enter closing/draining state) error = conn_state.terminated() => match conn_state.current() { // entered closing_state, keep receiving packets, and send ccf state if state == Some(state::CLOSING) => ConnectionCloseFrame::from(error), // entered other state, do nothing _ => return } }; let terminator = Terminator::new(ccf, &components); // Release the primary connection state drop(components); while let Some((packet, (_bind_uri, pathway, _link))) = packets.recv().await { if let Some(ccf) = parse_closing_packet(&space, packet) { event_broker.emit(Event::Closed(ccf)); } // TODO:尝试解决计数分离的问题?将收包统计转为连接和路径级?发送数据包交给路径? if terminator.should_send() { terminator .try_send_on(pathway, |buffer, ccf| { assemble_closing_packet(space.as_ref(), &terminator, buffer, ccf) }) .await } } } pub struct InitialTracker { journal: InitialJournal, crypto_stream: CryptoStream, } impl Feedback for InitialTracker { fn may_loss(&self, trigger: PacketLostTrigger, pns: &mut dyn Iterator) { let sent_jornal = self.journal.of_sent_packets(); let outgoing = self.crypto_stream.outgoing(); let mut sent_packets = sent_jornal.rotate(); for pn in pns { let mut may_lost_frames = QuicFramesCollector::::new(); for frame in sent_packets.may_loss_packet(pn) { may_lost_frames.extend([&frame]); outgoing.may_loss_data(&frame); } qevent::event!(PacketLost { header: PacketHeader { packet_type: PacketType::Initial, packet_number: pn }, frames: may_lost_frames, trigger }); } } } ================================================ FILE: qconnection/src/space.rs ================================================ pub mod data; pub mod handshake; pub mod initial; use std::{borrow::Cow, fmt::Debug, sync::Arc}; use bytes::Bytes; use qbase::{ error::{Error, QuicError}, frame::{ AckFrame, ConnectionCloseFrame, CryptoFrame, FrameFeature, FrameReader, GetFrameType, ReliableFrame, StreamCtlFrame, StreamFrame, io::ReceiveFrame, }, 
packet::{ AssemblePacket, Package, PacketContent, PacketSpace, PacketWriter, ProductHeader, header::{GetDcid, GetType, short::OneRttHeader}, io::{Packages, PadTo20}, }, }; use qevent::{ quic::{ PacketHeaderBuilder, QuicFramesCollector, transport::{PacketReceived, PacketsAcked}, }, telemetry::Instrument, }; use qinterface::component::route::PlainPacket; use qrecovery::{ crypto::{CryptoStream, CryptoStreamOutgoing}, journal::{ArcSentJournal, Journal}, }; use tokio::sync::mpsc::UnboundedReceiver; use tracing::Instrument as _; use crate::{ Components, DataStreams, FlowController, GuaranteedFrame, SpecificComponents, events::{ArcEventBroker, EmitEvent, Event}, termination::Terminator, }; #[derive(Clone)] pub struct Spaces { initial: Arc, handshake: Arc, data: Arc, } impl Spaces { pub fn new( initial: initial::InitialSpace, handshake: handshake::HandshakeSpace, data: data::DataSpace, ) -> Self { Self { initial: Arc::new(initial), handshake: Arc::new(handshake), data: Arc::new(data), } } pub fn initial(&self) -> &Arc { &self.initial } pub fn handshake(&self) -> &Arc { &self.handshake } pub fn data(&self) -> &Arc { &self.data } } fn assemble_closing_packet<'s, 'b: 's, H, S>( space: &'s S, product_header: &impl ProductHeader, buffer: &'b mut [u8], ccf: &ConnectionCloseFrame, ) -> Option where S: PacketSpace, S::PacketAssembler<'s>: AsRef>, for<'f> &'f ConnectionCloseFrame: Package>, { let header = product_header.new_header().ok()?; let mut packet = S::new_packet(space, header, buffer).ok()?; let ccf = match ccf.belongs_to(packet.as_ref().packet_type()) { true => Cow::Borrowed(ccf), false => Cow::Owned(ConnectionCloseFrame::from(match ccf { ConnectionCloseFrame::App(app_close_frame) => app_close_frame.conceal(), ConnectionCloseFrame::Quic(..) 
=> unreachable!(), })), }; packet .assemble_packet(&mut Packages((ccf.as_ref(), PadTo20))) .ok()?; Some(packet.encrypt_and_protect_packet().0) } impl Spaces { pub async fn send_ccf_packets(&self, t: &Terminator) { t.try_send(|mut buf, ccf| { let original_size = buf.len(); let initial_size = assemble_closing_packet(self.initial().as_ref(), t, buf, ccf); buf = &mut buf[initial_size.unwrap_or(0)..]; let handshake_size = assemble_closing_packet(self.handshake().as_ref(), t, buf, ccf); buf = &mut buf[handshake_size.unwrap_or(0)..]; let one_rtt_size = assemble_closing_packet::(self.data().as_ref(), t, buf, ccf); buf = &mut buf[one_rtt_size.unwrap_or(0)..]; if initial_size.is_some() { buf.fill(0); Some(original_size) } else { (original_size != buf.len()).then_some(original_size - buf.len()) } }) .await; } } fn pipe( mut source: UnboundedReceiver, destination: impl ReceiveFrame + Send + 'static, broker: ArcEventBroker, ) { tokio::spawn( async move { while let Some(f) = source.recv().await { if let Err(Error::Quic(e)) = destination.recv_frame(f) { broker.emit(Event::Failed(e)); break; } } } .instrument_in_current() .in_current_span(), ); } /// When receiving a [`StreamFrame`] or [`StreamCtlFrame`], /// flow control must be updated accordingly #[derive(Clone)] struct FlowControlledDataStreams { streams: DataStreams, flow_ctrl: FlowController, } impl FlowControlledDataStreams { fn new(streams: DataStreams, flow_ctrl: FlowController) -> Self { Self { streams, flow_ctrl } } } impl ReceiveFrame<(StreamFrame, Bytes)> for FlowControlledDataStreams { type Output = (); fn recv_frame(&self, data_frame: (StreamFrame, Bytes)) -> Result { let frame_type = data_frame.0.frame_type(); let new_data_size = self.streams.recv_data(data_frame)?; self.flow_ctrl.on_new_rcvd(frame_type, new_data_size)?; Ok(()) } } impl ReceiveFrame for FlowControlledDataStreams { type Output = (); fn recv_frame(&self, frame: StreamCtlFrame) -> Result { let new_data_size = self.streams.recv_stream_control(frame)?; 
self.flow_ctrl .on_new_rcvd(frame.frame_type(), new_data_size)?; Ok(()) } } struct AckInitialSpace { sent_journal: ArcSentJournal, crypto_stream_outgoing: CryptoStreamOutgoing, } impl AckInitialSpace { fn new(journal: &Journal, crypto_stream: &CryptoStream) -> Self { Self { sent_journal: journal.of_sent_packets(), crypto_stream_outgoing: crypto_stream.outgoing(), } } } impl ReceiveFrame for AckInitialSpace { type Output = (); fn recv_frame(&self, ack_frame: AckFrame) -> Result { let mut rotate_guard = self.sent_journal.rotate(); rotate_guard.update_largest(&ack_frame)?; let acked = ack_frame.iter().flat_map(|r| r.rev()).collect::>(); qevent::event!(PacketsAcked { packet_number_space: qbase::Epoch::Initial, packet_nubers: acked.clone(), }); for pn in acked { for frame in rotate_guard.on_packet_acked(pn) { self.crypto_stream_outgoing.on_data_acked(&frame); } } Ok(()) } } struct AckHandshakeSpace { sent_journal: ArcSentJournal, crypto_stream_outgoing: CryptoStreamOutgoing, } impl AckHandshakeSpace { fn new(journal: &Journal, crypto_stream: &CryptoStream) -> Self { Self { sent_journal: journal.of_sent_packets(), crypto_stream_outgoing: crypto_stream.outgoing(), } } } impl ReceiveFrame for AckHandshakeSpace { type Output = (); fn recv_frame(&self, ack_frame: AckFrame) -> Result { let mut rotate_guard = self.sent_journal.rotate(); rotate_guard.update_largest(&ack_frame)?; let acked = ack_frame.iter().flat_map(|r| r.rev()).collect::>(); qevent::event!(PacketsAcked { packet_number_space: qbase::Epoch::Handshake, packet_nubers: acked.clone(), }); for pn in acked { for frame in rotate_guard.on_packet_acked(pn) { self.crypto_stream_outgoing.on_data_acked(&frame); } } Ok(()) } } struct AckDataSpace { send_journal: ArcSentJournal, data_streams: DataStreams, crypto_stream_outgoing: CryptoStreamOutgoing, } impl AckDataSpace { fn new( journal: &Journal, data_streams: DataStreams, crypto_stream: &CryptoStream, ) -> Self { Self { send_journal: journal.of_sent_packets(), 
data_streams, crypto_stream_outgoing: crypto_stream.outgoing(), } } } impl ReceiveFrame for AckDataSpace { type Output = (); fn recv_frame(&self, ack_frame: AckFrame) -> Result { let mut rotate_guard = self.send_journal.rotate(); rotate_guard.update_largest(&ack_frame)?; let acked = ack_frame.iter().flat_map(|r| r.rev()).collect::>(); qevent::event!(PacketsAcked { packet_number_space: qbase::Epoch::Data, packet_nubers: acked.clone(), }); for pn in acked { for frame in rotate_guard.on_packet_acked(pn) { match frame { GuaranteedFrame::Stream(stream_frame) => { self.data_streams.on_data_acked(stream_frame) } GuaranteedFrame::Crypto(crypto_frame) => { self.crypto_stream_outgoing.on_data_acked(&crypto_frame) } GuaranteedFrame::Reliable(ReliableFrame::StreamCtl( StreamCtlFrame::ResetStream(reset_frame), )) => self.data_streams.on_reset_acked(reset_frame), _ => { /* nothing to do */ } } } } Ok(()) } } pub fn spawn_deliver_and_parse(components: &Components) { let received_packets_queue = &components.rcvd_pkt_q; let initial = initial::deliver_and_parse_packets( received_packets_queue.initial().clone(), components.spaces.initial.clone(), components.clone(), components.event_broker.clone(), ); let handshake = handshake::deliver_and_parse_packets( received_packets_queue.handshake().clone(), components.spaces.handshake.clone(), components.clone(), components.event_broker.clone(), ); let data = data::deliver_and_parse_packets( received_packets_queue.zero_rtt().clone(), received_packets_queue.one_rtt().clone(), components.spaces.data.clone(), components.clone(), components.event_broker.clone(), ); tokio::spawn( async move { tokio::join!(biased; data, handshake, initial) } .instrument_in_current() .in_current_span(), ); } /// For server connection, the origin dcid doesnot own a sequences number, once we received a packet which dcid != odcid, /// we should stop using the odcid, and drop the subsequent packets with odcid. 
/// /// We do not remove the route to odcid, otherwise the server may establish multiple connections for packets with same odcid. /// /// https://www.rfc-editor.org/rfc/rfc9000.html#name-negotiating-connection-ids fn filter_odcid_packet( packet: PlainPacket, specific: &SpecificComponents, ) -> Option> { use std::sync::atomic::Ordering::SeqCst; if let SpecificComponents::Server { odcid_router_entry, using_odcid, } = &specific { let dcid = (*packet.dcid()).into(); if odcid_router_entry.signpost() == dcid && !using_odcid.load(SeqCst) { drop(packet); // just drop the packet, It's like we never received this packet. return None; } if odcid_router_entry.signpost() != dcid { using_odcid.store(false, SeqCst); } } Some(packet) } fn read_plain_packet( packet: &PlainPacket, mut dispatch_frame: impl FnMut(qbase::frame::Frame), ) -> Result where H: GetType, PacketHeaderBuilder: for<'a> From<&'a H>, { let mut frames_collector = QuicFramesCollector::::new(); let mut packet_content = PacketContent::default(); let frame_reader = FrameReader::new(packet.body(), packet.get_type()); for frame_result in frame_reader { let (frame, r#type) = frame_result.map_err(QuicError::from)?; frames_collector.extend([&frame]); packet_content += r#type; dispatch_frame(frame); } packet.log_received(frames_collector); Ok(packet_content) } ================================================ FILE: qconnection/src/state.rs ================================================ use std::{ future::Future, sync::{ Arc, atomic::{AtomicU8, Ordering}, }, }; use qbase::{error::Error, frame::ConnectionCloseFrame, net::route::Link, role::Role}; use qevent::{ quic::{ Owner, connectivity::{ BaseConnectionStates, ConnectionStarted, ConnectionState as QlogConnectionState, ConnectionStateUpdated, GranularConnectionStates, }, transport::ParametersSet, }, telemetry::Instrument, }; use tokio::sync::SetOnce; use tracing::Instrument as _; use crate::Components; #[derive(Clone)] pub struct ArcConnState { state: Arc, handshaked: Arc>, 
terminated: Arc>, } impl Default for ArcConnState { fn default() -> Self { Self { state: Default::default(), handshaked: Arc::new(SetOnce::new()), terminated: Arc::new(SetOnce::new()), } } } impl ArcConnState { pub fn new() -> Self { Self::default() } /// Attempt to set the connection state from None to `BaseConnectionStates::Attempted`. /// /// Returns true if the state was successfully set to `BaseConnectionStates::Attempted`. /// /// Called when creating paths. If it returns true, it means that the path is the first path to connect. pub fn try_entry_attempted(&self, components: &Components, link: Link) -> Result { let attempted = encode(BaseConnectionStates::Attempted.into()); let success = self .state .compare_exchange(0, attempted, Ordering::AcqRel, Ordering::Acquire) .is_ok(); if success { // same as Self::update qevent::event!(ConnectionStateUpdated { new: BaseConnectionStates::Attempted, }); qevent::event!(ConnectionStarted { socket: { (link.src, link.dst) } // cid不在这一层,未知 }); match components.role() { Role::Client => { let lock_guard = components.parameters.lock_guard(); if let Some(local_parameters) = lock_guard.as_ref().ok().and_then(|p| p.client()) { qevent::event!(ParametersSet { owner: Owner::Local, client_parameters: local_parameters.as_ref(), }) } } Role::Server => { let lock_guard = components.parameters.lock_guard(); if let Some(local_parameters) = lock_guard.as_ref().ok().and_then(|p| p.server()) { qevent::event!(ParametersSet { owner: Owner::Local, server_parameters: local_parameters.as_ref(), }) } } }; } Ok(success) } /// Try to update the connection state, return the old state if successful. 
pub fn update(&self, state: QlogConnectionState) -> Option { let new_state_code = encode(state); let mut old_state_code = self.state.load(Ordering::Acquire); loop { if new_state_code <= old_state_code { return None; } match self.state.compare_exchange( old_state_code, new_state_code, Ordering::AcqRel, Ordering::Acquire, ) { Ok(_old_state_code) => { // when server received a initial packet but failed to decrypt it, connection state will // enter Closing directly without enter Attempted. let old_state = decode(old_state_code).unwrap_or(BaseConnectionStates::Attempted.into()); qevent::event!(ConnectionStateUpdated { new: state, old: old_state }); return Some(old_state); } Err(current_state_code) => old_state_code = current_state_code, } } } pub fn enter_handshaked(&self) -> Option { if let Some(old_state) = self.update(GranularConnectionStates::HandshakeConfirmed.into()) { self.handshaked.set(()).expect("Handshaked already set"); return Some(old_state); } None } pub fn enter_closing(&self, error: &(impl Into + Clone)) -> Option { if let Some(old_state) = self.update(GranularConnectionStates::Closing.into()) { self.terminated .set(error.clone().into()) .expect("Terminated error already set"); return Some(old_state); } None } pub fn enter_draining(&self, ccf: &ConnectionCloseFrame) -> Option { if let Some(old_state) = self.update(GranularConnectionStates::Draining.into()) { if old_state != QlogConnectionState::Granular(GranularConnectionStates::Closing) { self.terminated .set(ccf.clone().into()) .expect("Terminated error already set"); } return Some(old_state); } None } pub fn handshaked(&self) -> impl Future> + Send + use<> { let handshaked = self.handshaked.clone(); let terminated = self.terminated.clone(); async move { tokio::select! 
{ _ = handshaked.wait() => Ok(()), error = terminated.wait() => Err(error.clone()), } } .instrument_in_current() .in_current_span() } pub fn terminated(&self) -> impl Future + Send + use<> { let terminated = self.terminated.clone(); async move { terminated.wait().await.clone() } .instrument_in_current() .in_current_span() } pub fn current(&self) -> Option { decode(self.state.load(Ordering::Acquire)) } } macro_rules! mapping { ($( $a:ident ::$ b:ident ( $c:ident :: $d:ident ) => $number:literal, )*) => { pub fn decode(code: u8) -> Option { match code { $( $number => Some($a::$b($c::$d)), )* _ => None, } } pub fn encode(state: QlogConnectionState) -> u8 { match state { $( $a::$b($c::$d) => $number, )* _ => unreachable!("base closed and granular closed are repeated, use the base one"), } } }; } mapping! { QlogConnectionState::Base(BaseConnectionStates::Attempted) => 1, QlogConnectionState::Base(BaseConnectionStates::HandshakeStarted) => 2, // miss QlogConnectionState::Granular(GranularConnectionStates::PeerValidated) => 3, // miss QlogConnectionState::Granular(GranularConnectionStates::EarlyWrite) => 4, // miss QlogConnectionState::Base(BaseConnectionStates::HandshakeComplete) => 5, // miss QlogConnectionState::Granular(GranularConnectionStates::HandshakeConfirmed) => 6, QlogConnectionState::Granular(GranularConnectionStates::Closing) => 7, QlogConnectionState::Granular(GranularConnectionStates::Draining) => 8, // QlogConnectionState::Granular(GranularConnectionStates::Closed) => 9, QlogConnectionState::Base(BaseConnectionStates::Closed) => 9, } pub const HANDSHAKE_CONFIRMED: QlogConnectionState = QlogConnectionState::Granular(GranularConnectionStates::HandshakeConfirmed); pub const CLOSING: QlogConnectionState = QlogConnectionState::Granular(GranularConnectionStates::Closing); pub const DRAINING: QlogConnectionState = QlogConnectionState::Granular(GranularConnectionStates::Draining); pub const CLOSED: QlogConnectionState = 
QlogConnectionState::Granular(GranularConnectionStates::Closed); ================================================ FILE: qconnection/src/termination.rs ================================================ use std::{ io, mem, sync::{ Arc, Mutex, atomic::{AtomicUsize, Ordering}, }, time::Duration, }; use qbase::{ cid::ConnectionId, error::Error, frame::ConnectionCloseFrame, net::{route::Pathway, tx::Signals}, packet::{ header::{ long::{HandshakeHeader, InitialHeader, io::LongHeaderBuilder}, short::OneRttHeader, }, io::ProductHeader, }, }; use qinterface::component::route::RcvdPacketQueue; use tokio::time::Instant; use crate::{ArcLocalCids, Components, path::ArcPathContexts}; /// Keep a few states to support sending packets with ccf. /// /// when it is dropped all paths will be destroyed pub struct Terminator { last_recv_time: Mutex, rcvd_packets: AtomicUsize, scid: Option, dcid: Option, ccf: ConnectionCloseFrame, paths: ArcPathContexts, } impl Drop for Terminator { fn drop(&mut self) { self.paths.clear(); } } impl ProductHeader for Terminator { fn new_header(&self) -> Result { let (Some(dcid), Some(scid)) = (self.dcid, self.scid) else { return Err(Signals::empty()); }; // TODO: initial token Ok(LongHeaderBuilder::with_cid(dcid, scid).initial(vec![])) } } impl ProductHeader for Terminator { fn new_header(&self) -> Result { let (Some(dcid), Some(scid)) = (self.dcid, self.scid) else { return Err(Signals::empty()); }; Ok(LongHeaderBuilder::with_cid(dcid, scid).handshake()) } } impl ProductHeader for Terminator { fn new_header(&self) -> Result { let Some(dcid) = self.dcid else { return Err(Signals::empty()); }; // TODO: spin bit Ok(OneRttHeader::new(false.into(), dcid)) } } impl Terminator { pub fn new(ccf: ConnectionCloseFrame, components: &Components) -> Self { Self { last_recv_time: Mutex::new(Instant::now()), rcvd_packets: AtomicUsize::new(0), scid: components.cid_registry.local.initial_scid(), dcid: components.cid_registry.remote.latest_dcid(), ccf, paths: 
components.paths.clone(), } } pub fn should_send(&self) -> bool { let mut last_recv_time_guard = self.last_recv_time.lock().unwrap(); self.rcvd_packets.fetch_add(1, Ordering::AcqRel); if self.rcvd_packets.load(Ordering::Acquire) >= 3 || last_recv_time_guard.elapsed() > Duration::from_secs(1) { *last_recv_time_guard = tokio::time::Instant::now(); self.rcvd_packets.store(0, Ordering::Release); true } else { false } } pub async fn try_send(&self, mut write: W) where W: FnMut(&mut [u8], &ConnectionCloseFrame) -> Option, { for (_pathway, path) in self.paths.paths::>() { let mut datagram = vec![0; path.mtu() as _]; if let Some(written) = write(&mut datagram, &self.ccf) && written > 0 { _ = path .send_packets(&[io::IoSlice::new(&datagram[..written])]) .await; } } } pub async fn try_send_on(&self, pathway: Pathway, write: W) where W: FnOnce(&mut [u8], &ConnectionCloseFrame) -> Option, { let Some(path) = self.paths.get(&pathway) else { return; }; let mut datagram = vec![0; path.mtu() as _]; match write(&mut datagram, &self.ccf) { Some(written) if written > 0 => { _ = path .send_packets(&[io::IoSlice::new(&datagram[..written])]) .await; } _ => {} }; } } #[derive(Clone)] enum State { Closing(Arc), Draining, } #[derive(Clone)] pub struct Termination { // for generate io::Error error: Error, // keep this to keep the routing _local_cids: ArcLocalCids, state: State, } impl Termination { pub fn closing(error: Error, local_cids: ArcLocalCids, state: Arc) -> Self { Self { error, _local_cids: local_cids, state: State::Closing(state), } } pub fn draining(error: Error, local_cids: ArcLocalCids) -> Self { Self { error, _local_cids: local_cids, state: State::Draining, } } pub fn error(&self) -> Error { self.error.clone() } // Close packets queues, dont send and receive any more packets. 
pub fn enter_draining(&mut self) -> bool { match mem::replace(&mut self.state, State::Draining) { State::Closing(rcvd_pkt_q) => { rcvd_pkt_q.close_all(); true } _ => false, } } } ================================================ FILE: qconnection/src/tls/agent.rs ================================================ use std::sync::Arc; use derive_more::AsRef; use rustls::{ SignatureScheme, pki_types::{CertificateDer, SubjectPublicKeyInfoDer}, sign::{CertifiedKey, SigningKey}, }; use thiserror::Error; use x509_parser::prelude::FromDer; #[derive(Debug, Clone, AsRef)] pub struct LocalAgent { name: Arc, certified_key: Arc, } #[derive(Debug, Error)] pub enum SignError { #[error("Unsupported signature scheme {scheme:?}")] UnsupportedScheme { scheme: SignatureScheme }, #[error(transparent)] Crypto { #[from] source: rustls::Error, }, } impl LocalAgent { pub fn new(name: Arc, certified_key: Arc) -> Self { Self { name, certified_key, } } pub fn name(&self) -> &str { &self.name } pub fn cert_chain(&self) -> &[CertificateDer<'static>] { &self.certified_key.cert } pub fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { public_key(self.cert_chain()) } pub fn sign_algorithm(&self) -> rustls::SignatureAlgorithm { self.certified_key.key.algorithm() } pub fn sign(&self, scheme: SignatureScheme, data: &[u8]) -> Result, SignError> { sign(self.certified_key.key.as_ref(), scheme, data) } pub fn verify( &self, scheme: SignatureScheme, data: &[u8], signature: &[u8], ) -> Result { verify(self.public_key(), scheme, data, signature) } } #[derive(Debug, Clone, AsRef)] pub struct RemoteAgent { name: Arc, cert: Arc<[CertificateDer<'static>]>, } #[derive(Debug, Error)] pub enum VerifyError { #[error("Unsupported signature scheme {scheme:?}")] UnsupportedScheme { scheme: SignatureScheme }, } impl RemoteAgent { pub fn new(name: Arc, cert: Arc<[CertificateDer<'static>]>) -> Self { Self { name, cert } } pub fn name(&self) -> &str { &self.name } pub fn cert_chain(&self) -> &[CertificateDer<'static>] { 
&self.cert } pub fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { public_key(self.cert_chain()) } pub fn verify( &self, scheme: SignatureScheme, data: &[u8], signature: &[u8], ) -> Result { verify(self.public_key(), scheme, data, signature) } } fn public_key<'d>(cert_chain: &'d [CertificateDer<'d>]) -> SubjectPublicKeyInfoDer<'d> { use x509_parser::prelude::*; match x509_parser::certificate::X509Certificate::from_der(&cert_chain[0]) { Ok((_remain, certificate)) => { let spki = certificate.public_key().raw; spki.to_owned().into() } Err(_error) if cert_chain.len() == 1 => cert_chain[0].as_ref().into(), Err(_error) => unreachable!("rustls returned an invalid peer_certificates."), } } fn sign( key: &(impl SigningKey + ?Sized), scheme: SignatureScheme, data: &[u8], ) -> Result, SignError> { // FIXME: same as load spki then sign with ring? let signer = key .choose_scheme(&[scheme]) .ok_or(SignError::UnsupportedScheme { scheme })?; Ok(signer.sign(data)?) } fn verify( spki: SubjectPublicKeyInfoDer, scheme: SignatureScheme, data: &[u8], signature: &[u8], ) -> Result { let algorithm: &'static dyn ring::signature::VerificationAlgorithm = match scheme { SignatureScheme::ECDSA_NISTP384_SHA384 => &ring::signature::ECDSA_P384_SHA384_ASN1, SignatureScheme::ECDSA_NISTP256_SHA256 => &ring::signature::ECDSA_P256_SHA256_ASN1, SignatureScheme::ED25519 => &ring::signature::ED25519, SignatureScheme::RSA_PKCS1_SHA256 => &ring::signature::RSA_PKCS1_2048_8192_SHA256, SignatureScheme::RSA_PKCS1_SHA384 => &ring::signature::RSA_PKCS1_2048_8192_SHA384, SignatureScheme::RSA_PKCS1_SHA512 => &ring::signature::RSA_PKCS1_2048_8192_SHA512, SignatureScheme::RSA_PSS_SHA256 => &ring::signature::RSA_PSS_2048_8192_SHA256, SignatureScheme::RSA_PSS_SHA384 => &ring::signature::RSA_PSS_2048_8192_SHA384, SignatureScheme::RSA_PSS_SHA512 => &ring::signature::RSA_PSS_2048_8192_SHA512, _ => return Err(VerifyError::UnsupportedScheme { scheme }), }; let public_key = match 
x509_parser::x509::SubjectPublicKeyInfo::from_der(&spki) { Ok((_remain, spki)) => spki.subject_public_key, Err(_error) => unreachable!("rustls returned an invalid peer_certificates."), }; Ok( ring::signature::UnparsedPublicKey::new(algorithm, public_key) .verify(data, signature) .is_ok(), ) } ================================================ FILE: qconnection/src/tls/client_auth.rs ================================================ use std::{ ops::{BitAnd, Deref}, sync::Arc, }; use tokio::sync::SetOnce; use crate::prelude::{LocalAgent, RemoteAgent}; #[derive(Default, Clone, Debug, PartialEq, Eq)] pub enum ClientNameVerifyResult { #[default] Accept, /// Refuse the connection with a reason that will be sent to the client. Refuse(String), /// Refuse the connection silently without sending any reason to the client. /// /// Left a reason for logging purpose only. SilentRefuse(String), } impl BitAnd for ClientNameVerifyResult { type Output = Self; fn bitand(self, rhs: Self) -> Self::Output { use ClientNameVerifyResult::*; match (self, rhs) { (Accept, Accept) => Accept, (SilentRefuse(reason), ..) | (.., SilentRefuse(reason)) => SilentRefuse(reason), (Refuse(reason), ..) | (.., Refuse(reason)) => Refuse(reason), } } } #[derive(Default, Clone, Debug, PartialEq, Eq)] pub enum ClientAgentVerifyResult { #[default] Accept, Refuse(String), } impl BitAnd for ClientAgentVerifyResult { type Output = Self; fn bitand(self, rhs: Self) -> Self::Output { use ClientAgentVerifyResult::*; match (self, rhs) { (Accept, Accept) => Accept, (Refuse(reason), ..) 
| (.., Refuse(reason)) => Refuse(reason), } } } pub trait AuthClient: Send + Sync { fn verify_client_name( &self, server_agent: &LocalAgent, client_name: Option<&str>, ) -> ClientNameVerifyResult; fn verify_client_agent( &self, server_agent: &LocalAgent, client_agent: &RemoteAgent, ) -> ClientAgentVerifyResult; } #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct AcceptAllClientAuther; impl AuthClient for AcceptAllClientAuther { fn verify_client_name(&self, _: &LocalAgent, _: Option<&str>) -> ClientNameVerifyResult { ClientNameVerifyResult::Accept } fn verify_client_agent(&self, _: &LocalAgent, _: &RemoteAgent) -> ClientAgentVerifyResult { ClientAgentVerifyResult::Accept } } #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ClientNameAuther; impl AuthClient for ClientNameAuther { fn verify_client_name(&self, _: &LocalAgent, _: Option<&str>) -> ClientNameVerifyResult { ClientNameVerifyResult::Accept } fn verify_client_agent( &self, _: &LocalAgent, client_agent: &RemoteAgent, ) -> ClientAgentVerifyResult { use x509_parser::prelude::*; macro_rules! 
refuse { ($($tt:tt)*) => { return ClientAgentVerifyResult::Refuse(format!($($tt)*)) }; } let cert = match x509_parser::parse_x509_certificate(&client_agent.cert_chain()[0]) { Ok((_remain, cert)) => cert, Err(error) => refuse!("Invalid certificate: {error}"), }; let san = match cert.subject_alternative_name() { Ok(Some(san)) => san, Ok(None) => refuse!("Missing SAN in certificate"), Err(error) => refuse!("Invalid SAN in certificate: {error}"), }; if san.value.general_names.iter().any(|name| match name { GeneralName::DNSName(name) => *name == client_agent.name(), _ => false, }) { return ClientAgentVerifyResult::Accept; } refuse!("Client name not verified by client certificate") } } impl AuthClient for &A { fn verify_client_name( &self, server_agent: &LocalAgent, client_name: Option<&str>, ) -> ClientNameVerifyResult { A::verify_client_name(self, server_agent, client_name) } fn verify_client_agent( &self, server_agent: &LocalAgent, client_agent: &RemoteAgent, ) -> ClientAgentVerifyResult { A::verify_client_agent(self, server_agent, client_agent) } } impl AuthClient for Box { fn verify_client_name( &self, server_agent: &LocalAgent, client_name: Option<&str>, ) -> ClientNameVerifyResult { self.deref().verify_client_name(server_agent, client_name) } fn verify_client_agent( &self, server_agent: &LocalAgent, client_agent: &RemoteAgent, ) -> ClientAgentVerifyResult { self.deref().verify_client_agent(server_agent, client_agent) } } impl AuthClient for Arc { fn verify_client_name( &self, server_agent: &LocalAgent, client_name: Option<&str>, ) -> ClientNameVerifyResult { self.deref().verify_client_name(server_agent, client_name) } fn verify_client_agent( &self, server_agent: &LocalAgent, client_agent: &RemoteAgent, ) -> ClientAgentVerifyResult { self.deref().verify_client_agent(server_agent, client_agent) } } macro_rules! 
impl_auth_client_for_tuple { ($head:ident $($tail:ident)*) => { impl_auth_client_for_tuple!(@impl $head $($tail)*); impl_auth_client_for_tuple!($($tail)*); }; (@impl $($t:ident)*) => { impl<$($t,)*> AuthClient for ($($t,)*) where $($t: AuthClient,)* { fn verify_client_name( &self, server_agent: &LocalAgent, client_name: Option<&str> ) -> ClientNameVerifyResult { #[allow(non_snake_case)] let ($($t,)*) = self; $($t.verify_client_name(server_agent, client_name) &)* Default::default() } fn verify_client_agent( &self, server_agent: &LocalAgent, client_agent: &RemoteAgent ) -> ClientAgentVerifyResult { #[allow(non_snake_case)] let ($($t,)*) = self; $($t.verify_client_agent(server_agent, client_agent) &)* Default::default() } } }; () => {} } impl_auth_client_for_tuple! { Z Y X W V U T S R Q P O N M L K J I H G F E D C B A } /// A gate that controls server transmission permissions during parameter verification. /// /// `SendLock` is used by the server to restrict data transmission until transport /// parameter validation and server name verification are completed. It provides operations to: /// - `request_permit()`: Request permission to send (public method) /// - `grant_permit()`: Grant permission to send (internal method, pub(super) visibility) /// /// This mechanism ensures that the server sends no data until it has properly validated /// the client's transport parameters and verified the requested server name (SNI), /// enhancing security by preventing premature data transmission before proper validation. #[derive(Default, Debug, Clone)] pub struct ArcSendLock(Arc>); impl ArcSendLock { /// Create a new `SendLock` in the restricted state. /// /// Transmission will be blocked until client parameters and server /// verification are completed, or when silent rejection is not enabled. /// /// Usually for server, which needs to do extra verify client name and certs. pub fn new() -> Self { Self::default() } /// Create a new `SendLock` in the unrestricted state. 
/// /// Transmission is immediately permitted, used when silent rejection /// is disabled or verification has already been completed. /// /// Usually for client, which does not need to do extra verify server name and certs. pub fn unrestricted() -> Self { Self(Arc::new(SetOnce::new_with(Some(())))) } /// Request permission to send data. /// /// This method will block until client parameters and server verification /// are completed, or connection error occured. /// /// This method will not block when silent rejection is not enabled pub async fn request_permit(&self) { _ = self.0.wait().await } /// Check if transmission is currently permitted. pub fn is_permitted(&self) -> bool { self.0.get().is_some() } /// Grant permission for transmission. /// /// Called after client parameters and server verification are completed /// successfully. Unblocks all pending transmission requests. pub fn grant_permit(&self) { _ = self.0.set(()); } } ================================================ FILE: qconnection/src/tls.rs ================================================ mod agent; mod client_auth; use std::{ future::Future, sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll, Waker}, }; pub use agent::{LocalAgent, RemoteAgent, SignError, VerifyError}; pub use client_auth::{ AcceptAllClientAuther, ArcSendLock, AuthClient, ClientAgentVerifyResult, ClientNameVerifyResult, }; use futures::{future::poll_fn, never::Never}; use qbase::{ Epoch, error::{Error, ErrorKind, QuicError}, packet::keys::{ArcKeys, ArcOneRttKeys, ArcZeroRttKeys, DirectionalKeys}, param::{ArcParameters, ClientParameters, ParameterId, ServerParameters, WriteParameters}, }; use qrecovery::crypto::CryptoStream; use rustls::{ ClientConfig, HandshakeKind, ServerConfig, SignatureScheme, client::ResolvesClientCert, quic::{ClientConnection, KeyChange, ServerConnection}, server::{ClientHello, ResolvesServerCert}, sign::CertifiedKey, }; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::{Handshake, 
tls::client_auth::ClientNameAuther}; pub enum TlsSession { Client(ClientTlsSession), Server(ServerTlsSession), } pub const QUIC_VERSION: rustls::quic::Version = rustls::quic::Version::V1; impl TlsSession { fn poll_read_hs(&mut self, cx: &mut Context, buf: &mut Vec) -> Poll> { match match self { TlsSession::Client(session) => session.tls_conn.write_hs(buf), TlsSession::Server(session) => session.tls_conn.write_hs(buf), } { None if buf.is_empty() => { match self { TlsSession::Client(session) => session.read_waker = Some(cx.waker().clone()), TlsSession::Server(session) => session.read_waker = Some(cx.waker().clone()), } Poll::Pending } key_change => Poll::Ready(key_change), } } fn write_hs(&mut self, buf: &[u8]) -> Result<(), rustls::Error> { match self { TlsSession::Client(ClientTlsSession { tls_conn, .. }) => tls_conn.read_hs(buf)?, TlsSession::Server(ServerTlsSession { tls_conn, .. }) => tls_conn.read_hs(buf)?, } if let Some(waker) = match self { TlsSession::Client(ClientTlsSession { read_waker, .. }) => read_waker.take(), TlsSession::Server(ServerTlsSession { read_waker, .. 
}) => read_waker.take(), } { waker.wake(); } Ok(()) } fn alert(&self) -> Option { match self { TlsSession::Client(session) => session.tls_conn.alert(), TlsSession::Server(session) => session.tls_conn.alert(), } } fn is_handshaking(&self) -> bool { match self { TlsSession::Client(session) => session.tls_conn.is_handshaking(), TlsSession::Server(session) => session.tls_conn.is_handshaking(), } } fn handshake_kind(&self) -> Option { match self { TlsSession::Client(session) => session.tls_conn.handshake_kind(), TlsSession::Server(session) => session.tls_conn.handshake_kind(), } } fn is_finished(&self) -> bool { !self.is_handshaking() && self.handshake_kind().is_some() } fn r#yield(&self) -> TlsHandshakeInfo { const INCOMPLETE: &str = ""; match self { TlsSession::Client(tls_session) => TlsHandshakeInfo::Client { zero_rtt_accepted: tls_session.zero_rtt_accepted.expect(INCOMPLETE), local_agent: tls_session.local_agent().clone(), remote_agent: tls_session.remote_agent.clone().expect(INCOMPLETE), }, TlsSession::Server(tls_session) => TlsHandshakeInfo::Server { local_agent: tls_session.local_agent().clone().expect(INCOMPLETE), remote_agent: tls_session.remote_agent.clone(), }, } } } pub struct ClientTlsSession { server_name: String, tls_conn: ClientConnection, read_waker: Option, // shared with ClientCertResolver local_agent: Arc>>, zero_rtt_accepted: Option, remote_agent: Option, } #[derive(Debug, Clone)] struct ClientCertResolver { client_name: Arc, inner: Arc, client_agent: Arc>>, } impl ResolvesClientCert for ClientCertResolver { fn resolve( &self, root_hint_subjects: &[&[u8]], sigschemes: &[SignatureScheme], ) -> Option> { self.inner .resolve(root_hint_subjects, sigschemes) .inspect(|resolved_cert| { let client_agent = LocalAgent::new(self.client_name.clone(), resolved_cert.clone()); let old = self.client_agent.lock().unwrap().replace(client_agent); assert!( old.is_none(), "unreachable: qconnection::tls::ClientCertResolver resolve only once" ) }) } fn 
only_raw_public_keys(&self) -> bool { self.inner.only_raw_public_keys() } fn has_certs(&self) -> bool { self.inner.has_certs() } } impl ClientTlsSession { pub fn init( server_name: String, mut tls_config: Arc, client_params: &ClientParameters, ) -> Result { let mut params_buf = Vec::with_capacity(1024); params_buf.put_parameters(client_params); let local_agent = Arc::new(Mutex::new(None)); // 通过注入ServerCertResolver实现CertifiedKey向上传递 if let Some(client_name) = client_params.get::(ParameterId::ClientName) { let tls_config = Arc::make_mut(&mut tls_config); tls_config.client_auth_cert_resolver = Arc::new(ClientCertResolver { client_name: client_name.into(), inner: tls_config.client_auth_cert_resolver.clone(), client_agent: local_agent.clone(), }); }; let name = rustls::pki_types::ServerName::try_from(server_name.clone()) .map_err(|e| rustls::Error::Other(rustls::OtherError(Arc::new(e))))?; let tls_conn = ClientConnection::new(tls_config, QUIC_VERSION, name, params_buf)?; let tls_session = Self { local_agent, server_name, tls_conn, read_waker: None, zero_rtt_accepted: None, remote_agent: None, }; Ok(tls_session) } fn local_agent(&self) -> MutexGuard<'_, Option> { self.local_agent.lock().expect("Poison") } #[must_use] pub fn load_zero_rtt(&self) -> Option<(ServerParameters, DirectionalKeys)> { match ( self.tls_conn.quic_transport_parameters(), self.tls_conn.zero_rtt_keys(), ) { (Some(raw_params), Some(keys)) => { let params = ServerParameters::parse_from_bytes(raw_params).ok()?; Some((params, keys.into())) } _ => None, } } fn try_process_sh(&mut self) { self.remote_agent = (self.tls_conn.peer_certificates()) .map(|cert| RemoteAgent::new(self.server_name.as_str().into(), Arc::from(cert))) } fn try_process_ee(&mut self, parameters: &ArcParameters) -> Result<(), Error> { let Some(handshake_kind) = self.tls_conn.handshake_kind() else { return Ok(()); }; let raw_params = self .tls_conn .quic_transport_parameters() .expect("Parameters must be known at this point"); let mut 
parameters = parameters.lock_guard()?; let remebered = parameters.remembered().cloned(); let params = ServerParameters::parse_from_bytes(raw_params)?; self.zero_rtt_accepted = Some( matches!(remebered, Some(remembered) if remembered.is_0rtt_accepted(¶ms)) && matches!(handshake_kind, rustls::HandshakeKind::Resumed), ); parameters.recv_remote_params(params)?; Ok(()) } } impl Drop for ClientTlsSession { fn drop(&mut self) { if let Some(read_waker) = self.read_waker.take() { read_waker.wake(); } } } pub struct ServerTlsSession { client_auther: Box, tls_conn: ServerConnection, read_waker: Option, // shared with ServerCertResolver local_agent: Arc>>, client_name: Option>, send_lock: ArcSendLock, remote_agent: Option, } #[derive(Debug, Clone)] struct ServerCertResolver { inner: Arc, server_agent: Arc>>, } impl ResolvesServerCert for ServerCertResolver { fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { let server_name = client_hello.server_name()?.into(); self.inner.resolve(client_hello).inspect(|resolved_cert| { let sever_agent = LocalAgent::new(server_name, resolved_cert.clone()); let old = self.server_agent.lock().unwrap().replace(sever_agent); assert!( old.is_none(), "unreachable: qconnection::tls::ServerCertResolver resolve only once" ) }) } fn only_raw_public_keys(&self) -> bool { self.inner.only_raw_public_keys() } } impl ServerTlsSession { pub fn init( mut tls_config: Arc, server_params: &ServerParameters, client_auther: Box, ) -> Result { let mut params_buf = Vec::with_capacity(1024); params_buf.put_parameters(server_params); let local_agent = Arc::new(Mutex::new(None)); // 通过注入ServerCertResolver实现CertifiedKey向上传递 { let tls_config = Arc::make_mut(&mut tls_config); tls_config.cert_resolver = Arc::new(ServerCertResolver { inner: tls_config.cert_resolver.clone(), server_agent: local_agent.clone(), }); }; let tls_conn = ServerConnection::new(tls_config, QUIC_VERSION, params_buf)?; let tls_session = Self { client_auther, tls_conn, read_waker: None, 
local_agent, client_name: None, send_lock: ArcSendLock::new(), remote_agent: None, }; Ok(tls_session) } pub fn send_lock(&self) -> &ArcSendLock { &self.send_lock } fn local_agent(&self) -> MutexGuard<'_, Option> { self.local_agent.lock().expect("Poison") } pub fn server_name(&self) -> Option { Some(self.local_agent().as_ref()?.name().to_owned()) } fn try_process_ch( &mut self, parameters: &ArcParameters, zero_rtt_keys: &ArcZeroRttKeys, ) -> Result<(), Error> { let client_params = ClientParameters::parse_from_bytes( self.tls_conn .quic_transport_parameters() .expect("Client parameters must be present in ClientHello"), )?; let client_name = client_params.get::(ParameterId::ClientName); let server_agent = self.local_agent().clone().ok_or_else(|| { QuicError::with_default_fty(ErrorKind::ConnectionRefused, "Missing SNI in client hello") })?; match self .client_auther .verify_client_name(&server_agent, client_name.as_deref()) { ClientNameVerifyResult::Accept => { self.send_lock.grant_permit(); tracing::debug!(?client_name); self.client_name = client_name.map(Arc::from); parameters.lock_guard()?.recv_remote_params(client_params)?; match self.tls_conn.zero_rtt_keys() { Some(keys) => zero_rtt_keys.set_keys(keys.into()), None => _ = zero_rtt_keys.invalid(), } Ok(()) } ClientNameVerifyResult::Refuse(reason) => { self.send_lock.grant_permit(); tracing::debug!( target: "quic", server_name = %server_agent.name(), client_name = ?self.client_name.as_deref(), ?reason, "Client name verification failed, refusing connection." ); Err(Error::Quic(QuicError::with_default_fty( ErrorKind::ConnectionRefused, reason, ))) } ClientNameVerifyResult::SilentRefuse(reason) => { tracing::debug!( target: "quic", server_name = %server_agent.name(), client_name = ?self.client_name.as_deref(), ?reason, "Client name verification failed, refusing connection silently." 
); Err(Error::Quic(QuicError::with_default_fty( ErrorKind::ConnectionRefused, "", ))) } } } fn try_process_cert(&mut self) -> Result<(), Error> { let Some(client_name) = self.client_name.as_ref() else { return Ok(()); }; let Some(client_cert) = self.tls_conn.peer_certificates().map(Arc::from) else { return Ok(()); }; let client_agent = RemoteAgent::new(client_name.clone(), client_cert); let server_agent = self .local_agent() .clone() .expect("Server name must be known at this point"); match (ClientNameAuther, &self.client_auther) .verify_client_agent(&server_agent, &client_agent) { ClientAgentVerifyResult::Accept => { self.remote_agent = Some(client_agent); Ok(()) } ClientAgentVerifyResult::Refuse(reason) => { tracing::debug!( target: "quic", server_name = %server_agent.name(), ?self.client_name, ?reason, "Client certificate verification failed, refusing connection." ); Err(Error::Quic(QuicError::with_default_fty( ErrorKind::ConnectionRefused, reason, ))) } } } } impl Drop for ServerTlsSession { fn drop(&mut self) { if let Some(read_waker) = self.read_waker.take() { read_waker.wake(); } } } #[derive(Debug, Clone)] pub enum TlsHandshakeInfo { Client { local_agent: Option, remote_agent: RemoteAgent, zero_rtt_accepted: bool, }, Server { local_agent: LocalAgent, remote_agent: Option, }, } impl TlsHandshakeInfo { pub fn zero_rtt_accepted(&self) -> Option { match self { TlsHandshakeInfo::Client { zero_rtt_accepted, .. } => Some(*zero_rtt_accepted), TlsHandshakeInfo::Server { .. } => None, } } } enum InfoState { Demand(Vec), Ready(Arc), } impl InfoState { fn set(&mut self, info: Arc) { // wakers woken in drop *self = Self::Ready(info); } fn poll_get(&mut self, cx: &mut Context) -> Poll> { match self { InfoState::Demand(wakers) => { wakers.push(cx.waker().clone()); Poll::Pending } InfoState::Ready(tls_handshake_info) => Poll::Ready(tls_handshake_info.clone()), } } fn get(&self) -> Option<&Arc> { match self { InfoState::Demand(..) 
=> None, InfoState::Ready(tls_handshake_info) => Some(tls_handshake_info), } } } impl Default for InfoState { fn default() -> Self { Self::Demand(vec![]) } } impl Drop for InfoState { fn drop(&mut self) { if let Self::Demand(wakers) = self { for waker in wakers.drain(..) { waker.wake(); } } } } pub struct TlsHandshake { session: TlsSession, info: InfoState, } #[derive(Clone)] pub struct ArcTlsHandshake(Arc>>); impl ArcTlsHandshake { pub fn new(session: TlsSession) -> ArcTlsHandshake { Self(Arc::new(Mutex::new(Ok(TlsHandshake { session, info: Default::default(), })))) } fn state(&self) -> MutexGuard<'_, Result> { self.0.lock().unwrap() } async fn read_hs(&self, buf: &mut Vec) -> Result, Error> { poll_fn(|cx| { let mut tls_handshake = self.state(); match tls_handshake.as_mut() { Ok(state) => state.session.poll_read_hs(cx, buf).map(Ok), Err(e) => Poll::Ready(Err(e.clone())), } }) .await } fn write_hs(&self, buf: &[u8]) -> Result<(), Error> { let mut tls_handshake = self.state(); let tls_handshake = tls_handshake.as_mut().map_err(|e| e.clone())?; match tls_handshake.session.write_hs(buf) { Ok(_) => Ok(()), Err(error) => { let error_kind = match tls_handshake.session.alert() { Some(alert) => ErrorKind::Crypto(alert.into()), None => ErrorKind::ProtocolViolation, }; Err(Error::Quic(QuicError::with_default_fty( error_kind, format!("TLS error: {error}"), ))) } } } pub fn info( &self, ) -> impl Future, Error>> + Unpin + use<'_> { poll_fn(|cx| { let mut tls_handshake = self.state(); match tls_handshake.as_mut() { Ok(state) => state.info.poll_get(cx).map(Ok), Err(e) => Poll::Ready(Err(e.clone())), } }) } pub fn is_finished(&self) -> Result { let tls_handshake = self.state(); match tls_handshake.as_ref() { Ok(state) => Ok(state.session.is_finished()), Err(e) => Err(e.clone()), } } pub fn server_name(&self) -> Result, Error> { let tls_handshake = self.state(); let tls_handshake = tls_handshake.as_ref().map_err(|error| error.clone())?; Ok(match &tls_handshake.session { 
TlsSession::Client(session) => Some(session.server_name.clone()), TlsSession::Server(session) => session.server_name(), }) } pub fn on_conn_error(&self, error: &Error) { *self.state() = Err(error.clone()) } fn try_process_tls_message( &self, parameters: &ArcParameters, zero_rtt_keys: &ArcZeroRttKeys, ) -> Result>, Error> { let mut state = self.state(); let tls_handshake = state.as_mut().map_err(|e| e.clone())?; match &mut tls_handshake.session { TlsSession::Client(session) => { if session.remote_agent.is_none() { session.try_process_sh(); } if !parameters.lock_guard()?.is_remote_params_received() { session.try_process_ee(parameters)?; } } TlsSession::Server(session) => { if !parameters.lock_guard()?.is_remote_params_received() { session.try_process_ch(parameters, zero_rtt_keys)?; } if session.remote_agent.is_none() { session.try_process_cert()?; } } } if tls_handshake.session.is_finished() && tls_handshake.info.get().is_none() { let info = Arc::new(tls_handshake.session.r#yield()); tracing::debug!(target: "quic", "TLS handshake finished"); tls_handshake.info.set(info.clone()); return Ok(Some(info)); } Ok(None) } pub fn start( self, parameters: ArcParameters, quic_handshake: Handshake, crypto_streams: [CryptoStream; 3], (handshake_keys, zero_rtt_keys, one_rtt_keys): (ArcKeys, ArcZeroRttKeys, ArcOneRttKeys), on_handshake_conmplete: impl FnOnce(&TlsHandshakeInfo) -> Result<(), Error> + Send + 'static, ) -> impl futures::Future> + Send + 'static { let mut on_handshake_conmplete = Some(on_handshake_conmplete); let crypto_read_task = |epoch: Epoch| { let tls_handshake = self.clone(); let mut stream_reader = crypto_streams[epoch].reader(); async move { let mut buf = [0; 2048]; while let Ok(read) = stream_reader.read(&mut buf).await { tls_handshake.write_hs(&buf[..read])?; } Result::<_, Error>::Ok(()) } }; let [initial_read_task, handshake_read_task, data_read_task] = Epoch::EPOCHS.map(|epoch: Epoch| crypto_read_task(epoch)); let mut crypto_writers = 
Epoch::EPOCHS.map(|epoch: Epoch| crypto_streams[epoch].writer().clone()); let crypto_write_task = async move { let mut buf = Vec::with_capacity(2048); let mut cur_epoch = Epoch::Initial; loop { let key_change = self.read_hs(&mut buf).await?; if !buf.is_empty() { // error: crypto buffer offset overflow (crypto_writers[cur_epoch].write_all(&buf).await).map_err(|e| { QuicError::with_default_fty(ErrorKind::Internal, format!("{e:?}")) })?; buf.clear(); } match key_change { Some(KeyChange::Handshake { keys }) => { handshake_keys.set_keys(keys.into()); quic_handshake.got_handshake_key(); cur_epoch = Epoch::Handshake; } Some(KeyChange::OneRtt { keys, next }) => { one_rtt_keys.set_keys(keys, next); cur_epoch = Epoch::Data; } None => {} }; if let Some(info) = self.try_process_tls_message(¶meters, &zero_rtt_keys)? { (on_handshake_conmplete.take().expect("TLS complete twice"))(&info)?; } } }; // rustc: error[E0282]: type annotations needed let crypto_write_task = async move { let result: Result = crypto_write_task.await; result }; async move { tokio::try_join!( initial_read_task, handshake_read_task, data_read_task, crypto_write_task, )?; Ok(()) } } } ================================================ FILE: qconnection/src/traversal.rs ================================================ use std::{io, net::SocketAddr}; use futures::{StreamExt, stream::FuturesUnordered}; use qbase::{ frame::{PunchHelloFrame, ReliableFrame, io::ReceiveFrame}, net::{ addr::EndpointAddr, route::{Link, Pathway}, tx::Signals, }, packet::{ProductHeader, header::short::OneRttHeader}, }; use qevent::telemetry::Instrument; use qinterface::{bind_uri::BindUri, component::location::AddressEvent}; use qtraversal::nat::client::{ClientLocationData, StunClientsComponent}; use tracing::Instrument as _; use super::Components; use crate::CidRegistry; impl ReceiveFrame<(BindUri, Pathway, Link, ReliableFrame)> for Components { type Output = (); fn recv_frame( &self, frame: (BindUri, Pathway, Link, ReliableFrame), ) -> 
Result { self.puncher.recv_frame(frame) } } impl ReceiveFrame<(BindUri, Pathway, Link, PunchHelloFrame)> for Components { type Output = (); fn recv_frame( &self, frame: (BindUri, Pathway, Link, PunchHelloFrame), ) -> Result { self.puncher.recv_frame(frame) } } impl Components { pub fn subscribe_local_address(&self) { let mut observer = self.locations.subscribe(); let conn = self.clone(); let future = async move { let handle_address_event = |(bind_uri, event): (BindUri, AddressEvent)| { let event = match event.downcast::>() { Ok(AddressEvent::Upsert(data)) => { // on error: delect from address book // THINK: Err和remove的异同? let Ok(bound_addr) = data.as_ref() else { return; }; let endpoint_addr = EndpointAddr::direct(*bound_addr); conn.add_local_endpoint(bind_uri, endpoint_addr); return; } Ok(AddressEvent::Remove(_type_id)) => return, Ok(AddressEvent::Closed) => return, Err(event) => event, }; let _event = match event.downcast::() { Ok(AddressEvent::Upsert(data)) => { let Ok(endpoint_addr) = data.as_ref() else { return; }; conn.add_local_endpoint(bind_uri.clone(), *endpoint_addr); if matches!(*endpoint_addr, EndpointAddr::Agent { .. }) { _ = conn.add_local_punch_address(bind_uri.clone(), *endpoint_addr); } return; } Ok(AddressEvent::Remove(_type_id)) => return, Ok(AddressEvent::Closed) => return, Err(_event) => return, }; }; loop { tokio::select! { _ = conn.conn_state.terminated() => break, address_event = observer.recv() => { match address_event { Some(event) => handle_address_event(event), None => break, } } } } }; // Terminates when the connection is closed or the observer channel drops. 
tokio::spawn(future.instrument_in_current().in_current_span()); } // 添加本地直通地址 可以直接新建 path pub fn add_local_endpoint(&self, bind: BindUri, addr: EndpointAddr) { tracing::trace!(target: "quic", bind_uri = %bind, %addr,"add local endpoint"); match self.puncher.add_local_endpoint(bind, addr) { Ok(ways) => { let ways: Vec<(BindUri, Link, qtraversal::PathWay)> = ways; ways.into_iter().for_each(|way| { let _ = self.add_path(way.0, way.1, way.2); }); } Err(error) => { tracing::debug!(target: "quic", ?error, "Add local endpoint failed"); } } } // 添加对端直通地址,可以直接新建 path pub fn add_peer_endpoint(&self, addr: EndpointAddr, source: qresolve::Source) { tracing::trace!(target: "quic", %addr, ?source, "add peer endpoint"); match self.puncher.add_peer_endpoint(addr, source) { Ok(ways) => { ways.into_iter().for_each(|way| { let _ = self.add_path(way.0, way.1, way.2); }); } Err(error) => { tracing::warn!(target: "quic", ?error, "Add peer endpoint failed"); } } } // 添加本地直连地址,用于打洞,不能直接新建路径 pub fn add_local_punch_address( &self, bind_uri: BindUri, endpoint_addr: EndpointAddr, ) -> io::Result<()> { let iface = self .interfaces .borrow(&bind_uri) .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "interface not found"))?; let local_addr = endpoint_addr.addr(); let conn = self.clone(); let tasks = iface.with_component(|clinets: &StunClientsComponent| { clinets.with_clients(|map| { // workaround. 
clippy issue: https://github.com/rust-lang/rust-clippy/issues/16428 #[allow(clippy::redundant_iter_cloned)] map.values() .cloned() .map(|client| async move { client.nat_type().await }) .collect::>() }) })?; let Some(mut tasks) = tasks else { return Ok(()); }; tokio::spawn( async move { while let Some(result) = tasks.next().await { if let Ok(nat_type) = result { _ = conn.puncher.add_local_address( bind_uri.clone(), local_addr, nat_type, 0, ); } } } .instrument_in_current() .in_current_span(), ); Ok(()) } pub fn remove_address(&self, addr: SocketAddr) { let _ = self.puncher.remove_local_address(addr); } } #[derive(Clone)] pub struct PunchTransaction { cid_registry: CidRegistry, } impl PunchTransaction { pub(crate) fn new(cid_registry: CidRegistry) -> Self { Self { cid_registry } } } impl ProductHeader for PunchTransaction { fn new_header(&self) -> Result { Ok(OneRttHeader::new( false.into(), self.cid_registry .remote .latest_dcid() .ok_or(Signals::CONNECTION_ID)?, )) } } ================================================ FILE: qconnection/src/tx.rs ================================================ use bytes::BufMut; use derive_more::Deref; use qbase::{ frame::{ContainSpec, FrameFeature, Spec}, net::tx::Signals, packet::{ AssemblePacket, PacketInfo, PacketWriter as BasePacketWriter, RecordFrame, header::{EncodeHeader, GetType, io::WriteHeader, long::LongHeader, short::OneRttHeader}, keys::DirectionalKeys, signal::KeyPhaseBit, }, util::ContinuousData, }; use qevent::packet::PacketWriter as QEventPacketWriter; use qrecovery::journal::{ArcSentJournal, NewPacketGuard}; use tokio::time::Duration; #[derive(Deref)] pub struct PacketWriter<'b, 's, F> { #[deref] writer: QEventPacketWriter<'b>, // 不同空间的send guard类型不一样 clerk: NewPacketGuard<'s, F>, retran_timeout: Duration, expire_timeout: Duration, } impl<'b, F> AsRef> for PacketWriter<'b, '_, F> { #[inline] fn as_ref(&self) -> &BasePacketWriter<'b> { &self.writer } } impl<'b, F> AsRef> for PacketWriter<'b, '_, F> { #[inline] fn 
as_ref(&self) -> &QEventPacketWriter<'b> {
        &self.writer
    }
}

impl<'b, 's, F> PacketWriter<'b, 's, F> {
    /// Create a writer for a long-header packet.
    ///
    /// Reserves a new packet number from `journal` and wraps a qevent
    /// [`QEventPacketWriter`] over `buffer`. The timeouts are attached to the
    /// journal entry when the packet is finally assembled.
    ///
    /// NOTE(review): generic parameter lists were lost in extraction; the
    /// `<S>` parameter and `Result<Self, Signals>` return were reconstructed
    /// from the surrounding `where` clauses — verify against
    /// `qevent::packet::PacketWriter::new_long`.
    pub fn new_long<S>(
        header: LongHeader<S>,
        buffer: &'b mut [u8],
        keys: DirectionalKeys,
        journal: &'s ArcSentJournal<F>,
        retran_timeout: Duration,
        expire_timeout: Duration,
    ) -> Result<Self, Signals>
    where
        S: EncodeHeader + 'static,
        LongHeader<S>: GetType,
        for<'a> &'a mut [u8]: WriteHeader<LongHeader<S>>,
    {
        let clerk = journal.new_packet();
        let pn = clerk.pn();
        Ok(Self {
            clerk,
            writer: QEventPacketWriter::new_long(&header, buffer, pn, keys)?,
            expire_timeout,
            retran_timeout,
        })
    }

    /// Create a writer for a 1-RTT (short header) packet.
    ///
    /// Same contract as [`Self::new_long`], plus the current key phase bit.
    pub fn new_short(
        header: OneRttHeader,
        buffer: &'b mut [u8],
        keys: DirectionalKeys,
        key_phase: KeyPhaseBit,
        journal: &'s ArcSentJournal<F>,
        retran_timeout: Duration,
        expire_timeout: Duration,
    ) -> Result<Self, Signals> {
        let clerk = journal.new_packet();
        let pn = clerk.pn();
        Ok(Self {
            clerk,
            writer: QEventPacketWriter::new_short(&header, buffer, pn, keys, key_phase)?,
            expire_timeout,
            retran_timeout,
        })
    }
}

unsafe impl<'b, 's, F> BufMut for PacketWriter<'b, 's, F> {
    #[inline]
    fn remaining_mut(&self) -> usize {
        self.writer.remaining_mut()
    }

    #[inline]
    unsafe fn advance_mut(&mut self, cnt: usize) {
        unsafe { self.writer.advance_mut(cnt) };
    }

    #[inline]
    fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
        self.writer.chunk_mut()
    }

    // stream/datagram frames may pad manually; padding must also be recorded,
    // so the default implementation cannot be used here.
    #[inline]
    fn put_bytes(&mut self, val: u8, cnt: usize) {
        self.writer.put_bytes(val, cnt);
    }
}

impl<F> AssemblePacket for PacketWriter<'_, '_, F> {
    /// Seal the packet: commit the journal entry with the retransmission and
    /// expiration timeouts, then encrypt and apply header protection.
    #[inline]
    fn encrypt_and_protect_packet(self) -> (usize, PacketInfo) {
        self.clerk
            .build_with_time(self.retran_timeout, self.expire_timeout);
        self.writer.encrypt_and_protect_packet()
    }
}

// NOTE(review): trait generic arguments were lost in extraction; `RecordFrame<F, D>`
// and `TryInto<GF>` were reconstructed — verify against qbase's `RecordFrame` trait.
impl<'b, GF, F, D: ContinuousData> RecordFrame<F, D> for PacketWriter<'b, '_, GF>
where
    QEventPacketWriter<'b>: RecordFrame<F, D>,
    for<'f> &'f F: TryInto<GF>,
{
    /// Record a frame both in the sent journal (for loss recovery, if the frame
    /// is convertible to the journal's frame type) and in the qevent writer.
    #[inline]
    fn record_frame(&mut self, frame: &F) {
        if let Ok(frame) = frame.try_into() {
            self.clerk.record_frame(frame);
        } else {
            // Frames the journal does not track still count as "something sent".
            self.clerk.record_trivial();
        }
        self.writer.record_frame(frame);
    }
}
#[derive(Deref)] pub struct TrivialPacketWriter<'b, 's, F> { #[deref] writer: QEventPacketWriter<'b>, // 不同空间的send guard类型不一样 clerk: NewPacketGuard<'s, F>, } impl<'b, F> AsRef> for TrivialPacketWriter<'b, '_, F> { #[inline] fn as_ref(&self) -> &BasePacketWriter<'b> { &self.writer } } impl<'b, F> AsRef> for TrivialPacketWriter<'b, '_, F> { #[inline] fn as_ref(&self) -> &QEventPacketWriter<'b> { &self.writer } } impl<'b, 's, F> TrivialPacketWriter<'b, 's, F> { #[inline] pub fn new_long( header: LongHeader, buffer: &'b mut [u8], keys: DirectionalKeys, journal: &'s ArcSentJournal, ) -> Result where S: EncodeHeader + 'static, LongHeader: GetType, for<'a> &'a mut [u8]: WriteHeader>, { let clerk = journal.new_packet(); let pn = clerk.pn(); Ok(Self { clerk, writer: QEventPacketWriter::new_long(&header, buffer, pn, keys)?, }) } #[inline] pub fn new_short( header: OneRttHeader, buffer: &'b mut [u8], keys: DirectionalKeys, key_phase: KeyPhaseBit, journal: &'s ArcSentJournal, ) -> Result { let clerk = journal.new_packet(); let pn = clerk.pn(); Ok(Self { clerk, writer: QEventPacketWriter::new_short(&header, buffer, pn, keys, key_phase)?, }) } } unsafe impl<'b, 's, F> BufMut for TrivialPacketWriter<'b, 's, F> { #[inline] fn remaining_mut(&self) -> usize { self.writer.remaining_mut() } #[inline] unsafe fn advance_mut(&mut self, cnt: usize) { unsafe { self.writer.advance_mut(cnt) }; } #[inline] fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice { self.writer.chunk_mut() } // steam/datagram可能会手动padding,padding也要被记录,所以这里不能用默认实现 #[inline] fn put_bytes(&mut self, val: u8, cnt: usize) { self.writer.put_bytes(val, cnt); } } impl AssemblePacket for TrivialPacketWriter<'_, '_, F> { #[inline] fn encrypt_and_protect_packet(self) -> (usize, PacketInfo) { self.clerk.build_trivial(); self.writer.encrypt_and_protect_packet() } } impl<'b, GF, F, D: ContinuousData> RecordFrame for TrivialPacketWriter<'b, '_, GF> where F: FrameFeature, QEventPacketWriter<'b>: RecordFrame, { #[inline] fn 
record_frame(&mut self, frame: &F) { // however, this will be checked again in NewPacketGuard::build_trivial debug_assert!( frame.specs().contain(Spec::NonAckEliciting), "Frame is not non-ack eliciting {}", std::any::type_name::() ); self.clerk.record_trivial(); self.writer.record_frame(frame); } } ================================================ FILE: qdatagram/Cargo.toml ================================================ [package] name = "qdatagram" version = "0.5.0" edition.workspace = true description = "Datagram transmission of dquic" readme.workspace = true repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true rust-version.workspace = true [dependencies] bytes = { workspace = true } futures = { workspace = true } qbase = { workspace = true } tokio = { workspace = true } tracing = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["test-util", "macros"] } ================================================ FILE: qdatagram/src/lib.rs ================================================ mod reader; use bytes::Bytes; pub use reader::*; mod writer; use std::io; use qbase::{ error::Error, frame::{DatagramFrame, io::ReceiveFrame}, net::tx::{ArcSendWakers, Signals}, packet::Package, }; pub use writer::*; /// Combination of [`DatagramIncoming`] and [`DatagramOutgoing`] #[derive(Debug, Clone)] pub struct DatagramFlow { /// The incoming datagram frame, see type's doc for more details. incoming: DatagramIncoming, /// The outgoing datagram frame, see type's doc for more details. outgoing: DatagramOutgoing, } impl DatagramFlow { /// Creates a new instance of [`DatagramFlow`]. /// /// This method takes local protocol parameter [`max_datagram_frame_size`], /// the local's transport parameter [`max_datagram_frame_size`] limits the size of the datagram frames that peer /// can send. 
/// /// [`max_datagram_frame_size`]: https://www.rfc-editor.org/rfc/rfc9221.html#name-transport-parameter #[inline] pub fn new(local_max_datagram_frame_size: u64, tx_wakers: ArcSendWakers) -> Self { Self { incoming: DatagramIncoming::new(local_max_datagram_frame_size as _), outgoing: DatagramOutgoing::new(tx_wakers), } } pub fn try_load_data_into

(&self, packet: &mut P) -> Result<(), Signals> where P: bytes::BufMut + ?Sized, (DatagramFrame, Bytes): Package

, { self.outgoing.try_load_data_into(packet) } /// Create a new **unique** instance of [`DatagramReader`]. /// /// Return an error if the connection is closing or already closed, /// or datagram is disenabled by local. /// /// See [`DatagramIncoming::new_reader`] for more details. #[inline] pub fn reader(&self) -> io::Result { self.incoming.new_reader() } /// Create a new instance of [`DatagramWriter`]. /// /// Return an error if the connection is closing or already closed, /// or datagram is disenabled by peer(`max_datagram_frame_size` is `0`) /// /// See [`DatagramOutgoing::new_writer`] for more details. #[inline] pub fn writer(&self, max_datagram_frame_size: u64) -> io::Result { self.outgoing.new_writer(max_datagram_frame_size) } /// See [`DatagramOutgoing::on_conn_error`] and [`DatagramIncoming::on_conn_error`] for more details. #[inline] pub fn on_conn_error(&self, error: &Error) { self.incoming.on_conn_error(error); self.outgoing.on_conn_error(error); } } /// See [`DatagramIncoming::recv_datagram`] for more details. 
impl ReceiveFrame<(DatagramFrame, Bytes)> for DatagramFlow { type Output = (); #[inline] fn recv_frame(&self, (frame, body): (DatagramFrame, Bytes)) -> Result { self.incoming.recv_datagram(frame, body) } } ================================================ FILE: qdatagram/src/reader.rs ================================================ use std::{ collections::VecDeque, future::Future, io, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll, Waker, ready}, }; use bytes::{BufMut, Bytes}; use qbase::{ error::{Error, ErrorKind, QuicError}, frame::{DatagramFrame, EncodeSize, GetFrameType}, }; #[derive(Debug)] struct RawDatagarmReader { local_max_size: usize, rcvd_datagrams: VecDeque, read_waker: Option, } impl RawDatagarmReader { fn new(local_max_size: usize) -> Self { Self { local_max_size, rcvd_datagrams: VecDeque::new(), read_waker: None, } } } #[derive(Debug, Clone)] pub struct DatagramIncoming(Arc>>); impl DatagramIncoming { /// Create a new [`DatagramIncoming`] to receive datagram frames. pub fn new(local_max_size: usize) -> Self { Self(Arc::new(Mutex::new(Ok(RawDatagarmReader::new( local_max_size, ))))) } /// Try to create a new [`DatagramReader`] for the application to read the received datagram frames. /// /// Returns an error when the Unreliable Datagram Extension was disenabled by local parameters, /// see for more delails. pub fn new_reader(&self) -> io::Result { let mut guard = self.0.lock().unwrap(); let reader = guard.as_mut().map_err(|e| e.clone())?; if reader.local_max_size == 0 { tracing::error!(" Cause by: DatagramIncoming::new_reader local_max_size is 0"); return Err(io::Error::new( io::ErrorKind::Unsupported, "Unreliable Datagram Extension was disenabled by local parameters", )); } Ok(DatagramReader(self.0.clone())) } /// Receives a datagram frame for the application to read. /// /// If the size of the received datagram exceeds the maximum size set by the local protocol parameters `max_datagram_frame_size`, /// a connection error occurs. 
/// /// If the connection is closing or closed, the new datagram will be ignored. /// /// If the application is waiting for the data to be read, the task will be woken up when the datagram is received. pub fn recv_datagram(&self, frame: DatagramFrame, data: bytes::Bytes) -> Result<(), Error> { let mut guard = self.0.lock().unwrap(); let reader = guard.as_mut().map_err(|e| e.clone())?; if (frame.encoding_size() + data.len()) > reader.local_max_size { tracing::error!(" Cause by: DatagramIncoming::recv_datagram"); return Err(QuicError::new( ErrorKind::ProtocolViolation, frame.frame_type().into(), format!( "datagram size {} exceeds the maximum size {}", frame.encoding_size() + data.len(), reader.local_max_size ), ) .into()); } reader.rcvd_datagrams.push_back(data); if let Some(waker) = reader.read_waker.take() { waker.wake(); } Ok(()) } /// When a connection error occurs, the error will be set to the reader. /// /// Any subsequent calls to [`DatagramIncoming::new_reader`], [`DatagramReader::poll_recv`], [`DatagramReader::read`] /// and [`DatagramReader::read_buf`] will return an error. /// /// If there is a task waiting for the data to be read, the task will be woken up and return an error immediately. /// /// All the received datagrams will be discarded, and subsequent calls to [`DatagramIncoming::recv_datagram`] will be ignored. pub fn on_conn_error(&self, error: &Error) { let guard = &mut self.0.lock().unwrap(); if let Ok(reader) = guard.as_mut() { if let Some(waker) = reader.read_waker.take() { waker.wake(); } **guard = Err(error.clone()); } } } // The reader for the application to read the received [datagram frames]. /// /// [datagram frames]: https://www.rfc-editor.org/rfc/rfc9221.html #[derive(Debug, Clone)] pub struct DatagramReader(Arc>>); impl DatagramReader { // Poll to receive a [datagram frame] from peer. /// /// This is the internal implementation of the [`DatagramReader::recv`] method. 
/// /// If the datagram is not ready, and the connection is active, /// the method will return [`Poll::Pending`] and set the waker for waking up the task when the datagram is received. /// /// Note that only the waker set by the last call may be awakened /// /// While there has a datagram frame received but unread, /// this method will return [`Poll::Ready`] with the received datagram frame as [`Ok`]. /// /// If the connection is closing or already closed, /// this method will return [`Poll::Ready`] with an error as [`Err`]. /// /// [datagram frame]: https://www.rfc-editor.org/rfc/rfc9221.html pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll> { let mut reader = self.0.lock().unwrap(); match reader.as_mut() { Ok(reader) => match reader.rcvd_datagrams.pop_front() { Some(bytes) => Poll::Ready(Ok(bytes)), None => { reader.read_waker = Some(cx.waker().clone()); Poll::Pending } }, Err(e) => Poll::Ready(Err(io::Error::from(e.clone()))), } } /// Receive a [datagram frame] from peer. /// /// This method is asynchronous and returns a future that resolves to the received datagram. /// /// ``` rust, ignore /// pub async fn recv(&self) -> io::Result /// ``` /// /// The future will yield the received datagram as [`Ok`]. /// /// If the connection is closing or already closed, the future will yield an error as [`Err`]. /// /// The future is *Cancel Safe*. /// /// [datagram frame]: https://www.rfc-editor.org/rfc/rfc9221.html pub fn recv(&mut self) -> RecvDatagram<'_> { RecvDatagram { reader: self } } /// Reads the received [datagram frame] into a mutable slice. /// /// This method is asynchronous and returns a future that resolves to the number of bytes read. /// /// ``` rust, ignore /// pub async fn read(&self, buf: & mut [u8]) -> io::Result /// ``` /// /// The future will yield the size of bytes read from the received datagram as [`Ok`]. /// /// If the buffer is not large enough to hold the received data, the received data will be truncated. 
/// /// If the connection is closing or already closed, the future will yield an error as [`Err`]. /// /// [datagram frame]: https://www.rfc-editor.org/rfc/rfc9221.html pub fn read<'b>(&'b mut self, buf: &'b mut [u8]) -> ReadIntoSlice<'b> { ReadIntoSlice { reader: self, buf } } /// Reads the received [datagram frame] into a mutable reference to [`bytes::BufMut`]. /// /// This method is asynchronous and returns a future that resolves to the number of bytes read. /// /// ``` rust, ignore /// pub async fn read_buf(&self, buf: & mut [u8]) -> io::Result /// ``` /// /// The future will yield the size of bytes read from the received datagram as [`Ok`]. /// /// If the buffer is not large enough to hold the received data, the behavior is defined by the [`bytes::BufMut::put`] implementation. /// /// If the connection is closing or already closed, the future will yield an error as [`Err`]. /// /// [datagram frame]: https://www.rfc-editor.org/rfc/rfc9221.html pub fn read_buf<'b, B: BufMut>(&'b mut self, buf: &'b mut B) -> ReadIntoBuf<'b, B> { ReadIntoBuf { reader: self, buf } } } /// The [`Future`] created by [`DatagramReader::recv`], see [`DatagramReader::recv`] for more. pub struct RecvDatagram<'a> { reader: &'a mut DatagramReader, } impl Future for RecvDatagram<'_> { type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.reader.poll_recv(cx) } } /// the [`Future`] created by [`DatagramReader::read`], see [`DatagramReader::read`] for more. pub struct ReadIntoSlice<'a> { reader: &'a mut DatagramReader, buf: &'a mut [u8], } impl Future for ReadIntoSlice<'_> { type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let s = self.get_mut(); let bytes = ready!(s.reader.poll_recv(cx)?); let len = bytes.len().min(s.buf.len()); s.buf[..len].copy_from_slice(&bytes[..len]); Poll::Ready(Ok(len)) } } /// the [`Future`] created by [`DatagramReader::read_buf`], see [`DatagramReader::read_buf`] for more. 
pub struct ReadIntoBuf<'a, B> { reader: &'a mut DatagramReader, buf: &'a mut B, } impl Future for ReadIntoBuf<'_, B> where B: BufMut, { type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let s = self.get_mut(); let bytes = ready!(s.reader.poll_recv(cx)?); let len = bytes.len(); s.buf.put(bytes); Poll::Ready(Ok(len)) } } #[cfg(test)] mod tests { use qbase::{frame::FrameType, varint::VarInt}; use super::*; #[tokio::test] async fn test_datagram_reader_recv_buf() { let incoming = DatagramIncoming::new(1024); let recv = tokio::spawn({ let mut reader = incoming.new_reader().unwrap(); async move { let n = reader.read(&mut [0u8; 1024]).await.unwrap(); assert_eq!(n, 11); } }); incoming .recv_datagram( DatagramFrame::new(false, VarInt::from_u32(11)), Bytes::from_static(b"hello world"), ) .unwrap(); recv.await.unwrap(); } #[tokio::test] async fn test_datagram_reader_on_conn_error() { let incoming = DatagramIncoming::new(1024); let error = QuicError::new( ErrorKind::ProtocolViolation, FrameType::Datagram(0).into(), "protocol violation", ) .into(); incoming.on_conn_error(&error); let new_reader = incoming.new_reader(); assert!(new_reader.is_err()); assert_eq!(new_reader.unwrap_err().kind(), io::ErrorKind::BrokenPipe); } } ================================================ FILE: qdatagram/src/writer.rs ================================================ use std::{ collections::VecDeque, io, ops::DerefMut, sync::{Arc, Mutex}, }; use bytes::{BufMut, Bytes}; use qbase::{ error::Error, frame::{DatagramFrame, EncodeSize}, net::tx::{ArcSendWakers, Signals}, packet::Package, varint::VarInt, }; #[derive(Debug)] struct RawDatagramWriter { /// The queue that stores the datagram frame to send. datagrams: VecDeque, tx_wakers: ArcSendWakers, } impl RawDatagramWriter { fn new(tx_wakers: ArcSendWakers) -> Self { Self { datagrams: VecDeque::new(), tx_wakers, } } } /// The struct for protocol layer to mange the outgoing side of the datagram flow. 
#[derive(Debug, Clone)] pub struct DatagramOutgoing(Arc>>); impl DatagramOutgoing { pub fn new(tx_wakers: ArcSendWakers) -> DatagramOutgoing { DatagramOutgoing(Arc::new(Mutex::new(Ok(RawDatagramWriter::new(tx_wakers))))) } /// Try to reate a new instance of [`DatagramWriter`]. /// /// This method takes the remote transport parameters `max_datagram_frame_size`. /// /// Return an error if the connection is closing or already closed, /// or datagram is disenabled by peer(`max_datagram_frame_size` is `0`) pub fn new_writer(&self, max_datagram_frame_size: u64) -> io::Result { let mut guard = self.0.lock().unwrap(); let _writer = guard.as_mut().map_err(|e| e.clone())?; if max_datagram_frame_size == 0 { tracing::error!(" Cause by: DatagramOutgoing::new_writer"); return Err(io::Error::new( io::ErrorKind::Unsupported, "Unreliable Datagram Extension was disenabled by peer's parameters", )); } Ok(DatagramWriter { writer: self.0.clone(), max_datagram_frame_size: max_datagram_frame_size as _, }) } // Same logic with `try_load_data_into`, only used for test purpose. 
#[cfg(test)] fn try_read_datagram(&self, mut buf: &mut [u8]) -> Option<(DatagramFrame, usize)> { use qbase::frame::io::WriteDataFrame; let mut guard = self.0.lock().unwrap(); let Ok(writer) = guard.as_mut() else { return None; }; let datagram = writer.datagrams.front()?; let available = buf.remaining_mut(); let max_encoding_size = available.saturating_sub(datagram.len()); if max_encoding_size == 0 { return None; } let data = writer.datagrams.pop_front().expect("unreachable"); let data_len = VarInt::try_from(data.len()).unwrap(); let frame_without_len = DatagramFrame::new(false, data_len); let frame_with_len = DatagramFrame::new(true, data_len); let frame = match max_encoding_size { // Encode length n if n >= frame_with_len.encoding_size() => { buf.put_data_frame(&frame_with_len, &data); frame_with_len } // Do not encode length, may need padding n => { buf.put_bytes(0, n - frame_without_len.encoding_size()); buf.put_data_frame(&frame_without_len, &data); frame_without_len } }; Some((frame, available - buf.remaining_mut())) } /// Attempts to load the datagram frame into the packet. /// /// # Encoding /// /// [`DatagramFrame`] has two types: /// - frame type `0x30`: The datagram frame without the data's length. /// /// The size of this form of frame is `1 byte` + `the size of the data`. /// /// - frame type `0x31`: The datagram frame with the data's length. /// /// The size of this form of frame is `1 byte` + `the size of the data's length` + `the size of the data`. /// /// The datagram won't be split into multiple frames. If the remaining space of packet is not enough to encode the datagram frame, /// the datagram will not be loaded. /// /// This method tries to encode the [`DatagramFrame`] with the data's length first (frame type `0x31`). /// /// If remaining space of the packet is not enough to encode the length, /// it will encode the [`DatagramFrame`] without the data's length (frame type `0x30`). 
/// Because no frame can be put after the datagram frame without length, /// padding frames will be put before the datagram frame. /// In this case, the packet will be filled. pub fn try_load_data_into

(&self, packet: &mut P) -> Result<(), Signals> where P: BufMut + ?Sized, (DatagramFrame, Bytes): Package

, { let mut guard = self.0.lock().unwrap(); let Ok(writer) = guard.as_mut() else { return Err(Signals::empty()); // connection closed }; let Some(datagram) = writer.datagrams.front() else { return Err(Signals::TRANSPORT); }; let available = packet.remaining_mut(); let max_encoding_size = available.saturating_sub(datagram.len()); if max_encoding_size == 0 { return Err(Signals::CONGESTION); } let data = writer.datagrams.pop_front().expect("unreachable"); let data_len = VarInt::try_from(data.len()).unwrap(); let frame_without_len = DatagramFrame::new(false, data_len); let frame_with_len = DatagramFrame::new(true, data_len); match max_encoding_size { // Encode length n if n >= frame_with_len.encoding_size() => { (frame_with_len, data).dump(packet).unwrap(); } // Do not encode length, may need padding n => { packet.put_bytes(0, n - frame_without_len.encoding_size()); (frame_without_len, data).dump(packet).unwrap(); } } Ok(()) } /// When a connection error occurs, set the internal state to an error state. /// /// Any subsequent calls to [`DatagramWriter::send`] or [`DatagramWriter::send_bytes`] will return an error. /// All datagrams in the internal queue will be dropped and not sent to the peer. pub fn on_conn_error(&self, error: &Error) { let writer = &mut self.0.lock().unwrap(); if writer.is_ok() { **writer = Err(error.clone()); } } } /// The writer for application to send the [datagram frames] to the peer. /// /// You can clone the writer or wrapper it in an [`Arc`] to send the datagram frames in many tasks. /// /// [datagram frames]: https://www.rfc-editor.org/rfc/rfc9221.html #[derive(Debug, Clone)] pub struct DatagramWriter { writer: Arc>>, /// The maximum size of the datagram frame that can be sent to the peer. /// /// The value is set by the remote peer, and the protocol layer will use this value to limit the size of the datagram frame. /// /// If the size of the datagram frame exceeds this value, the protocol layer will return an error. 
/// /// See [RFC](https://www.rfc-editor.org/rfc/rfc9221.html#name-transport-parameter) for more details. max_datagram_frame_size: usize, } impl DatagramWriter { /// Send unreliable data to the peer. /// /// The `data` will not be sent immediately, and the `data` sent is not guaranteed to be delivered. /// /// If the peer dont support want to receive datagram frames, the method will return an error. /// /// The size of the datagram frame is limited by the `max_datagram_frame_size` transport parameter set by the peer. /// See [RFC](https://www.rfc-editor.org/rfc/rfc9221.html#name-transport-parameter) for more details about transport /// parameters. /// /// If the size of the `data` exceeds the limit, the method will return an error. /// /// You can call [`DatagramWriter::max_datagram_frame_size`] to know the maximum size of the datagram frame you can /// send, read its documentation for more details. /// /// If the connection is closing or already closed, the method will also return an error. pub fn send_bytes(&self, data: Bytes) -> io::Result<()> { match self.writer.lock().unwrap().deref_mut() { Ok(writer) => { // Only consider the smallest encoding method: 1 byte if (1 + data.len()) > self.max_datagram_frame_size { tracing::error!(" Cause by: DatagramWriter::send_bytes"); return Err(io::Error::new( io::ErrorKind::InvalidInput, format!( "data size {} exceeds the limit {}", data.len(), self.max_datagram_frame_size ), )); } writer.tx_wakers.wake_all_by(Signals::TRANSPORT); writer.datagrams.push_back(data.clone()); Ok(()) } Err(e) => Err(io::Error::from(e.clone())), } } /// Send unreliable data to the peer. /// /// The `data` will not be sent immediately, and the `data` sent is not guaranteed to be delivered. /// /// The size of the datagram frame is limited by the `max_datagram_frame_size` transport parameter set by the peer. /// See [RFC](https://www.rfc-editor.org/rfc/rfc9221.html#name-transport-parameter) for more details about transport /// parameters. 
/// /// If the size of the `data` exceeds the limit, the method will return an error. /// /// You can call [`DatagramWriter::max_datagram_frame_size`] to know the maximum size of the datagram frame you can /// send, read its documentation for more details. /// /// If the connection is closing or already closed, the method will also return an error. pub fn send(&self, data: &[u8]) -> io::Result<()> { self.send_bytes(data.to_vec().into()) } /// Returns the maximum size of the datagram frame that can be sent to the peer. /// /// If the connection is closing or already closed, the method will return an error. /// /// The value is a transport parameter set by the peer, /// and you cant send a datagram frame whose size exceeds this value. /// /// Because of the encoding, the size of the data you can send is less than this value, usually 1 byte less. Although /// its possiable to send a datagram frame with the size of `max_datagram_frame_size` - 1, its hardly to happen. /// /// We recommend you to send unreliable data that the size is less or equal to `max_encoding_size` - `1` - `the size /// of the size of the data's length in varint form`. [varint] in definded in the QUIC RFC. /// /// Size 0 means the peer does not want to receive datagram frames, but it dont means the peer will not send datagram /// frames to you. /// /// [varint]: https://www.rfc-editor.org/rfc/rfc9000.html#integer-encoding pub fn max_datagram_frame_size(&self) -> io::Result { match self.writer.lock().unwrap().deref_mut() { Ok(..) 
=> Ok(self.max_datagram_frame_size), Err(e) => Err(io::Error::from(e.clone())), } } } #[cfg(test)] mod tests { use qbase::{ error::{ErrorKind, QuicError}, frame::{ FrameType, PaddingFrame, io::{WriteDataFrame, WriteFrame}, }, }; use super::*; #[test] fn test_datagram_writer_with_length() { let outgoing = DatagramOutgoing::new(Default::default()); let writer = outgoing.new_writer(1024).unwrap(); let data = Bytes::from_static(b"hello world"); writer.send_bytes(data.clone()).unwrap(); let mut buffer = [0; 1024]; let expected_frame = DatagramFrame::new(true, VarInt::try_from(data.len()).unwrap()); assert_eq!( outgoing.try_read_datagram(&mut buffer), Some((expected_frame, 1 + 1 + data.len())) ); let mut expected_buffer = [0; 1024]; { let mut expected_buffer = &mut expected_buffer[..]; expected_buffer.put_data_frame(&expected_frame, &data); } assert_eq!(buffer, expected_buffer); } #[test] fn test_datagram_writer_without_length() { let outgoing = DatagramOutgoing::new(Default::default()); let writer = outgoing.new_writer(1024).unwrap(); let data = Bytes::from_static(b"hello world"); writer.send_bytes(data.clone()).unwrap(); let mut buffer = [0; 1024]; assert_eq!( outgoing.try_read_datagram(&mut buffer[0..12]), Some((DatagramFrame::new(false, VarInt::from_u32(11)), 12)) ); let mut expected_buffer = [0; 1024]; { let mut expected_buffer = &mut expected_buffer[..]; expected_buffer.put_data_frame(&DatagramFrame::new(false, VarInt::from_u32(12)), &data); } assert_eq!(buffer, expected_buffer); } #[test] fn test_datagram_writer_unwritten() { let outgoing = DatagramOutgoing::new(Default::default()); let writer = outgoing.new_writer(1024).unwrap(); let data = Bytes::from_static(b"hello world"); writer.send_bytes(data.clone()).unwrap(); let mut buffer = [0; 1024]; assert!(outgoing.try_read_datagram(&mut buffer[0..1]).is_none()); let expected_buffer = [0; 1024]; assert_eq!(buffer, expected_buffer); } #[test] fn test_datagram_writer_padding_first() { let outgoing = 
DatagramOutgoing::new(Default::default()); let writer = outgoing.new_writer(1024).unwrap(); // Will be encoded to 2 bytes let data = Bytes::from_static(&[b'a'; 2usize.pow(8 - 2)]); let data_len = VarInt::from_u32(data.len() as u32); writer.send_bytes(data.clone()).unwrap(); let mut buffer = [0; 1024]; assert_eq!( outgoing.try_read_datagram(&mut buffer[..data.len() + 2]), Some((DatagramFrame::new(false, data_len), data.len() + 2)) ); let mut expected_buffer = [0; 1024]; { let mut expected_buffer = &mut expected_buffer[..]; expected_buffer.put_frame(&PaddingFrame); expected_buffer.put_data_frame(&DatagramFrame::new(false, data_len), &data); } assert_eq!(buffer, expected_buffer); } #[test] fn test_datagram_writer_exceeds_limit() { let outgoing = DatagramOutgoing::new(Default::default()); assert!(outgoing.new_writer(0).is_err()); } #[test] fn test_datagram_writer_on_conn_error() { let outgoing = DatagramOutgoing::new(Default::default()); let writer = outgoing.new_writer(1024).unwrap(); outgoing.on_conn_error( &QuicError::new( ErrorKind::ProtocolViolation, FrameType::Datagram(0).into(), "test", ) .into(), ); let writer_guard = writer.writer.lock().unwrap(); assert!(writer_guard.as_ref().is_err()); } } ================================================ FILE: qevent/Cargo.toml ================================================ [package] name = "qevent" version = "0.5.0" edition.workspace = true description = "qlog implementation" readme.workspace = true repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true rust-version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] bytes = { workspace = true } enum_dispatch = { workspace = true } derive_builder = { workspace = true } derive_more = { workspace = true, features = ["from", "into", "display"] } serde = { workspace = true, features = ["derive"] } pin-project-lite = { workspace = true } qbase = { 
workspace = true } serde_json = { workspace = true } serde_with = { workspace = true, features = ["hex"] } tokio = { workspace = true, features = [ "fs", "rt", "sync", "io-std", "io-util", ] } tracing = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["macros", "io-std"] } [features] telemetry = [] raw_data = [] ================================================ FILE: qevent/src/legacy/exporter.rs ================================================ use std::io; use tokio::{ io::{AsyncWrite, AsyncWriteExt}, sync::mpsc, }; use super::QlogFileSeq; use crate::{Event, telemetry::ExportEvent}; pub struct IoExpoter(mpsc::UnboundedSender); impl IoExpoter { pub fn new(qlog_file_seq: QlogFileSeq, mut output: O) -> Self where O: AsyncWrite + Unpin + Send + 'static, { let (tx, mut rx) = mpsc::unbounded_channel(); tokio::spawn(async move { let task = async { const RS: u8 = 0x1E; output.write_u8(RS).await?; let qlog_file_seq = serde_json::to_string(&qlog_file_seq).unwrap(); output.write_all(qlog_file_seq.as_bytes()).await?; output.write_u8(b'\n').await?; while let Some(event) = rx.recv().await { let event = match super::Event::try_from(event) { Ok(event) => serde_json::to_string(&event).unwrap(), Err(_unsuppert) => continue, }; output.write_u8(RS).await?; output.write_all(event.as_bytes()).await?; output.write_u8(b'\n').await?; } io::Result::Ok(()) }; if let Err(error) = task.await { tracing::error!( target: "qlog", ?error, ?qlog_file_seq, "Failed to write qlog, subsequent qlogs in this exporter will be ignored." 
// NOTE(review): this span is an extraction artifact. It contains the tail of
// qevent/src/legacy/exporter.rs (an impl plus its test module) followed by the
// head of qevent/src/legacy/quic.rs. Angle-bracketed generic arguments appear
// to have been stripped by the extraction everywhere (e.g. `Option,`, `Vec,`,
// `HashMap,`, `load::(..)`, `fn serialize(..)`), so this text does not compile
// as-is — restore the generics from the original files before building.
// Only comments were added or fixed below; all other tokens are unchanged.
// Tail of a constructor (start is outside this chunk): spawns the writer task
// and wraps the channel sender.
); } }); Self(tx) } }
// Emitting is fire-and-forget: if the writer task has stopped (receiver
// dropped), the send error is deliberately discarded.
// NOTE(review): `IoExpoter` looks like a typo for `IoExporter`; renaming would
// touch callers elsewhere, so it is only flagged here.
impl ExportEvent for IoExpoter { fn emit(&self, event: Event) { _ = self.0.send(event); } }
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use super::*;
    use crate::{ legacy::TraceSeq, quic::connectivity::ServerListening, telemetry::{Instrument, Span}, };
    // Smoke test: events emitted inside a span reach the exporter, and a task
    // instrumented with a child span can read both the child's field
    // ("path_id") and the parent's field ("meaningless_field").
    #[tokio::test]
    async fn io_exporter() {
        let exporter = IoExpoter::new( crate::build!(QlogFileSeq { title: "io exporter example", trace: TraceSeq {} }), tokio::io::stdout(), );
        let meaningless_field = 112233u64;
        crate::span!(Arc::new(exporter), meaningless_field).in_scope(|| {
            crate::event!(ServerListening { ip_v4: "127.0.0.1".to_owned(), port_v4: 443u16 });
            tokio::spawn(
                async move {
                    assert_eq!(Span::current().load::("path_id"), "new path");
                    assert_eq!(Span::current().load::("meaningless_field"), 112233u64);
                    // do something
                }
                .instrument(crate::span!(@current, path_id = String::from("new path"))),
            );
        });
        tokio::task::yield_now().await;
    }
}
================================================ FILE: qevent/src/legacy/quic.rs ================================================
// Legacy qlog event-data schema, mirroring
// draft-ietf-quic-qlog-quic-events-02 (struct names = "<category><EventName>").
use std::collections::HashMap;
use derive_builder::Builder;
use derive_more::{From, Into};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::{HexString, RawInfo};
// qlog "connectivity:server_listening" event data.
#[serde_with::skip_serializing_none]
#[derive(Default, Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct ConnectivityServerListening {
    ip_v4: Option,
    ip_v6: Option,
    port_v4: Option,
    port_v6: Option,
    /// the server will always answer client initials with a retry
    /// (no 1-RTT connection setups by choice)
    retry_required: Option,
}
// qlog "connectivity:connection_started" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct ConnectivityConnectionStarted {
    #[builder(default)] ip_version: Option,
    src_ip: IPAddress,
    dst_ip: IPAddress,
    /// transport layer protocol
    #[builder(default = "ConnectivityConnectionStarted::default_protocol()")]
    #[serde(default = "ConnectivityConnectionStarted::default_protocol")]
    protocol: String,
    #[builder(default)] src_port: Option,
    #[builder(default)] dst_port: Option,
    #[builder(default)] src_cid: Option,
    #[builder(default)] dst_cid: Option,
}
impl ConnectivityConnectionStarted {
    // Shared serde/builder default for `protocol`.
    pub fn default_protocol() -> String { "QUIC".to_string() }
}
// qlog "connectivity:connection_closed" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct ConnectivityConnectionClosed {
    /// which side closed the connection
    #[builder(default)] owner: Option,
    #[builder(default)] connection_code: Option,
    #[builder(default)] application_code: Option,
    #[builder(default)] internal_code: Option,
    #[builder(default)] reason: Option,
    #[builder(default)] trigger: Option,
}
// Untagged: serialized as whichever inner representation matches.
#[derive(Debug, Clone, Copy, From, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ConnectionCode { TransportError(TransportError), CryptoError(CryptoError), Value(u32), }
#[derive(Debug, Clone, From, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ApplicationCode { ApplicationError(ApplicationError), Value(u32), }
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ConnectivityConnectionClosedTrigger {
    Clean,
    HandshakeTimeout,
    IdleTimeout,
    /// this is called the "immediate close" in the QUIC RFC
    Error,
    StatelessReset,
    VersionMismatch,
    /// for example HTTP/3's GOAWAY frame
    Application,
}
// qlog "connectivity:connection_id_updated" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct ConnectivityConnectionIdUpdated { owner: Owner, old: Option, new: Option, }
// qlog "connectivity:spin_bit_updated" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct ConnectivitySpinBitUpdated { state: bool, }
// qlog "connectivity:connection_state_updated" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct ConnectivityConnectionStateUpdated {
    #[builder(default)] old: Option,
    new: ConnectionState,
}
// SimpleConnectionState is a subset of this, so skip SimpleConnectionState
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ConnectionState {
    /// initial sent/received
    Attempted,
    /// peer address validated by: client sent Handshake packet OR
    /// client used CONNID chosen by the server.
    /// transport-draft-32, section-8.1
    PeerValidated,
    HandshakeStarted,
    /// 1 RTT can be sent, but handshake isn't done yet
    EarlyWrite,
    /// TLS handshake complete: Finished received and sent
    /// tls-draft-32, section-4.1.1
    HandshakeComplete,
    /// HANDSHAKE_DONE sent/received (connection is now "active", 1RTT
    /// can be sent). tls-draft-32, section-4.1.2
    HandshakeConfirmed,
    Closing,
    /// connection_close sent/received
    Draining,
    /// draining period done, connection state discarded
    Closed,
}
// qlog "security:key_updated" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct SecurityKeyUpdated {
    key_type: KeyType,
    old: Option,
    new: HexString,
    /// needed for 1RTT key updates
    generation: Option,
    trigger: Option,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SecurityKeyUpdatedTrigger {
    /// (e.g., initial, handshake and 0-RTT keys
    /// are generated by TLS)
    Tls,
    RemoteUpdate,
    LocalUpdate,
}
// qlog "security:key_retired" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct SecurityKeyRetired {
    key_type: KeyType,
    key: Option,
    /// needed for 1RTT key updates
    generation: Option,
    trigger: Option,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SecurityKeyRetiredTrigger {
    /// (e.g., initial, handshake and 0-RTT keys
    /// are generated by TLS)
    Tls,
    RemoteUpdate,
    LocalUpdate,
}
// qlog "transport:version_information" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct TransportVersionInformation {
    #[serde(skip_serializing_if = "Vec::is_empty")]
    server_versions: Vec,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    client_versions: Vec,
    chosen_version: Option,
}
// qlog "transport:alpn_information" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct TransportALPNInformation {
    server_alpns: Option>,
    client_alpns: Option>,
    chosen_alpn: Option,
}
// qlog "transport:parameters_set" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct TransportParametersSet {
    owner: Option,
    /// true if valid session ticket was received
    resumption_allowed: Option,
    /// true if early data extension was enabled on the TLS layer
    early_data_enabled: Option,
    /// e.g., "AES_128_GCM_SHA256"
    tls_cipher: Option,
    /// depends on the TLS cipher, but it's easier to be explicit.
    /// in bytes
    #[serde(default = "TransportParametersSet::default_aead_key_length")]
    #[builder(default = "TransportParametersSet::default_aead_key_length()")]
    aead_tag_length: u8,
    /// transport parameters from the TLS layer:
    original_destination_connection_id: Option,
    initial_source_connection_id: Option,
    retry_source_connection_id: Option,
    stateless_reset_token: Option,
    disable_active_migration: Option,
    max_idle_timeout: Option,
    max_udp_payload_size: Option,
    ack_delay_exponent: Option,
    max_ack_delay: Option,
    active_connection_id_limit: Option,
    initial_max_data: Option,
    initial_max_stream_data_bidi_local: Option,
    initial_max_stream_data_bidi_remote: Option,
    initial_max_stream_data_uni: Option,
    initial_max_streams_bidi: Option,
    initial_max_streams_uni: Option,
    preferred_address: Option,
}
impl TransportParametersSet {
    // NOTE(review): helper is named default_aead_*key*_length but backs the
    // `aead_tag_length` field (16 is the AEAD tag length); renaming would also
    // require updating the attribute strings above — flagged only.
    pub fn default_aead_key_length() -> u8 { 16 }
}
// Server's preferred address transport parameter (RFC 9000 §18.2).
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct PreferredAddress {
    ip_v4: IPAddress,
    ip_v6: IPAddress,
    port_v4: u16,
    port_v6: u16,
    connection_id: ConnectionID,
    stateless_reset_token: Token,
}
// qlog "transport:parameters_restored" event data (0-RTT resumption).
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct TransportParametersRestored {
    disable_active_migration: Option,
    max_idle_timeout: Option,
    max_udp_payload_size: Option,
    active_connection_id_limit: Option,
    initial_max_data: Option,
    initial_max_stream_data_bidi_local: Option,
    initial_max_stream_data_bidi_remote: Option,
    initial_max_stream_data_uni: Option,
    initial_max_streams_bidi: Option,
    initial_max_streams_uni: Option,
}
// qlog "transport:packet_sent" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct TransportPacketSent {
    header: PacketHeader,
    /// see appendix for the QuicFrame definitions
    frames: Option>,
    #[serde(default)] #[builder(default)] is_coalesced: bool,
    /// only if header.packet_type === "retry"
    #[builder(default)] retry_token: Option,
    /// only if header.packet_type === "stateless_reset"
    /// is always 128 bits in length.
    #[builder(default)] stateless_reset_token: Option,
    /// only if header.packet_type === "version_negotiation"
    #[builder(default)] #[serde(skip_serializing_if = "Vec::is_empty")] supported_versions: Vec,
    #[builder(default)] raw: Option,
    #[builder(default)] datagram_id: Option,
    #[builder(default)] trigger: Option,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TransportPacketSentTrigger { RetransmitReordered, RetransmitTimeout, PtoProbe, RetransmitCrypto, CcBandwidthProbe, }
// qlog "transport:packet_received" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct TransportPacketReceived {
    header: PacketHeader,
    /// see appendix for the definitions
    #[builder(default)] frames: Option>,
    #[serde(default)] #[builder(default)] is_coalesced: bool,
    /// only if header.packet_type === "retry"
    #[builder(default)] retry_token: Option,
    /// only if header.packet_type === "stateless_reset"
    #[builder(default)]
    /// Is always 128 bits in length.
    stateless_reset_token: Option,
    /// only if header.packet_type === "version_negotiation"
    #[builder(default)] #[serde(skip_serializing_if = "Vec::is_empty")] supported_versions: Vec,
    #[builder(default)] raw: Option,
    #[builder(default)] datagram_id: Option,
    #[builder(default)] trigger: Option,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TransportPacketReceivedTrigger { KeysAvailable, }
// qlog "transport:packet_dropped" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct TransportPacketDropped {
    /// primarily packet_type should be filled here,
    /// as other fields might not be parseable
    header: Option,
    raw: Option,
    datagram_id: Option,
    trigger: Option,
}
// NOTE(review): `Transportpacket…` has a lowercase "p" — typo for
// `TransportPacketDroppedTrigger`; renaming would touch callers, flagged only.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TransportpacketDroppedTrigger { KeyUnavailable, UnknownConnectionId, HeaderParseError, PayloadDecryptError, ProtocolViolation, DosPrevention, UnsupportedVersion, UnexpectedPacket, UnexpectedSourceConnectionId, UnexpectedVersion, Duplicate, InvalidInitial, }
// qlog "transport:packet_buffered" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct TransportPacketBuffered {
    /// primarily packet_type and possible packet_number should be
    /// filled here as other elements might not be available yet
    header: Option,
    raw: Option,
    datagram_id: Option,
    trigger: Option,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TransportPacketBufferedTrigger {
    /// indicates the parser cannot keep up, temporarily buffers
    /// packet for later processing
    Backpressure,
    /// if packet cannot be decrypted because the proper keys were
    /// not yet available
    KeysUnavailable,
}
// qlog "transport:packets_acked" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct TransportPacketsAcked {
    packet_number_space: Option,
    #[serde(default, skip_serializing_if = "Vec::is_empty")] packet_numbers: Vec,
}
// qlog "transport:datagrams_sent" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct TransportDatagramsSent {
    /// to support passing multiple at once
    count: Option,
    /// RawInfo:length field indicates total length of the datagrams
    /// including UDP header length
    #[serde(default, skip_serializing_if = "Vec::is_empty")] raw: Vec,
    #[serde(default, skip_serializing_if = "Vec::is_empty")] datagram_ids: Vec,
}
// qlog "transport:datagrams_received" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct TransportDatagramsReceived {
    /// to support passing multiple at once
    count: Option,
    /// RawInfo:length field indicates total length of the datagrams
    /// including UDP header length
    #[serde(default, skip_serializing_if = "Vec::is_empty")] raw: Vec,
    #[serde(default, skip_serializing_if = "Vec::is_empty")] datagram_ids: Vec,
}
// qlog "transport:datagram_dropped" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct TransportDatagramDropped { raw: Option, }
// qlog "transport:stream_state_updated" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct TransportStreamStateUpdated {
    stream_id: u64,
    /// mainly useful when opening the stream
    #[builder(default)] stream_type: Option,
    #[builder(default)] old: Option,
    new: StreamState,
    #[builder(default)] stream_side: Option,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StreamState {
    Idle,
    Open,
    // bidirectional stream states, RFC 9000 Section 3.4.
    HalfClosedLocal,
    HalfClosedRemote,
    Closed,
    // sending-side stream states, RFC 9000 Section 3.1.
    Ready,
    Send,
    DataSent,
    ResetSent,
    ResetReceived,
    // receive-side stream states, RFC 9000 Section 3.2.
    Receive,
    SizeKnown,
    DataRead,
    ResetRead,
    // both-side states
    DataReceived,
    // qlog-defined: memory actually freed
    Destroyed,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StreamType { Unidirectional, Bidirectional, }
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StreamSide { Sending, Receiving, }
// qlog "transport:frames_processed" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct TransportFramesProcessed {
    /// see appendix for the QuicFrame definitions
    frames: Vec,
    #[builder(default)] packet_number: Option,
}
// qlog "transport:data_moved" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct TransportDataMoved {
    stream_id: Option,
    offset: Option,
    /// byte length of the moved data
    length: Option,
    from: Option,
    to: Option,
    /// raw bytes that were transferred
    data: Option,
}
// Where stream data was moved from/to; Other covers non-standard layers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamDataLocation { User, Application, Transport, Network, Other(String), }
// Hand-written serde: known variants map to fixed lowercase strings,
// Other(s) passes its string through unchanged.
// NOTE(review): the generic parameters of these impls were stripped by the
// extraction (`fn serialize(…)` / `fn deserialize(…)` / bare `Result`) —
// restore `<S>`, `<D>` and the full `Result<…>` types from the original file.
impl Serialize for StreamDataLocation {
    fn serialize(&self, serializer: S) -> Result
    where S: serde::Serializer,
    {
        match self {
            StreamDataLocation::User => serializer.serialize_str("user"),
            StreamDataLocation::Application => serializer.serialize_str("application"),
            StreamDataLocation::Transport => serializer.serialize_str("transport"),
            StreamDataLocation::Network => serializer.serialize_str("network"),
            StreamDataLocation::Other(s) => serializer.serialize_str(s),
        }
    }
}
impl<'de> Deserialize<'de> for StreamDataLocation {
    fn deserialize(deserializer: D) -> Result
    where D: serde::Deserializer<'de>,
    {
        match String::deserialize(deserializer)? {
            s if s == "user" => Ok(StreamDataLocation::User),
            s if s == "application" => Ok(StreamDataLocation::Application),
            s if s == "transport" => Ok(StreamDataLocation::Transport),
            s if s == "network" => Ok(StreamDataLocation::Network),
            s => Ok(StreamDataLocation::Other(s)),
        }
    }
}
// qlog "recovery:parameters_set" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct RecoveryParametersSet {
    /// Loss detection, see recovery draft-23, Appendix A.2
    /// in amount of packets
    #[builder(default)] reordering_threshold: Option,
    /// as RTT multiplier
    #[builder(default)] time_threshold: Option,
    /// in ms
    timer_granularity: u16,
    /// in ms
    #[builder(default)] initial_rtt: Option,
    /// congestion control, Appendix B.1.
    /// in bytes. Note: this could be updated after pmtud
    #[builder(default)] max_datagram_size: Option,
    /// in bytes
    #[builder(default)] initial_congestion_window: Option,
    /// Note: this could change when max_datagram_size changes
    /// in bytes
    #[builder(default)] minimum_congestion_window: Option,
    #[builder(default)] loss_reduction_factor: Option,
    /// as PTO multiplier
    #[builder(default)] persistent_congestion_threshold: Option,
    /// Additionally, this event can contain any number of unspecified fields
    /// to support different recovery approaches.
    #[builder(default)] #[serde(flatten, skip_serializing_if = "HashMap::is_empty")] custom_fields: HashMap,
}
// qlog "recovery:metrics_updated" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct RecoveryMetricsUpdated {
    /// Loss detection, see recovery draft-23, Appendix A.3
    /// all following rtt fields are expressed in ms
    #[builder(default)] min_rtt: Option,
    #[builder(default)] smoothed_rtt: Option,
    #[builder(default)] latest_rtt: Option,
    #[builder(default)] rtt_variance: Option,
    #[builder(default)] pto_count: Option,
    /// Congestion control, Appendix B.2.
    /// in bytes
    #[builder(default)] congestion_window: Option,
    #[builder(default)] bytes_in_flight: Option,
    /// in bytes
    #[builder(default)] ssthresh: Option,
    /// qlog defined
    /// sum of all packet number spaces
    #[builder(default)] packets_in_flight: Option,
    /// in bits per second
    #[builder(default)] pacing_rate: Option,
    /// Additionally, this event can contain any number of unspecified fields
    /// to support different recovery approaches.
    #[builder(default)] #[serde(flatten)] #[serde(skip_serializing_if = "HashMap::is_empty")] custom_fields: HashMap,
}
// qlog "recovery:congestion_state_updated" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct RecoveryCongestionStateUpdated {
    #[builder(default)] old: Option,
    new: String,
    #[builder(default)] trigger: Option,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RecoveryCongestionStateUpdatedTrigger { PersistentCongestion, Ecn, }
// qlog "recovery:loss_timer_updated" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct RecoveryLossTimerUpdated {
    /// called "mode" in draft-23 A.9.
    #[builder(default)] timer_type: Option,
    #[builder(default)] packet_number_space: Option,
    event_type: LossTimerEventType,
    /// if event_type === "set": delta, time is in ms from
    /// this event's timestamp until when the timer will trigger
    #[builder(default)] delta: Option,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum LossTimerType { Ack, Pto, }
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum LossTimerEventType { Set, Expired, Cancelled, }
// qlog "recovery:packet_lost" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct RecoveryPacketLost {
    /// should include at least the packet_type and packet_number
    header: Option,
    /// not all implementations will keep track of full
    /// packets, so these are optional
    /// see appendix for the QuicFrame definitions
    frames: Option>,
    trigger: Option,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RecoveryPacketLostTrigger {
    ReorderingThreshold,
    TimeThreshold,
    /// draft-23 section 5.3.1, MAY
    PtoExpired,
}
// qlog "recovery:marked_for_retransmit" event data.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct RecoveryMarkedForRetransmit {
    /// see appendix for the QuicFrame definitions
    frames: Vec,
}
// A.1: skip
// A.2
#[derive(Debug, Clone, From, Into, Serialize, Deserialize, PartialEq, Eq)]
#[serde(transparent)]
pub struct QuicVersion(HexString);
#[derive(Debug, Clone, From, Into, Serialize, Deserialize, PartialEq, Eq)]
#[serde(transparent)]
pub struct ConnectionID(HexString);
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Owner { Local, Remote, }
// A.5
#[derive(Debug, Clone, From, Into, Serialize, Deserialize, PartialEq, Eq)]
#[serde(transparent)]
pub struct IPAddress(String);
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum IPVersion { V4, V6, }
// A.6
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum PacketType {
    Initial,
    Retry,
    Handshake,
    #[serde(rename = "0RTT")] ZeroRTT,
    #[serde(rename = "1RTT")] OneRTT,
    StatelessReset,
    VersionNegotiation,
    Unknown,
}
// A.7
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum PacketNumberSpace { Initial, Handshake, ApplicationData, }
// A.8
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct PacketHeader {
    packet_type: PacketType,
    // In rfc https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-quic-events-02#name-packetheader, this field must be present.
    // But in fact, for packet type Retry and VN and for packets dropped before the pn is decoded, this field does not exist.
    // In the updated RFC this field is optional, so here we simply mark it as optional as well.
    #[builder(default)] packet_number: Option,
    /// the bit flags of the packet headers (spin bit, key update bit,
    /// etc. up to and including the packet number length bits
    /// if present
    #[builder(default)] flags: Option,
    /// only if packet_type === "initial"
    #[builder(default)] token: Option,
    /// only if packet_type === "initial" || "handshake" || "0RTT"
    /// Signifies length of the packet_number plus the payload
    #[builder(default)] length: Option,
    /// only if present in the header
    /// if correctly using transport:connection_id_updated events,
    /// dcid can be skipped for 1RTT packets
    #[builder(default)] version: Option,
    #[builder(default)] scil: Option,
    #[builder(default)] dcil: Option,
    #[builder(default)] scid: Option,
    #[builder(default)] dcid: Option,
}
// A.9
#[serde_with::skip_serializing_none]
#[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct Token {
    r#type: Option,
    /// byte length of the token
    length: Option,
    /// raw byte value of the token
    data: Option,
    /// decoded fields included in the token
    /// (typically: peer's IP address, creation time)
    #[builder(default)] #[serde(default, skip_serializing_if = "HashMap::is_empty")] details: HashMap,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TokenType { Retry, Resumption, StatelessReset, }
// A.10
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum KeyType {
    ServerInitialSecret,
    ClientInitialSecret,
    ServerHandshakeSecret,
    ClientHandshakeSecret,
    #[serde(rename = "server_0rtt_secret")] Server0RTTSecret,
    #[serde(rename = "client_0rtt_secret")] Client0RTTSecret,
    #[serde(rename = "server_1rtt_secret")] Server1RTTSecret,
    #[serde(rename = "client_1rtt_secret")] Client1RTTSecret,
}
// Frame type that triggered a connection close: numeric id or textual name.
#[derive(Debug, Clone, Serialize, From, Deserialize, PartialEq, Eq)]
#[serde(untagged)]
pub enum ConnectionCloseTriggerFrameType { Id(u64), Text(String), }
#[derive(Debug, Clone, From, Serialize, Deserialize, PartialEq, Eq)] #[serde(untagged)] pub enum ConnectionCloseErrorCode { TransportError(TransportError), ApplicationError(ApplicationError), Value(u64), } // A.11# #[serde_with::skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(tag = "frame_type")] #[serde(rename_all = "snake_case")] pub enum QuicFrame { Padding { length: Option, payload_length: u32, }, Ping { length: Option, payload_length: Option, }, Ack { ack_delay: Option, acked_ranges: Vec<[u64; 2]>, ect1: Option, ect0: Option, ce: Option, length: Option, payload_length: Option, }, ResetStream { stream_id: u64, error_code: ApplicationCode, final_size: u64, length: Option, payload_length: Option, }, StopSending { stream_id: u64, error_code: ApplicationCode, length: Option, payload_length: Option, }, Crypto { offset: u64, length: u64, payload_length: Option, }, NewToken { token: Token, }, Stream { stream_id: u64, offset: u64, length: u64, #[serde(default)] fin: bool, raw: Option, }, MaxData { maximum: u64, }, MaxStreamData { stream_id: u64, maximum: u64, }, MaxStreams { stream_type: StreamType, maximum: u64, }, DataBlocked { limit: u64, }, StreamDataBlocked { stream_id: u64, limit: u64, }, StreamsBlocked { stream_type: StreamType, limit: u64, }, NewConnectionId { sequence_number: u32, retire_prior_to: u32, connection_id_length: Option, connection_id: ConnectionID, stateless_reset_token: Option, }, RetireConnectionId { sequence_number: u32, }, PathChallenge { data: Option, }, PathResponse { data: Option, }, ConnectionClose { error_space: Option, error_code: Option, raw_error_code: Option, reason: Option, trigger_frame_type: Option, }, HandshakeDone {}, Unknown { raw_frame_type: u64, raw_length: Option, raw: Option, }, // not in v1 Datagram { length: Option, raw: Option, }, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] pub enum ConnectionCloseErrorSpace { Transport, Application, } // A.11.22 
#[derive(Debug, Clone, Copy, From, Serialize, Deserialize, PartialEq, Eq)] pub enum TransportError { NoError, InternalError, ConnectionRefused, FlowControlError, StreamLimitError, StreamStateError, FinalSizeError, FrameEncodingError, TransportParameterError, ConnectionIdLimitError, ProtocolViolation, InvalidToken, ApplicationError, CryptoBufferExceeded, // not in v1 KeyUpdateError, AeadLimitReached, NoViablePath, } // A.11.23 #[derive(Debug, Clone, From, Serialize, Deserialize, PartialEq, Eq)] pub struct ApplicationError(String); // A.11.24 #[derive(Debug, Clone, Copy, From, PartialEq)] pub struct CryptoError(u8); impl Serialize for CryptoError { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_str(&format!("crypto_error_0x1{:02x}", self.0)) } } impl<'de> Deserialize<'de> for CryptoError { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let string = String::deserialize(deserializer)?; string.strip_prefix("crypto_error_0x1").map_or_else( || Err(serde::de::Error::custom("invalid crypto error")), |s| { u8::from_str_radix(s, 16) .map(CryptoError) .map_err(serde::de::Error::custom) }, ) } } crate::gen_builder_method! 
{ ConnectivityServerListeningBuilder => ConnectivityServerListening; ConnectivityConnectionStartedBuilder => ConnectivityConnectionStarted; ConnectivityConnectionClosedBuilder => ConnectivityConnectionClosed; ConnectivityConnectionIdUpdatedBuilder => ConnectivityConnectionIdUpdated; ConnectivitySpinBitUpdatedBuilder => ConnectivitySpinBitUpdated; ConnectivityConnectionStateUpdatedBuilder => ConnectivityConnectionStateUpdated; SecurityKeyUpdatedBuilder => SecurityKeyUpdated; SecurityKeyRetiredBuilder => SecurityKeyRetired; TransportVersionInformationBuilder => TransportVersionInformation; TransportALPNInformationBuilder => TransportALPNInformation; TransportParametersSetBuilder => TransportParametersSet; PreferredAddressBuilder => PreferredAddress; TransportParametersRestoredBuilder => TransportParametersRestored; TransportPacketSentBuilder => TransportPacketSent; TransportPacketReceivedBuilder => TransportPacketReceived; TransportPacketDroppedBuilder => TransportPacketDropped; TransportPacketBufferedBuilder => TransportPacketBuffered; TransportPacketsAckedBuilder => TransportPacketsAcked; TransportDatagramsSentBuilder => TransportDatagramsSent; TransportDatagramsReceivedBuilder => TransportDatagramsReceived; TransportDatagramDroppedBuilder => TransportDatagramDropped; TransportStreamStateUpdatedBuilder => TransportStreamStateUpdated; TransportFramesProcessedBuilder => TransportFramesProcessed; TransportDataMovedBuilder => TransportDataMoved; RecoveryParametersSetBuilder => RecoveryParametersSet; RecoveryMetricsUpdatedBuilder => RecoveryMetricsUpdated; RecoveryCongestionStateUpdatedBuilder => RecoveryCongestionStateUpdated; RecoveryLossTimerUpdatedBuilder => RecoveryLossTimerUpdated; RecoveryPacketLostBuilder => RecoveryPacketLost; RecoveryMarkedForRetransmitBuilder => RecoveryMarkedForRetransmit; PacketHeaderBuilder => PacketHeader; TokenBuilder => Token; } ================================================ FILE: qevent/src/legacy.rs 
================================================
// FILE: qevent/src/legacy.rs — legacy qlog (v0.3) top-level document schema.
// NOTE(review): as in the rest of this chunk, angle-bracketed generic
// arguments were stripped by the extraction (`Option,`, `Vec,`, `HashMap,`);
// restore them from the original file before building.
pub mod exporter;
pub mod quic;
use std::collections::HashMap;
use derive_builder::Builder;
use derive_more::{From, Into};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::{GroupID, VantagePoint};
// qlog schema version implemented by this module.
pub const QLOG_VERSION: &str = "0.3";
// Top-level qlog document for the "JSON" (non-streaming) container format.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct QlogFile {
    qlog_version: String,
    // NOTE(review): these defaults delegate to QlogFileSeq::default_format
    // ("JSON-SEQ") although QlogFile::default_format ("JSON") exists below and
    // is otherwise unused — this looks like a copy/paste bug; confirm intent.
    #[builder(default = "QlogFileSeq::default_format()")]
    #[serde(default = "QlogFileSeq::default_format")]
    qlog_format: String,
    title: Option,
    description: Option,
    #[builder(default)] #[serde(default, skip_serializing_if = "HashMap::is_empty")] summary: HashMap,
    #[builder(default)] #[serde(default, skip_serializing_if = "Vec::is_empty")] traces: Vec,
}
impl QlogFile {
    // Intended default for `qlog_format` (see NOTE on the field above).
    pub fn default_format() -> String { "JSON".to_string() }
}
// Top-level qlog document for the streaming "JSON-SEQ" container format
// (one record per trace event).
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct QlogFileSeq {
    #[builder(default = "QlogFileSeq::default_qlog_version()")]
    #[serde(default = "QlogFileSeq::default_qlog_version")]
    qlog_version: String,
    #[builder(default = "QlogFileSeq::default_format()")]
    #[serde(default = "QlogFileSeq::default_format")]
    qlog_format: String,
    #[builder(default)] title: Option,
    #[builder(default)] description: Option,
    #[builder(default)] #[serde(default, skip_serializing_if = "HashMap::is_empty")] summary: HashMap,
    trace: TraceSeq,
}
impl QlogFileSeq {
    // Shared serde/builder defaults for the header fields.
    pub fn default_qlog_version() -> String { QLOG_VERSION.to_string() }
    pub fn default_format() -> String { "JSON-SEQ".to_string() }
}
// Free-form summary section: implementation-defined key/value pairs.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, From, Into, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct Summary {
    #[builder(default)] #[serde(default, skip_serializing_if = "HashMap::is_empty")] custom_fields: HashMap,
}
// A trace slot is either a successfully loaded trace or a load error.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum Traces { TraceError(TraceError), Trace(Trace), }
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct TraceError {
    error_description: String,
    /// the original URI at which we attempted to find the file
    uri: Option,
    vantage_point: Option,
}
// A single trace with its events inline ("JSON" container format).
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct Trace {
    title: Option,
    description: Option,
    configuration: Option,
    common_fields: Option,
    vantage_point: Option,
    events: Vec,
}
// Trace header for the streaming format: like Trace but without the events,
// which follow as separate records.
#[serde_with::skip_serializing_none]
#[derive(Default, Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )]
pub struct TraceSeq {
    title: Option,
    description: Option,
    configuration: Option,
    common_fields: Option,
    vantage_point: Option,
}
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct Configuration {
    /// time_offset is in milliseconds
    time_offset: f64,
    original_uris: Vec,
    #[builder(default)] #[serde(flatten, default, skip_serializing_if = "HashMap::is_empty")] custom_fields: HashMap,
}
// One qlog event: timestamp plus a flattened name/data payload.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct Event {
    time: f64,
    #[serde(flatten)] data: EventData,
    #[builder(default)] time_format: Option,
    #[builder(default)] protocol_type: Option,
    #[builder(default)] group_id: Option,
    /// events can contain any amount of custom fields
    #[builder(default)] #[serde(flatten, default, skip_serializing_if = "HashMap::is_empty")] custom_fields: HashMap,
}
#[derive(Debug, Clone, Copy, From, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TimeFormat { Relative, Delta, Absolute, }
// Event payload union: internally tagged as {"name": "...", "data": {...}},
// with explicit "category:event" renames per the qlog spec.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "name", content = "data")]
#[serde(rename_all = "snake_case")]
pub enum EventData {
    // Connectivity
    #[serde(rename = "connectivity:server_listening")] ServerListening(quic::ConnectivityServerListening),
    #[serde(rename = "connectivity:connection_started")] ConnectionStarted(quic::ConnectivityConnectionStarted),
    #[serde(rename = "connectivity:connection_closed")] ConnectionClosed(quic::ConnectivityConnectionClosed),
    #[serde(rename = "connectivity:connection_id_updated")] ConnectionIdUpdated(quic::ConnectivityConnectionIdUpdated),
    #[serde(rename = "connectivity:spin_bit_updated")] SpinBitUpdated(quic::ConnectivitySpinBitUpdated),
    #[serde(rename = "connectivity:connection_state_updated")] ConnectionStateUpdated(quic::ConnectivityConnectionStateUpdated),
    // Security
    #[serde(rename = "security:key_updated")] KeyUpdated(quic::SecurityKeyUpdated),
    #[serde(rename = "security:key_retired")] KeyDiscarded(quic::SecurityKeyRetired),
    // Transport
    #[serde(rename = "transport:version_information")] VersionInformation(quic::TransportVersionInformation),
    #[serde(rename = "transport:alpn_information")] AlpnInformation(quic::TransportALPNInformation),
    #[serde(rename = "transport:parameters_set")] TransportParametersSet(quic::TransportParametersSet),
    #[serde(rename = "transport:parameters_restored")] TransportParametersRestored(quic::TransportParametersRestored),
    #[serde(rename = "transport:datagrams_received")] DatagramsReceived(quic::TransportDatagramsReceived),
    #[serde(rename = "transport:datagrams_sent")] DatagramsSent(quic::TransportDatagramsSent),
    // NOTE(review): this enum continues beyond the end of this extraction
    // chunk; the remaining variants are not visible here.
#[serde(rename = "transport:datagram_dropped")] DatagramDropped(quic::TransportDatagramDropped), #[serde(rename = "transport:packet_received")] PacketReceived(quic::TransportPacketReceived), #[serde(rename = "transport:packet_sent")] PacketSent(quic::TransportPacketSent), #[serde(rename = "transport:packet_dropped")] PacketDropped(quic::TransportPacketDropped), #[serde(rename = "transport:packet_buffered")] PacketBuffered(quic::TransportPacketBuffered), #[serde(rename = "transport:packets_acked")] PacketsAcked(quic::TransportPacketsAcked), #[serde(rename = "transport:stream_state_updated")] StreamStateUpdated(quic::TransportStreamStateUpdated), #[serde(rename = "transport:frames_processed")] FramesProcessed(quic::TransportFramesProcessed), #[serde(rename = "transport:data_moved")] DataMoved(quic::TransportDataMoved), // Recovery #[serde(rename = "recovery:parameters_set")] RecoveryParametersSet(quic::RecoveryParametersSet), #[serde(rename = "recovery:metrics_updated")] MetricsUpdated(quic::RecoveryMetricsUpdated), #[serde(rename = "recovery:congestion_state_updated")] CongestionStateUpdated(quic::RecoveryCongestionStateUpdated), #[serde(rename = "recovery:loss_timer_updated")] LossTimerUpdated(quic::RecoveryLossTimerUpdated), #[serde(rename = "recovery:packet_lost")] PacketLost(quic::RecoveryPacketLost), #[serde(rename = "recovery:marked_for_retransmit")] MarkedForRetransmit(quic::RecoveryMarkedForRetransmit), #[serde(rename = "generic:error")] GenericError(GenericError), #[serde(rename = "generic:warning")] GenericWarning(GenericWarning), #[serde(rename = "generic:info")] GenericInfo(GenericInfo), #[serde(rename = "generic:debug")] GenericDebug(GenericDebug), #[serde(rename = "generic:verbose")] GenericVerbose(GenericVerbose), #[serde(rename = "simulation:scenario")] SimulationScenario(SimulationScenario), #[serde(rename = "simulation:marker")] SimulationMarker(SimulationMarker), } #[derive(Default, Debug, Clone, From, Into, Serialize, Deserialize, PartialEq)] 
#[serde(transparent)]
pub struct ProtocolType(Vec<String>);

#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct CommonFields {
    time_format: Option<TimeFormat>,
    // TODO confirm: the type parameter was lost in extraction; qlog 0.3
    // specifies a millisecond timestamp here.
    reference_time: Option<f64>,
    protocol_type: Option<ProtocolType>,
    group_id: Option<GroupID>,
    custom_fields: HashMap<String, Value>,
}

#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct GenericError {
    code: Option<u64>,
    message: Option<String>,
}

#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct GenericWarning {
    code: Option<u64>,
    message: Option<String>,
}

#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct GenericInfo {
    message: String,
}

#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct GenericDebug {
    message: String,
}

#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct GenericVerbose {
    message: String,
}

#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct SimulationScenario {
    name: Option<String>,
    #[builder(default)]
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    details: HashMap<String, Value>,
}

#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct SimulationMarker {
    r#type: Option<String>,
    message: Option<String>,
}

crate::gen_builder_method! {
    QlogFileBuilder => QlogFile;
    QlogFileSeqBuilder => QlogFileSeq;
    SummaryBuilder => Summary;
    TraceErrorBuilder => TraceError;
    TraceBuilder => Trace;
    TraceSeqBuilder => TraceSeq;
    ConfigurationBuilder => Configuration;
    EventBuilder => Event;
    GenericErrorBuilder => GenericError;
    GenericWarningBuilder => GenericWarning;
    GenericInfoBuilder => GenericInfo;
    GenericDebugBuilder => GenericDebug;
    GenericVerboseBuilder => GenericVerbose;
    SimulationScenarioBuilder => SimulationScenario;
    SimulationMarkerBuilder => SimulationMarker;
}

================================================
FILE: qevent/src/lib.rs
================================================
pub mod legacy;
pub mod loglevel;
pub mod quic;
pub mod telemetry;

#[doc(hidden)]
pub mod macro_support;
mod macros;
pub mod packet;

use std::{collections::HashMap, fmt::Display, net::SocketAddr};

use bytes::Bytes;
use derive_builder::Builder;
use derive_more::{Display, From, Into};
use qbase::{cid::ConnectionId, role::Role, util::ContinuousData};
use quic::ConnectionID;
use serde::{Deserialize, Serialize};

/// Shared "header" fields of both qlog container schemas.
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct LogFile {
    file_schema: String,
    serialization_format: String,
    #[builder(default)]
    title: Option<String>,
    #[builder(default)]
    description: Option<String>,
    #[builder(default)]
    event_schemas: Vec<String>,
}

#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into), build_fn(private, name = "fallible_build"))]
pub struct QlogFile {
    #[serde(flatten)]
    log_file: LogFile,
    traces: Vec<Traces>,
}

/// A qlog file using the QlogFileSeq schema can be serialized to a
/// streamable JSON format called JSON Text Sequences (JSON-SEQ)
///
([RFC7464]). The top-level element in this schema defines only a /// small set of "header" fields and an array of component traces. /// /// [RFC7464]: https://www.rfc-editor.org/rfc/rfc7464 #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into), build_fn(private, name = "fallible_build"))] pub struct QlogFileSeq { #[serde(flatten)] log_file: LogFile, trace_seq: TraceSeq, } impl QlogFileSeq { pub const SCHEMA: &'static str = "urn:ietf:params:qlog:file:sequential"; } #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(untagged)] pub enum Traces { Trace(Trace), TraceError(TraceError), } /// The exact conceptual definition of a Trace can be fluid. For /// example, a trace could contain all events for a single connection, /// for a single endpoint, for a single measurement interval, for a /// single protocol, etc. In the normal use case however, a trace is a /// log of a single data flow collected at a single location or vantage /// point. For example, for QUIC, a single trace only contains events /// for a single logical QUIC connection for either the client or the /// server. #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct Trace { /// The optional "title" fields provide additional free-text information about the trace. #[builder(default)] title: Option, /// The optional "description" fields provide additional free-text information about the trace. #[builder(default)] description: Option, #[builder(default)] common_fields: Option, #[builder(default)] vantage_point: Option, events: Vec, } /// TraceSeq is used with QlogFileSeq. It is conceptually similar to a /// Trace, with the exception that qlog events are not contained within /// it, but rather appended after it in a QlogFileSeq. 
#[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct TraceSeq { /// The optional "title" fields provide additional free-text information about the trace. title: Option, /// The optional "description" fields provide additional free-text information about the trace. description: Option, common_fields: Option, vantage_point: Option, } #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct CommonFields { path: PathID, time_format: TimeFormat, reference_time: ReferenceTime, protocol_types: ProtocolTypeList, group_id: GroupID, #[builder(default)] #[serde(flatten)] #[serde(skip_serializing_if = "HashMap::is_empty")] // * text => any custom_fields: HashMap, } /// A VantagePoint describes the vantage point from which a trace originates #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct VantagePoint { #[builder(default)] name: Option, r#type: VantagePointType, #[builder(default)] flow: Option, } #[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum VantagePointType { /// endpoint which initiates the connection Client, /// endpoint which accepts the connection Server, /// observer in between client and server Network, #[default] Unknow, } impl From for VantagePointType { fn from(role: Role) -> Self { match role { Role::Client => VantagePointType::Client, Role::Server => VantagePointType::Server, } } } impl Display for VantagePointType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { VantagePointType::Client => write!(f, "client"), 
VantagePointType::Server => write!(f, "server"), VantagePointType::Network => write!(f, "network"), VantagePointType::Unknow => write!(f, "unknow"), } } } #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct TraceError { error_description: String, #[builder(default)] uri: Option, #[builder(default)] vantage_point: Option, } /// Events are logged at a time instant and convey specific details of the logging use case. /// /// Events can contain any amount of custom fields. #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct Event { time: f64, #[serde(flatten)] data: EventData, /// A qlog event can be associated with a single "network path" (usually, but not always, identified by a 4-tuple /// of IP addresses and ports). In many cases, the path will be the same for all events in a given trace, and does /// not need to be logged explicitly with each event. 
In this case, the "path" field can be omitted (in which case /// the default value of "" is assumed) or reflected in "common_fields" instead #[builder(default)] path: Option, #[builder(default)] time_format: Option, #[builder(default)] protocol_types: Option, #[builder(default)] group_id: Option, #[builder(default)] system_info: Option, /// events can contain any amount of custom fields #[builder(default)] #[serde(flatten)] #[serde(skip_serializing_if = "HashMap::is_empty")] // * text => any custom_fields: HashMap, } #[derive(Debug, Clone, From, Into, Serialize, Deserialize, PartialEq, Eq)] #[serde(transparent)] pub struct PathID(String); #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] #[serde(try_from = "UncheckedReferenceTime")] pub struct ReferenceTime { /// The required "clock_type" field represents the type of clock used for time measurements. The value "system" /// represents a clock that uses system time, commonly measured against a chosen or well-known epoch. However, /// depending on the system, System time can potentially jump forward or back. In contrast, a clock using monotonic /// time is generally guaranteed to never go backwards. The value "monotonic" represents such a clock. clock_type: TimeClockType, /// The required "epoch" field is the start of the ReferenceTime. When using the "system" clock type, the epoch field /// **SHOULD** have a date/time value using the format defined in [RFC3339]. However, the value "unknown" **MAY** be /// used /// /// [RFC3339]: https://www.rfc-editor.org/rfc/rfc3339 #[serde(default)] epoch: TimeEpoch, /// The optional "wall_clock_time" field can be used to provide an approximate date/time value that logging commenced /// at if the epoch value is "unknown". It uses the format defined in [RFC3339]. 
Note that conversion of timestamps /// to calendar time based on wall clock times cannot be safely relied on. /// /// [RFC3339]: https://www.rfc-editor.org/rfc/rfc3339 #[builder(default)] wall_clock_time: Option, } /// Intermediate data types during deserialization #[derive(Deserialize)] struct UncheckedReferenceTime { clock_type: TimeClockType, #[serde(default)] epoch: TimeEpoch, wall_clock_time: Option, } impl TryFrom for ReferenceTime { type Error = &'static str; fn try_from(value: UncheckedReferenceTime) -> Result { if value.clock_type == TimeClockType::Monotaonic && value.epoch != TimeEpoch::Unknow { return Err( r#"When using the "monotonic" clock type, the epoch field MUST have the value "unknown"."#, ); } Ok(ReferenceTime { clock_type: value.clock_type, epoch: value.epoch, wall_clock_time: value.wall_clock_time, }) } } #[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum TimeClockType { /// The value "system" represents a clock that uses system time, commonly measured against a chosen or well-known /// epoch #[default] System, /// A clock using monotonic time is generally guaranteed to never go backwards. The value "monotonic" represents /// such a clock. /// /// When using the "monotonic" clock type, the epoch field MUST have the value "unknown". 
Monotaonic, #[serde(untagged)] Custom(String), } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub enum TimeEpoch { Unknow, #[serde(untagged)] RFC3339DateTime(RFC3339DateTime), } impl Default for TimeEpoch { fn default() -> Self { Self::RFC3339DateTime(Default::default()) } } #[derive(Debug, Clone, From, Into, Serialize, Deserialize, PartialEq, Eq)] #[serde(transparent)] pub struct RFC3339DateTime(String); impl Default for RFC3339DateTime { fn default() -> Self { Self("1970-01-01T00:00:00.000Z".to_owned()) } } #[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum TimeFormat { /// A duration relative to the ReferenceTime "epoch" field. This approach uses the largest amount of characters. /// It is good for stateless loggers. This is the default value of the "time_format" field. #[default] RelativeToEpoch, /// A delta-encoded value, based on the previously logged value. The first event in a trace is always relative to /// the ReferenceTime. This approach uses the least amount of characters. It is suitable for stateful loggers. 
RelativeToPreviousEvent, } #[derive(Debug, Clone, From, Into, Serialize, Deserialize, PartialEq, Eq)] #[serde(transparent)] pub struct ProtocolTypeList(Vec); #[derive(Debug, Clone, From, Into, Serialize, Deserialize, PartialEq, Eq)] #[serde(transparent)] pub struct ProtocolType(String); impl ProtocolType { pub fn quic() -> ProtocolType { ProtocolType("QUIC".to_owned()) } pub fn http3() -> ProtocolType { ProtocolType("HTTP/3".to_owned()) } } #[derive(Debug, Display, Clone, From, Into, Serialize, Deserialize, PartialEq, Eq)] #[serde(transparent)] pub struct GroupID(String); impl From for GroupID { fn from(value: ConnectionId) -> Self { Self(format!("{value:x}")) } } impl From for GroupID { fn from(value: ConnectionID) -> Self { Self(format!("{value:x}")) } } impl From<(SocketAddr, SocketAddr)> for GroupID { fn from(_value: (SocketAddr, SocketAddr)) -> Self { todo!() } } /// The "system_info" field can be used to record system-specific details related to an event. This is useful, for instance, /// where an application splits work across CPUs, processes, or threads and events for a single trace occur on potentially /// different combinations thereof. Each field is optional to support deployment diversity. 
// NOTE(review): `#[serde_with::skip_serializing_none]` originally appeared
// AFTER the derive; the attribute macro must precede `#[derive(Serialize)]`
// to take effect, otherwise `None` fields serialize as `null`. Order fixed.
#[serde_with::skip_serializing_none]
#[derive(Builder, Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[builder(
    default,
    setter(into, strip_option),
    build_fn(private, name = "fallible_build")
)]
pub struct SystemInformation {
    // TODO confirm integer width: type parameters were lost in extraction;
    // the qlog main schema specifies uint32 for these identifiers.
    processor_id: Option<u32>,
    process_id: Option<u32>,
    thread_id: Option<u32>,
}

/// Importance levels defined by the qlog main schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventImportance {
    Core = 1,
    Base = 2,
    Extra = 3,
}

/// Every event kind of the current qlog schema; serialized with the event
/// name in `name` and the payload under `data`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "name", content = "data")]
#[enum_dispatch::enum_dispatch(BeEventData)]
pub enum EventData {
    #[serde(rename = "quic:server_listening")]
    ServerListening(quic::connectivity::ServerListening),
    #[serde(rename = "quic:connection_started")]
    ConnectionStarted(quic::connectivity::ConnectionStarted),
    #[serde(rename = "quic:connection_closed")]
    ConnectionClosed(quic::connectivity::ConnectionClosed),
    #[serde(rename = "quic:connection_id_updated")]
    ConnectionIdUpdated(quic::connectivity::ConnectionIdUpdated),
    #[serde(rename = "quic:spin_bit_updated")]
    SpinBitUpdated(quic::connectivity::SpinBitUpdated),
    #[serde(rename = "quic:connection_state_updated")]
    ConnectionStateUpdated(quic::connectivity::ConnectionStateUpdated),
    #[serde(rename = "quic:path_assigned")]
    PathAssigned(quic::connectivity::PathAssigned),
    #[serde(rename = "quic:mtu_updated")]
    MtuUpdated(quic::connectivity::MtuUpdated),
    #[serde(rename = "quic:version_information")]
    VersionInformation(quic::transport::VersionInformation),
    #[serde(rename = "quic:alpn_information")]
    ALPNInformation(quic::transport::ALPNInformation),
    #[serde(rename = "quic:parameters_set")]
    ParametersSet(quic::transport::ParametersSet),
    #[serde(rename = "quic:parameters_restored")]
    ParametersRestored(quic::transport::ParametersRestored),
    #[serde(rename = "quic:packet_sent")]
    PacketSent(quic::transport::PacketSent),
    #[serde(rename = "quic:packet_received")]
    PacketReceived(quic::transport::PacketReceived),
    #[serde(rename = "quic:packet_dropped")]
    PacketDropped(quic::transport::PacketDropped),
    #[serde(rename = "quic:packet_buffered")]
    PacketBuffered(quic::transport::PacketBuffered),
    #[serde(rename = "quic:packets_acked")]
    PacketsAcked(quic::transport::PacketsAcked),
    #[serde(rename = "quic:udp_datagrams_sent")]
    UdpDatagramSent(quic::transport::UdpDatagramsSent),
    #[serde(rename = "quic:udp_datagrams_received")]
    UdpDatagramReceived(quic::transport::UdpDatagramsReceived),
    #[serde(rename = "quic:udp_datagram_dropped")]
    UdpDatagramDropped(quic::transport::UdpDatagramDropped),
    #[serde(rename = "quic:stream_state_updated")]
    StreamStateUpdated(quic::transport::StreamStateUpdated),
    #[serde(rename = "quic:frames_processed")]
    FramesProcessed(quic::transport::FramesProcessed),
    #[serde(rename = "quic:stream_data_moved")]
    StreamDataMoved(quic::transport::StreamDataMoved),
    #[serde(rename = "quic:datagram_data_moved")]
    DatagramDataMoved(quic::transport::DatagramDataMoved),
    #[serde(rename = "quic:migration_state_updated")]
    MigrationStateUpdated(quic::transport::MigrationStateUpdated),
    #[serde(rename = "quic:key_updated")]
    KeyUpdated(quic::security::KeyUpdated),
    #[serde(rename = "quic:key_discarded")]
    KeyDiscarded(quic::security::KeyDiscarded),
    #[serde(rename = "quic:recovery_parameters_set")]
    RecoveryParametersSet(quic::recovery::RecoveryParametersSet),
    #[serde(rename = "quic:recovery_metrics_updated")]
    RecoveryMetricsUpdated(quic::recovery::RecoveryMetricsUpdated),
    #[serde(rename = "quic:congestion_state_updated")]
    CongestionStateUpdated(quic::recovery::CongestionStateUpdated),
    #[serde(rename = "quic:loss_timer_updated")]
    LossTimerUpdated(quic::recovery::LossTimerUpdated),
    #[serde(rename = "quic:packet_lost")]
    PacketLost(quic::recovery::PacketLost),
    #[serde(rename = "quic:marked_for_retransmit")]
    MarkedForRetransmit(quic::recovery::MarkedForRetransmit),
    #[serde(rename = "quic:ecn_state_updated")]
    ECNStateUpdated(quic::recovery::ECNStateUpdated),
    #[serde(rename = "loglevel:error")]
    Error(loglevel::Error),
    #[serde(rename = "loglevel:warning")]
    Warning(loglevel::Warning),
    #[serde(rename = "loglevel:info")]
    Info(loglevel::Info),
    #[serde(rename = "loglevel:debug")]
    Debug(loglevel::Debug),
    #[serde(rename = "loglevel:verbose")]
    Verbose(loglevel::Verbose),
}

/// Static metadata every concrete event type provides.
pub trait BeSpecificEventData {
    fn scheme() -> &'static str;
    fn importance() -> EventImportance;
}

/// Object-safe view of [`BeSpecificEventData`], dispatched over [`EventData`].
#[enum_dispatch::enum_dispatch]
pub trait BeEventData {
    fn scheme(&self) -> &'static str;
    fn importance(&self) -> EventImportance;
}

impl<S: BeSpecificEventData> BeEventData for S {
    #[inline]
    fn scheme(&self) -> &'static str {
        S::scheme()
    }

    #[inline]
    fn importance(&self) -> EventImportance {
        S::importance()
    }
}

macro_rules! imp_be_events {
    ( $($importance:ident $event:ty => $prefix:ident $schme:literal ;)* ) => {
        $( imp_be_events!{@impl_one $importance $event => $prefix $schme ; } )*
    };
    (@impl_one $importance:ident $event:ty => urn $schme:literal ; ) => {
        impl BeSpecificEventData for $event {
            fn scheme() -> &'static str {
                concat!["urn:ietf:params:qlog:events:", $schme]
            }
            fn importance() -> EventImportance {
                EventImportance::$importance
            }
        }
    };
}

imp_be_events!
{ Extra quic::connectivity::ServerListening => urn "quic:server_listening"; Base quic::connectivity::ConnectionStarted => urn "quic:connection_started"; Base quic::connectivity::ConnectionClosed => urn "quic:connection_closed"; Base quic::connectivity::ConnectionIdUpdated => urn "quic:connection_id_updated"; Base quic::connectivity::SpinBitUpdated => urn "quic:spin_bit_updated"; Base quic::connectivity::ConnectionStateUpdated => urn "quic:connection_state_updated"; Base quic::connectivity::PathAssigned => urn "quic:path_assigned"; Extra quic::connectivity::MtuUpdated => urn "quic:mtu_updated"; Core quic::transport::VersionInformation => urn "quic:version_information"; Core quic::transport::ALPNInformation => urn "quic:alpn_information"; Core quic::transport::ParametersSet => urn "quic:parameters_set"; Base quic::transport::ParametersRestored => urn "quic:parameters_restored"; Core quic::transport::PacketSent => urn "quic:packet_sent"; Core quic::transport::PacketReceived => urn "quic:packet_received"; Base quic::transport::PacketDropped => urn "quic:packet_dropped"; Base quic::transport::PacketBuffered => urn "quic:packet_buffered"; Extra quic::transport::PacketsAcked => urn "quic:packets_acked"; Extra quic::transport::UdpDatagramsSent => urn "quic:udp_datagrams_sent"; Extra quic::transport::UdpDatagramsReceived => urn "quic:udp_datagrams_received"; Extra quic::transport::UdpDatagramDropped => urn "quic:udp_datagram_dropped"; Base quic::transport::StreamStateUpdated => urn "quic:stream_state_updated"; Extra quic::transport::FramesProcessed => urn "quic:frames_processed"; Base quic::transport::StreamDataMoved => urn "quic:stream_data_moved"; Base quic::transport::DatagramDataMoved => urn "quic:datagram_data_moved"; Extra quic::transport::MigrationStateUpdated => urn "quic:migration_state_updated"; Base quic::security::KeyUpdated => urn "quic:key_updated"; Base quic::security::KeyDiscarded => urn "quic:key_discarded"; Base quic::recovery::RecoveryParametersSet => urn 
"quic:recovery_parameters_set"; Core quic::recovery::RecoveryMetricsUpdated => urn "quic:recovery_metrics_updated"; Base quic::recovery::CongestionStateUpdated => urn "quic:congestion_state_updated"; Extra quic::recovery::LossTimerUpdated => urn "quic:loss_timer_updated"; Core quic::recovery::PacketLost => urn "quic:packet_lost"; Extra quic::recovery::MarkedForRetransmit => urn "quic:marked_for_retransmit"; Extra quic::recovery::ECNStateUpdated => urn "quic:ecn_state_updated"; Core loglevel::Error => urn "loglevel:error"; Base loglevel::Warning => urn "loglevel:warning"; Extra loglevel::Info => urn "loglevel:info"; Extra loglevel::Debug => urn "loglevel:debug"; Extra loglevel::Verbose => urn "loglevel:verbose"; } /// serialize/deserialize as hex string, but store as bytes in memory #[serde_with::serde_as] #[derive(Debug, Clone, From, Into, Serialize, Deserialize, PartialEq, Eq)] #[serde(transparent)] pub struct HexString(#[serde_as(as = "serde_with::hex::Hex")] Bytes); impl Display for HexString { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:x}", self.0) } } #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct RawInfo { /// the full byte length of the entity (e.g., packet or frame), /// including possible headers and trailers length: Option, /// the byte length of the entity's payload, /// excluding possible headers or trailers payload_length: Option, /// the (potentially truncated) contents of the full entity, /// including headers and possibly trailers #[builder(setter(custom))] data: Option, } impl RawInfoBuilder { /// the (potentially truncated) contents of the full entity, /// including headers and possibly trailers pub fn data(&mut self, data: D) -> &mut Self { self.data = telemetry::filter::raw_data().then(|| Some(data.to_bytes().into())); self } } impl 
From for RawInfo { fn from(data: D) -> Self { build!(RawInfo { length: data.len() as u64, data: data }) } } /// ``` rust, ignore /// crate::gen_builder_method! { /// FooBuilder => Foo; /// BarBuilder => Bar; /// } /// ``` #[doc(hidden)] #[macro_export] // used in this crate only macro_rules! gen_builder_method { ( $($builder:ty => $event:ty;)* ) => { $( $crate::gen_builder_method!{@impl_one $event => $builder ;} )* }; (@impl_one $event:ty => $builder:ty ; ) => { impl $event { pub fn builder() -> $builder { Default::default() } } impl $builder { pub fn build(&mut self) -> $event { self.fallible_build().expect("Failed to build event") } } }; } gen_builder_method! { LogFileBuilder => LogFile; QlogFileBuilder => QlogFile; QlogFileSeqBuilder => QlogFileSeq; TraceBuilder => Trace; TraceSeqBuilder => TraceSeq; TraceErrorBuilder => TraceError; CommonFieldsBuilder => CommonFields; VantagePointBuilder => VantagePoint; EventBuilder => Event; ReferenceTimeBuilder => ReferenceTime; RawInfoBuilder => RawInfo; } mod rollback { use super::*; use crate::{build, legacy}; impl TryFrom for legacy::EventData { type Error = (); #[rustfmt::skip] fn try_from(value: EventData) -> Result { match value { EventData::ServerListening(data) => Ok(legacy::EventData::ServerListening(data.into())), EventData::ConnectionStarted(data) => Ok(legacy::EventData::ConnectionStarted(data.into())), EventData::ConnectionClosed(data) => Ok(legacy::EventData::ConnectionClosed(data.into())), EventData::ConnectionIdUpdated(data) => Ok(legacy::EventData::ConnectionIdUpdated(data.into())), EventData::SpinBitUpdated(data) => Ok(legacy::EventData::SpinBitUpdated(data.into())), EventData::ConnectionStateUpdated(data) => Ok(legacy::EventData::ConnectionStateUpdated(data.into())), EventData::PathAssigned(_data) => Err(()), EventData::MtuUpdated(_data) => Err(()), EventData::VersionInformation(data) => Ok(legacy::EventData::VersionInformation(data.into())), EventData::ALPNInformation(data) => 
Ok(legacy::EventData::AlpnInformation(data.into())),
                EventData::ParametersSet(data) => Ok(legacy::EventData::TransportParametersSet(data.into())),
                EventData::ParametersRestored(data) => Ok(legacy::EventData::TransportParametersRestored(data.into())),
                EventData::PacketSent(data) => Ok(legacy::EventData::PacketSent(data.into())),
                EventData::PacketReceived(data) => Ok(legacy::EventData::PacketReceived(data.into())),
                EventData::PacketDropped(data) => Ok(legacy::EventData::PacketDropped(data.into())),
                EventData::PacketBuffered(data) => Ok(legacy::EventData::PacketBuffered(data.into())),
                EventData::PacketsAcked(data) => Ok(legacy::EventData::PacketsAcked(data.into())),
                EventData::UdpDatagramSent(data) => Ok(legacy::EventData::DatagramsSent(data.into())),
                EventData::UdpDatagramReceived(data) => Ok(legacy::EventData::DatagramsReceived(data.into())),
                EventData::UdpDatagramDropped(data) => Ok(legacy::EventData::DatagramDropped(data.into())),
                EventData::StreamStateUpdated(data) => Ok(legacy::EventData::StreamStateUpdated(data.into())),
                EventData::FramesProcessed(data) => Ok(legacy::EventData::FramesProcessed(data.into())),
                EventData::StreamDataMoved(data) => Ok(legacy::EventData::DataMoved(data.into())),
                EventData::DatagramDataMoved(_data) => Err(()),
                EventData::MigrationStateUpdated(_data) => Err(()),
                EventData::KeyUpdated(data) => Ok(legacy::EventData::KeyUpdated(data.into())),
                EventData::KeyDiscarded(data) => Ok(legacy::EventData::KeyDiscarded(data.into())),
                EventData::RecoveryParametersSet(data) => Ok(legacy::EventData::RecoveryParametersSet(data.into())),
                EventData::RecoveryMetricsUpdated(data) => Ok(legacy::EventData::MetricsUpdated(data.into())),
                EventData::CongestionStateUpdated(data) => Ok(legacy::EventData::CongestionStateUpdated(data.into())),
                EventData::LossTimerUpdated(data) => Ok(legacy::EventData::LossTimerUpdated(data.into())),
                EventData::PacketLost(data) => Ok(legacy::EventData::PacketLost(data.into())),
                EventData::MarkedForRetransmit(data) => Ok(legacy::EventData::MarkedForRetransmit(data.into())),
                EventData::ECNStateUpdated(_data) => Err(()),
                EventData::Error(data) => Ok(legacy::EventData::GenericError(data.into())),
                EventData::Warning(data) => Ok(legacy::EventData::GenericWarning(data.into())),
                EventData::Info(data) => Ok(legacy::EventData::GenericInfo(data.into())),
                EventData::Debug(data) => Ok(legacy::EventData::GenericDebug(data.into())),
                EventData::Verbose(data) => Ok(legacy::EventData::GenericVerbose(data.into())),
            }
        }
    }

    impl From<TimeFormat> for legacy::TimeFormat {
        fn from(value: TimeFormat) -> Self {
            match value {
                // note: depending on reference_time
                // TODO: check reference_time here
                TimeFormat::RelativeToEpoch => legacy::TimeFormat::Absolute,
                TimeFormat::RelativeToPreviousEvent => legacy::TimeFormat::Delta,
            }
        }
    }

    impl From<ProtocolTypeList> for legacy::ProtocolType {
        fn from(value: ProtocolTypeList) -> Self {
            value
                .0
                .into_iter()
                .map(|x| x.into())
                .collect::<Vec<_>>()
                .into()
        }
    }

    impl TryFrom<Event> for legacy::Event {
        type Error = ();

        fn try_from(mut event: Event) -> Result<Self, Self::Error> {
            // Fields the legacy schema lacks are folded into custom_fields.
            if let Some(system_info) = event.system_info {
                let value = serde_json::to_value(system_info).unwrap();
                event.custom_fields.insert("system_info".to_owned(), value);
            }
            if let Some(path) = event.path {
                let value = serde_json::to_value(path).unwrap();
                event.custom_fields.insert("path".to_owned(), value);
            }
            Ok(build!(legacy::Event {
                time: event.time,
                data: {
                    legacy::EventData::try_from(event.data)?
}, ?time_format: event.time_format, ?protocol_type: event.protocol_types, ?group_id: event.group_id, custom_fields: event.custom_fields })) } } } #[cfg(test)] mod tests { use std::sync::Arc; use qbase::cid::ConnectionId; use super::*; use crate::{loglevel::Warning, quic::connectivity::ConnectionStarted, telemetry::ExportEvent}; #[test] fn custom_fields() { let odcid = ConnectionID::from(ConnectionId::from_slice(&[ 0x61, 0xb6, 0x91, 0x78, 0x80, 0xf7, 0x95, 0xee, ])); let common_fields = build!(CommonFields { path: "".to_owned(), time_format: TimeFormat::default(), reference_time: ReferenceTime::default(), protocol_types: ProtocolTypeList::from(vec![ProtocolType::quic()]), group_id: GroupID::from(odcid), }); let expect = r#"{ "path": "", "time_format": "relative_to_epoch", "reference_time": { "clock_type": "system", "epoch": "1970-01-01T00:00:00.000Z" }, "protocol_types": [ "QUIC" ], "group_id": "61b6917880f795ee" }"#; assert_eq!( serde_json::to_string_pretty(&common_fields).unwrap(), expect ); let with_custom_fields = r#"{ "path": "", "time_format": "relative_to_epoch", "reference_time": { "clock_type": "system", "epoch": "1970-01-01T00:00:00.000Z" }, "protocol_types": [ "QUIC" ], "group_id": "61b6917880f795ee", "pathway": "from A to relay", "customB": "some other extensions" }"#; let des = serde_json::from_str::(with_custom_fields).unwrap(); let filed_string = serde_json::to_string_pretty(&des).unwrap(); let des2 = serde_json::from_str::(&filed_string).unwrap(); assert_eq!(des, des2); } #[test] fn event_data() { let data = EventData::from(build!(Warning { message: "deepseek(已深度思考(用时0秒)):服务器繁忙,请稍后再试。", code: 255u64, })); let event = build!(Event { time: 1.0, data: data.clone(), }); let expect = r#"{ "time": 1.0, "name": "loglevel:warning", "data": { "code": 255, "message": "deepseek(已深度思考(用时0秒)):服务器繁忙,请稍后再试。" } }"#; assert_eq!(serde_json::to_string_pretty(&event).unwrap(), expect); assert_eq!(data.importance(), EventImportance::Base); } #[test] fn rollback() { fn 
group_id() -> GroupID { GroupID::from(ConnectionID::from(ConnectionId::from_slice(&[ 0xfe, 0xdc, 0xba, 0x09, 0x87, 0x65, 0x43, 0x32, ]))) } fn protocol_types() -> Vec { vec!["QUIC".to_owned(), "UNKNOW".to_owned()] } struct TestBroker; impl ExportEvent for TestBroker { fn emit(&self, event: Event) { let legacy = legacy::Event::try_from(event).unwrap(); let event = serde_json::to_value(legacy).unwrap(); let data = serde_json::json!({ "ip_version": "v4", "src_ip": "127.0.0.1", "dst_ip": "192.168.31.1", "protocol": "QUIC", "src_port": 23456, "dst_port": 21 }); // in 10: this callde protocol_types let protocol_type = serde_json::json!(["QUIC", "UNKNOW"]); assert_eq!(event["data"], data); assert_eq!(event["protocol_types"], serde_json::Value::Null); assert_eq!(event["protocol_type"], protocol_type); assert_eq!(event["to_router"], true); } } span!( Arc::new(TestBroker), group_id = group_id(), protocol_types = protocol_types() ) .in_scope(|| { let src = "127.0.0.1:23456".parse().unwrap(); let dst = "192.168.31.1:21".parse().unwrap(); event!(ConnectionStarted { socket: (src, dst) }, to_router = true) }) } } ================================================ FILE: qevent/src/loglevel.rs ================================================ use derive_builder::Builder; use serde::{Deserialize, Serialize}; #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct Error { code: Option, message: Option, } #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct Warning { code: Option, message: Option, } #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub 
struct Info { message: String, } #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct Debug { message: String, } #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct Verbose { message: String, } crate::gen_builder_method! { ErrorBuilder => Error; WarningBuilder => Warning; InfoBuilder => Info; DebugBuilder => Debug; VerboseBuilder => Verbose; } mod rollback { use super::*; use crate::{build, legacy}; impl From for legacy::GenericError { fn from(value: Error) -> Self { build!(legacy::GenericError { ?code: value.code, ?message: value.message }) } } impl From for legacy::GenericWarning { fn from(value: Warning) -> Self { build!(legacy::GenericWarning { ?code: value.code, ?message: value.message }) } } impl From for legacy::GenericInfo { fn from(value: Info) -> Self { build!(legacy::GenericInfo { message: value.message }) } } impl From for legacy::GenericDebug { fn from(value: Debug) -> Self { build!(legacy::GenericDebug { message: value.message }) } } impl From for legacy::GenericVerbose { fn from(value: Verbose) -> Self { build!(legacy::GenericVerbose { message: value.message }) } } } ================================================ FILE: qevent/src/macro_support.rs ================================================ pub use serde_json::Value; ================================================ FILE: qevent/src/macros.rs ================================================ /// A macro to crate a qlog event struct from a set of fields. #[macro_export] macro_rules! build { ($struct:ty { $($tt:tt)* }) => {{ let mut __builder = <$struct>::builder(); $crate::build!(@field __builder, $($tt)*); __builder.build() }}; (@field $builder:expr, $field:ident $(, $($remain:tt)* )? ) => { $builder.$field($field); $crate::build!(@field $builder $(, $($remain)* )? 
); }; (@field $builder:expr, $field:ident: Map { $($tt:tt)* } $(, $($remain:tt)* )? ) => { $builder.$field($crate::map!{ $($tt)* }); $crate::build!(@field $builder $(, $($remain)* )? ); }; (@field $builder:expr, $field:ident: $struct:ty { $($tt:tt)* } $(, $($remain:tt)* )? ) => { $builder.$field($crate::build!($struct { $($tt)* })); $crate::build!(@field $builder $(, $($remain)* )? ); }; (@field $builder:expr, $field:ident: $value:expr $(, $($remain:tt)* )? ) => { $builder.$field($value); $crate::build!(@field $builder $(, $($remain)* )? ); }; (@field $builder:expr, ? $field:ident $(, $($remain:tt)* )? ) => { if let Some(__value) = $field { $builder.$field(__value); } $crate::build!(@field $builder $(, $($remain)* )? ); }; (@field $builder:expr, ? $field:ident: $value:expr $(, $($remain:tt)* )? ) => { if let Some(__value) = $value { $builder.$field(__value); } $crate::build!(@field $builder $(, $($remain)* )? ); }; (@field $builder:expr $(,)?) => {}; } /// A macro to create a `HashMap` from a set of fields. /// ``` rust, ignore /// qevent::map! { /// field1: value, /// field2, /// field3: Map { /// subfield1: value, /// }, /// event: loglevel::Error { /// message: "An error occurred", /// } /// } /// ``` #[macro_export] macro_rules! map { {$($tt:tt)*}=>{ { let mut map = ::std::collections::HashMap::::new(); $crate::map_internal!(map, $($tt)*); map }}; } #[doc(hidden)] #[macro_export] macro_rules! map_internal { ($map:expr, $field:ident $(, $($remain:tt)* )?) => { $map.insert(stringify!($field).to_owned(), $field.into()); $crate::map_internal!($map $(, $($remain)* )?) }; ($map:expr, $field:ident: Map {$($tt:tt)*} $(, $($remain:tt)* )?) => { $map.insert(stringify!($field).to_owned(), $crate::map!{ $($tt)* }); $crate::map_internal!($map $(, $($remain)* )?) }; ($map:expr, $field:ident: $struct:ty {$($tt:tt)*} $(, $($remain:tt)* )?) => { $map.insert(stringify!($field).to_owned(), $crate::build!($struct { $($tt)* }).into()); $crate::map_internal!($map $(, $($remain)* )?) 
};
    ($map:expr, $field:ident: $value:expr $(, $($remain:tt)* )?) => {
        $map.insert(stringify!($field).to_owned(), $value.into());
        $crate::map_internal!($map $(, $($remain)* )?)
    };
    ($map:expr $(,)?) => {};
}


================================================
FILE: qevent/src/packet.rs
================================================
use bytes::{BufMut, buf::UninitSlice};
use derive_more::Deref;
use qbase::{
    net::tx::Signals,
    packet::{
        RecordFrame,
        header::{
            EncodeHeader, GetDcid, GetScid, GetType, io::WriteHeader, long::LongHeader,
            short::OneRttHeader,
        },
        io::{AssemblePacket, PacketInfo, PacketWriter as BasePacketWriter},
        keys::DirectionalKeys,
        number::PacketNumber,
        signal::KeyPhaseBit,
    },
    util::ContinuousData,
};

use crate::{
    RawInfo,
    quic::{
        PacketHeader as QEventPacketHeader, PacketHeaderBuilder as QEventPacketHeaderBuilder,
        QuicFrame as QEventFrame, QuicFramesCollector, transport::PacketSent,
    },
};

/// Accumulates qlog metadata (header + frames) for a packet under assembly
/// and emits a `PacketSent` event once the packet is sealed.
struct PacketLogger {
    header: QEventPacketHeaderBuilder,
    frames: QuicFramesCollector,
}

impl PacketLogger {
    pub fn record_frame(&mut self, frame: impl Into<QEventFrame>) {
        self.frames.extend([frame]);
    }

    pub fn log_sent(mut self, packet: &BasePacketWriter) {
        // TODO: if assembling VN/Retry packets is ever supported,
        // this logic must change (they carry no payload length field).
        if !packet.is_short_header() {
            self.header.length((packet.payload_len()) as u16);
        }
        crate::event!(PacketSent {
            header: self.header.build(),
            frames: self.frames,
            raw: RawInfo {
                length: packet.packet_len() as u64,
                payload_length: packet.payload_len() as u64,
                data: packet.buffer(),
            },
            // TODO: trigger
        })
    }
}

/// A packet writer that wraps `qbase`'s writer and mirrors every recorded
/// frame into a qlog `PacketSent` event when the packet is finalized.
#[derive(Deref)]
pub struct PacketWriter<'b> {
    #[deref]
    writer: BasePacketWriter<'b>,
    logger: PacketLogger,
}

impl<'b> AsRef<BasePacketWriter<'b>> for PacketWriter<'b> {
    #[inline]
    fn as_ref(&self) -> &BasePacketWriter<'b> {
        &self.writer
    }
}

impl<'b> PacketWriter<'b> {
    // NOTE(review): the generic parameters below were lost in extraction and
    // have been reconstructed from the visible `where` bounds — confirm
    // against qbase's `PacketWriter::new_long` signature.
    pub fn new_long<S>(
        header: &LongHeader<S>,
        buffer: &'b mut [u8],
        pn: (u64, PacketNumber),
        keys: DirectionalKeys,
    ) -> Result<Self, Signals>
    where
        S: EncodeHeader,
        LongHeader<S>: GetType,
        for<'a> &'a mut [u8]: WriteHeader<LongHeader<S>>,
    {
        Ok(Self {
            writer: BasePacketWriter::new_long(header, buffer, pn, keys)?,
            logger: PacketLogger {
                header: {
                    let mut builder = QEventPacketHeader::builder();
                    builder
                        .packet_type(header.get_type())
                        .packet_number(pn.0)
                        .scil(header.scid().len() as u8)
                        .scid(*header.scid())
                        .dcil(header.dcid().len() as u8)
                        .dcid(*header.dcid());
                    builder
                },
                frames: QuicFramesCollector::new(),
            },
        })
    }

    pub fn new_short(
        header: &OneRttHeader,
        buffer: &'b mut [u8],
        pn: (u64, PacketNumber),
        keys: DirectionalKeys,
        key_phase: KeyPhaseBit,
    ) -> Result<Self, Signals> {
        Ok(Self {
            writer: BasePacketWriter::new_short(header, buffer, pn, keys, key_phase)?,
            logger: PacketLogger {
                header: {
                    let mut builder = QEventPacketHeader::builder();
                    // 1-RTT headers carry no source connection ID.
                    builder
                        .packet_type(header.get_type())
                        .packet_number(pn.0)
                        .dcil(header.dcid().len() as u8)
                        .dcid(*header.dcid());
                    builder
                },
                frames: QuicFramesCollector::new(),
            },
        })
    }
}

unsafe impl<'b> BufMut for PacketWriter<'b> {
    #[inline]
    fn remaining_mut(&self) -> usize {
        self.writer.remaining_mut()
    }

    #[inline]
    unsafe fn advance_mut(&mut self, cnt: usize) {
        unsafe { self.writer.advance_mut(cnt) }
    }

    #[inline]
    fn chunk_mut(&mut self) -> &mut UninitSlice {
        self.writer.chunk_mut()
    }

    #[inline]
    fn put_bytes(&mut self, val: u8, cnt: usize) {
        if cnt > 0 {
            // Bulk byte fills are padding: record them as a PADDING frame.
            self.logger.record_frame(QEventFrame::Padding {
                length: Some(cnt as _),
                payload_length: cnt as _,
            });
            self.writer.put_bytes(val, cnt);
        }
    }
}

// NOTE(review): the trait/bound generics were lost in extraction; the role of
// `D: ContinuousData` has been reconstructed as the frame's data parameter —
// confirm against qbase's `RecordFrame` trait definition.
impl<'b, F, D: ContinuousData> RecordFrame<F, D> for PacketWriter<'b>
where
    for<'f> &'f F: Into<QEventFrame>,
    BasePacketWriter<'b>: RecordFrame<F, D>,
{
    #[inline]
    fn record_frame(&mut self, frame: &F) {
        self.logger.record_frame(frame);
        self.writer.record_frame(frame);
    }
}

impl<'b> AssemblePacket for PacketWriter<'b> {
    #[inline]
    fn encrypt_and_protect_packet(self) -> (usize, PacketInfo) {
        // Emit the qlog event before sealing; the writer is consumed next.
        self.logger.log_sent(&self.writer);
        self.writer.encrypt_and_protect_packet()
    }
}


================================================
FILE: qevent/src/quic/connectivity.rs
================================================
use std::net::SocketAddr;

use
derive_builder::Builder; use derive_more::From; use qbase::{ error::{AppError, Error, ErrorKind, QuicError}, frame::{AppCloseFrame, ConnectionCloseFrame, QuicCloseFrame}, }; use super::{ ApplicationCode, ConnectionID, CryptoError, IPAddress, IpVersion, Owner, PathEndpointInfo, TransportError, }; use crate::{Deserialize, PathID, Serialize}; /// Emitted when the server starts accepting connections. It has Extra /// importance level; see Section 9.2 of [QLOG-MAIN]. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct ServerListening { #[builder(default)] ip_v4: Option, #[builder(default)] ip_v6: Option, #[builder(default)] port_v4: Option, #[builder(default)] port_v6: Option, /// the server will always answer client initials with a retry /// (no 1-RTT connection setups by choice) #[builder(default)] retry_required: Option, } impl ServerListeningBuilder { pub fn address(&mut self, socket_addr: SocketAddr) -> &mut Self { match socket_addr { SocketAddr::V4(addr) => self.ip_v4(addr.ip().to_string()).port_v4(addr.port()), SocketAddr::V6(addr) => self.ip_v6(addr.ip().to_string()).port_v6(addr.port()), } } } /// The connection_started event is used for both attempting (client- /// perspective) and accepting (server-perspective) new connections. Note /// that while there is overlap with the connection_state_updated event, /// this event is separate event in order to capture additional data that /// can be useful to log. It has Base importance level; see Section 9.2 /// of [QLOG-MAIN]. 
/// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct ConnectionStarted { ip_version: IpVersion, src_ip: IPAddress, dst_ip: IPAddress, // transport layer protocol #[builder(default = "ConnectionStarted::default_protocol()")] #[serde(default = "ConnectionStarted::default_protocol")] protocol: String, #[builder(default)] src_port: Option, #[builder(default)] dst_port: Option, #[builder(default)] src_cid: Option, #[builder(default)] dst_cid: Option, } impl ConnectionStartedBuilder { /// helper method to set the source and destination socket addresses pub fn socket(&mut self, (src, dst): (SocketAddr, SocketAddr)) -> &mut Self { debug_assert_eq!(src.is_ipv4(), dst.is_ipv4()); self.ip_version(if src.is_ipv4() { IpVersion::V4 } else { IpVersion::V6 }) .src_ip(src.ip().to_string()) .dst_ip(dst.ip().to_string()) .src_port(src.port()) .dst_port(dst.port()) } } impl ConnectionStarted { pub fn default_protocol() -> String { String::from("QUIC") } } /// The connection_closed event is used for logging when a connection was /// closed, typically when an error or timeout occurred. It has Base /// importance level; see Section 9.2 of [QLOG-MAIN]. /// /// Note that this event has overlap with the connection_state_updated /// event, as well as the CONNECTION_CLOSE frame. However, in practice, /// when analyzing large deployments, it can be useful to have a single /// event representing a connection_closed event, which also includes an /// additional reason field to provide more information. Furthermore, it /// is useful to log closures due to timeouts, which are difficult to /// reflect using the other options. 
/// /// The connection_closed event is intended to be logged either when the /// local endpoint silently discards the connection due to an idle /// timeout, when a CONNECTION_CLOSE frame is sent (the connection enters /// the 'closing' state on the sender side), when a CONNECTION_CLOSE /// frame is received (the connection enters the 'draining' state on the /// receiver side) or when a Stateless Reset packet is received (the /// connection is discarded at the receiver side). Connectivity-related /// updates after this point (e.g., exiting a 'closing' or 'draining' /// state), should be logged using the connection_state_updated event /// instead. /// /// In QUIC there are two main connection-closing error categories: /// connection and application errors. They have well-defined error /// codes and semantics. Next to these however, there can be internal /// errors that occur that may or may not get mapped to the official /// error codes in implementation-specific ways. As such, multiple error /// codes can be set on the same event to reflect this. 
/// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct ConnectionClosed { /// which side closed the connection owner: Option, connection_code: Option, application_code: Option, internal_code: Option, reason: Option, trigger: Option, } impl ConnectionClosedBuilder { pub fn ccf(&mut self, ccf: &ConnectionCloseFrame) -> &mut Self { match &ccf { ConnectionCloseFrame::Quic(frame) => self.quic_close_frame(frame), ConnectionCloseFrame::App(frame) => self.app_close_frame(frame), } } fn quic_close_frame(&mut self, frame: &QuicCloseFrame) -> &mut ConnectionClosedBuilder { self.connection_code(frame.error_kind()) .reason(frame.reason().to_owned()) } fn app_close_frame(&mut self, frame: &AppCloseFrame) -> &mut ConnectionClosedBuilder { self.application_code(frame.error_code() as u32) .reason(frame.reason().to_owned()) } pub fn quic_error(&mut self, error: &QuicError) -> &mut Self { self.connection_code(error.kind()) .reason(error.reason().to_owned()) } pub fn app_error(&mut self, error: &AppError) -> &mut Self { self.application_code(error.error_code() as u32) .reason(error.reason().to_owned()) } pub fn error(&mut self, error: &Error) { match error { Error::Quic(quic_error) => self.quic_error(quic_error), Error::App(app_error) => self.app_error(app_error), }; } } #[derive(Debug, Clone, Copy, From, Serialize, Deserialize, PartialEq, Eq)] #[serde(untagged)] pub enum ConnectionCode { TransportError(TransportError), CryptoError(CryptoError), Value(u32), } impl From for super::ConnectionCloseErrorCode { fn from(value: ConnectionCode) -> Self { match value { ConnectionCode::TransportError(err) => err.into(), ConnectionCode::CryptoError(err) => err.into(), ConnectionCode::Value(code) => (code as u64).into(), } } } impl From for 
ConnectionCode { fn from(kind: ErrorKind) -> Self { match kind { ErrorKind::None => TransportError::NoError.into(), ErrorKind::Internal => TransportError::InternalError.into(), ErrorKind::ConnectionRefused => TransportError::ConnectionRefused.into(), ErrorKind::FlowControl => TransportError::FlowControlError.into(), ErrorKind::StreamLimit => TransportError::StreamLimitError.into(), ErrorKind::StreamState => TransportError::StreamStateError.into(), ErrorKind::FinalSize => TransportError::FinalSizeError.into(), ErrorKind::FrameEncoding => TransportError::FrameEncodingError.into(), ErrorKind::TransportParameter => TransportError::TransportParameterError.into(), ErrorKind::ConnectionIdLimit => TransportError::ConnectionIdLimitError.into(), ErrorKind::ProtocolViolation => TransportError::ProtocolViolation.into(), ErrorKind::InvalidToken => TransportError::InvalidToken.into(), ErrorKind::Application => TransportError::ApplicationError.into(), ErrorKind::CryptoBufferExceeded => TransportError::CryptoBufferExceeded.into(), ErrorKind::KeyUpdate => TransportError::KeyUpdateError.into(), ErrorKind::AeadLimitReached => TransportError::AeadLimitReached.into(), ErrorKind::NoViablePath => TransportError::NoViablePath.into(), ErrorKind::Crypto(code) => CryptoError(code).into(), } } } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum ConnectionCloseTrigger { IdleTimeout, Application, Error, VersionMismatch, /// when received from peer StatelessReset, /// when it is unclear what triggered the CONNECTION_CLOSE Unspecified, } /// The connection_id_updated event is emitted when either party updates /// their current Connection ID. As this typically happens only /// sparingly over the course of a connection, using this event is more /// efficient than logging the observed CID with each and every /// packet_sent or packet_received events. It has Base importance level; /// see Section 9.2 of [QLOG-MAIN]. 
///
/// The connection_id_updated event is viewed from the perspective of the
/// endpoint applying the new ID. As such, when the endpoint receives a
/// new connection ID from the peer, the owner field will be "remote".
/// When the endpoint updates its own connection ID, the owner field will
/// be "local".
///
/// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct ConnectionIdUpdated {
    owner: Owner,
    #[builder(default)]
    old: Option<ConnectionID>,
    #[builder(default)]
    new: Option<ConnectionID>,
}

/// The spin_bit_updated event conveys information about the QUIC latency
/// spin bit; see Section 17.4 of [QUIC-TRANSPORT]. The event is emitted
/// when the spin bit changes value, it SHOULD NOT be emitted if the spin
/// bit is set without changing its value. It has Base importance level;
/// see Section 9.2 of [QLOG-MAIN].
///
/// [QUIC-TRANSPORT]: https://www.rfc-editor.org/rfc/rfc9000
/// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[builder(setter(into), build_fn(private, name = "fallible_build"))]
pub struct SpinBitUpdated {
    state: bool,
}

/// The connection_state_updated event is used to track progress through
/// QUIC's complex handshake and connection close procedures. It has
/// Base importance level; see Section 9.2 of [QLOG-MAIN].
///
/// [QUIC-TRANSPORT] does not contain an exhaustive flow diagram with
/// possible connection states nor their transitions (though some are
/// explicitly mentioned, like the 'closing' and 'draining' states). As
/// such, this document *non-exhaustively* defines those states that are
/// most likely to be useful for debugging QUIC connections.
///
/// QUIC implementations SHOULD mainly log the simplified
/// BaseConnectionStates, adding the more fine-grained
/// GranularConnectionStates when more in-depth debugging is required.
/// Tools SHOULD be able to deal with both types equally.
///
/// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09
/// [QUIC-TRANSPORT]: https://www.rfc-editor.org/rfc/rfc9000
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct ConnectionStateUpdated {
    #[builder(default)]
    old: Option<ConnectionState>,
    new: ConnectionState,
}

#[derive(Debug, Clone, Copy, From, Serialize, Deserialize, PartialEq, Eq)]
#[serde(untagged)]
pub enum ConnectionState {
    Base(BaseConnectionStates),
    Granular(GranularConnectionStates),
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum BaseConnectionStates {
    /// Initial packet sent/received
    Attempted,
    /// Handshake packet sent/received
    HandshakeStarted,
    /// Both sent a TLS Finished message
    /// and verified the peer's TLS Finished message
    /// 1-RTT packets can be sent
    /// RFC 9001 Section 4.1.1
    HandshakeComplete,
    /// CONNECTION_CLOSE sent/received,
    /// stateless reset received or idle timeout
    Closed,
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum GranularConnectionStates {
    /// RFC 9000 Section 8.1
    /// client sent Handshake packet OR
    /// client used connection ID chosen by the server OR
    /// client used valid address validation token
    PeerValidated,
    /// 1-RTT data can be sent by the server,
    /// but handshake is not done yet
    /// (server has sent TLS Finished; sometimes called 0.5 RTT data)
    EarlyWrite,
    /// HANDSHAKE_DONE sent/received.
    /// RFC 9001 Section 4.1.2
    HandshakeConfirmed,
    /// CONNECTION_CLOSE sent
    Closing,
    /// CONNECTION_CLOSE received
    Draining,
    /// draining or closing period done, connection state discarded
    Closed,
}

/// This event is used to associate a single PathID's value with other
/// parameters that describe a unique network path.
///
/// As described in [QLOG-MAIN], each qlog event can be linked to a
/// single network path by means of the top-level "path" field, whose
/// value is a PathID. However, since it can be cumbersome to encode
/// additional path metadata (such as IP addresses or Connection IDs)
/// directly into the PathID, this event allows such an association to
/// happen separately. As such, PathIDs can be short and unique, and can
/// even be updated to be associated with new metadata as the
/// connection's state evolves.
///
/// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09
#[serde_with::skip_serializing_none]
#[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))]
pub struct PathAssigned {
    path_id: PathID,
    /// the information for traffic going towards the remote receiver
    #[builder(default)]
    path_remote: Option<PathEndpointInfo>,
    /// the information for traffic coming in at the local endpoint
    #[builder(default)]
    path_local: Option<PathEndpointInfo>,
}

/// The mtu_updated event indicates that the estimated Path MTU was
/// updated. This happens as part of the Path MTU discovery process. It
/// has Extra importance level; see Section 9.2 of [QLOG-MAIN].
/// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct MtuUpdated { #[builder(default)] old: Option, new: u32, /// at some point, MTU discovery stops, as a "good enough" /// packet size has been found #[builder(default)] #[serde(default)] done: bool, } crate::gen_builder_method! { ServerListeningBuilder => ServerListening; ConnectionStartedBuilder => ConnectionStarted; ConnectionClosedBuilder => ConnectionClosed; ConnectionIdUpdatedBuilder => ConnectionIdUpdated; SpinBitUpdatedBuilder => SpinBitUpdated; ConnectionStateUpdatedBuilder => ConnectionStateUpdated; PathAssignedBuilder => PathAssigned; MtuUpdatedBuilder => MtuUpdated; } mod rollback { use super::*; use crate::{build, legacy::quic as legacy}; impl From for legacy::ConnectivityServerListening { #[inline] fn from(value: ServerListening) -> Self { build!(legacy::ConnectivityServerListening { ?ip_v4: value.ip_v4, ?ip_v6: value.ip_v6, ?port_v4: value.port_v4, ?port_v6: value.port_v6, ?retry_required: value.retry_required, }) } } impl From for legacy::ConnectivityConnectionStarted { #[inline] fn from(value: ConnectionStarted) -> Self { build!(legacy::ConnectivityConnectionStarted { ip_version: value.ip_version, src_ip: value.src_ip, dst_ip: value.dst_ip, protocol: value.protocol, ?src_port: value.src_port, ?dst_port: value.dst_port, ?src_cid: value.src_cid, ?dst_cid: value.dst_cid, }) } } impl From for legacy::CryptoError { #[inline] fn from(value: CryptoError) -> Self { legacy::CryptoError::from(value.0) } } impl From for legacy::ConnectionCode { #[inline] fn from(value: ConnectionCode) -> Self { match value { ConnectionCode::TransportError(err) => legacy::TransportError::from(err).into(), ConnectionCode::CryptoError(err) => legacy::CryptoError::from(err).into(), 
ConnectionCode::Value(code) => code.into(), } } } // 这两类型的交集有限 impl TryFrom for legacy::ConnectivityConnectionClosedTrigger { type Error = (); #[inline] fn try_from(value: ConnectionCloseTrigger) -> Result { match value { ConnectionCloseTrigger::IdleTimeout => { Ok(legacy::ConnectivityConnectionClosedTrigger::IdleTimeout) } ConnectionCloseTrigger::Application => { Ok(legacy::ConnectivityConnectionClosedTrigger::Application) } ConnectionCloseTrigger::Error => { Ok(legacy::ConnectivityConnectionClosedTrigger::Error) } ConnectionCloseTrigger::VersionMismatch => { Ok(legacy::ConnectivityConnectionClosedTrigger::VersionMismatch) } ConnectionCloseTrigger::StatelessReset => { Ok(legacy::ConnectivityConnectionClosedTrigger::StatelessReset) } ConnectionCloseTrigger::Unspecified => Err(()), } } } impl From for legacy::ConnectivityConnectionClosed { #[inline] fn from(value: ConnectionClosed) -> Self { build!(legacy::ConnectivityConnectionClosed { ?owner: value.owner, ?connection_code: value.connection_code, ?application_code: value.application_code, ?internal_code: value.internal_code, ?reason: value.reason, ?trigger: value.trigger.and_then(|v| legacy::ConnectivityConnectionClosedTrigger::try_from(v).ok()), }) } } impl From for legacy::ConnectivityConnectionIdUpdated { #[inline] fn from(value: ConnectionIdUpdated) -> Self { build!(legacy::ConnectivityConnectionIdUpdated { owner: value.owner, ?old: value.old, ?new: value.new, }) } } impl From for legacy::ConnectivitySpinBitUpdated { #[inline] fn from(value: SpinBitUpdated) -> Self { build!(legacy::ConnectivitySpinBitUpdated { state: value.state }) } } impl From for legacy::ConnectionState { #[inline] fn from(value: ConnectionState) -> Self { match value { ConnectionState::Base(BaseConnectionStates::Attempted) => { legacy::ConnectionState::Attempted } ConnectionState::Base(BaseConnectionStates::HandshakeStarted) => { legacy::ConnectionState::HandshakeStarted } ConnectionState::Base(BaseConnectionStates::HandshakeComplete) => { 
legacy::ConnectionState::HandshakeComplete } ConnectionState::Base(BaseConnectionStates::Closed) => { legacy::ConnectionState::Closed } ConnectionState::Granular(GranularConnectionStates::PeerValidated) => { legacy::ConnectionState::PeerValidated } ConnectionState::Granular(GranularConnectionStates::EarlyWrite) => { legacy::ConnectionState::EarlyWrite } ConnectionState::Granular(GranularConnectionStates::HandshakeConfirmed) => { legacy::ConnectionState::HandshakeConfirmed } ConnectionState::Granular(GranularConnectionStates::Closing) => { legacy::ConnectionState::Closing } ConnectionState::Granular(GranularConnectionStates::Draining) => { legacy::ConnectionState::Draining } ConnectionState::Granular(GranularConnectionStates::Closed) => { legacy::ConnectionState::Closed } } } } impl From for legacy::ConnectivityConnectionStateUpdated { #[inline] fn from(value: ConnectionStateUpdated) -> Self { build!(legacy::ConnectivityConnectionStateUpdated { ?old: value.old, new: value.new, }) } } // event not exist in legacy version // impl From for // event not exist in legacy version // impl From for } ================================================ FILE: qevent/src/quic/recovery.rs ================================================ use std::collections::HashMap; use derive_builder::Builder; use serde::{Deserialize, Serialize}; use super::{PacketHeader, PacketNumberSpace, QuicFrame}; /// The recovery_parameters_set event groups initial parameters from both /// loss detection and congestion control into a single event. It has /// Base importance level; see Section 9.2 of [QLOG-MAIN]. /// /// All these settings are typically set once and never change. 
/// Implementation that do, for some reason, change these parameters /// during execution, MAY emit the recovery_parameters_set event more /// than once /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct RecoveryParametersSet { /// Loss detection, see RFC 9002 Appendix A.2 /// in amount of packets #[builder(default)] reordering_threshold: Option, /// as RTT multiplier #[builder(default)] time_threshold: Option, /// in ms timer_granularity: u16, /// in ms #[builder(default)] initial_rtt: Option, /// congestion control, see RFC 9002 Appendix B.2 /// in bytes. Note that this could be updated after pmtud #[builder(default)] max_datagram_size: Option, /// in bytes #[builder(default)] initial_congestion_window: Option, /// Note that this could change when max_datagram_size changes /// in bytes #[builder(default)] minimum_congestion_window: Option, loss_reduction_factor: Option, /// as PTO multiplier #[builder(default)] persistent_congestion_threshold: Option, /// Additionally, this event can contain any number of unspecified fields /// to support different recovery approaches. #[builder(default)] #[serde(flatten)] #[serde(skip_serializing_if = "HashMap::is_empty")] custom_fields: HashMap, } /// The recovery_metrics_updated event is emitted when one or more of the /// observable recovery metrics changes value. It has Core importance /// level; see Section 9.2 of [QLOG-MAIN]. /// /// This event SHOULD group all possible metric updates that happen at or /// around the same time in a single event (e.g., if min_rtt and /// smoothed_rtt change at the same time, they should be bundled in a /// single recovery_metrics_updated entry, rather than split out into /// two). 
Consequently, a recovery_metrics_updated event is only /// guaranteed to contain at least one of the listed metrics. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct RecoveryMetricsUpdated { /// Loss detection, see RFC 9002 Appendix A.3 /// all following rtt fields are expressed in ms smoothed_rtt: Option, min_rtt: Option, latest_rtt: Option, rtt_variance: Option, pto_count: Option, /// Congestion control, see RFC 9002 Appendix B.2. /// in bytes congestion_window: Option, bytes_in_flight: Option, /// in bytes ssthresh: Option, /// qlog defined /// sum of all packet number spaces packets_in_flight: Option, /// in bits per second pacing_rate: Option, /// Additionally, the recovery_metrics_updated event can contain any /// number of unspecified fields to support different recovery /// approaches. #[serde(flatten)] #[serde(skip_serializing_if = "HashMap::is_empty")] custom_fields: HashMap, } /// The congestion_state_updated event indicates when the congestion /// controller enters a significant new state and changes its behaviour. /// It has Base importance level; see Section 9.2 of [QLOG-MAIN]. /// /// The values of the event's fields are intentionally unspecified here /// in order to support different Congestion Control algorithms, as these /// typically have different states and even different implementations of /// these states across stacks. For example, for the algorithm defined /// in the QUIC Recovery RFC ("enhanced" New Reno), the following states /// are used: Slow Start, Congestion Avoidance, Application Limited and /// Recovery. Similarly, states can be triggered by a variety of events, /// including detection of Persistent Congestion or receipt of ECN /// markings. 
/// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct CongestionStateUpdated { #[builder(default)] old: Option, new: String, #[builder(default)] trigger: Option, } /// The loss_timer_updated event is emitted when a recovery loss timer /// changes state. It has Extra importance level; see Section 9.2 of /// [QLOG-MAIN]. /// /// The three main event types are: /// /// * set: the timer is set with a delta timeout for when it will /// trigger next /// /// * expired: when the timer effectively expires after the delta /// timeout /// /// * cancelled: when a timer is cancelled (e.g., all outstanding /// packets are acknowledged, start idle period) /// /// In order to indicate an active timer's timeout update, a new set /// event is used. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct LossTimerUpdated { /// called "mode" in RFC 9002 A.9. #[builder(default)] timer_type: Option, #[builder(default)] packet_number_space: Option, event_type: EventType, /// if event_type === "set": delta time is in ms from /// this event's timestamp until when the timer will trigger #[builder(default)] delta: Option, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum TimerType { Ack, Pto, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum EventType { Set, Expired, Cancelled, } /// The packet_lost event is emitted when a packet is deemed lost by loss /// detection. 
It has Core importance level; see Section 9.2 of /// [QLOG-MAIN]. /// /// It is RECOMMENDED to populate the optional trigger field in order to /// help disambiguate among the various possible causes of a loss /// declaration. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct PacketLost { /// should include at least the packet_type and packet_number header: Option, /// not all implementations will keep track of full /// packets, so these are optional frames: Option>, is_mtu_probe_packet: bool, trigger: Option, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum PacketLostTrigger { ReorderingThreshold, TimeThreshold, /// RFC 9002 Section 6.2.4 paragraph 6, MAY PtoExpired, } /// The marked_for_retransmit event indicates which data was marked for /// retransmission upon detection of packet loss (see packet_lost). It /// has Extra importance level; see Section 9.2 of [QLOG-MAIN]. /// /// Similar to the reasoning for the frames_processed event, in order to /// keep the amount of different events low, this signal is grouped into /// in a single event based on existing QUIC frame definitions for all /// types of retransmittable data. /// /// Implementations retransmitting full packets or frames directly can /// just log the constituent frames of the lost packet here (or do away /// with this event and use the contents of the packet_lost event /// instead). 
Conversely, implementations that have more complex logic /// (e.g., marking ranges in a stream's data buffer as in-flight), or /// that do not track sent frames in full (e.g., only stream offset + /// length), can translate their internal behaviour into the appropriate /// frame instance here even if that frame was never or will never be put /// on the wire. /// /// Much of this data can be inferred if implementations log packet_sent /// events (e.g., looking at overlapping stream data offsets and length, /// one can determine when data was retransmitted). /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into), build_fn(private, name = "fallible_build"))] pub struct MarkedForRetransmit { frames: Vec, } /// The ecn_state_updated event indicates a progression in the ECN state /// machine as described in section A.4 of [QUIC-TRANSPORT]. It has /// Extra importance level; see Section 9.2 of [QLOG-MAIN]. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct ECNStateUpdated { #[builder(default)] old: Option, new: ECNState, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum ECNState { /// ECN testing in progress Testing, /// ECN state unknown, waiting for acknowledgements /// for testing packets Unknown, /// ECN testing failed Failed, /// testing was successful, the endpoint now /// sends packets with ECT(0) marking Capable, } crate::gen_builder_method! 
{ RecoveryParametersSetBuilder => RecoveryParametersSet; RecoveryMetricsUpdatedBuilder => RecoveryMetricsUpdated; CongestionStateUpdatedBuilder => CongestionStateUpdated; LossTimerUpdatedBuilder => LossTimerUpdated; PacketLostBuilder => PacketLost; MarkedForRetransmitBuilder => MarkedForRetransmit; ECNStateUpdatedBuilder => ECNStateUpdated; } mod rollback { use super::*; use crate::{build, legacy::quic as legacy}; impl From for legacy::RecoveryParametersSet { fn from(value: RecoveryParametersSet) -> Self { build!(legacy::RecoveryParametersSet { ?reordering_threshold: value.reordering_threshold, ?time_threshold: value.time_threshold, timer_granularity: value.timer_granularity, ?initial_rtt: value.initial_rtt, ?max_datagram_size: value.max_datagram_size, ?initial_congestion_window: value.initial_congestion_window, ?minimum_congestion_window: value.minimum_congestion_window.map(|v| v as u32), ?loss_reduction_factor: value.loss_reduction_factor, ?persistent_congestion_threshold: value.persistent_congestion_threshold, custom_fields: value.custom_fields, }) } } impl From for legacy::RecoveryMetricsUpdated { fn from(value: RecoveryMetricsUpdated) -> Self { build!(legacy::RecoveryMetricsUpdated { ?smoothed_rtt: value.smoothed_rtt, ?min_rtt: value.min_rtt, ?latest_rtt: value.latest_rtt, ?rtt_variance: value.rtt_variance, ?pto_count: value.pto_count, ?congestion_window: value.congestion_window, ?bytes_in_flight: value.bytes_in_flight, ?ssthresh: value.ssthresh, ?packets_in_flight: value.packets_in_flight, ?pacing_rate: value.pacing_rate, custom_fields: value.custom_fields, }) } } impl From for legacy::RecoveryCongestionStateUpdated { fn from(value: CongestionStateUpdated) -> Self { build!(legacy::RecoveryCongestionStateUpdated { ?old: value.old, new: value.new, ?trigger: match value.trigger { Some(s) if s == "persistent_congestion" => Some(legacy::RecoveryCongestionStateUpdatedTrigger::PersistentCongestion), Some(s) if s == "ecn" => 
Some(legacy::RecoveryCongestionStateUpdatedTrigger::Ecn),
                _ => None,
            },
        })
    }
}

impl From<TimerType> for legacy::LossTimerType {
    #[inline]
    fn from(value: TimerType) -> Self {
        match value {
            TimerType::Ack => legacy::LossTimerType::Ack,
            TimerType::Pto => legacy::LossTimerType::Pto,
        }
    }
}

impl From<EventType> for legacy::LossTimerEventType {
    #[inline]
    fn from(value: EventType) -> Self {
        match value {
            EventType::Set => legacy::LossTimerEventType::Set,
            EventType::Expired => legacy::LossTimerEventType::Expired,
            EventType::Cancelled => legacy::LossTimerEventType::Cancelled,
        }
    }
}

impl From<LossTimerUpdated> for legacy::RecoveryLossTimerUpdated {
    fn from(value: LossTimerUpdated) -> Self {
        build!(legacy::RecoveryLossTimerUpdated {
            ?timer_type: value.timer_type,
            ?packet_number_space: value.packet_number_space,
            event_type: value.event_type,
            ?delta: value.delta,
        })
    }
}

impl From<PacketLostTrigger> for legacy::RecoveryPacketLostTrigger {
    #[inline]
    fn from(value: PacketLostTrigger) -> Self {
        match value {
            PacketLostTrigger::ReorderingThreshold => {
                legacy::RecoveryPacketLostTrigger::ReorderingThreshold
            }
            PacketLostTrigger::TimeThreshold => {
                legacy::RecoveryPacketLostTrigger::TimeThreshold
            }
            PacketLostTrigger::PtoExpired => legacy::RecoveryPacketLostTrigger::PtoExpired,
        }
    }
}

impl From<PacketLost> for legacy::RecoveryPacketLost {
    fn from(value: PacketLost) -> Self {
        build!(legacy::RecoveryPacketLost {
            ?header: value.header,
            // Frames are converted one-by-one into their legacy representation.
            ?frames: value.frames.map(|v| v.into_iter().map(Into::into).collect::<Vec<_>>()),
            ?trigger: value.trigger,
        })
    }
}

impl From<MarkedForRetransmit> for legacy::RecoveryMarkedForRetransmit {
    fn from(value: MarkedForRetransmit) -> Self {
        build!(legacy::RecoveryMarkedForRetransmit {
            frames: value.frames.into_iter().map(Into::into).collect::<Vec<_>>(),
        })
    }
}
}

================================================
FILE: qevent/src/quic/security.rs
================================================
use derive_builder::Builder;
use serde::{Deserialize, Serialize};

use super::KeyType;
use crate::HexString;

/// The key_updated event has Base importance level; see Section 9.2 of
///
[QLOG-MAIN] /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct KeyUpdated { key_type: KeyType, #[builder(default)] old: Option, #[builder(default)] new: Option, /// needed for 1RTT key updates #[builder(default)] key_phase: Option, #[builder(default)] trigger: Option, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum KeyUpdatedTrigger { /// (e.g., initial, handshake and 0-RTT keys /// are generated by TLS) Tls, RemoteUpdate, LocalUpdate, } /// The key_discarded event has Base importance level; see Section 9.2 of /// [QLOG-MAIN]. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct KeyDiscarded { key_type: KeyType, #[builder(default)] key: Option, /// needed for 1RTT key updates key_phase: Option, #[builder(default)] trigger: Option, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum KeyDiscardedTrigger { /// (e.g., initial, handshake and 0-RTT keys /// are generated by TLS) Tls, RemoteUpdate, LocalUpdate, } crate::gen_builder_method! 
{ KeyUpdatedBuilder => KeyUpdated; KeyDiscardedBuilder => KeyDiscarded; } mod rollback { use super::*; use crate::{build, legacy::quic as legacy}; impl From for legacy::KeyType { fn from(value: KeyType) -> Self { match value { KeyType::ServerInitialSecret => legacy::KeyType::ServerInitialSecret, KeyType::ClientInitialSecret => legacy::KeyType::ClientInitialSecret, KeyType::ServerHandshakeSecret => legacy::KeyType::ServerHandshakeSecret, KeyType::ClientHandshakeSecret => legacy::KeyType::ClientHandshakeSecret, KeyType::Server0RttSecret => legacy::KeyType::Server0RTTSecret, KeyType::Client0RttSecret => legacy::KeyType::Client0RTTSecret, KeyType::Server1RttSecret => legacy::KeyType::Server1RTTSecret, KeyType::Client1RttSecret => legacy::KeyType::Client1RTTSecret, } } } impl From for legacy::SecurityKeyUpdatedTrigger { #[inline] fn from(value: KeyUpdatedTrigger) -> Self { match value { KeyUpdatedTrigger::Tls => legacy::SecurityKeyUpdatedTrigger::Tls, KeyUpdatedTrigger::RemoteUpdate => legacy::SecurityKeyUpdatedTrigger::RemoteUpdate, KeyUpdatedTrigger::LocalUpdate => legacy::SecurityKeyUpdatedTrigger::LocalUpdate, } } } impl From for legacy::SecurityKeyUpdated { #[inline] fn from(value: KeyUpdated) -> Self { build!(legacy::SecurityKeyUpdated { key_type: value.key_type, ?old: value.old, // for legacy new is not optional ?new: value.new, // is this key_phase? 
?generation: value.key_phase.map(|p| p as u32), ?trigger: value.trigger, }) } } impl From for legacy::SecurityKeyRetiredTrigger { #[inline] fn from(value: KeyDiscardedTrigger) -> Self { match value { KeyDiscardedTrigger::Tls => legacy::SecurityKeyRetiredTrigger::Tls, KeyDiscardedTrigger::RemoteUpdate => { legacy::SecurityKeyRetiredTrigger::RemoteUpdate } KeyDiscardedTrigger::LocalUpdate => legacy::SecurityKeyRetiredTrigger::LocalUpdate, } } } impl From for legacy::SecurityKeyRetired { #[inline] fn from(value: KeyDiscarded) -> Self { build!(legacy::SecurityKeyRetired { key_type: value.key_type, ?key: value.key, // is this key_phase? ?generation: value.key_phase .map(|p| p as u32), ?trigger: value.trigger, }) } } } ================================================ FILE: qevent/src/quic/transport.rs ================================================ use std::{collections::HashMap, time::Duration}; use derive_builder::Builder; use derive_more::From; use qbase::param::{ClientParameters, ParameterId, ServerParameters}; use serde::{Deserialize, Serialize}; use super::{ ConnectionID, ECN, IPAddress, Owner, PacketHeader, PacketNumberSpace, PathEndpointInfo, QuicFrame, QuicVersion, StatelessResetToken, StreamType, }; use crate::{HexString, PathID, RawInfo}; /// The version_information event supports QUIC version negotiation; see /// Section 6 of [QUIC-TRANSPORT]. It has Core importance level; see /// Section 9.2 of [QLOG-MAIN]. /// /// QUIC endpoints each have their own list of QUIC versions they /// support. The client uses the most likely version in their first /// initial. If the server does not support that version, it replies /// with a Version Negotiation packet, which contains its supported /// versions. From this, the client selects a version. The /// version_information event aggregates all this information in a single /// event type. It also allows logging of supported versions at an /// endpoint without actual version negotiation needing to happen. 
/// /// [QUIC-TRANSPORT]: https://www.rfc-editor.org/rfc/rfc9000 /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct VersionInformation { // Vec for `? filed: [ +ty]``, Option for `* filed: [* ty]` #[serde(skip_serializing_if = "Vec::is_empty")] server_versions: Vec, #[serde(skip_serializing_if = "Vec::is_empty")] client_versions: Vec, chosen_version: Option, } /// The alpn_information event supports Application-Layer Protocol /// Negotiation (ALPN) over the QUIC transport; see [RFC7301] and /// Section 7.4 of [QUIC-TRANSPORT]. It has Core importance level; see /// Section 9.2 of [QLOG-MAIN]. /// /// QUIC endpoints are configured with a list of supported ALPN /// identifiers. Clients send the list in a TLS ClientHello, and servers /// match against their list. On success, a single ALPN identifier is /// chosen and sent back in a TLS ServerHello. If no match is found, the /// connection is closed. /// /// ALPN identifiers are byte sequences, that may be possible to present /// as UTF-8. The ALPNIdentifier` type supports either format. /// Implementations SHOULD log at least one format, but MAY log both or /// none. /// /// [RFC7301]: https://www.rfc-editor.org/rfc/rfc7301 /// [QUIC-TRANSPORT]: https://www.rfc-editor.org/rfc/rfc9000 /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct ALPNInformation { server_alpns: Option>, client_alpns: Option>, chosen_alpn: Option, } /// ALPN identifiers are byte sequences, that may be possible to present /// as UTF-8. 
The ALPNIdentifier` type supports either format. /// Implementations SHOULD log at least one format, but MAY log both or /// none. #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct ALPNIdentifier { byte_value: Option, string_value: Option, } /// The parameters_set event groups settings from several different /// sources (transport parameters, TLS ciphers, etc.) into a single /// event. This is done to minimize the amount of events and to decouple /// conceptual setting impacts from their underlying mechanism for easier /// high-level reasoning. The event has Core importance level; see /// Section 9.2 of [QLOG-MAIN]. /// /// Most of these settings are typically set once and never change. /// However, they are usually set at different times during the /// connection, so there will regularly be several instances of this /// event with different fields set. /// /// Note that some settings have two variations (one set locally, one /// requested by the remote peer). This is reflected in the owner field. /// As such, this field MUST be correct for all settings included a /// single event instance. If you need to log settings from two sides, /// you MUST emit two separate event instances. /// /// Implementations are not required to recognize, process or support /// every setting/parameter received in all situations. For example, /// QUIC implementations MUST discard transport parameters that they do /// not understand Section 7.4.2 of [QUIC-TRANSPORT]. The /// unknown_parameters field can be used to log the raw values of any /// unknown parameters (e.g., GREASE, private extensions, peer-side /// experimentation). /// /// In the case of connection resumption and 0-RTT, some of the server's /// parameters are stored up-front at the client and used for the initial /// connection startup. 
They are later updated with the server's reply. /// In these cases, utilize the separate parameters_restored event to /// indicate the initial values, and this event to indicate the updated /// values, as normal. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 /// [QUIC-TRANSPORT]: https://www.rfc-editor.org/rfc/rfc9000 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct ParametersSet { owner: Option, /// true if valid session ticket was received resumption_allowed: Option, /// true if early data extension was enabled on the TLS layer early_data_enabled: Option, /// e.g., "AES_128_GCM_SHA256" tls_cipher: Option, // RFC9000 original_destination_connection_id: Option, initial_source_connection_id: Option, retry_source_connection_id: Option, stateless_reset_token: Option, disable_active_migration: Option, max_idle_timeout: Option, max_udp_payload_size: Option, ack_delay_exponent: Option, max_ack_delay: Option, active_connection_id_limit: Option, initial_max_data: Option, initial_max_stream_data_bidi_local: Option, initial_max_stream_data_bidi_remote: Option, initial_max_stream_data_uni: Option, initial_max_streams_bidi: Option, initial_max_streams_uni: Option, preferred_address: Option, unknown_parameters: Option>, // RFC9221 max_datagram_frame_size: Option, // RFC9287 /// can only be restored at the client. /// servers MUST NOT restore this parameter! grease_quic_bit: Option, } macro_rules! extract_parameter { ( $( $id:ident as $as:ident $(.map($($tt:tt)*))? from $set:ident to $this:ident.$field:ident ),* $(,)? ) => { $( extract_parameter!(@one $id as $as $(.map($($tt)*))? 
from $set to $this.$field); )* }; (@one $id:ident as $as:ident .map($($tt:tt)*) from $set:ident to $this:ident.$field:ident) => { $this.$field = $this.$field.take().or_else(|| { Some($set.get::<$as>(ParameterId::$id).map($($tt)*)) }); }; (@one $id:ident as $as:ident from $set:ident to $this:ident.$field:ident) => { $this.$field = $this.$field.take().or_else(|| { Some($set.get::<$as>(ParameterId::$id).map(Into::into)) }); }; } impl ParametersSetBuilder { /// helper method to set all client parameters at once pub fn client_parameters(&mut self, params: &ClientParameters) -> &mut Self { use qbase::cid::ConnectionId; extract_parameter! { InitialSourceConnectionId as ConnectionId from params to self.initial_source_connection_id, DisableActiveMigration as bool from params to self.disable_active_migration, MaxIdleTimeout as Duration.map(|d| d.as_millis() as _) from params to self.max_idle_timeout, MaxUdpPayloadSize as u64.map(|u| u as u32) from params to self.max_udp_payload_size, AckDelayExponent as u64.map(|u| u as u16) from params to self.ack_delay_exponent, MaxAckDelay as Duration.map(|d| d.as_millis() as _) from params to self.max_ack_delay, ActiveConnectionIdLimit as u64.map(|u| u as u32) from params to self.active_connection_id_limit, InitialMaxData as u64 from params to self.initial_max_data, InitialMaxStreamDataBidiLocal as u64 from params to self.initial_max_stream_data_bidi_local, InitialMaxStreamDataBidiRemote as u64 from params to self.initial_max_stream_data_bidi_remote, InitialMaxStreamDataUni as u64 from params to self.initial_max_stream_data_uni, InitialMaxStreamsBidi as u64 from params to self.initial_max_streams_bidi, InitialMaxStreamsUni as u64 from params to self.initial_max_streams_uni, MaxDatagramFrameSize as u64 from params to self.max_datagram_frame_size, GreaseQuicBit as bool from params to self.grease_quic_bit, } self } /// helper method to set all server parameters at once pub fn server_parameters(&mut self, params: &ServerParameters) -> &mut 
Self { use qbase::{ cid::ConnectionId, param::preferred_address::PreferredAddress, token::ResetToken, }; extract_parameter! { OriginalDestinationConnectionId as ConnectionId from params to self.original_destination_connection_id, InitialSourceConnectionId as ConnectionId from params to self.initial_source_connection_id, RetrySourceConnectionId as ConnectionId from params to self.retry_source_connection_id, StatelessResetToken as ResetToken from params to self.stateless_reset_token, DisableActiveMigration as bool from params to self.disable_active_migration, MaxIdleTimeout as Duration.map(|d| d.as_millis() as _) from params to self.max_idle_timeout, MaxUdpPayloadSize as u64.map(|u| u as u32) from params to self.max_udp_payload_size, AckDelayExponent as u64.map(|u| u as u16) from params to self.ack_delay_exponent, MaxAckDelay as Duration.map(|d| d.as_millis() as _) from params to self.max_ack_delay, ActiveConnectionIdLimit as u64.map(|u| u as u32) from params to self.active_connection_id_limit, InitialMaxData as u64 from params to self.initial_max_data, InitialMaxStreamDataBidiLocal as u64 from params to self.initial_max_stream_data_bidi_local, InitialMaxStreamDataBidiRemote as u64 from params to self.initial_max_stream_data_bidi_remote, InitialMaxStreamDataUni as u64 from params to self.initial_max_stream_data_uni, InitialMaxStreamsBidi as u64 from params to self.initial_max_streams_bidi, InitialMaxStreamsUni as u64 from params to self.initial_max_streams_uni, PreferredAddress as PreferredAddress from params to self.preferred_address, MaxDatagramFrameSize as u64 from params to self.max_datagram_frame_size, GreaseQuicBit as bool from params to self.grease_quic_bit, } self } } #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into), build_fn(private, name = "fallible_build"))] pub struct PreferredAddress { ip_v4: IPAddress, ip_v6: IPAddress, port_v4: u16, port_v6: u16, connection_id: ConnectionID, stateless_reset_token: 
StatelessResetToken, } impl From for PreferredAddress { fn from(pa: qbase::param::preferred_address::PreferredAddress) -> Self { crate::build!(Self { ip_v4: pa.address_v4().ip().to_string(), ip_v6: pa.address_v6().ip().to_string(), port_v4: pa.address_v4().port(), port_v6: pa.address_v6().port(), connection_id: pa.connection_id(), stateless_reset_token: pa.stateless_reset_token(), }) } } #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct UnknownParameter { id: u64, #[builder(default)] value: Option, } /// When using QUIC 0-RTT, clients are expected to remember and restore /// the server's transport parameters from the previous connection. The /// parameters_restored event is used to indicate which parameters were /// restored and to which values when utilizing 0-RTT. It has Base /// importance level; see Section 9.2 of [QLOG-MAIN]. /// /// Note that not all transport parameters should be restored (many are /// even prohibited from being re-utilized). The ones listed here are /// the ones expected to be useful for correct 0-RTT usage. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct ParametersRestored { // RFC 9000 disable_active_migration: Option, max_idle_timeout: Option, max_udp_payload_size: Option, active_connection_id_limit: Option, initial_max_data: Option, initial_max_stream_data_bidi_local: Option, initial_max_stream_data_bidi_remote: Option, initial_max_stream_data_uni: Option, initial_max_streams_bidi: Option, initial_max_streams_uni: Option, // RFC9221 max_datagram_frame_size: Option, // RFC9287 /// can only be restored at the client. 
/// servers MUST NOT restore this parameter! grease_quic_bit: Option, } impl ParametersRestoredBuilder { /// helper method to set all client parameters at once pub fn client_parameters(&mut self, params: &ServerParameters) -> &mut Self { extract_parameter! { DisableActiveMigration as bool from params to self.disable_active_migration, MaxIdleTimeout as Duration.map(|d| d.as_millis() as _) from params to self.max_idle_timeout, MaxUdpPayloadSize as u64.map(|u| u as u32) from params to self.max_udp_payload_size, ActiveConnectionIdLimit as u64.map(|u| u as u32) from params to self.active_connection_id_limit, InitialMaxData as u64 from params to self.initial_max_data, InitialMaxStreamDataBidiLocal as u64 from params to self.initial_max_stream_data_bidi_local, InitialMaxStreamDataBidiRemote as u64 from params to self.initial_max_stream_data_bidi_remote, InitialMaxStreamDataUni as u64 from params to self.initial_max_stream_data_uni, InitialMaxStreamsBidi as u64 from params to self.initial_max_streams_bidi, InitialMaxStreamsUni as u64 from params to self.initial_max_streams_uni, MaxDatagramFrameSize as u64 from params to self.max_datagram_frame_size, GreaseQuicBit as bool from params to self.grease_quic_bit, } self } } /// The packet_sent event indicates a QUIC-level packet was sent. It has /// Core importance level; see Section 9.2 of [QLOG-MAIN]. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct PacketSent { header: PacketHeader, #[builder(default)] frames: Option>, /// only if header.packet_type === "stateless_reset" /// is always 128 bits in length. 
#[builder(default)] stateless_reset_token: Option, /// only if header.packet_type === "version_negotiation" #[builder(default)] #[serde(skip_serializing_if = "Vec::is_empty")] supported_versions: Vec, #[builder(default)] raw: Option, #[builder(default)] datagram_id: Option, #[builder(default)] #[serde(default)] is_mtu_probe_packet: bool, #[builder(default)] trigger: Option, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum PacketSentTrigger { RetransmitReordered, RetransmitTimeout, PtoProbe, RetransmitCrypto, CcBandwidthProbe, } /// The packet_received event indicates a QUIC-level packet was received. /// It has Core importance level; see Section 9.2 of [QLOG-MAIN]. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct PacketReceived { header: PacketHeader, #[builder(default)] frames: Option>, /// only if header.packet_type === "stateless_reset" /// is always 128 bits in length. #[builder(default)] stateless_reset_token: Option, /// only if header.packet_type === "version_negotiation" #[builder(default)] #[serde(skip_serializing_if = "Vec::is_empty")] supported_versions: Vec, #[builder(default)] raw: Option, #[builder(default)] datagram_id: Option, #[builder(default)] trigger: Option, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum PacketReceivedTrigger { /// if packet was buffered because it couldn't be /// decrypted before KeysAvailable, } /// The packet_dropped event indicates a QUIC-level packet was dropped. /// It has Base importance level; see Section 9.2 of [QLOG-MAIN]. 
/// /// The trigger field indicates a general reason category for dropping /// the packet, while the details field can contain additional /// implementation-specific information. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] #[serde(default)] pub struct PacketDropped { /// Primarily packet_type should be filled here, /// as other fields might not be decrypteable or parseable header: Option, raw: Option, datagram_id: Option, details: HashMap, trigger: Option, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum PacketDroppedTrigger { /// not initialized, out of memory InternalError, /// limits reached, DDoS protection, unwilling to track more paths, duplicate packet Rejected, /// unknown or unsupported version. 
Unsupported, /// packet parsing or validation error Invalid, /// duplicate packet Duplicate, /// packet does not relate to a known connection or Connection ID ConnectionUnknown, /// decryption failed DecryptionFailure, /// decryption key was unavailable KeyUnavailable, /// situations not clearly covered in the other categories Genera, } impl From for PacketDroppedTrigger { fn from(value: qbase::packet::InvalidPacketNumber) -> Self { match value { qbase::packet::InvalidPacketNumber::TooOld | qbase::packet::InvalidPacketNumber::TooLarge => PacketDroppedTrigger::Genera, qbase::packet::InvalidPacketNumber::Duplicate => PacketDroppedTrigger::Duplicate, } } } impl From for PacketDroppedTrigger { fn from(error: qbase::packet::error::Error) -> Self { match error { qbase::packet::error::Error::UnsupportedVersion(_) => Self::Unsupported, qbase::packet::error::Error::InvalidFixedBit | qbase::packet::error::Error::InvalidReservedBits(_, _) | qbase::packet::error::Error::IncompleteType(_) | qbase::packet::error::Error::IncompleteHeader(_, _) | qbase::packet::error::Error::IncompletePacket(_, _) | qbase::packet::error::Error::UnderSampling(..) => Self::Invalid, qbase::packet::error::Error::RemoveProtectionFailure | qbase::packet::error::Error::DecryptPacketFailure => Self::DecryptionFailure, } } } /// The packet_buffered event is emitted when a packet is buffered /// because it cannot be processed yet. Typically, this is because the /// packet cannot be parsed yet, and thus only the full packet contents /// can be logged when it was parsed in a packet_received event. The /// event has Base importance level; see Section 9.2 of [QLOG-MAIN]. 
/// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct PacketBuffered { /// Primarily packet_type should be filled here, /// as other fields might not be decrypteable or parseable header: Option, raw: Option, datagram_id: Option, trigger: Option, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum PacketBufferedTrigger { /// indicates the parser cannot keep up, temporarily buffers /// packet for later processing Backpressure, /// if packet cannot be decrypted because the proper keys were /// not yet available KeysUnavailable, } /// The packets_acked event is emitted when a (group of) sent packet(s) /// is acknowledged by the remote peer _for the first time_. It has Extra /// importance level; see Section 9.2 of [QLOG-MAIN]. /// /// This information could also be deduced from the contents of received /// ACK frames. However, ACK frames require additional processing logic /// to determine when a given packet is acknowledged for the first time, /// as QUIC uses ACK ranges which can include repeated ACKs. /// Additionally, this event can be used by implementations that do not /// log frame contents. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct PacketsAcked { packet_number_space: Option, #[serde(skip_serializing_if = "Vec::is_empty")] packet_nubers: Vec, } /// The datagrams_sent event indicates when one or more UDP-level /// datagrams are passed to the underlying network socket. 
This is /// useful for determining how QUIC packet buffers are drained to the OS. /// The event has Extra importance level; see Section 9.2 of [QLOG-MAIN]. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct UdpDatagramsSent { /// to support passing multiple at once count: Option, /// The RawInfo fields do not include the UDP headers, /// only the UDP payload #[serde(skip_serializing_if = "Vec::is_empty")] raw: Vec, /// ECN bits in the IP header /// if not set, defaults to the value used on the last /// QUICDatagramsSent event #[serde(skip_serializing_if = "Vec::is_empty")] ecn: Vec, #[serde(skip_serializing_if = "Vec::is_empty")] datagram_ids: Vec, } /// When one or more UDP-level datagrams are received from the socket. /// This is useful for determining how datagrams are passed to the user /// space stack from the OS. The event has Extra importance level; see /// Section 9.2 of [QLOG-MAIN]. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct UdpDatagramsReceived { /// to support passing multiple at once count: Option, /// The RawInfo fields do not include the UDP headers, /// only the UDP payload #[serde(skip_serializing_if = "Vec::is_empty")] raw: Vec, /// ECN bits in the IP header /// if not set, defaults to the value used on the last /// QUICDatagramsSent event #[serde(skip_serializing_if = "Vec::is_empty")] ecn: Vec, #[serde(skip_serializing_if = "Vec::is_empty")] datagram_ids: Vec, } /// When a UDP-level datagram is dropped. 
This is typically done if it /// does not contain a valid QUIC packet. If it does, but the QUIC /// packet is dropped for other reasons, the packet_dropped event /// (Section 5.7) should be used instead. The event has Extra importance /// level; see Section 9.2 of [QLOG-MAIN]. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct UdpDatagramDropped { /// The RawInfo fields do not include the UDP headers, /// only the UDP payload raw: Option, } // The stream_state_updated event is emitted whenever the internal state // of a QUIC stream is updated; see Section 3 of [QUIC-TRANSPORT]. Most // of this can be inferred from several types of frames going over the // wire, but it's much easier to have explicit signals for these state // changes. The event has Base importance level; see Section 9.2 of // [QLOG-MAIN]. 
/// /// [QUIC-TRANSPORT]: https://www.rfc-editor.org/rfc/rfc9000 /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct StreamStateUpdated { stream_id: u64, /// mainly useful when opening the stream #[builder(default)] stream_type: Option, #[builder(default)] old: Option, new: StreamState, #[builder(default)] stream_side: Option, } #[derive(Debug, Clone, Copy, From, Serialize, Deserialize, PartialEq, Eq)] #[serde(untagged)] pub enum StreamState { Base(BaseStreamStates), Granular(GranularStreamStates), } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum BaseStreamStates { Idle, Open, Closed, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum GranularStreamStates { // bidirectional stream states, RFC 9000 Section 3.4. HalfClosedLocal, HalfClosedRemote, // sending-side stream states, RFC 9000 Section 3.1. Ready, Send, DataSent, ResetSent, ResetReceived, // receive-side stream states, RFC 9000 Section 3.2. Receive, SizeKnown, DataRead, ResetRead, // both-side states DataReceived, // qlog-defined: memory actually freed Destroyed, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum StreamSide { Sending, Receiving, } /// The frame_processed event is intended to prevent a large /// proliferation of specific purpose events (e.g., packets_acknowledged, /// flow_control_updated, stream_data_received). It has Extra importance /// level; see Section 9.2 of [QLOG-MAIN]. /// /// Implementations have the opportunity to (selectively) log this type /// of signal without having to log packet-level details (e.g., in /// packet_received). 
Since for almost all cases, the effects of /// applying a frame to the internal state of an implementation can be /// inferred from that frame's contents, these events are aggregated into /// this single frames_processed event. /// /// The frame_processed event can be used to signal internal state change /// not resulting directly from the actual "parsing" of a frame (e.g., /// the frame could have been parsed, data put into a buffer, then later /// processed, then logged with this event). /// /// The packet_received event can convey all constituent frames. It is /// not expected that the frames_processed event will also be used for a /// redundant purpose. Rather, implementations can use this event to /// avoid having to log full packets or to convey extra information about /// when frames are processed (for example, if frame processing is /// deferred for any reason). /// /// Note that for some events, this approach will lose some information /// (e.g., for which encryption level are packets being acknowledged?). /// If this information is important, the packet_received event can be /// used instead. /// /// In some implementations, it can be difficult to log frames directly, /// even when using packet_sent and packet_received events. For these /// cases, the frames_processed event also contains the packet_numbers /// field, which can be used to more explicitly link this event to the /// packet_sent/received events. The field is an array, which supports /// using a single frames_processed event for multiple frames received /// over multiple packets. To map between frames and packets, the /// position and order of entries in the frames and packet_numbers is /// used. If the optional packet_numbers field is used, each frame MUST /// have a corresponding packet number at the same index. 
/// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct FramesProcessed { frames: Vec, #[builder(default)] packet_numbers: Option>, } /// The stream_data_moved event is used to indicate when QUIC stream data /// moves between the different layers. This helps make clear the flow /// of data, how long data remains in various buffers, and the overheads /// introduced by individual layers. The event has Base importance /// level; see Section 9.2 of [QLOG-MAIN]. /// /// This event relates to stream data only. There are no packet or frame /// headers and length values in the length or raw fields MUST reflect /// that. /// /// For example, it can be useful to understand when data moves from an /// application protocol (e.g., HTTP) to QUIC stream buffers and vice /// versa. /// /// The stream_data_moved event can provide insight into whether received /// data on a QUIC stream is moved to the application protocol /// immediately (for example per received packet) or in larger batches /// (for example, all QUIC packets are processed first and afterwards the /// application layer reads from the streams with newly available data). /// This can help identify bottlenecks, flow control issues, or /// scheduling problems. /// /// The additional_info field supports optional logging of information /// related to the stream state. For example, an application layer that /// moves data into transport and simultaneously ends the stream, can log /// fin_set. As another example, a transport layer that has received an /// instruction to reset a stream can indicate this to the application /// layer using reset_stream. In both cases, the length-carrying fields /// (length or raw) can be omitted or contain zero values. 
/// /// This event is only for data in QUIC streams. For data in QUIC /// Datagram Frames, see the datagram_data_moved event defined in /// Section 5.16. /// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct StreamDataMoved { stream_id: Option, offset: Option, /// byte length of the moved data length: Option, from: Option, to: Option, additional_info: Option, raw: Option, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum StreamDataLocation { Application, Transport, Network, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum DataMovedAdditionalInfo { FinSet, StreamReset, } /// The datagram_data_moved event is used to indicate when QUIC Datagram /// Frame data (see [RFC9221]) moves between the different layers. This /// helps make clear the flow of data, how long data remains in various /// buffers, and the overheads introduced by individual layers. The /// event has Base importance level; see Section 9.2 of [QLOG-MAIN]. /// /// This event relates to datagram data only. There are no packet or /// frame headers and length values in the length or raw fields MUST /// reflect that. /// /// For example, passing from the application protocol (e.g., /// WebTransport) to QUIC Datagram Frame buffers and vice versa. /// /// The datagram_data_moved event can provide insight into whether /// received data in a QUIC Datagram Frame is moved to the application /// protocol immediately (for example per received packet) or in larger /// batches (for example, all QUIC packets are processed first and /// afterwards the application layer reads all Datagrams at once). 
This /// can help identify bottlenecks, flow control issues, or scheduling /// problems. /// /// This event is only for data in QUIC Datagram Frames. For data in /// QUIC streams, see the stream_data_moved event defined in /// Section 5.15. /// /// [RFC9221]: https://www.rfc-editor.org/rfc/rfc9221.html /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct DatagramDataMoved { /// byte length of the moved data length: Option, from: Option, to: Option, raw: Option, } /// Use to provide additional information when attempting (client-side) /// connection migration. While most details of the QUIC connection /// migration process can be inferred by observing the PATH_CHALLENGE and /// PATH_RESPONSE frames, in combination with the QUICPathAssigned event, /// it can be useful to explicitly log the progression of the migration /// and potentially made decisions in a single location/event. The event /// has Extra importance level; see Section 9.2 of [QLOG-MAIN]. /// /// Generally speaking, connection migration goes through two phases: a /// probing phase (which is not always needed/present), and a migration /// phase (which can be abandoned upon error). /// /// Implementations that log per-path information in a /// QUICMigrationStateUpdated, SHOULD also emit QUICPathAssigned events, /// to serve as a ground-truth source of information. 
/// /// [QLOG-MAIN]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-09 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct MigrationStateUpdated { #[builder(default)] old: Option, new: MigrationState, #[builder(default)] path_id: Option, /// the information for traffic going towards the remote receiver #[builder(default)] path_remote: Option, /// the information for traffic coming in at the local endpoint #[builder(default)] path_local: Option, } /// Note that MigrationState does not describe a full state machine /// These entries are not necessarily chronological, /// nor will they always all appear during /// a connection migration attempt. #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum MigrationState { /// probing packets are sent, migration not initiated yet ProbingStarted, /// did not get reply to probing packets, /// discarding path as an option ProbingAbandoned, /// received reply to probing packets, path is migration candidate ProbingSuccessful, /// non-probing packets are sent, attempting migration MigrationStarted, /// something went wrong during the migration, abandoning attempt MigrationAbandoned, /// new path is now fully used, old path is discarded MigrationComplete, } crate::gen_builder_method! 
{ VersionInformationBuilder => VersionInformation; ALPNInformationBuilder => ALPNInformation; ALPNIdentifierBuilder => ALPNIdentifier; ParametersSetBuilder => ParametersSet; PreferredAddressBuilder => PreferredAddress; UnknownParameterBuilder => UnknownParameter; ParametersRestoredBuilder => ParametersRestored; PacketSentBuilder => PacketSent; PacketReceivedBuilder => PacketReceived; PacketDroppedBuilder => PacketDropped; PacketBufferedBuilder => PacketBuffered; PacketsAckedBuilder => PacketsAcked; UdpDatagramsSentBuilder => UdpDatagramsSent; UdpDatagramsReceivedBuilder => UdpDatagramsReceived; UdpDatagramDroppedBuilder => UdpDatagramDropped; StreamStateUpdatedBuilder => StreamStateUpdated; FramesProcessedBuilder => FramesProcessed; StreamDataMovedBuilder => StreamDataMoved; DatagramDataMovedBuilder => DatagramDataMoved; MigrationStateUpdatedBuilder => MigrationStateUpdated; } mod rollback { use bytes::Bytes; use super::*; use crate::{build, legacy::quic as legacy}; impl From for legacy::QuicVersion { #[inline] fn from(value: QuicVersion) -> Self { HexString::from(Bytes::from(value.0.to_be_bytes().to_vec())).into() } } impl From for legacy::TransportVersionInformation { fn from(vi: VersionInformation) -> Self { build!(legacy::TransportVersionInformation { server_versions: vi .server_versions .into_iter() .map(Into::into) .collect::>(), client_versions: vi .client_versions .into_iter() .map(Into::into) .collect::>(), ?chosen_version: vi.chosen_version, }) } } impl From for String { fn from(value: ALPNIdentifier) -> Self { value.string_value.as_ref().map_or( value .byte_value .as_ref() .map(|b| b.to_string()) .unwrap_or_default(), |s| s.to_string(), ) } } impl From for legacy::TransportALPNInformation { fn from(ai: ALPNInformation) -> Self { build!(legacy::TransportALPNInformation { ?client_alpns: ai.client_alpns.map( |v| { v.into_iter() .map(Into::into) .collect::>() }), ?server_alpns: ai.server_alpns.map( |v| { v.into_iter() .map(Into::into) .collect::>() }), // 
?chosen_alpn: ai.chosen_alpn.map(String::from), }) } } impl From for legacy::PreferredAddress { fn from(pa: PreferredAddress) -> Self { build!(legacy::PreferredAddress { ip_v4: pa.ip_v4, ip_v6: pa.ip_v6, port_v4: pa.port_v4, port_v6: pa.port_v6, connection_id: pa.connection_id, stateless_reset_token: pa.stateless_reset_token, }) } } impl From for legacy::TransportParametersSet { fn from(ps: ParametersSet) -> Self { build!(legacy::TransportParametersSet { ?owner: ps.owner, ?resumption_allowed: ps.resumption_allowed, ?early_data_enabled: ps.early_data_enabled, ?tls_cipher: ps.tls_cipher, ?original_destination_connection_id: ps.original_destination_connection_id, ?initial_source_connection_id: ps.initial_source_connection_id, ?retry_source_connection_id: ps.retry_source_connection_id, ?stateless_reset_token: ps.stateless_reset_token, ?disable_active_migration: ps.disable_active_migration, ?max_idle_timeout: ps.max_idle_timeout, ?max_udp_payload_size: ps.max_udp_payload_size, ?ack_delay_exponent: ps.ack_delay_exponent, ?max_ack_delay: ps.max_ack_delay, ?active_connection_id_limit: ps.active_connection_id_limit, ?initial_max_data: ps.initial_max_data, ?initial_max_stream_data_bidi_local: ps.initial_max_stream_data_bidi_local, ?initial_max_stream_data_bidi_remote: ps.initial_max_stream_data_bidi_remote, ?initial_max_stream_data_uni: ps.initial_max_stream_data_uni, ?initial_max_streams_bidi: ps.initial_max_streams_bidi, ?initial_max_streams_uni: ps.initial_max_streams_uni, ?preferred_address: ps.preferred_address, // legacy doesnt support these // ?unknown_parameters: , // ?max_datagram_frame_size: ps.max_datagram_frame_size, // ?grease_quic_bit: ps.grease_quic_bit, }) } } impl From for legacy::TransportParametersRestored { fn from(value: ParametersRestored) -> Self { build!(legacy::TransportParametersRestored { ?disable_active_migration: value.disable_active_migration, ?max_idle_timeout: value.max_idle_timeout, ?max_udp_payload_size: value.max_udp_payload_size, 
?active_connection_id_limit: value.active_connection_id_limit, ?initial_max_data: value.initial_max_data, ?initial_max_stream_data_bidi_local: value.initial_max_stream_data_bidi_local, ?initial_max_stream_data_bidi_remote: value.initial_max_stream_data_bidi_remote, ?initial_max_stream_data_uni: value.initial_max_stream_data_uni, ?initial_max_streams_bidi: value.initial_max_streams_bidi, ?initial_max_streams_uni: value.initial_max_streams_uni, // legacy doesnt support these // ?max_datagram_frame_size: value.max_datagram_frame_size, // ?grease_quic_bit: value.grease_quic_bit, }) } } impl From for legacy::TransportPacketSentTrigger { fn from(value: PacketSentTrigger) -> Self { match value { PacketSentTrigger::RetransmitReordered => Self::RetransmitReordered, PacketSentTrigger::RetransmitTimeout => Self::RetransmitTimeout, PacketSentTrigger::PtoProbe => Self::PtoProbe, PacketSentTrigger::RetransmitCrypto => Self::RetransmitCrypto, PacketSentTrigger::CcBandwidthProbe => Self::CcBandwidthProbe, } } } impl From for legacy::TransportPacketSent { fn from(value: PacketSent) -> Self { build!(legacy::TransportPacketSent { header: value.header, ?frames: value.frames.map(|v| { v.into_iter() .map(Into::into) .collect::>() }), ?stateless_reset_token: value.stateless_reset_token.map(|tk| Bytes::from(tk.0.to_vec())), supported_versions: value.supported_versions.into_iter() .map(Into::into) .collect::>(), ?raw: value.raw, ?datagram_id: value.datagram_id, ?trigger: value.trigger, }) } } impl From for legacy::TransportPacketReceivedTrigger { #[inline] fn from(value: PacketReceivedTrigger) -> Self { match value { PacketReceivedTrigger::KeysAvailable => Self::KeysAvailable, } } } impl From for legacy::TransportPacketReceived { fn from(value: PacketReceived) -> Self { build!(legacy::TransportPacketReceived { header: value.header, ?frames: value.frames.map(|v| { v.into_iter() .map(Into::into) .collect::>() }), ?stateless_reset_token: value.stateless_reset_token.map(|tk| 
Bytes::from(tk.0.to_vec())), supported_versions: value.supported_versions.into_iter() .map(Into::into) .collect::>(), ?raw: value.raw, ?datagram_id: value.datagram_id, ?trigger: value.trigger, }) } } impl TryFrom for legacy::TransportpacketDroppedTrigger { type Error = (); #[inline] fn try_from(value: PacketDroppedTrigger) -> Result { match value { // 新设计不如旧的 PacketDroppedTrigger::InternalError | PacketDroppedTrigger::Invalid | PacketDroppedTrigger::Genera // 似乎并没有完全对应, 移除头部保护失败也是这个错误 // PacketDroppedTrigger::DecryptionFailure => Ok(Self::PayloadDecryptError), | PacketDroppedTrigger::DecryptionFailure | PacketDroppedTrigger::Rejected => Err(()), PacketDroppedTrigger::Unsupported => Ok(Self::UnsupportedVersion), PacketDroppedTrigger::Duplicate => Ok(Self::Duplicate), PacketDroppedTrigger::ConnectionUnknown => Ok(Self::UnknownConnectionId), PacketDroppedTrigger::KeyUnavailable => Ok(Self::KeyUnavailable), } } } impl From for legacy::TransportPacketDropped { fn from(value: PacketDropped) -> Self { build!(legacy::TransportPacketDropped { ?header: value.header, ?raw: value.raw, ?datagram_id: value.datagram_id, ?trigger: value.trigger.and_then(|trigger| legacy::TransportpacketDroppedTrigger::try_from(trigger).ok()), }) } } impl From for legacy::TransportPacketBufferedTrigger { #[inline] fn from(value: PacketBufferedTrigger) -> Self { match value { PacketBufferedTrigger::Backpressure => Self::Backpressure, PacketBufferedTrigger::KeysUnavailable => Self::KeysUnavailable, } } } impl From for legacy::TransportPacketBuffered { fn from(value: PacketBuffered) -> Self { build!(legacy::TransportPacketBuffered { ?header: value.header, ?raw: value.raw, ?datagram_id: value.datagram_id, ?trigger: value.trigger, }) } } impl From for legacy::TransportPacketsAcked { fn from(value: PacketsAcked) -> Self { build!(legacy::TransportPacketsAcked { ?packet_number_space: value.packet_number_space, packet_numbers: value.packet_nubers, }) } } impl From for legacy::TransportDatagramsSent { fn 
from(value: UdpDatagramsSent) -> Self { build!(legacy::TransportDatagramsSent { ?count: value.count, raw: value.raw.into_iter().collect::>(), datagram_ids: value.datagram_ids, }) } } impl From for legacy::TransportDatagramsReceived { fn from(value: UdpDatagramsReceived) -> Self { build!(legacy::TransportDatagramsReceived { ?count: value.count, raw: value.raw.into_iter().collect::>(), datagram_ids: value.datagram_ids, }) } } impl From for legacy::TransportDatagramDropped { fn from(value: UdpDatagramDropped) -> Self { build!(legacy::TransportDatagramDropped { ?raw: value.raw, }) } } impl From for legacy::StreamState { #[inline] fn from(value: StreamState) -> Self { match value { StreamState::Base(BaseStreamStates::Idle) => Self::Idle, StreamState::Base(BaseStreamStates::Open) => Self::Open, StreamState::Base(BaseStreamStates::Closed) => Self::Closed, StreamState::Granular(GranularStreamStates::HalfClosedLocal) => { Self::HalfClosedLocal } StreamState::Granular(GranularStreamStates::HalfClosedRemote) => { Self::HalfClosedRemote } StreamState::Granular(GranularStreamStates::Ready) => Self::Ready, StreamState::Granular(GranularStreamStates::Send) => Self::Send, StreamState::Granular(GranularStreamStates::DataSent) => Self::DataSent, StreamState::Granular(GranularStreamStates::ResetSent) => Self::ResetSent, StreamState::Granular(GranularStreamStates::ResetReceived) => Self::ResetReceived, StreamState::Granular(GranularStreamStates::Receive) => Self::Receive, StreamState::Granular(GranularStreamStates::SizeKnown) => Self::SizeKnown, StreamState::Granular(GranularStreamStates::DataRead) => Self::DataRead, StreamState::Granular(GranularStreamStates::ResetRead) => Self::ResetRead, StreamState::Granular(GranularStreamStates::DataReceived) => Self::DataReceived, StreamState::Granular(GranularStreamStates::Destroyed) => Self::Destroyed, } } } impl From for legacy::StreamSide { #[inline] fn from(value: StreamSide) -> Self { match value { StreamSide::Sending => Self::Sending, 
StreamSide::Receiving => Self::Receiving, } } } impl From for legacy::TransportStreamStateUpdated { fn from(value: StreamStateUpdated) -> Self { build!(legacy::TransportStreamStateUpdated { stream_id: value.stream_id, ?stream_type: value.stream_type, ?old: value.old, new: value.new, ?stream_side: value.stream_side, }) } } impl From for legacy::TransportFramesProcessed { fn from(value: FramesProcessed) -> Self { assert!( value.packet_numbers.as_ref().is_none() || value.packet_numbers.as_ref().is_some_and(|v| v.len() != 1), "it not possible to do this convert" ); build!(legacy::TransportFramesProcessed { frames: value.frames.into_iter().map(Into::into).collect::>(), ?packet_number: value.packet_numbers.map(|v| v[0]), }) } } impl From for legacy::StreamDataLocation { #[inline] fn from(value: StreamDataLocation) -> Self { match value { StreamDataLocation::Application => Self::Application, StreamDataLocation::Transport => Self::Transport, StreamDataLocation::Network => Self::Network, } } } impl From for legacy::TransportDataMoved { fn from(value: StreamDataMoved) -> Self { build!(legacy::TransportDataMoved { ?stream_id: value.stream_id, ?offset: value.offset, ?length: value.length, ?from: value.from, ?to: value.to, ?data: value.raw.and_then(|raw| raw.data), }) } } } ================================================ FILE: qevent/src/quic.rs ================================================ use std::{ collections::HashMap, fmt::Display, marker::PhantomData, net::SocketAddr, time::Duration, }; use bytes::Bytes; use derive_builder::Builder; use derive_more::{From, Into, LowerHex}; use qbase::{ frame::{ AckFrame, ConnectionCloseFrame, CryptoFrame, DatagramFrame, EncodeSize, Frame, GetFrameType, MaxStreamsFrame, NewTokenFrame, PathChallengeFrame, PathResponseFrame, PingFrame, ReliableFrame, StreamCtlFrame, StreamFrame, StreamsBlockedFrame, }, packet::header::{ GetDcid, GetScid, long::{HandshakeHeader, InitialHeader, ZeroRttHeader}, short::OneRttHeader, }, util::ContinuousData, 
varint::VarInt, }; use serde::{Deserialize, Serialize}; pub mod connectivity; pub mod recovery; pub mod security; pub mod transport; use crate::{BeSpecificEventData, HexString, RawInfo}; // 8.1 #[derive(Debug, Clone, From, Into, PartialEq, Eq)] pub struct QuicVersion(u32); impl Serialize for QuicVersion { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { #[serde_with::serde_as] #[derive(Serialize)] struct Helper(#[serde_as(as = "serde_with::hex::Hex")] [u8; 4]); Helper(self.0.to_be_bytes()).serialize(serializer) } } impl<'de> Deserialize<'de> for QuicVersion { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { #[serde_with::serde_as] #[derive(Deserialize)] struct Helper(#[serde_as(as = "serde_with::hex::Hex")] [u8; 4]); Helper::deserialize(deserializer).map(|b| Self(u32::from_be_bytes(b.0))) } } // 8.2 // TOOD: 这些结构的序列化/反序列化之后都可以写到qbase中,也不重复写两份结构 #[derive(Default, Debug, LowerHex, From, Into, Clone, Copy, PartialEq, Eq)] pub struct ConnectionID(qbase::cid::ConnectionId); impl Serialize for ConnectionID { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { #[serde_with::serde_as] #[derive(Serialize)] struct Helper<'b>(#[serde_as(as = "serde_with::hex::Hex")] &'b [u8]); Helper(self.0.as_ref()).serialize(serializer) } } impl<'de> Deserialize<'de> for ConnectionID { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { #[serde_with::serde_as] #[derive(Deserialize)] struct Helper(#[serde_as(as = "serde_with::hex::Hex")] Vec); let bytes = Helper::deserialize(deserializer)?.0; if bytes.len() > qbase::cid::MAX_CID_SIZE { return Err(serde::de::Error::custom("ConnectionID too long")); } Ok(Self(qbase::cid::ConnectionId::from_slice(&bytes))) } } // 8.3 #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum Owner { Local, Remote, } // 8.4 /// an IPAddress can either be a "human readable" form /// (e.g., "127.0.0.1" for 
v4 or /// "2001:0db8:85a3:0000:0000:8a2e:0370:7334" for v6) or /// use a raw byte-form (as the string forms can be ambiguous). /// Additionally, a hash-based or redacted representation /// can be used if needed for privacy or security reasons. #[derive(Debug, Clone, From, Into, Serialize, Deserialize, PartialEq, Eq)] #[serde(transparent)] pub struct IPAddress(String); #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde_with::skip_serializing_none] #[serde(rename_all = "snake_case")] pub enum IpVersion { V4, V6, } // 8.5 /// PathEndpointInfo indicates a single half/direction of a path. A full /// path is comprised of two halves. Firstly: the server sends to the /// remote client IP + port using a specific destination Connection ID. /// Secondly: the client sends to the remote server IP + port using a /// different destination Connection ID. /// /// As such, structures logging path information SHOULD include two /// different PathEndpointInfo instances, one for each half of the path. 
#[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] pub struct PathEndpointInfo { ip_v4: Option, ip_v6: Option, port_v4: Option, port_v6: Option, /// Even though usually only a single ConnectionID /// is associated with a given path at a time, /// there are situations where there can be an overlap /// or a need to keep track of previous ConnectionIDs conenction_ids: Vec, } impl From for PathEndpointInfo { fn from(value: SocketAddr) -> Self { match value { SocketAddr::V4(addr) => crate::build!(PathEndpointInfo { ip_v4: addr.ip().to_string(), port_v4: addr.port(), }), SocketAddr::V6(addr) => crate::build!(PathEndpointInfo { ip_v6: addr.ip().to_string(), port_v6: addr.port(), }), } } } // 8.6 #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum PacketType { Initial, Handshake, #[serde(rename = "0RTT")] ZeroRTT, #[serde(rename = "1RTT")] OneRTT, Retry, VersionNegotiation, StatelessReset, Unknown, } impl From for PacketType { fn from(r#type: qbase::packet::Type) -> Self { match r#type { qbase::packet::r#type::Type::Long(long) => match long { qbase::packet::r#type::long::Type::VersionNegotiation => { PacketType::VersionNegotiation } qbase::packet::r#type::long::Type::V1( qbase::packet::r#type::long::Version::INITIAL, ) => PacketType::Initial, qbase::packet::r#type::long::Type::V1( qbase::packet::r#type::long::Version::HANDSHAKE, ) => PacketType::Handshake, qbase::packet::r#type::long::Type::V1( qbase::packet::r#type::long::Version::ZERO_RTT, ) => PacketType::ZeroRTT, qbase::packet::r#type::long::Type::V1( qbase::packet::r#type::long::Version::RETRY, ) => PacketType::Retry, }, qbase::packet::r#type::Type::Short(_one_rtt) => PacketType::OneRTT, } } } // 8.7 #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = 
"snake_case")] pub enum PacketNumberSpace { Initial, Handshake, ApplicationData, } impl From for PacketNumberSpace { fn from(value: qbase::Epoch) -> Self { match value { qbase::Epoch::Initial => Self::Initial, qbase::Epoch::Handshake => Self::Handshake, qbase::Epoch::Data => Self::ApplicationData, } } } // 8.8 #[serde_with::skip_serializing_none] #[derive(Builder, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder(setter(into, strip_option), build_fn(private, name = "fallible_build"))] pub struct PacketHeader { #[builder(default)] #[serde(default)] quic_bit: bool, packet_type: PacketType, /// only if packet_type === "initial" || "handshake" || "0RTT" || "1RTT" #[builder(default)] packet_number: Option, /// the bit flags of the packet headers (spin bit, key update bit, /// etc. up to and including the packet number length bits /// if present #[builder(default)] flags: Option, /// only if packet_type === "initial" || "retry" #[builder(default)] token: Option, /// only if packet_type === "initial" || "handshake" || "0RTT" /// Signifies length of the packet_number plus the payload #[builder(default)] length: Option, /// only if present in the header /// if correctly using transport:connection_id_updated events, /// dcid can be skipped for 1RTT packets #[builder(default)] version: Option, #[builder(default)] scil: Option, #[builder(default)] dcil: Option, #[builder(default)] scid: Option, #[builder(default)] dcid: Option, } impl PacketHeaderBuilder { /// Helper method used to set the fields of the initial header, /// /// Since the header defined by qbase is not complete enough, there are still many fields that need to be set manually. 
pub fn initial(&mut self, header: &InitialHeader) -> &mut Self { crate::build!(@field self, packet_type: PacketType::Initial, ?token: Token::try_from(header).ok(), scil: header.scid().len() as u8, scid: { *header.scid() }, dcil: header.dcid().len() as u8, dcid: { *header.dcid() } ); self } /// Helper method used to set the fields of the handshake header, /// /// Since the header defined by qbase is not complete enough, there are still many fields that need to be set manually. pub fn handshake(&mut self, header: &HandshakeHeader) -> &mut Self { self.packet_type(PacketType::Handshake) .scil(header.scid().len() as u8) .scid(*header.scid()) .dcil(header.dcid().len() as u8) .dcid(*header.dcid()) } /// Helper method used to set the fields of the 0rtt header, /// /// Since the header defined by qbase is not complete enough, there are still many fields that need to be set manually. pub fn zero_rtt(&mut self, header: &ZeroRttHeader) -> &mut Self { self.packet_type(PacketType::ZeroRTT) .scil(header.scid().len() as u8) .scid(*header.scid()) .dcil(header.dcid().len() as u8) .dcid(*header.dcid()) } /// Helper method used to set the fields of the 1rtt header, /// /// Since the header defined by qbase is not complete enough, there are still many fields that need to be set manually. 
pub fn one_rtt(&mut self, header: &OneRttHeader) -> &mut Self { self.packet_type(PacketType::OneRTT) .dcil(header.dcid().len() as u8) .dcid(*header.dcid()) } } impl From<&InitialHeader> for PacketHeaderBuilder { fn from(header: &InitialHeader) -> Self { let mut builder = PacketHeader::builder(); builder.initial(header); builder } } impl From<&HandshakeHeader> for PacketHeaderBuilder { fn from(header: &HandshakeHeader) -> Self { let mut builder = PacketHeader::builder(); builder.handshake(header); builder } } impl From<&ZeroRttHeader> for PacketHeaderBuilder { fn from(header: &ZeroRttHeader) -> Self { let mut builder = PacketHeader::builder(); builder.zero_rtt(header); builder } } impl From<&OneRttHeader> for PacketHeaderBuilder { fn from(header: &OneRttHeader) -> Self { let mut builder = PacketHeader::builder(); builder.one_rtt(header); builder } } // 8.9 #[serde_with::skip_serializing_none] #[derive(Builder, Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[builder( default, setter(into, strip_option), build_fn(private, name = "fallible_build") )] #[serde(default)] pub struct Token { pub r#type: Option, /// decoded fields included in the token /// (typically: peer's IP address, creation time) #[serde(skip_serializing_if = "HashMap::is_empty")] details: HashMap, raw: Option, } impl TryFrom<&qbase::packet::header::LongHeader> for Token { type Error = (); fn try_from(header: &qbase::packet::header::LongHeader) -> Result { use qbase::packet::header::RetryHeader; let header: &dyn core::any::Any = header; if let Some(initial) = header.downcast_ref::() { if initial.token().is_empty() { return Err(()); } return Ok(crate::build!(Token { // r#type: TokenType::? 
raw: initial.token(), })); } if let Some(retry) = header.downcast_ref::() { return Ok(crate::build!(Token { r#type: TokenType::Retry, raw: retry.token(), })); } Err(()) } } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum TokenType { Retry, Resumption, } // 8.10 #[serde_with::serde_as] #[derive(Debug, Clone, Copy, From, Into, Serialize, Deserialize, PartialEq, Eq)] pub struct StatelessResetToken(#[serde_as(as = "serde_with::hex::Hex")] [u8; 16]); impl From for StatelessResetToken { fn from(value: qbase::token::ResetToken) -> Self { Self(*value) } } // 8.11 #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum KeyType { ServerInitialSecret, ClientInitialSecret, ServerHandshakeSecret, ClientHandshakeSecret, #[serde(rename = "server_0rtt_secret")] Server0RttSecret, #[serde(rename = "client_0rtt_secret")] Client0RttSecret, #[serde(rename = "server_1rtt_secret")] Server1RttSecret, #[serde(rename = "client_1rtt_secret")] Client1RttSecret, } // 8.12 #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] pub enum ECN { #[serde(rename = "Not-ECT")] NotEct, #[serde(rename = "ECT(1)")] Ect1, #[serde(rename = "ECT(0)")] Ect0, CE, } // 8.13 #[serde_with::skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(tag = "frame_type")] #[serde(rename_all = "snake_case")] pub enum QuicFrame { Padding { /// total frame length, including frame header length: Option, payload_length: u32, }, Ping { /// total frame length, including frame header length: Option, payload_length: Option, }, Ack { /// in ms ack_delay: Option, /// e.g., looks like \[\[1,2],\[4,5], \[7], \[10,22]] serialized /// /// ### AckRange: /// either a single number (e.g., \[1]) or two numbers (e.g., \[1,2]). 
/// /// For two numbers: /// /// the first number is "from": lowest packet number in interval /// /// the second number is "to": up to and including the highest /// packet number in the interval acked_ranges: Vec<[u64; 2]>, /// ECN (explicit congestion notification) related fields /// (not always present) ect1: Option, ect0: Option, ce: Option, /// total frame length, including frame header length: Option, payload_length: Option, }, ResetStream { stream_id: u64, error_code: ApplicationCode, /// in bytes final_size: u64, /// total frame length, including frame header length: Option, payload_length: Option, }, StopSending { stream_id: u64, error_code: ApplicationCode, /// total frame length, including frame header length: Option, payload_length: Option, }, Crypto { offset: u64, length: u64, payload_length: Option, raw: Option, }, NewToken { token: Token, }, Stream { stream_id: u64, /// These two MUST always be set /// If not present in the Frame type, log their default values offset: u64, length: u64, /// this MAY be set any time, /// but MUST only be set if the value is true /// if absent, the value MUST be assumed to be false #[serde(default)] fin: bool, raw: Option, }, MaxData { maximum: u64, }, MaxStreamData { stream_id: u64, maximum: u64, }, MaxStreams { stream_type: StreamType, maximum: u64, }, DataBlocked { limit: u64, }, StreamDataBlocked { stream_id: u64, limit: u64, }, StreamsBlocked { stream_type: StreamType, limit: u64, }, NewConnectionId { sequence_number: u32, retire_prior_to: u32, /// mainly used if e.g., for privacy reasons the full /// connection_id cannot be logged connection_id_length: Option, connection_id: ConnectionID, stateless_reset_token: Option, }, RetireConnectionId { sequence_number: u32, }, PathChallenge { /// always 64-bit data: Option, }, PathResponse { /// always 64-bit data: Option, }, /// An endpoint that receives unknown error codes can record it in the /// error_code field using the numerical value without variable-length /// 
integer encoding. /// /// When the connection is closed due a connection-level error, the /// trigger_frame_type field can be used to log the frame that triggered /// the error. For known frame types, the appropriate string value is /// used. For unknown frame types, the numerical value without variable- /// length integer encoding is used. /// /// The CONNECTION_CLOSE reason phrase is a byte sequences. It is likely /// that this sequence is presentable as UTF-8, in which case it can be /// logged in the reason field. The reason_bytes field supports logging /// the raw bytes, which can be useful when the value is not UTF-8 or /// when an endpoint does not want to decode it. Implementations SHOULD /// log at least one format, but MAY log both or none. ConnectionClose { error_space: Option, error_code: Option, reason: Option, reason_bytes: Option, /// when error_space === "transport" trigger_frame_type: Option, }, HandshakeDone {}, /// The frame_type_bytes field is the numerical value without variable- /// length integer encoding. 
Unknow { frame_type_bytes: u64, raw: Option, }, Datagram { length: Option, raw: Option, }, } impl From<&PingFrame> for QuicFrame { fn from(frame: &PingFrame) -> Self { QuicFrame::Ping { length: Some(frame.encoding_size() as u32), payload_length: Some(0), } } } impl From<(&CryptoFrame, &D)> for QuicFrame { fn from((frame, data): (&CryptoFrame, &D)) -> Self { let payload_length = frame.len(); let length = frame.encoding_size() as u64 + payload_length; QuicFrame::Crypto { offset: frame.offset(), length, payload_length: Some(payload_length as _), raw: Some(crate::build!(RawInfo { length, payload_length, data, })), } } } impl From<&CryptoFrame> for QuicFrame { fn from(frame: &CryptoFrame) -> Self { let payload_length = frame.len(); let length = frame.encoding_size() as u64 + payload_length; QuicFrame::Crypto { offset: frame.offset(), length, payload_length: Some(payload_length as _), raw: Some(crate::build!(RawInfo { length, payload_length, })), } } } impl From<(&StreamFrame, &D)> for QuicFrame { fn from((frame, data): (&StreamFrame, &D)) -> Self { let payload_length = frame.len(); let length = frame.encoding_size() + payload_length; QuicFrame::Stream { stream_id: frame.stream_id().into(), offset: frame.offset(), length: payload_length as u64, fin: frame.is_fin(), raw: Some(crate::build!(RawInfo { length: length as u64, payload_length: payload_length as u64, data: data, })), } } } impl From<&StreamFrame> for QuicFrame { fn from(frame: &StreamFrame) -> Self { let payload_length = frame.len(); let length = frame.encoding_size() + payload_length; QuicFrame::Stream { stream_id: frame.stream_id().into(), offset: frame.offset(), length: payload_length as u64, fin: frame.is_fin(), raw: Some(crate::build!(RawInfo { length: length as u64, payload_length: payload_length as u64, })), } } } impl From<(&DatagramFrame, &D)> for QuicFrame { fn from((frame, data): (&DatagramFrame, &D)) -> Self { let payload_length = frame.len().into_u64(); let length = frame.encoding_size() as u64 + 
payload_length; QuicFrame::Datagram { length: Some(payload_length as _), raw: Some(crate::build!(RawInfo { length, payload_length, data: data, })), } } } impl From<&DatagramFrame> for QuicFrame { fn from(frame: &DatagramFrame) -> Self { let payload_length = frame.len().into_u64(); let length = frame.encoding_size() as u64 + payload_length; QuicFrame::Datagram { length: Some(payload_length as _), raw: Some(crate::build!(RawInfo { length, payload_length, })), } } } impl From<&PathChallengeFrame> for QuicFrame { fn from(frame: &PathChallengeFrame) -> Self { QuicFrame::PathChallenge { data: Some(Bytes::from_owner(frame.to_vec()).into()), } } } impl From<&PathResponseFrame> for QuicFrame { fn from(frame: &PathResponseFrame) -> Self { QuicFrame::PathResponse { data: Some(Bytes::from_owner(frame.to_vec()).into()), } } } impl From<&AckFrame> for QuicFrame { fn from(frame: &AckFrame) -> Self { Self::Ack { ack_delay: Some(Duration::from_micros(frame.delay()).as_secs_f32() * 1000.0), acked_ranges: frame .ranges() .iter() .fold( ( frame.largest() - frame.first_range(), vec![[frame.largest() - frame.first_range(), frame.largest()]], ), |(previous_smallest, mut acked_ranges), (gap, ack)| { // see https://www.rfc-editor.org/rfc/rfc9000.html#name-ack-ranges let largest = previous_smallest - gap.into_u64() - 2; let smallest = largest - ack.into_u64(); acked_ranges.push([smallest, largest]); (smallest, acked_ranges) }, ) .1, ect1: frame.ecn().map(|ecn| ecn.ect1()), ect0: frame.ecn().map(|ecn| ecn.ect0()), ce: frame.ecn().map(|ecn| ecn.ce()), length: Some(frame.encoding_size() as u32), payload_length: None, } } } impl From<&ReliableFrame> for QuicFrame { fn from(frame: &ReliableFrame) -> Self { match frame { ReliableFrame::NewToken(new_token_frame) => new_token_frame.into(), ReliableFrame::MaxData(max_data_frame) => QuicFrame::MaxData { maximum: max_data_frame.max_data(), }, ReliableFrame::DataBlocked(data_blocked_frame) => QuicFrame::DataBlocked { limit: data_blocked_frame.limit(), 
}, ReliableFrame::NewConnectionId(new_connection_id_frame) => QuicFrame::NewConnectionId { sequence_number: new_connection_id_frame.sequence() as u32, retire_prior_to: new_connection_id_frame.retire_prior_to() as u32, connection_id_length: Some(new_connection_id_frame.connection_id().len() as u8), connection_id: (*new_connection_id_frame.connection_id()).into(), stateless_reset_token: Some((**new_connection_id_frame.reset_token()).into()), }, ReliableFrame::RetireConnectionId(retire_connection_id_frame) => { QuicFrame::RetireConnectionId { sequence_number: retire_connection_id_frame.sequence() as u32, } } ReliableFrame::HandshakeDone(_handshake_done_frame) => QuicFrame::HandshakeDone {}, ReliableFrame::AddAddress(frame) => QuicFrame::Unknow { frame_type_bytes: VarInt::from(frame.frame_type()).into_u64() as _, raw: None, }, ReliableFrame::RemoveAddress(frame) => QuicFrame::Unknow { frame_type_bytes: VarInt::from(frame.frame_type()).into_u64() as _, raw: None, }, ReliableFrame::PunchMeNow(frame) => QuicFrame::Unknow { frame_type_bytes: VarInt::from(frame.frame_type()).into_u64() as _, raw: None, }, ReliableFrame::PunchDone(frame) => QuicFrame::Unknow { frame_type_bytes: VarInt::from(frame.frame_type()).into_u64() as _, raw: None, }, ReliableFrame::StreamCtl(stream_ctl_frame) => QuicFrame::from(stream_ctl_frame), } } } impl From<&NewTokenFrame> for QuicFrame { fn from(value: &NewTokenFrame) -> Self { QuicFrame::NewToken { token: crate::build!(Token { r#type: TokenType::Retry, raw: RawInfo { length: value.encoding_size() as u64, payload_length: value.token().len() as u64, data: value.token(), }, }), } } } impl From<&StreamCtlFrame> for QuicFrame { fn from(frame: &StreamCtlFrame) -> Self { match frame { StreamCtlFrame::ResetStream(reset_stream_frame) => QuicFrame::ResetStream { stream_id: reset_stream_frame.stream_id().id(), error_code: (reset_stream_frame.app_error_code() as u32).into(), final_size: reset_stream_frame.final_size(), length: None, payload_length: None, 
}, StreamCtlFrame::StopSending(stop_sending_frame) => QuicFrame::StopSending { stream_id: stop_sending_frame.stream_id().id(), error_code: (stop_sending_frame.app_err_code() as u32).into(), length: None, payload_length: None, }, StreamCtlFrame::MaxStreamData(max_stream_data_frame) => QuicFrame::MaxStreamData { stream_id: max_stream_data_frame.stream_id().id(), maximum: max_stream_data_frame.max_stream_data(), }, StreamCtlFrame::MaxStreams(max_streams_frame) => match max_streams_frame { MaxStreamsFrame::Bi(maximum) => QuicFrame::MaxStreams { stream_type: StreamType::Bidirectional, maximum: maximum.into_u64(), }, MaxStreamsFrame::Uni(maximum) => QuicFrame::MaxStreams { stream_type: StreamType::Unidirectional, maximum: maximum.into_u64(), }, }, StreamCtlFrame::StreamDataBlocked(stream_data_blocked_frame) => { QuicFrame::StreamDataBlocked { stream_id: stream_data_blocked_frame.stream_id().id(), limit: stream_data_blocked_frame.maximum_stream_data(), } } StreamCtlFrame::StreamsBlocked(streams_blocked_frame) => match streams_blocked_frame { StreamsBlockedFrame::Bi(limit) => QuicFrame::StreamsBlocked { stream_type: StreamType::Bidirectional, limit: limit.into_u64(), }, StreamsBlockedFrame::Uni(limit) => QuicFrame::StreamsBlocked { stream_type: StreamType::Unidirectional, limit: limit.into_u64(), }, }, } } } impl From<&ConnectionCloseFrame> for QuicFrame { fn from(frame: &ConnectionCloseFrame) -> Self { Self::ConnectionClose { error_space: Some(match &frame { ConnectionCloseFrame::App(..) => ConnectionCloseErrorSpace::Application, ConnectionCloseFrame::Quic(..) 
=> ConnectionCloseErrorSpace::Transport, }), error_code: match &frame { ConnectionCloseFrame::App(frame) => { Some(ApplicationCode::from(frame.error_code() as u32).into()) } ConnectionCloseFrame::Quic(frame) => { Some(connectivity::ConnectionCode::from(frame.error_kind()).into()) } }, reason: match &frame { ConnectionCloseFrame::App(frame) => Some(frame.reason().to_owned()), ConnectionCloseFrame::Quic(frame) => Some(frame.reason().to_owned()), }, // TODO: 不应该强制要求reason是utf8的 reason_bytes: None, trigger_frame_type: match &frame { ConnectionCloseFrame::Quic(frame) => { Some((VarInt::from(frame.frame_type()).into_u64()).into()) } ConnectionCloseFrame::App(..) => None, }, } } } impl From<&Frame> for QuicFrame { fn from(frame: &Frame) -> Self { match frame { Frame::Padding(..) => QuicFrame::Padding { length: Some(1), payload_length: 1, }, Frame::Ping(..) => QuicFrame::Ping { length: Some(1), payload_length: Some(1), }, Frame::Ack(frame) => frame.into(), Frame::Close(frame) => frame.into(), Frame::NewToken(frame) => frame.into(), Frame::MaxData(frame) => (&ReliableFrame::from(*frame)).into(), Frame::DataBlocked(frame) => (&ReliableFrame::from(*frame)).into(), Frame::NewConnectionId(frame) => (&ReliableFrame::from(*frame)).into(), Frame::RetireConnectionId(frame) => (&ReliableFrame::from(*frame)).into(), Frame::HandshakeDone(frame) => (&ReliableFrame::from(*frame)).into(), Frame::AddAddress(frame) => (&ReliableFrame::from(*frame)).into(), Frame::RemoveAddress(frame) => (&ReliableFrame::from(*frame)).into(), Frame::PunchMeNow(frame) => (&ReliableFrame::from(*frame)).into(), Frame::PathChallenge(frame) => frame.into(), Frame::PathResponse(frame) => frame.into(), Frame::StreamCtl(frame) => frame.into(), Frame::Stream(frame, bytes) if bytes.is_empty() => (frame, bytes).into(), Frame::Crypto(frame, bytes) if bytes.is_empty() => (frame, bytes).into(), Frame::Datagram(frame, bytes) if bytes.is_empty() => (frame, bytes).into(), Frame::Stream(frame, bytes) => (frame, 
bytes).into(), Frame::Crypto(frame, bytes) => (frame, bytes).into(), Frame::Datagram(frame, bytes) => (frame, bytes).into(), Frame::PunchHello(frame) => QuicFrame::Unknow { frame_type_bytes: VarInt::from(frame.frame_type()).into_u64() as _, raw: None, }, Frame::PunchDone(frame) => QuicFrame::Unknow { frame_type_bytes: VarInt::from(frame.frame_type()).into_u64() as _, raw: None, }, } } } /// A collection of automatically and efficiently converting raw quic frames into qlog quic frames. #[derive(Debug)] pub struct QuicFramesCollector { event: PhantomData, frames: Vec, } impl QuicFramesCollector { pub fn new() -> Self { Self { event: PhantomData, frames: Vec::new(), } } } impl Default for QuicFramesCollector { fn default() -> Self { Self::new() } } impl Extend for QuicFramesCollector where E: BeSpecificEventData, F: Into, { fn extend>(&mut self, iter: T) { if !crate::telemetry::Span::current().filter_event(E::scheme()) { return; } for frame in iter.into_iter().map(Into::into) { if let Some(last) = self.frames.last_mut() { match last { QuicFrame::Padding { length, payload_length, } => { *last = QuicFrame::Padding { length: length.map(|length| length + 1), payload_length: *payload_length + 1, }; continue; } QuicFrame::Ping { length, payload_length, } => { *last = QuicFrame::Ping { length: length.map(|length| length + 1), payload_length: payload_length.map(|length| length + 1), }; continue; } _ => {} } } self.frames.push(frame); } } } impl From> for Vec { fn from(value: QuicFramesCollector) -> Self { value.frames } } #[derive(Debug, Clone, From, Serialize, Deserialize, PartialEq, Eq)] #[serde(untagged)] pub enum ApplicationCode { ApplicationError(ApplicationError), Value(u32), } impl From for ConnectionCloseErrorCode { fn from(value: ApplicationCode) -> Self { match value { ApplicationCode::ApplicationError(error) => { ConnectionCloseErrorCode::ApplicationError(error) } ApplicationCode::Value(value) => ConnectionCloseErrorCode::Value(value as _), } } } #[derive(Debug, 
Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum StreamType { Unidirectional, Bidirectional, } impl From for StreamType { fn from(dir: qbase::sid::Dir) -> Self { match dir { qbase::sid::Dir::Bi => Self::Bidirectional, qbase::sid::Dir::Uni => Self::Unidirectional, } } } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum ConnectionCloseErrorSpace { Transport, Application, } #[derive(Debug, Clone, From, Serialize, Deserialize, PartialEq, Eq)] #[serde(untagged)] pub enum ConnectionCloseErrorCode { TransportError(TransportError), CryptoError(CryptoError), ApplicationError(ApplicationError), Value(u64), } #[derive(Debug, Clone, Serialize, From, Deserialize, PartialEq, Eq)] #[serde(untagged)] pub enum ConnectionCloseTriggerFrameType { Id(u64), Text(String), } // 8.13.23 #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum TransportError { NoError, InternalError, ConnectionRefused, FlowControlError, StreamLimitError, StreamStateError, FinalSizeError, FrameEncodingError, TransportParameterError, ConnectionIdLimitError, ProtocolViolation, InvalidToken, ApplicationError, CryptoBufferExceeded, KeyUpdateError, AeadLimitReached, NoViablePath, } // 8.13.24 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub struct ApplicationError(String); // 8.13.25 #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct CryptoError(u8); impl Display for CryptoError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "crypto_error_0x1{:02x}", self.0) } } impl Serialize for CryptoError { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_str(&self.to_string()) } } impl<'de> Deserialize<'de> for CryptoError { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let string = 
String::deserialize(deserializer)?; string.strip_prefix("crypto_error_0x1").map_or_else( || Err(serde::de::Error::custom("invalid crypto error")), |s| { u8::from_str_radix(s, 16) .map(CryptoError) .map_err(serde::de::Error::custom) }, ) } } crate::gen_builder_method! { PathEndpointInfoBuilder => PathEndpointInfo; PacketHeaderBuilder => PacketHeader; TokenBuilder => Token; } mod rollback { use super::*; use crate::{build, legacy::quic as legacy}; impl From for legacy::IPAddress { #[inline] fn from(value: IPAddress) -> Self { legacy::IPAddress::from(value.0) } } impl From for legacy::IPVersion { #[inline] fn from(value: IpVersion) -> Self { match value { IpVersion::V4 => legacy::IPVersion::V4, IpVersion::V6 => legacy::IPVersion::V6, } } } impl From for legacy::ConnectionID { #[inline] fn from(value: ConnectionID) -> Self { legacy::ConnectionID::from(HexString::from(Bytes::from(value.0.to_vec()))) } } impl From for legacy::Owner { #[inline] fn from(value: Owner) -> Self { match value { Owner::Local => legacy::Owner::Local, Owner::Remote => legacy::Owner::Remote, } } } impl From for legacy::PacketType { #[inline] fn from(value: PacketType) -> Self { match value { PacketType::Initial => legacy::PacketType::Initial, PacketType::Handshake => legacy::PacketType::Handshake, PacketType::ZeroRTT => legacy::PacketType::ZeroRTT, PacketType::OneRTT => legacy::PacketType::OneRTT, PacketType::Retry => legacy::PacketType::Retry, PacketType::VersionNegotiation => legacy::PacketType::VersionNegotiation, PacketType::StatelessReset => legacy::PacketType::StatelessReset, PacketType::Unknown => legacy::PacketType::Unknown, } } } impl From for legacy::PacketNumberSpace { #[inline] fn from(value: PacketNumberSpace) -> Self { match value { PacketNumberSpace::Initial => legacy::PacketNumberSpace::Initial, PacketNumberSpace::Handshake => legacy::PacketNumberSpace::Handshake, PacketNumberSpace::ApplicationData => legacy::PacketNumberSpace::ApplicationData, } } } impl From for legacy::TokenType 
{ #[inline] fn from(value: TokenType) -> Self { match value { TokenType::Retry => legacy::TokenType::Retry, TokenType::Resumption => legacy::TokenType::Resumption, } } } impl From for legacy::Token { #[inline] fn from(value: Token) -> Self { build!(legacy::Token { ?r#type: value.r#type, details: value.details, ?length: value.raw.as_ref().and_then(|raw| raw.length.map(|length| length as u32)), ?data: value.raw.and_then(|raw| raw.data) }) } } impl From for legacy::Token { #[inline] fn from(value: StatelessResetToken) -> Self { build!(legacy::Token { r#type: TokenType::Resumption, details: HashMap::new(), length: 16u32, data: { Bytes::from_owner(value.0.to_vec()) } }) } } impl From for legacy::PacketHeader { fn from(value: PacketHeader) -> Self { build!(legacy::PacketHeader { packet_type: value.packet_type, ?packet_number: value.packet_number, ?flags: value.flags, ?token: value.token, ?length: value.length, ?version: value.version, ?scil: value.scil, ?dcil: value.dcil, ?scid: value.scid, ?dcid: value.dcid }) } } impl From for legacy::TransportError { #[inline] fn from(value: TransportError) -> Self { match value { TransportError::NoError => legacy::TransportError::NoError, TransportError::InternalError => legacy::TransportError::InternalError, TransportError::ConnectionRefused => legacy::TransportError::ConnectionRefused, TransportError::FlowControlError => legacy::TransportError::FlowControlError, TransportError::StreamLimitError => legacy::TransportError::StreamLimitError, TransportError::StreamStateError => legacy::TransportError::StreamStateError, TransportError::FinalSizeError => legacy::TransportError::FinalSizeError, TransportError::FrameEncodingError => legacy::TransportError::FrameEncodingError, TransportError::TransportParameterError => { legacy::TransportError::TransportParameterError } TransportError::ConnectionIdLimitError => { legacy::TransportError::ConnectionIdLimitError } TransportError::ProtocolViolation => legacy::TransportError::ProtocolViolation, 
TransportError::InvalidToken => legacy::TransportError::InvalidToken, TransportError::ApplicationError => legacy::TransportError::ApplicationError, TransportError::CryptoBufferExceeded => { legacy::TransportError::CryptoBufferExceeded } TransportError::KeyUpdateError => legacy::TransportError::KeyUpdateError, TransportError::AeadLimitReached => legacy::TransportError::AeadLimitReached, TransportError::NoViablePath => legacy::TransportError::NoViablePath, } } } impl From for legacy::StreamType { #[inline] fn from(value: StreamType) -> Self { match value { StreamType::Unidirectional => legacy::StreamType::Unidirectional, StreamType::Bidirectional => legacy::StreamType::Bidirectional, } } } impl From for legacy::ConnectionCloseErrorSpace { #[inline] fn from(value: ConnectionCloseErrorSpace) -> Self { match value { ConnectionCloseErrorSpace::Transport => { legacy::ConnectionCloseErrorSpace::Transport } ConnectionCloseErrorSpace::Application => { legacy::ConnectionCloseErrorSpace::Application } } } } impl TryFrom for legacy::ConnectionCloseErrorCode { type Error = (); #[inline] fn try_from(value: ConnectionCloseErrorCode) -> Result { match value { ConnectionCloseErrorCode::TransportError(error) => Ok( legacy::ConnectionCloseErrorCode::TransportError(error.into()), ), ConnectionCloseErrorCode::CryptoError(_error) => Err(()), ConnectionCloseErrorCode::ApplicationError(error) => Ok( legacy::ConnectionCloseErrorCode::ApplicationError(error.into()), ), ConnectionCloseErrorCode::Value(value) => { Ok(legacy::ConnectionCloseErrorCode::Value(value)) } } } } impl From for legacy::ConnectionCloseTriggerFrameType { #[inline] fn from(value: ConnectionCloseTriggerFrameType) -> Self { match value { ConnectionCloseTriggerFrameType::Id(id) => { legacy::ConnectionCloseTriggerFrameType::Id(id) } ConnectionCloseTriggerFrameType::Text(text) => { legacy::ConnectionCloseTriggerFrameType::Text(text) } } } } impl From for legacy::QuicFrame { fn from(value: QuicFrame) -> Self { match value { 
QuicFrame::Padding { length, payload_length, } => legacy::QuicFrame::Padding { length, payload_length, }, QuicFrame::Ping { length, payload_length, } => legacy::QuicFrame::Ping { length, payload_length, }, QuicFrame::Ack { ack_delay, acked_ranges, ect1, ect0, ce, length, payload_length, } => legacy::QuicFrame::Ack { ack_delay, acked_ranges, ect1, ect0, ce, length, payload_length, }, QuicFrame::ResetStream { stream_id, error_code, final_size, length, payload_length, } => legacy::QuicFrame::ResetStream { stream_id, error_code: error_code.into(), final_size, length, payload_length, }, QuicFrame::StopSending { stream_id, error_code, length, payload_length, } => legacy::QuicFrame::StopSending { stream_id, error_code: error_code.into(), length, payload_length, }, QuicFrame::Crypto { offset, length, payload_length, raw: _, } => legacy::QuicFrame::Crypto { offset, length, payload_length, }, QuicFrame::NewToken { token } => legacy::QuicFrame::NewToken { token: token.into(), }, QuicFrame::Stream { stream_id, offset, length, fin, raw, } => legacy::QuicFrame::Stream { stream_id, offset, length, fin, raw, }, QuicFrame::MaxData { maximum } => legacy::QuicFrame::MaxData { maximum }, QuicFrame::MaxStreamData { stream_id, maximum } => { legacy::QuicFrame::MaxStreamData { stream_id, maximum } } QuicFrame::MaxStreams { stream_type, maximum, } => legacy::QuicFrame::MaxStreams { stream_type: stream_type.into(), maximum, }, QuicFrame::DataBlocked { limit } => legacy::QuicFrame::DataBlocked { limit }, QuicFrame::StreamDataBlocked { stream_id, limit } => { legacy::QuicFrame::StreamDataBlocked { stream_id, limit } } QuicFrame::StreamsBlocked { stream_type, limit } => { legacy::QuicFrame::StreamsBlocked { stream_type: stream_type.into(), limit, } } QuicFrame::NewConnectionId { sequence_number, retire_prior_to, connection_id_length, connection_id, stateless_reset_token, } => legacy::QuicFrame::NewConnectionId { sequence_number, retire_prior_to, connection_id_length, connection_id: 
connection_id.into(), stateless_reset_token: stateless_reset_token.map(Into::into), }, QuicFrame::RetireConnectionId { sequence_number } => { legacy::QuicFrame::RetireConnectionId { sequence_number } } QuicFrame::PathChallenge { data } => legacy::QuicFrame::PathChallenge { data }, QuicFrame::PathResponse { data } => legacy::QuicFrame::PathResponse { data }, QuicFrame::ConnectionClose { error_space, error_code, reason, reason_bytes: _, trigger_frame_type, } => legacy::QuicFrame::ConnectionClose { error_space: error_space.map(Into::into), raw_error_code: match &error_code { Some(ConnectionCloseErrorCode::CryptoError(CryptoError(value))) => { Some(*value as u32) } _ => None, }, error_code: error_code.and_then(|error_code| error_code.try_into().ok()), reason, trigger_frame_type: trigger_frame_type.map(Into::into), }, QuicFrame::HandshakeDone {} => legacy::QuicFrame::HandshakeDone {}, QuicFrame::Unknow { frame_type_bytes, raw, } => legacy::QuicFrame::Unknown { raw_frame_type: frame_type_bytes, raw_length: raw .as_ref() .and_then(|raw| raw.length.map(|length| length as u32)), raw: raw.and_then(|raw| raw.data), }, QuicFrame::Datagram { length, raw } => legacy::QuicFrame::Datagram { length, raw }, } } } impl From for legacy::ApplicationError { #[inline] fn from(value: ApplicationError) -> Self { value.0.into() } } impl From for legacy::ApplicationCode { #[inline] fn from(value: ApplicationCode) -> Self { match value { ApplicationCode::ApplicationError(error) => { legacy::ApplicationCode::ApplicationError(error.into()) } ApplicationCode::Value(value) => legacy::ApplicationCode::Value(value), } } } } #[cfg(test)] mod tests { use super::*; #[test] fn ack() { // 123 56 9 let frame = AckFrame::new( 9u32.into(), 1000u32.into(), 0u32.into(), vec![(1u32.into(), 1u32.into()), (0u32.into(), 2u32.into())], None, ); let encoding_size = frame.encoding_size(); let quic_frame: QuicFrame = (&frame).into(); assert_eq!( quic_frame, QuicFrame::Ack { ack_delay: Some(1.0), acked_ranges: 
vec![[9, 9], [5, 6], [1, 3]], ect1: None, ect0: None, ce: None, length: Some(encoding_size as u32), payload_length: None, } ); } } ================================================ FILE: qevent/src/telemetry/filter.rs ================================================ #[inline] #[cfg(feature = "telemetry")] pub fn event(scheme: &'static str) -> bool { super::current_span::CURRENT_SPAN.with(|span| span.borrow().filter_event(scheme)) } #[inline] #[cfg(not(feature = "telemetry"))] pub fn event(_scheme: &'static str) -> bool { false } #[inline] #[cfg(all(feature = "telemetry", feature = "raw_data"))] pub fn raw_data() -> bool { super::current_span::CURRENT_SPAN.with(|span| span.borrow().filter_raw_data()) } #[inline] #[cfg(not(all(feature = "telemetry", feature = "raw_data")))] pub fn raw_data() -> bool { false } ================================================ FILE: qevent/src/telemetry/handy.rs ================================================ use std::{ future::Future, path::{Path, PathBuf}, sync::Arc, }; use tokio::{ io::{self, AsyncWrite, AsyncWriteExt}, sync::mpsc, }; use super::{ExportEvent, QLog, Span}; use crate::{Event, GroupID, VantagePoint, VantagePointType, span}; pub struct NoopExporter; impl ExportEvent for NoopExporter { fn emit(&self, event: Event) { _ = event; } fn filter_event(&self, _: &'static str) -> bool { false } fn filter_raw_data(&self) -> bool { false } } impl ExportEvent for mpsc::UnboundedSender { fn emit(&self, event: Event) { _ = self.send(event); } } pub struct NoopLogger; impl QLog for NoopLogger { #[inline] fn new_trace(&self, _: VantagePointType, _: GroupID) -> Span { span!(Arc::new(NoopExporter)) } } impl QLog for Arc { #[inline] fn new_trace(&self, vantage_point: VantagePointType, group_id: GroupID) -> Span { self.as_ref().new_trace(vantage_point, group_id) } } pub trait TelemetryStorage { fn join( &self, file_name: &str, ) -> impl Future + Send + 'static; } impl TelemetryStorage for PathBuf { fn join( &self, file_name: &str, ) -> impl 
Future + Send + 'static { let file_path = Path::join(self, file_name); async move { tokio::fs::OpenOptions::new() .create(true) .truncate(true) .write(true) .open(&file_path) .await .unwrap_or_else(|e| { panic!( "failed to create sqlog file {}: {e:?}, qlogs to this connection will be ignored.", file_path.display() ) }) } } } impl TelemetryStorage for tokio::io::Stdout { #[allow(clippy::manual_async_fn)] fn join( &self, _: &str, ) -> impl Future + Send + 'static { async move { tokio::io::stdout() } } } impl TelemetryStorage for tokio::io::Stderr { #[allow(clippy::manual_async_fn)] fn join( &self, _: &str, ) -> impl Future + Send + 'static { async move { tokio::io::stderr() } } } pub struct LegacySeqLogger { storage: S, } impl Clone for LegacySeqLogger { fn clone(&self) -> Self { Self { storage: self.storage.clone(), } } } impl LegacySeqLogger { pub fn new(storage: S) -> Self { Self { storage } } } impl QLog for LegacySeqLogger { fn new_trace(&self, vantage_point: VantagePointType, group_id: GroupID) -> Span { use crate::legacy; let file_name = format!("{group_id}_{vantage_point}.sqlog"); let file = self.storage.join(&file_name); let qlog_file_seq = crate::build!(legacy::QlogFileSeq { title: file_name, trace: legacy::TraceSeq { vantage_point: VantagePoint { r#type: vantage_point }, } }); let (tx, mut rx) = mpsc::unbounded_channel::(); tokio::spawn(async move { let mut log_file = io::BufWriter::new(file.await); const RS: u8 = 0x1E; log_file.write_u8(RS).await?; let qlog_file_seq = serde_json::to_string(&qlog_file_seq).unwrap(); log_file.write_all(qlog_file_seq.as_bytes()).await?; log_file.write_u8(b'\n').await?; while let Some(event) = rx.recv().await { let Ok(event) = legacy::Event::try_from(event) else { continue; }; let event = serde_json::to_string(&event).unwrap(); // log_file.write_vectored(); log_file.write_u8(RS).await?; log_file.write_all(event.as_bytes()).await?; log_file.write_u8(b'\n').await?; } log_file.shutdown().await }); crate::span!(Arc::new(tx), 
group_id = group_id) } } pub struct TracingLogger; impl QLog for TracingLogger { fn new_trace(&self, vantage_point: VantagePointType, group_id: GroupID) -> Span { use crate::legacy; let span = tracing::info_span!(parent: None,"qlog", role = %vantage_point, odcid = %group_id); let qlog_file_seq = crate::build!(legacy::QlogFileSeq { title: format!("{group_id}_{vantage_point}.sqlog"), trace: legacy::TraceSeq { vantage_point: VantagePoint { r#type: vantage_point }, } }); let (tx, mut rx) = mpsc::unbounded_channel::(); tokio::spawn(tracing::Instrument::instrument( async move { tracing::debug!(target: "qlog", "{}", serde_json::to_string(&qlog_file_seq).unwrap()); while let Some(event) = rx.recv().await { let Ok(event) = legacy::Event::try_from(event) else { continue; }; tracing::debug!(target: "qlog", "{}", serde_json::to_string(&event).unwrap()); } }, span, )); crate::span!(Arc::new(tx), group_id = group_id) } } #[cfg(test)] mod tests { use crate::{ quic::connectivity::ServerListening, telemetry::{Instrument, QLog, Span, handy::LegacySeqLogger}, }; #[tokio::test] #[cfg(feature = "telemetry")] async fn legacy_seq_exporter() { let exporter = LegacySeqLogger::new(tokio::io::stdout()); let root_span = exporter.new_trace( crate::VantagePointType::Server, crate::GroupID::from("test_group".to_string()), ); root_span.in_scope(|| { let any_field = 112233u64; crate::span!(@current, any_field).in_scope(|| { crate::event!(ServerListening { ip_v4: "127.0.0.1".to_owned(), port_v4: 443u16 }); tokio::spawn( async move { assert_eq!(Span::current().load::("any_field"), 112233u64); // do something } .instrument(crate::span!(@current, path_id = String::from("new path"))), ); }); }); tokio::task::yield_now().await; } } ================================================ FILE: qevent/src/telemetry/macro_support.rs ================================================ use serde::Serialize; use super::*; use crate::{BeSpecificEventData, EventBuilder}; #[inline] pub fn new_span(exporter: Arc, fields: 
HashMap<&'static str, Value>) -> Span { Span { exporter, fields: Arc::new(fields), } } pub fn modify_event_builder_costom_fields( builder: &mut EventBuilder, f: impl FnOnce(&mut HashMap), ) { if builder.custom_fields.is_none() { builder.custom_fields = Some(HashMap::new()); } let custom_fields = builder.custom_fields.as_mut().unwrap(); f(custom_fields); } pub fn current_span_exporter() -> Arc { current_span::CURRENT_SPAN.with(|span| span.borrow().exporter.clone()) } pub fn current_span_fields() -> HashMap<&'static str, Value> { current_span::CURRENT_SPAN.with(|span| span.borrow().fields.as_ref().clone()) } pub fn try_load_current_span(name: &'static str) -> Option { current_span::CURRENT_SPAN.with(|span| { let span = span.borrow(); Some(from_value::(span.fields.get(name)?.clone())) }) } pub fn build_and_emit_event( build_data: impl FnOnce() -> D, build_event: impl FnOnce(D) -> Event, ) { if !filter::event(D::scheme()) { return; } let event = build_event(build_data()); current_span::CURRENT_SPAN.with(|span| span.borrow().emit(event)); } pub fn to_value(value: T) -> Value { serde_json::to_value(value).unwrap() } pub fn from_value(value: Value) -> T { serde_json::from_value(value).unwrap() } ================================================ FILE: qevent/src/telemetry/macros.rs ================================================ #[macro_export] #[cfg(feature = "telemetry")] macro_rules! span { () => {{ $crate::telemetry::Span::current() }}; (@current $(, $($tt:tt)* )?) => {{ let __current_exporter = $crate::telemetry::macro_support::current_span_exporter(); $crate::span!(__current_exporter $(, $($tt)* )?) }}; ($broker:expr $(, $($tt:tt)* )?) => {{ #[allow(unused_mut)] let mut __current_fields = $crate::telemetry::macro_support::current_span_fields(); $crate::span!(@field __current_fields $(, $($tt)* )?); $crate::telemetry::macro_support::new_span($broker, __current_fields) }}; (@field $fields:expr, $name:ident $(, $($tt:tt)* )?) 
=> { $crate::span!( @field $fields, $name = $name $(, $($tt)* )? ); }; (@field $fields:expr, $name:ident = $value:expr $(, $($tt:tt)* )?) => { let __value = $crate::telemetry::macro_support::to_value($value); $fields.insert(stringify!($name), __value); $crate::span!( @field $fields $(, $($tt)* )? ); }; (@field $fields:expr $(,)? ) => {}; } #[macro_export] #[cfg(not(feature = "telemetry"))] macro_rules! span { () => {{ $crate::telemetry::Span::current() }}; (@current $(, $($tt:tt)* )?) => {{ let __current_exporter = $crate::telemetry::macro_support::current_span_exporter(); $crate::span!(__current_exporter $(, $($tt)* )?) }}; ($broker:expr $(, $($tt:tt)* )?) => {{ #[allow(unused_mut)] let mut __current_fields = $crate::telemetry::macro_support::current_span_fields(); $crate::span!(@field __current_fields $(, $($tt)* )?); $crate::telemetry::macro_support::new_span($broker, __current_fields) }}; (@field $fields:expr, $name:ident $(, $($tt:tt)* )?) => { $crate::span!( @field $fields, $name = $name $(, $($tt)* )? ); }; (@field $fields:expr, $name:ident = $value:expr $(, $($tt:tt)* )?) => { _ = $value; $crate::span!( @field $fields $(, $($tt)* )? ); }; (@field $fields:expr $(,)? ) => {}; } #[macro_export] macro_rules! event { ($event_type:ty { $($event_field:tt)* } $(, $($tt:tt)* )?) => {{ $crate::event!($crate::build!($event_type { $($event_field)* }) $(, $($tt)* )?); }}; ($event_data:expr $(, $($tt:tt)* )?) 
=> {{ let __build_data = || $event_data; let __build_event = |__event_data| { let mut __event_builder = $crate::Event::builder(); // as_millis_f64 is nightly only let __time = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs_f64() * 1000.0; __event_builder.time(__time); __event_builder.data(__event_data); $crate::event!(@load_known __event_builder, path: $crate::PathID); $crate::event!(@load_known __event_builder, protocol_types: $crate::ProtocolTypeList); $crate::event!(@load_known __event_builder, group_id: $crate::GroupID); $crate::event!(@field __event_builder $(, $($tt)* )?); __event_builder.build() }; $crate::telemetry::macro_support::build_and_emit_event(__build_data, __build_event); }}; (@load_known $event_builder:expr, $name:ident: $type:ty) => { if let Some(__value) = $crate::telemetry::macro_support::try_load_current_span::<$type>(stringify!($name)) { $event_builder.$name(__value); } }; (@field $event_builder:expr, $name:ident $(, $($tt:tt)* )?) => { $crate::event!( @field $event_builder, $name = $name $(, $($tt)* )? ); }; (@field $event_builder:expr, $name:ident = Map { $($build:tt)* } $(, $($tt:tt)* )?) => { let __value = $crate::telemetry::macro_support::to_value($crate::map!{ $($build)* }); $crate::telemetry::macro_support::modify_event_builder_costom_fields(&mut $event_builder, |__custom_fields| { __custom_fields.insert(stringify!($name).to_owned(), __value); }); $crate::event!( @field $event_builder $(, $($tt)* )? ); }; (@field $event_builder:expr, $name:ident = $struct:ident { $(build:tt)* } $(, $($tt:tt)* )?) => { let __value = $crate::telemetry::macro_support::to_value($crate::build!($struct { $(build)* })); $crate::telemetry::macro_support::modify_event_builder_costom_fields(&mut $event_builder, |__custom_fields| { __custom_fields.insert(stringify!($name).to_owned(), __value); }); $crate::event!( @field $event_builder $(, $($tt)* )? 
); }; (@field $event_builder:expr, $name:ident = $value:expr $(, $($tt:tt)* )?) => { let __value = $crate::telemetry::macro_support::to_value($value); $crate::telemetry::macro_support::modify_event_builder_costom_fields(&mut $event_builder, |__custom_fields| { __custom_fields.insert(stringify!($name).to_owned(), __value); }); $crate::event!( @field $event_builder $(, $($tt)* )? ); }; (@field $event_builder:expr $(,)? ) => {}; } ================================================ FILE: qevent/src/telemetry.rs ================================================ pub(crate) mod filter; pub mod handy; #[doc(hidden)] pub mod macro_support; mod macros; use std::{ collections::HashMap, fmt::Debug, future::Future, pin::Pin, sync::Arc, task::{Context, Poll}, }; use handy::NoopExporter; use serde::de::DeserializeOwned; use serde_json::Value; use crate::{Event, GroupID, VantagePointType}; pub trait QLog { fn new_trace(&self, vantage_point: VantagePointType, group_id: GroupID) -> Span; } pub trait ExportEvent: Send + Sync { fn emit(&self, event: Event); fn filter_event(&self, scheme: &'static str) -> bool { _ = scheme; true } fn filter_raw_data(&self) -> bool { false } } #[derive(Clone)] pub struct Span { exporter: Arc, fields: Arc>, } impl Debug for Span { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Span") .field("exporter", &"..") .field("fields", &self.fields) .finish() } } impl Span { #[inline] pub fn emit(&self, event: Event) { self.exporter.emit(event); } #[inline] pub fn filter_event(&self, scheme: &'static str) -> bool { self.exporter.filter_event(scheme) } #[inline] pub fn filter_raw_data(&self) -> bool { self.exporter.filter_raw_data() } #[inline] pub fn load(&self, name: &'static str) -> T { let Some(value) = self.fields.get(name) else { panic!( "Failed to load field `{name}` from span fields: {:?}", self.fields ); }; match serde_json::from_value(value.clone()) { Ok(value) => value, Err(e) => panic!( "Failed to load field `{name}` 
from span fields: {:?}, error: {:?}", self.fields, e ), } } #[inline] pub fn try_load(&self, name: &'static str) -> Option { serde_json::from_value(self.fields.get(name)?.clone()).ok() } } impl PartialEq for Span { fn eq(&self, other: &Self) -> bool { Arc::ptr_eq(&self.fields, &other.fields) && Arc::ptr_eq(&self.exporter, &other.exporter) } } impl Default for Span { fn default() -> Self { Self { exporter: Arc::new(NoopExporter), fields: Arc::new(HashMap::new()), } } } pub struct Entered { previous: Option, } mod current_span { use std::cell::RefCell; use super::{Entered, Span}; thread_local! { pub static CURRENT_SPAN: RefCell = RefCell::new(Span::default()); } impl Drop for Entered { fn drop(&mut self) { if let Some(previous) = &self.previous { CURRENT_SPAN.with(|span| { span.replace(previous.clone()); }); } } } impl Span { pub fn enter(&self) -> Entered { let previous = CURRENT_SPAN.with(|current| { if &*current.borrow() == self { None } else { Some(current.replace(self.clone())) } }); Entered { previous } } pub fn in_scope(&self, f: impl FnOnce() -> T) -> T { let _guard = self.enter(); f() } pub fn current() -> Span { CURRENT_SPAN.with(|span| span.borrow().clone()) } } } pin_project_lite::pin_project! 
{ pub struct Instrumented { span: Span, #[pin] inner: F, } } pub trait Instrument { fn instrument(self, span: Span) -> Instrumented; fn instrument_in_current(self) -> Instrumented; } impl Instrument for F { fn instrument(self, span: Span) -> Instrumented { Instrumented { span, inner: self } } fn instrument_in_current(self) -> Instrumented { self.instrument(crate::span!()) } } impl Future for Instrumented { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); this.span.in_scope(|| this.inner.poll(cx)) } } #[cfg(test)] mod tests { use std::sync::Arc; use qbase::cid::ConnectionId; use super::*; use crate::{ GroupID, event, quic::{ConnectionID, connectivity::ServerListening}, span, }; #[test] fn span_fields() { let exporter = Arc::new(NoopExporter); let _span = span!(exporter.clone()); let a = 0i32; let c = 123456789usize; span!(exporter.clone(), a, a, b = 12.3f32, c, d = "Hello world!").in_scope(|| { assert_eq!(Span::current().load::("a"), 0); assert_eq!(Span::current().load::("b"), 12.3); assert_eq!(Span::current().load::("c"), 123456789); assert_eq!(Span::current().load::("d"), "Hello world!"); let e = vec![1, 2, 3]; span!(exporter.clone(), a = 1, b = 2, c = 3, e).in_scope(|| { assert_eq!(Span::current().load::("a"), 1); assert_eq!(Span::current().load::("b"), 2); assert_eq!(Span::current().load::("c"), 3); assert_eq!(Span::current().load::("d"), "Hello world!"); assert_eq!(Span::current().load::>("e"), vec![1, 2, 3]); }); }) } #[test] fn event() { struct TestBroker; impl ExportEvent for TestBroker { fn emit(&self, event: Event) { let str = serde_json::to_string_pretty(&event).unwrap(); let event = serde_json::to_value(event).unwrap(); println!("{str}"); assert_eq!(event["name"], "quic:server_listening"); let event_data_json = serde_json::json!({ "ip_v4": "127.0.0.1", "port_v4": 8080, }); assert_eq!(event["data"], event_data_json); assert_eq!(event["group_id"], String::from(group_id())); 
assert_eq!(event["use_strict_mode"], true); } } fn group_id() -> GroupID { GroupID::from(ConnectionID::from(ConnectionId::from_slice(&[ 0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef, ]))) } span!(Arc::new(TestBroker), group_id = group_id()).in_scope(|| { event!( crate::build!(ServerListening { ip_v4: "127.0.0.1".to_owned(), port_v4: 8080u16, }), use_strict_mode = true ); }); } } ================================================ FILE: qinterface/Cargo.toml ================================================ [package] name = "qinterface" version = "0.5.0" edition.workspace = true description = "dquic's network interface and IO abstractions" readme.workspace = true repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true [dependencies] bytes = { workspace = true } dashmap = { workspace = true } derive_more = { workspace = true, features = ["deref"] } futures = { workspace = true } http = { workspace = true } netdev = { workspace = true } netwatcher = { workspace = true } parking_lot = { workspace = true } pin-project-lite = { workspace = true } qbase = { workspace = true } qevent = { workspace = true } rustls = { workspace = true } serde = { workspace = true, features = ["derive"] } tokio = { workspace = true, features = ["net", "rt", "sync", "time", "macros"] } tokio-util = { workspace = true, features = ["rt"] } thiserror = { workspace = true } tracing = { workspace = true } [target.'cfg(any(unix, windows))'.dependencies] qudp = { workspace = true, optional = true } [dev-dependencies] tokio = { workspace = true, features = [ "test-util", "macros", "rt-multi-thread", ] } [features] qudp = ["dep:qudp"] ================================================ FILE: qinterface/examples/interface-monitor.rs ================================================ use qinterface::device::Devices; #[tokio::main(flavor = "current_thread")] async fn main() { let global = Devices::global(); let mut monitor = global.monitor(); for (name, iface) 
in monitor.interfaces() {
        println!("Interface: {name} => {iface:#?}");
    }
    while let Some((_devices, event)) = monitor.update().await {
        println!("Event: {event:#?}");
    }
}

================================================
FILE: qinterface/src/bind_uri.rs
================================================
use std::{
    borrow::Cow,
    fmt::Display,
    io,
    net::{AddrParseError, IpAddr, SocketAddr},
    str::FromStr,
};

use derive_more::{Display, Into};
use qbase::{net::Family, util::UniqueIdGenerator};
use thiserror::Error;

/// A validated URI describing where a socket should be bound, in one of two
/// schemes: `iface://<family>.<interface>:<port>` (bind to a named network
/// interface) or `inet://<ip>:<port>` (bind to a concrete address).
#[derive(Debug, Display, Clone, Into, PartialEq, Eq, Hash)]
pub struct BindUri(http::Uri);

/// Errors produced while parsing or validating a [`BindUri`].
#[derive(Debug, Error)]
pub enum ParseError {
    // NOTE(review): the generic argument here was destroyed by extraction
    // (source read `InvalidUri(::Err)`); restored as `<http::Uri as
    // FromStr>::Err` based on the `map_err(ParseError::InvalidUri)` call on
    // `s.parse::<http::Uri>()` — confirm against upstream.
    #[error("Invalid uri {0}")]
    InvalidUri(<http::Uri as FromStr>::Err),
    #[error("Missing scheme")]
    NoScheme,
    #[error("Unsupported bind uri scheme: {0}")]
    Unsupported(String),
    #[error("Path must be empty")]
    Malformed,
    #[error("Missing ip family for iface scheme BindUri")]
    NoFamily,
    #[error("Missing port for iface scheme BindUri")]
    NoPort,
    #[error("Invalid IP address family for iface scheme")]
    UnknownFamily,
    #[error("Invalid IP address for inet scheme BindUri: {0}")]
    InvalidIpAddr(AddrParseError),
}

/// Splits an `iface://<family>.<interface>:<port>` URI into its components.
///
/// The host part must be `<family>.<interface>` (e.g. `v4.wlan0`) and the
/// port must be present. Panics if the URI has no authority, which the
/// callers guarantee by construction.
fn parse_iface_bind_uri(uri: &http::Uri) -> Result<(Family, &str, u16), ParseError> {
    let authority = uri.authority().expect("BindUri is absolute URI");
    let (ip_family, interface) = authority
        .host()
        .split_once('.')
        .ok_or(ParseError::NoFamily)?;
    let port = authority.port_u16().ok_or(ParseError::NoPort)?;
    let ip_family: Family = ip_family.parse().or(Err(ParseError::UnknownFamily))?;
    Ok((ip_family, interface, port))
}

/// Resolves an `inet://<ip>:<port>` URI into a [`SocketAddr`].
///
/// Accepts both bracketed (`[::1]`) and bare hosts; non-IP hostnames are
/// rejected with [`ParseError::InvalidIpAddr`].
fn parse_inet_bind_uri(uri: &http::Uri) -> Result<SocketAddr, ParseError> {
    let authority = uri.authority().expect("BindUri is absolute URI");
    let port = authority.port_u16().ok_or(ParseError::NoPort)?;
    let host = match authority.host().as_bytes() {
        // Strip the brackets an IPv6 authority carries before IpAddr parsing.
        [b'[', .., b']'] => authority.host().trim_matches(|c| matches!(c, '[' | ']')),
        _ => authority.host(),
    };
    match IpAddr::from_str(host) {
        Ok(ip) => Ok(SocketAddr::new(ip, port)),
        Err(e) => Err(ParseError::InvalidIpAddr(e)),
    }
}
impl FromStr for BindUri {
    type Err = ParseError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // A bare socket address such as "127.0.0.1:8080" is accepted as
        // shorthand for "inet://127.0.0.1:8080".
        if let Ok(socket_addr) = s.parse::<SocketAddr>() {
            return Ok(socket_addr.into());
        }
        s.parse::<http::Uri>()
            .map_err(ParseError::InvalidUri)?
            .try_into()
    }
}

impl TryFrom<http::Uri> for BindUri {
    type Error = ParseError;

    fn try_from(uri: http::Uri) -> Result<Self, Self::Error> {
        let scheme = uri
            .scheme()
            .ok_or(ParseError::NoScheme)?
            .as_str()
            .parse()
            .map_err(ParseError::Unsupported)?;
        debug_assert!(uri.authority().is_some(), "BindUri should be absolute URI");
        if uri.path() != "/" {
            return Err(ParseError::Malformed);
        }
        // Validate the scheme-specific components eagerly so that later
        // accessors (`as_iface_bind_uri`, `as_inet_bind_uri`) may unwrap.
        match scheme {
            Scheme::Iface => {
                parse_iface_bind_uri(&uri)?;
            }
            Scheme::Inet => {
                parse_inet_bind_uri(&uri)?;
            }
        }
        Ok(Self(uri))
    }
}

impl From<String> for BindUri {
    /// Panics if `value` is not a valid bind URI.
    #[inline]
    fn from(value: String) -> Self {
        match BindUri::from_str(&value) {
            Ok(bind_uri) => bind_uri,
            Err(e) => panic!("bind uri should be valid: {e}"),
        }
    }
}

impl From<&str> for BindUri {
    /// Panics if `value` is not a valid bind URI.
    #[inline]
    fn from(value: &str) -> Self {
        match BindUri::from_str(value) {
            Ok(bind_uri) => bind_uri,
            Err(e) => panic!("bind uri should be valid: {e}"),
        }
    }
}

impl From<SocketAddr> for BindUri {
    #[inline]
    fn from(value: SocketAddr) -> Self {
        match BindUri::from_str(&format!("inet://{value}")) {
            Ok(bind_uri) => bind_uri,
            Err(e) => panic!("{e}"),
        }
    }
}

// NOTE(review): the generic parameter list was destroyed by extraction
// (source read `impl> From<&T> for BindUri`); restored as
// `T: Copy + Into<BindUri>` to match the `(*value).into()` body and to avoid
// overlapping the `From<&BindUri>` impl below (BindUri is not Copy) — confirm
// against upstream.
impl<T: Copy + Into<BindUri>> From<&T> for BindUri {
    #[inline]
    fn from(value: &T) -> Self {
        (*value).into()
    }
}

impl From<&BindUri> for BindUri {
    #[inline]
    fn from(value: &BindUri) -> Self {
        value.clone()
    }
}

impl BindUri {
    pub const TEMPORARY_PROP: &str = "temporary";
    pub const STUN_PROP: &str = "stun";
    pub const STUN_SERVER_PROP: &str = "stun_server";
    pub const RELAY_PROP: &str = "relay";

    /// Returns this URI's scheme. Panics only if the invariant established by
    /// `TryFrom<http::Uri>` (scheme present and supported) was violated.
    pub fn scheme(&self) -> Scheme {
        self.0
            .scheme()
            .expect("Invalid BindUri: Missing scheme")
            .as_str()
            .parse()
            .expect("Invalid BindUri: Invalid scheme")
    }

    /// Borrows the underlying validated URI.
    #[inline]
    pub fn as_uri(&self) -> &http::Uri {
        &self.0
    }

    /// Returns the IP family this URI binds to, derived from the family tag
    /// of an `iface` URI or from the address of an `inet` URI.
    pub fn family(&self) -> Family {
        match self.scheme() {
            Scheme::Iface => {
                self.as_iface_bind_uri()
                    .expect("Already checked BindUriScheme is iface")
                    .0
            }
            Scheme::Inet => {
match self .as_inet_bind_uri() .expect("Already checked BindUriScheme is inet") { SocketAddr::V4(_) => Family::V4, SocketAddr::V6(_) => Family::V6, } } } } pub fn as_iface_bind_uri(&self) -> Option<(Family, &str, u16)> { if self.scheme() != Scheme::Iface { return None; } Some(parse_iface_bind_uri(&self.0).expect("BindUri should be valid")) } pub fn as_inet_bind_uri(&self) -> Option { if self.scheme() != Scheme::Inet { return None; } Some(parse_inet_bind_uri(&self.0).expect("BindUri should be valid")) } pub fn add_prop(&mut self, key: &str, value: &str) { let mut uri_parts = self.0.clone().into_parts(); uri_parts.path_and_query = uri_parts.path_and_query.map(|pq| { let query = match pq.query() { Some(exist_query) => format!("{exist_query}&{key}={value}"), None => format!("{key}={value}"), }; format!("{}?{}", pq.path(), query) .parse() .expect("Path and query should be valid") }); self.0 = http::Uri::from_parts(uri_parts).expect("BindUri should be valid"); } pub const ALLOC_PORT_ID: &'static str = "alloc_port_id"; pub fn alloc_port(&self) -> Self { match self.scheme() { Scheme::Iface => { let (.., port) = self .as_iface_bind_uri() .expect("Already checked BindUriScheme is iface"); assert_eq!(port, 0, "Only port 0 is allocatable"); } Scheme::Inet => { let addr = self .as_inet_bind_uri() .expect("Already checked BindUriScheme is inet"); assert_eq!(addr.port(), 0, "Only port 0 is allocatable"); } } let mut new_uri = self.clone(); static ID_GENERATOR: UniqueIdGenerator = UniqueIdGenerator::new(); let alloc_port_id = usize::from(ID_GENERATOR.generate()).to_string(); new_uri.add_prop(Self::ALLOC_PORT_ID, &alloc_port_id); new_uri } #[inline] pub fn prop(&self, key: &str) -> Option> { // http://127.0.0.1/fx ?key=value self.0 .query()? 
.split('&') .find_map(|pair| match pair.split_once('=') { Some((k, v)) if k == key => Some(Cow::Borrowed(v)), None if pair == key => Some(Cow::Borrowed("")), _ => None, }) } pub fn is_temporary(&self) -> bool { match self.prop(Self::TEMPORARY_PROP) { Some(bool) if bool == "true" => true, None | Some(..) => false, } } pub fn enable_stun(&mut self) { self.add_prop(Self::STUN_PROP, "true"); } pub fn is_stun_enabled(&self) -> bool { match self.prop(Self::STUN_PROP) { Some(bool) if bool == "true" => true, None | Some(..) => false, } } pub fn with_stun_server(mut self, stun_server: &str) -> Self { self.add_prop(Self::STUN_SERVER_PROP, stun_server); self } pub fn stun_server(&self) -> Option> { self.prop(Self::STUN_SERVER_PROP) } // TODO: change to bool flag pub fn with_relay(mut self, relay: &str) -> Self { self.add_prop(Self::RELAY_PROP, relay); self } pub fn relay(&self) -> Option> { self.prop(Self::RELAY_PROP) } /// Returns a canonical key for reconciliation purposes. /// /// Strips ephemeral query parameters (like `alloc_port_id`) so that two /// `BindUri`s pointing at the same interface/port compare as equal even /// when produced by separate `alloc_port()` calls. 
pub fn identity_key(&self) -> String { let uri = &self.0; let mut parts = uri.clone().into_parts(); parts.path_and_query = parts.path_and_query.map(|pq| { pq.path() .parse() .expect("path portion should always be valid") }); http::Uri::from_parts(parts) .expect("BindUri without query should be valid") .to_string() } pub fn resolve(&self) -> Result { match self.scheme() { Scheme::Iface => { let (ip_family, interface, port) = self .as_iface_bind_uri() .expect("Already checked BindUriScheme is iface"); let devices = crate::device::Devices::global(); devices.get(interface).ok_or(io::Error::new( io::ErrorKind::NotFound, "device not found".to_string(), ))?; let ip_addr = devices.resolve(interface, ip_family).ok_or(io::Error::new( io::ErrorKind::NotFound, "ip not matched".to_string(), ))?; Ok(SocketAddr::new(ip_addr, port)) } Scheme::Inet => Ok(self .as_inet_bind_uri() .expect("Already checked BindUriScheme is inet")), } } } impl TryFrom<&BindUri> for SocketAddr { type Error = io::Error; fn try_from(bind_uri: &BindUri) -> Result { bind_uri.resolve() } } impl TryFrom for SocketAddr { type Error = io::Error; fn try_from(bind_uri: BindUri) -> Result { SocketAddr::try_from(&bind_uri) } } #[non_exhaustive] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Scheme { Iface, Inet, } impl Scheme { pub const fn to_str(&self) -> &'static str { match self { Scheme::Iface => "iface", Scheme::Inet => "inet", } } } impl From for http::uri::Scheme { fn from(value: Scheme) -> Self { value .to_str() .parse() .expect("BindUriScheme should be valid URI scheme") } } impl FromStr for Scheme { type Err = String; fn from_str(s: &str) -> Result { match s { "iface" => Ok(Scheme::Iface), "inet" => Ok(Scheme::Inet), other => Err(other.to_string()), } } } impl Display for Scheme { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.to_str().fmt(f) } } #[cfg(test)] mod tests { use super::*; #[test] fn invalid_uri() { assert!(matches!( BindUri::from_str("iface://"), 
Err(ParseError::InvalidUri(_))
        ));
    }

    #[test]
    fn missing_scheme() {
        assert!(matches!(
            BindUri::from_str("invalid_uri"),
            Err(ParseError::NoScheme)
        ));
    }

    #[test]
    fn invalid_scheme() {
        assert!(matches!(
            BindUri::from_str("invalid://example.com"),
            Err(ParseError::Unsupported(_))
        ));
    }

    #[test]
    fn has_path() {
        assert!(matches!(
            BindUri::from_str("iface://v4.wlan0/1234"),
            Err(ParseError::Malformed)
        ));
    }

    #[test]
    fn missing_ip_family() {
        assert!(matches!(
            BindUri::from_str("iface://wlan0:8080"),
            Err(ParseError::NoFamily)
        ));
    }

    #[test]
    fn missing_port() {
        assert!(matches!(
            BindUri::from_str("iface://v4.wlan0"),
            Err(ParseError::NoPort)
        ));
    }

    #[test]
    fn invalid_ip_family() {
        assert!(matches!(
            BindUri::from_str("iface://invalid.wlan0:8080"),
            Err(ParseError::UnknownFamily)
        ));
    }

    #[test]
    fn invalid_ip_addr() {
        assert!(matches!(
            BindUri::from_str("inet://example.com:8080"),
            Err(ParseError::InvalidIpAddr(..))
        ));
    }

    #[test]
    fn iface_bind_uri() {
        let bind_uri = BindUri::from_str("iface://v4.wlan0:8080?temporary=true").unwrap();
        assert_eq!(bind_uri.scheme(), Scheme::Iface);
        let (family, interface, port) = bind_uri.as_iface_bind_uri().unwrap();
        assert_eq!(family, Family::V4);
        assert_eq!(interface, "wlan0");
        assert_eq!(port, 8080);
        assert_eq!(
            bind_uri.prop(BindUri::TEMPORARY_PROP).as_deref(),
            Some("true")
        );
    }

    #[test]
    fn inet_bind_uri() {
        let bind_uri = BindUri::from_str("inet://127.0.0.1:7777").unwrap();
        assert_eq!(bind_uri.scheme(), Scheme::Inet);
        let addr = bind_uri.as_inet_bind_uri().unwrap();
        assert_eq!(
            addr,
            SocketAddr::new(IpAddr::V4("127.0.0.1".parse().unwrap()), 7777)
        );
        assert!(bind_uri.as_uri().query().is_none());
    }

    // tokio runtime required for device listing
    #[tokio::test]
    async fn interface_not_found() {
        let bind_uri = BindUri::from_str(
            "iface://v4.ygiubiougbuyasiudbahsdbadfbkjadbhvkjabvckagdoiuehfjoiajhrpfhrbovhaelvkamdjkfs:8080",
        )
        .unwrap();
        assert!(SocketAddr::try_from(bind_uri).is_err_and(|e| e.kind() == io::ErrorKind::NotFound))
    }

    #[test]
    fn to_socket_addr() {
        let bind_uri = BindUri::from_str("inet://127.0.0.1:8080").unwrap();
        assert_eq!(
            SocketAddr::try_from(bind_uri).unwrap(),
            "127.0.0.1:8080".parse().unwrap()
        );
    }

    #[test]
    fn alloc_port() {
        // Port 0 means "allocate": two allocations must yield different URIs.
        let bind_uri = BindUri::from_str("inet://0.0.0.0:0").unwrap();
        assert_ne!(bind_uri.clone().alloc_port(), bind_uri.clone().alloc_port());
    }

    #[test]
    #[should_panic]
    fn alloc_port_for_non_zero_port1() {
        // Allocating on a URI that already carries an explicit port must panic.
        let bind_uri = BindUri::from_str("inet://127.0.0.1:8080").unwrap();
        bind_uri.alloc_port();
    }

    #[test]
    #[should_panic]
    fn alloc_port_for_non_zero_port2() {
        let bind_uri = BindUri::from_str("inet://v4.lo:12345").unwrap();
        bind_uri.alloc_port();
    }

    #[test]
    fn temporary() {
        // The `temporary` prop is read from the query string...
        let bind_uri = BindUri::from_str("iface://v4.wlan0:8080?temporary=true").unwrap();
        assert!(bind_uri.is_temporary());
        let bind_uri = BindUri::from_str("iface://v4.wlan0:8080?temporary=false").unwrap();
        assert!(!bind_uri.is_temporary());
        // ...and defaults to false when absent.
        let bind_uri = BindUri::from_str("iface://v4.wlan0:8080").unwrap();
        assert!(!bind_uri.is_temporary());
        let bind_uri =
            BindUri::from_str("iface://v4.C5563ED1-2BC9-42C5-8177-59F2F0AF37C8:8080").unwrap();
        assert!(!bind_uri.is_temporary());
        // Props can also be added programmatically and round-trip via Display.
        let mut bind_uri = BindUri::from_str("iface://v4.wlan0:8080").unwrap();
        bind_uri.add_prop(BindUri::TEMPORARY_PROP, "true");
        assert_eq!(
            bind_uri.to_string(),
            "iface://v4.wlan0:8080/?temporary=true"
        );
        assert!(bind_uri.is_temporary());
    }

    #[test]
    fn stun_enabled() {
        let mut bind_uri = BindUri::from_str("iface://v4.wlan0:8080").unwrap();
        assert!(!bind_uri.is_stun_enabled());
        bind_uri.enable_stun();
        assert!(bind_uri.is_stun_enabled());
        let bind_uri = BindUri::from_str("iface://v4.wlan0:8080?stun=true").unwrap();
        assert!(bind_uri.is_stun_enabled());
        let bind_uri = BindUri::from_str("iface://v4.wlan0:8080?stun=false").unwrap();
        assert!(!bind_uri.is_stun_enabled());
    }

    #[test]
    fn stun_server() {
        let bind_uri = BindUri::from_str("iface://v4.wlan0:8080").unwrap();
        assert!(bind_uri.stun_server().is_none());
        let bind_uri =
bind_uri.with_stun_server("stun.example.com:3478"); assert_eq!( bind_uri.stun_server().as_deref(), Some("stun.example.com:3478") ); let bind_uri = BindUri::from_str("iface://v4.wlan0:8080?stun_server=stun.genmeta.net").unwrap(); assert_eq!(bind_uri.stun_server().as_deref(), Some("stun.genmeta.net")); } #[test] fn relay() { let bind_uri = BindUri::from_str("iface://v4.wlan0:8080").unwrap(); assert!(bind_uri.relay().is_none()); let bind_uri = bind_uri.with_relay("turn.example.com:3478"); assert_eq!(bind_uri.relay().as_deref(), Some("turn.example.com:3478")); let bind_uri = BindUri::from_str("iface://v4.wlan0:8080?relay=turn.genmeta.net").unwrap(); assert_eq!(bind_uri.relay().as_deref(), Some("turn.genmeta.net")); } } ================================================ FILE: qinterface/src/component/alive.rs ================================================ use std::{ fmt::Debug, io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, pin::Pin, sync::{Mutex, MutexGuard}, task::{Context, Poll, ready}, }; use qbase::net::route::{Line, Link, Route}; use thiserror::Error; use tokio::net::UdpSocket; use tokio_util::task::AbortOnDropHandle; use crate::{ Interface, RebindedError, component::Component, device::Devices, io::{IO, IoExt}, }; #[derive(Debug, Error)] pub enum InterfaceFailure { #[error("Invalid QuicIO implementation")] InvalidImplementation, #[error("Interface is broken: {0}")] InterfaceBroken(io::Error), #[error("Real address does not match bind URI")] AddressMismatch, #[error("Failed to bind test socket: {0}")] TestSocketBindFailed(io::Error), #[error("Failed to send test packet: {0}")] SendTestFailed(io::Error), } impl From for InterfaceFailure { fn from(error: io::Error) -> Self { Self::TestSocketBindFailed(error) } } impl InterfaceFailure { pub fn is_recoverable(&self) -> bool { matches!( self, Self::InterfaceBroken(..) | Self::AddressMismatch | Self::SendTestFailed(..) 
) } } pub async fn is_alive(iface: &(impl IO + ?Sized)) -> Result<(), InterfaceFailure> { let bound_addr = iface .bound_addr() .map_err(InterfaceFailure::InterfaceBroken)?; let socket_addr = SocketAddr::try_from(&iface.bind_uri())?; // Check if addresses match if !(bound_addr.ip() == socket_addr.ip() && (socket_addr.port() == 0 || bound_addr.port() == socket_addr.port())) { return Err(InterfaceFailure::AddressMismatch); } // Test connectivity with a local socket let localhost = match bound_addr.ip() { IpAddr::V4(ip) if ip.is_unspecified() => Ipv4Addr::LOCALHOST.into(), IpAddr::V4(ip) => ip.into(), IpAddr::V6(ip) if ip.is_unspecified() => Ipv6Addr::LOCALHOST.into(), IpAddr::V6(ip) => ip.into(), }; let socket = UdpSocket::bind(SocketAddr::new(localhost, 0)) .await .map_err(InterfaceFailure::TestSocketBindFailed)?; let dst_addr = socket .local_addr() .map_err(InterfaceFailure::TestSocketBindFailed)?; // Send test packet let link = Link::new(bound_addr, dst_addr); let packets = [io::IoSlice::new(&[0; 1])]; let line = Line::new(link, 64, None, packets[0].len() as u16); let header = Route::new(link.into(), line); iface .sendmmsg(&packets, header) .await .map_err(InterfaceFailure::SendTestFailed)?; Ok(()) } #[derive(Debug)] pub struct RebindOnNetworkChangedComponent { devices: &'static Devices, task: Mutex>>, } impl RebindOnNetworkChangedComponent { pub fn new(iface: &Interface, devices: &'static Devices) -> Self { let component = Self { devices, task: Mutex::new(None), }; component.init(iface); component } fn lock_task(&self) -> MutexGuard<'_, Option>> { self.task .lock() .expect("RebindOnNetworkChanged task mutex poisoned") } fn init(&self, iface: &Interface) { let mut task = self.lock_task(); if !task.as_ref().is_none_or(|t| t.is_finished()) { return; } let bind_uri = iface.bind_uri(); if bind_uri.is_temporary() { return; } let Some((_, device, ..)) = bind_uri.as_iface_bind_uri() else { return; }; let device = device.to_owned(); let weak_iface = 
            iface.bind_interface().downgrade();
        let mut event_receiver = self.devices.event_receiver();
        // The watcher holds only a weak reference so it does not keep a
        // dropped interface alive; the handle aborts the task on drop.
        *task = Some(AbortOnDropHandle::new(tokio::spawn(async move {
            let try_rebind = async move || {
                // Rebind only when the failure is recoverable and was not
                // itself caused by a rebind already in progress.
                if let Ok(iface) = weak_iface.upgrade()
                    && let Err(error) = is_alive(&iface.borrow()).await
                    && error.is_recoverable()
                    && !RebindedError::is_source_of(&error)
                {
                    iface.rebind().await;
                }
            };
            // Probe once up front, then again on every event for our device.
            try_rebind().await;
            while let Some(event) = event_receiver.recv().await {
                if event.device() != device {
                    continue;
                }
                try_rebind().await;
            }
        })));
    }
}

impl Component for RebindOnNetworkChangedComponent {
    fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
        let mut task_guard = self.lock_task();
        if let Some(task) = task_guard.as_mut() {
            task.abort();
            // Wait for the aborted task to fully terminate before reporting done.
            _ = ready!(Pin::new(task).poll(cx));
            *task_guard = None;
        }
        Poll::Ready(())
    }

    fn reinit(&self, iface: &Interface) {
        self.init(iface);
    }
}

================================================ FILE: qinterface/src/component/location.rs ================================================
use std::{
    any::{Any, TypeId},
    collections::{HashMap, hash_map},
    fmt::Debug,
    ops::Deref,
    sync::{Arc, LazyLock, Mutex, MutexGuard},
    task::{Context, Poll},
};

use qbase::util::{UniqueId, UniqueIdGenerator};
use tokio::sync::mpsc;
use tokio_util::task::AbortOnDropHandle;

use crate::{
    BindUri, Interface, WeakInterface,
    component::Component,
    io::{IO, RefIO},
};

// An address-related event for a BindUri, carrying a type-erased payload.
// NOTE(review): the payload generics were stripped by the text extraction —
// `Upsert(Arc)` is presumably `Upsert(Arc<dyn Any + Send + Sync>)`; confirm
// against the upstream repository.
#[derive(Debug)]
pub enum AddressEvent {
    Upsert(Arc),
    Remove(TypeId),
    Closed,
}

impl Clone for AddressEvent {
    fn clone(&self) -> Self {
        match self {
            Self::Upsert(arg0) => Self::Upsert(arg0.clone()),
            Self::Remove(arg0) => Self::Remove(*arg0),
            Self::Closed => Self::Closed,
        }
    }
}

// TODO: use a concrete (fixed) payload type instead of type erasure
impl AddressEvent {
    // Narrow the type-erased event to a concrete payload type, returning the
    // event unchanged as the error when the payload is of a different type.
    // NOTE(review): the turbofish/generic parameters here were stripped by the
    // extraction (`downcast::()`, `TypeId::of::()`); restore before compiling.
    pub fn downcast(self) -> Result, Self> {
        match self {
            AddressEvent::Upsert(data) => match data.downcast::() {
                Ok(data) => Ok(AddressEvent::Upsert(data)),
                Err(data) => Err(AddressEvent::Upsert(data)),
            },
            AddressEvent::Remove(type_id) => match TypeId::of::() == type_id {
                true => Ok(AddressEvent::Remove(type_id)),
                false
=> Err(AddressEvent::Remove(type_id)), }, AddressEvent::Closed => Ok(AddressEvent::Closed), } } } type EventSender = mpsc::UnboundedSender<(BindUri, AddressEvent)>; type EventReceiver = mpsc::UnboundedReceiver<(BindUri, AddressEvent)>; struct EventPublisher { subscriber_id_generator: UniqueIdGenerator, datas: HashMap>>, subscribers: HashMap, } impl EventPublisher { pub fn new() -> Self { Self { subscriber_id_generator: UniqueIdGenerator::new(), datas: HashMap::new(), subscribers: HashMap::new(), } } pub fn publish_event(&mut self, bind_uri: BindUri, event: AddressEvent) { // 1. update state match event.clone() { AddressEvent::Upsert(data) => { let type_id = data.as_ref().type_id(); self.datas .entry(bind_uri.clone()) .or_default() .insert(type_id, data); } AddressEvent::Remove(type_id) => { let entry = self.datas.entry(bind_uri.clone()); if let hash_map::Entry::Occupied(mut entry) = entry { entry.get_mut().remove(&type_id); if entry.get().is_empty() { entry.remove_entry(); } } } AddressEvent::Closed => _ = self.datas.remove(&bind_uri), } // 2. forward event to subscribers self.subscribers .retain(|_, subscriber| subscriber.send((bind_uri.clone(), event.clone())).is_ok()); } pub fn register_subscriber(&mut self, subscriber: EventSender) { let subscriber_id = self.subscriber_id_generator.generate(); for (bind_uri, datas) in &self.datas { for (.., data) in datas { let event = AddressEvent::Upsert(data.clone()); if subscriber.send((bind_uri.clone(), event)).is_err() { // EventReceiver disconnected, so we skip registering this subscriber. 
return; } } } self.subscribers.insert(subscriber_id, subscriber); } } #[derive(Debug)] pub struct Locations { new_event_tx: EventSender, new_subscriber_tx: mpsc::UnboundedSender, _publisher_task: AbortOnDropHandle<()>, } impl Default for Locations { fn default() -> Self { Self::new() } } impl Locations { pub fn new() -> Self { let (new_event_tx, mut new_event_rx) = mpsc::unbounded_channel::<(BindUri, AddressEvent)>(); let (new_subscriber_tx, mut new_subscriber_rx) = mpsc::unbounded_channel(); let _publisher_task = AbortOnDropHandle::new(tokio::spawn(async move { let mut publisher = EventPublisher::new(); loop { tokio::select! { Some((bind_uri, event)) = new_event_rx.recv() => { publisher.publish_event(bind_uri, event); } Some(new_subscriber) = new_subscriber_rx.recv() => { publisher.register_subscriber(new_subscriber); } else => break } } })); Self { new_event_tx, new_subscriber_tx, _publisher_task, } } pub fn global() -> &'static Arc { static GLOBAL: LazyLock> = LazyLock::new(|| Arc::new(Locations::new())); &GLOBAL } pub fn publish(&self, bind_uri: BindUri, event: AddressEvent) { _ = self.new_event_tx.send((bind_uri, event)); } pub fn upsert(&self, bind_uri: BindUri, data: Arc) { self.publish(bind_uri, AddressEvent::Upsert(data)); } pub fn remove(&self, bind_uri: BindUri) { self.publish(bind_uri, AddressEvent::Remove(TypeId::of::())); } pub fn close(&self, bind_uri: BindUri) { self.publish(bind_uri, AddressEvent::Closed); } pub fn subscribe(&self) -> Observer { let (tx, rx) = mpsc::unbounded_channel(); // Register the new subscriber. 
        _ = self.new_subscriber_tx.send(tx);
        Observer { receiver: rx }
    }
}

/// Receiving side of a `Locations` subscription.
pub struct Observer {
    receiver: EventReceiver,
}

impl Observer {
    pub async fn recv(&mut self) -> Option<(BindUri, AddressEvent)> {
        self.receiver.recv().await
    }

    pub fn try_recv(&mut self) -> Result<(BindUri, AddressEvent), mpsc::error::TryRecvError> {
        self.receiver.try_recv()
    }
}

// Ties a Locations hub to one interface reference, publishing that
// interface's bound address on creation and on every rebind.
// NOTE(review): the generic parameters of this type were stripped by the
// extraction — presumably `IfaceLocations<I>` with fields
// `locations: Arc<Locations>` and `ref_iface: Arc<Mutex<I>>`; restore from the
// upstream repository before compiling.
#[derive(Debug, Clone)]
pub struct IfaceLocations {
    locations: Arc,
    ref_iface: Arc>,
}

impl IfaceLocations {
    pub fn new(ref_iface: I, locations: Arc) -> Self {
        // Publish the initial bound address immediately.
        locations.upsert(
            ref_iface.iface().bind_uri(),
            Arc::new(ref_iface.iface().bound_addr()),
        );
        Self {
            locations,
            ref_iface: Arc::new(Mutex::new(ref_iface)),
        }
    }

    fn lock_ref_iface(&self) -> MutexGuard<'_, I> {
        self.ref_iface.lock().expect("Mutex poisoned")
    }

    /// Scope operation to the newest interface.
    // Runs `f` only when `ref_iface` still refers to the same underlying IO
    // as the currently tracked interface; otherwise the call is a no-op.
    pub fn r#for(&self, ref_iface: &R, f: impl FnOnce(&Locations, BindUri))
    where
        R: RefIO + 'static,
    {
        let current_iface = self.lock_ref_iface();
        let current_iface = current_iface.deref();
        if !(ref_iface as &dyn Any)
            .downcast_ref::()
            .is_some_and(|ref_iface| ref_iface.same_io(current_iface))
        {
            return;
        }
        f(&self.locations, current_iface.iface().bind_uri());
    }
}

// NOTE(review): presumably `IfaceLocations<WeakInterface>`; generic stripped.
pub type LocationsComponent = IfaceLocations;

impl Component for LocationsComponent {
    fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
        _ = cx;
        Poll::Ready(())
    }

    fn reinit(&self, iface: &Interface) {
        let mut ref_iface = self.lock_ref_iface();
        // Same underlying IO: nothing changed, nothing to publish.
        if iface.downgrade().same_io(ref_iface.deref()) {
            return;
        }
        *ref_iface = iface.downgrade();
        let bind_uri = iface.bind_uri();
        // Close out the old address state, then publish the new bound address.
        self.locations.close(bind_uri.clone());
        self.locations
            .upsert(bind_uri.clone(), Arc::new(iface.bound_addr()));
    }
}

================================================ FILE: qinterface/src/component/route/handler.rs ================================================
use std::sync::{Mutex, MutexGuard};

use qbase::packet::Packet;

use super::Way;

// NOTE(review): the right-hand side generics were stripped by the extraction —
// presumably `PacketSink<P> = Box<dyn FnMut(P, Way) + Send + ...>`; restore
// from the upstream repository.
pub type PacketSink
= Box;

// Swappable sink for packets that reach this handler. `None` means "drain":
// packets are silently dropped.
// NOTE(review): every generic-parameter list in this file (`<P>`, the boxed
// closure type, `Option<PacketSink<P>>`, ...) was stripped by the text
// extraction; the tokens below are preserved verbatim and must be restored
// from the upstream repository before compiling.
pub struct PacketHandler
(Mutex>>);

impl
std::fmt::Debug for PacketHandler
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The sink closure is opaque, so only the type name is printed.
        f.debug_struct("PacketHandler").finish()
    }
}

impl
Default for PacketHandler
{
    fn default() -> Self {
        Self::drain()
    }
}

impl
PacketHandler
{
    pub fn new(sink: PacketSink
    ) -> Self {
        Self(Mutex::new(Some(sink)))
    }

    pub(crate) fn lock(&self) -> MutexGuard<'_, Option>> {
        self.0.lock().expect("PacketHandler mutex poisoned")
    }

    /// A handler that drops every packet.
    pub fn drain() -> PacketHandler
    {
        PacketHandler(Mutex::new(None))
    }

    /// Replace the current sink.
    pub fn update(&self, handler: PacketSink
    ) {
        *self.lock() = Some(handler);
    }

    pub fn is_drain(&self) -> bool {
        self.lock().is_none()
    }

    /// Remove and return the current sink, leaving the handler draining.
    pub fn take(&self) -> Option> {
        self.lock().take()
    }

    /// Feed one packet to the sink, if any is installed.
    pub fn deliver(&self, packet: P, way: Way) {
        if let Some(sink) = self.lock().as_mut() {
            sink(packet, way);
        }
    }

    /// Feed a batch of packets under a single lock acquisition.
    pub fn deliver_packets(&self, packets: impl IntoIterator) {
        if let Some(sink) = self.lock().as_mut() {
            for (packet, way) in packets {
                sink(packet, way);
            }
        }
    }
}

================================================ FILE: qinterface/src/component/route/packet.rs ================================================
use bytes::{Bytes, BytesMut};
use derive_more::Deref;
use qbase::{
    error::QuicError,
    packet::{
        decrypt::{
            decrypt_packet, remove_protection_of_long_packet, remove_protection_of_short_packet,
        },
        header::long::InitialHeader,
        keys::ArcOneRttPacketKeys,
        number::{InvalidPacketNumber, PacketNumber},
    },
};
use qevent::quic::{
    PacketHeader, PacketHeaderBuilder, QuicFrame,
    transport::{PacketDropped, PacketDroppedTrigger, PacketReceived},
};
use rustls::quic::{HeaderProtectionKey, PacketKey};

// A received, still-encrypted packet: parsed header plus the raw payload and
// the offset where the protected part begins.
// NOTE(review): the `<H>` generic parameter on this struct and its impls was
// stripped by the extraction; restore from the upstream repository.
#[derive(Debug, Deref)]
pub struct CipherPacket {
    #[deref]
    header: H,
    payload: BytesMut,
    payload_offset: usize,
}

impl CipherPacket
where
    PacketHeaderBuilder: for<'a> From<&'a H>,
{
    pub fn new(header: H, payload: BytesMut, payload_offset: usize) -> Self {
        Self {
            header,
            payload,
            payload_offset,
        }
    }

    pub fn header(&self) -> &H {
        &self.header
    }

    // Build the qlog representation of this packet's header.
    fn qlog_header(&self) -> PacketHeader {
        PacketHeaderBuilder::from(&self.header).build()
    }

    /// Emit a qlog "packet dropped" event: decryption keys not yet available.
    pub fn drop_on_key_unavailable(self) {
        qevent::event!(PacketDropped {
            header: self.qlog_header(),
            raw: self.payload.freeze(),
            trigger: PacketDroppedTrigger::KeyUnavailable
        })
    }

    fn drop_on_remove_header_protection_failure(self) {
        qevent::event!(
            PacketDropped {
                header: self.qlog_header(),
                raw: self.payload.freeze(),
                trigger: PacketDroppedTrigger::DecryptionFailure
            },
            details = Map {
                reason: "remove header protection failure"
            }
        );
    }

    fn drop_on_decryption_failure(self, error: qbase::packet::error::Error, pn:
    u64) {
        qevent::event!(
            PacketDropped {
                header: {
                    PacketHeaderBuilder::from(&self.header)
                        .packet_number(pn)
                        .build()
                },
                raw: self.payload.freeze(),
                trigger: PacketDroppedTrigger::DecryptionFailure
            },
            details = Map {
                reason: "decryption failure",
                error: error.to_string(),
            },
        )
    }

    fn drop_on_reverse_bit_error(self, error: &qbase::packet::error::Error) {
        qevent::event!(
            PacketDropped {
                header: self.qlog_header(),
                raw: self.payload.freeze(),
                trigger: PacketDroppedTrigger::Invalid,
            },
            details = Map {
                reason: "reverse bit error",
                error: error.to_string()
            },
        )
    }

    fn drop_on_invalid_pn(self, invalid_pn: InvalidPacketNumber) {
        qevent::event!(
            PacketDropped {
                header: self.qlog_header(),
                raw: self.payload.freeze(),
                trigger: PacketDroppedTrigger::Invalid,
            },
            details = Map {
                reason: "invalid packet number",
                invalid_pn: invalid_pn.to_string()
            },
        )
    }

    pub fn payload_len(&self) -> usize {
        self.payload.len()
    }

    // Remove header protection, decode the packet number via `pn_decoder`,
    // then decrypt the body. Returns:
    //   None            -> the packet was dropped (and a qlog event emitted);
    //   Some(Err(..))   -> a protocol-level error the caller must surface;
    //   Some(Ok(plain)) -> the decrypted packet.
    // NOTE(review): the return generics were stripped by the extraction —
    // presumably `Option<Result<PlainPacket<H>, QuicError>>`.
    pub fn decrypt_long_packet(
        mut self,
        hpk: &dyn HeaderProtectionKey,
        pk: &dyn PacketKey,
        pn_decoder: impl FnOnce(PacketNumber) -> Result,
    ) -> Option, QuicError>> {
        let pkt_buf = self.payload.as_mut();
        let undecoded_pn = match remove_protection_of_long_packet(hpk, pkt_buf, self.payload_offset)
        {
            Ok(Some(undecoded_pn)) => undecoded_pn,
            Ok(None) => {
                self.drop_on_remove_header_protection_failure();
                return None;
            }
            Err(invalid_reverse_bits) => {
                self.drop_on_reverse_bit_error(&invalid_reverse_bits);
                return Some(Err(invalid_reverse_bits.into()));
            }
        };
        let decoded_pn = match pn_decoder(undecoded_pn) {
            Ok(pn) => pn,
            Err(invalid_packet_number) => {
                self.drop_on_invalid_pn(invalid_packet_number);
                return None;
            }
        };
        // The body starts right after the (now known-length) packet number.
        let body_offset = self.payload_offset + undecoded_pn.size();
        let body_length = match decrypt_packet(pk, decoded_pn, pkt_buf, body_offset) {
            Ok(body_length) => body_length,
            Err(error) => {
                self.drop_on_decryption_failure(error, decoded_pn);
                return None;
            }
        };
        Some(Ok(PlainPacket {
            header: self.header,
            plain: self.payload.freeze(),
            payload_offset: self.payload_offset,
            undecoded_pn,
            decoded_pn,
            body_len: body_length,
        }))
    }

    // Same flow as `decrypt_long_packet`, but short-header packets also carry
    // a key phase bit that selects the 1-RTT packet key.
    pub fn decrypt_short_packet(
        mut self,
        hpk: &dyn HeaderProtectionKey,
        pk: &ArcOneRttPacketKeys,
        pn_decoder: impl FnOnce(PacketNumber) -> Result,
    ) -> Option, QuicError>> {
        let pkt_buf = self.payload.as_mut();
        let (undecoded_pn, key_phase) =
            match remove_protection_of_short_packet(hpk, pkt_buf, self.payload_offset) {
                Ok(Some((undecoded, key_phase))) => (undecoded, key_phase),
                Ok(None) => {
                    self.drop_on_remove_header_protection_failure();
                    return None;
                }
                Err(invalid_reverse_bits) => {
                    self.drop_on_reverse_bit_error(&invalid_reverse_bits);
                    return Some(Err(invalid_reverse_bits.into()));
                }
            };
        let decoded_pn = match pn_decoder(undecoded_pn) {
            Ok(pn) => pn,
            Err(invalid_pn) => {
                self.drop_on_invalid_pn(invalid_pn);
                return None;
            }
        };
        // Select the remote packet key for this key phase and packet number.
        let pk = pk.lock_guard().get_remote(key_phase, decoded_pn);
        let body_offset = self.payload_offset + undecoded_pn.size();
        let body_length = match decrypt_packet(pk.as_ref(), decoded_pn, pkt_buf, body_offset) {
            Ok(body_length) => body_length,
            Err(error) => {
                self.drop_on_decryption_failure(error, decoded_pn);
                return None;
            }
        };
        Some(Ok(PlainPacket {
            header: self.header,
            plain: self.payload.freeze(),
            payload_offset: self.payload_offset,
            undecoded_pn,
            decoded_pn,
            body_len: body_length,
        }))
    }
}

// NOTE(review): presumably `impl CipherPacket<InitialHeader>`; the generic
// argument was stripped by the extraction.
impl CipherPacket {
    /// Emit a qlog "packet dropped" event: the Initial packet's source CID
    /// differs from the one seen in the first Initial packet.
    pub fn drop_on_scid_unmatch(self) {
        qevent::event!(
            PacketDropped {
                header: self.qlog_header(),
                raw: self.payload.freeze(),
                trigger: PacketDroppedTrigger::Rejected
            },
            details = Map {
                reason: "different scid with first initial packet"
            },
        )
    }
}

// A decrypted packet: header, decoded packet number, and the full plaintext
// buffer with offsets locating the body.
// NOTE(review): `<H>` generic stripped by the extraction, as above.
#[derive(Deref)]
pub struct PlainPacket {
    #[deref]
    header: H,
    decoded_pn: u64,
    undecoded_pn: PacketNumber,
    plain: Bytes,
    payload_offset: usize,
    body_len: usize,
}

impl PlainPacket {
    /// Total size of the datagram payload in bytes.
    pub fn size(&self) -> usize {
        self.plain.len()
    }

    pub fn pn(&self) -> u64 {
        self.decoded_pn
    }

    /// Length of packet number + body (the protected portion).
    pub fn payload_len(&self) -> usize {
        self.undecoded_pn.size() + self.body_len
    }

    /// The decrypted body (frames), excluding header and packet number.
    pub fn body(&self) -> Bytes {
        let packet_offset = self.payload_offset + self.undecoded_pn.size();
        self.plain
            .slice(packet_offset..packet_offset + self.body_len)
    }

    pub fn raw_info(&self) -> qevent::RawInfo {
        qevent::build!(qevent::RawInfo {
            length: self.plain.len() as u64,
            payload_length: self.payload_len() as u64,
            data: &self.plain,
        })
    }
}

impl PlainPacket
where
    PacketHeaderBuilder: for<'a> From<&'a H>,
{
    pub fn qlog_header(&self) -> PacketHeader {
        let mut builder = PacketHeaderBuilder::from(&self.header);
        qevent::build! {@field builder,
            packet_number: self.decoded_pn,
            length: self.payload_len() as u16
        };
        builder.build()
    }

    // NOTE(review): `PacketDroppedTrigger::Genera` below looks truncated
    // (General?) — verify the variant name against the qevent crate.
    pub fn drop_on_interface_not_found(self) {
        qevent::event!(
            PacketDropped {
                header: self.qlog_header(),
                raw: self.raw_info(),
                trigger: PacketDroppedTrigger::Genera
            },
            details = Map {
                reason: "interface not found"
            }
        )
    }

    // NOTE(review): "conenction" typo is part of the public API name; renaming
    // would break callers, so it is preserved here.
    pub fn drop_on_conenction_closed(self) {
        qevent::event!(
            PacketDropped {
                header: self.qlog_header(),
                raw: self.raw_info(),
                trigger: PacketDroppedTrigger::Genera
            },
            details = Map {
                reason: "connection closed"
            }
        )
    }

    /// Emit a qlog "packet received" event with the parsed frames.
    pub fn log_received(&self, frames: impl Into>) {
        qevent::event!(PacketReceived {
            header: self.qlog_header(),
            frames,
            raw: self.raw_info(),
        })
    }
}

================================================ FILE: qinterface/src/component/route/queue.rs ================================================
use qbase::{
    packet::{
        DataHeader, Packet,
        header::{long, short},
    },
    util::BoundQueue,
};

use crate::component::route::{CipherPacket, Way};

// NOTE(review): generics stripped by the extraction — presumably
// `PacketQueue<H> = BoundQueue<(CipherPacket<H>, Way)>`.
type PacketQueue
= BoundQueue<(CipherPacket
, Way)>;

// Needs a four-tuple: pathway + src + dst
// Per-packet-space bounded queues for received (still encrypted) packets.
// NOTE(review): each field's header-type generic (e.g.
// `PacketQueue<long::InitialHeader>`) was stripped by the extraction.
#[derive(Debug)]
pub struct RcvdPacketQueue {
    initial: PacketQueue,
    handshake: PacketQueue,
    zero_rtt: PacketQueue,
    one_rtt: PacketQueue,
    // pub retry:
}

impl Default for RcvdPacketQueue {
    fn default() -> Self {
        Self::new()
    }
}

impl RcvdPacketQueue {
    pub fn new() -> Self {
        // 1-RTT gets a much deeper queue: it carries the bulk of traffic.
        Self {
            initial: BoundQueue::new(8),
            handshake: BoundQueue::new(8),
            zero_rtt: BoundQueue::new(8),
            one_rtt: BoundQueue::new(128),
        }
    }

    pub fn initial(&self) -> &PacketQueue {
        &self.initial
    }

    pub fn handshake(&self) -> &PacketQueue {
        &self.handshake
    }

    pub fn zero_rtt(&self) -> &PacketQueue {
        &self.zero_rtt
    }

    pub fn one_rtt(&self) -> &PacketQueue {
        &self.one_rtt
    }

    pub fn close_all(&self) {
        self.initial.close();
        self.handshake.close();
        self.zero_rtt.close();
        self.one_rtt.close();
    }

    /// Dispatch a packet into the queue matching its packet space.
    /// Version-negotiation and retry packets are currently ignored here.
    pub async fn deliver(&self, packet: Packet, way: Way) {
        match packet {
            Packet::Data(packet) => match packet.header {
                DataHeader::Long(long::DataHeader::Initial(header)) => {
                    let packet = CipherPacket::new(header, packet.bytes, packet.offset);
                    _ = self.initial.send((packet, way)).await;
                }
                DataHeader::Long(long::DataHeader::Handshake(header)) => {
                    let packet = CipherPacket::new(header, packet.bytes, packet.offset);
                    _ = self.handshake.send((packet, way)).await;
                }
                DataHeader::Long(long::DataHeader::ZeroRtt(header)) => {
                    let packet = CipherPacket::new(header, packet.bytes, packet.offset);
                    _ = self.zero_rtt.send((packet, way)).await;
                }
                DataHeader::Short(header) => {
                    let packet = CipherPacket::new(header, packet.bytes, packet.offset);
                    _ = self.one_rtt.send((packet, way)).await;
                }
            },
            Packet::VN(_vn) => {}
            Packet::Retry(_retry) => {}
        }
    }
}

================================================ FILE: qinterface/src/component/route.rs ================================================
use std::{
    net::SocketAddr,
    sync::{Arc, OnceLock, Weak},
    task::{Context, Poll},
};

use dashmap::DashMap;
use qbase::{
    cid::{ConnectionId, GenUniqueCid, RetireCid},
    error::Error,
    frame::{
        NewConnectionIdFrame,
RetireConnectionIdFrame, io::{ReceiveFrame, SendFrame}, }, net::route::{Link, Pathway}, packet::GetDcid, }; use crate::{BindUri, Interface, component::Component}; mod handler; mod packet; mod queue; pub type Way = (BindUri, Pathway, Link); pub use handler::PacketHandler; pub use packet::{CipherPacket, PlainPacket}; pub use qbase::packet::Packet; pub use queue::RcvdPacketQueue; #[derive(Debug)] pub struct QuicRouter { table: DashMap>, on_unrouted: handler::PacketHandler, } impl QuicRouter { pub fn global() -> &'static Arc { static GLOBAL_ROUTER: OnceLock> = OnceLock::new(); GLOBAL_ROUTER.get_or_init(|| { Arc::new(QuicRouter { table: DashMap::new(), on_unrouted: handler::PacketHandler::drain(), }) }) } pub fn new() -> Self { QuicRouter { table: DashMap::new(), on_unrouted: handler::PacketHandler::drain(), } } // for origin_dcid pub fn insert( self: &Arc, signpost: Signpost, queue: Arc, ) -> QuicRouterEntry { self.table.insert(signpost, queue.clone()); QuicRouterEntry { signpost, queue: Arc::downgrade(&queue), router: self.clone(), } } pub fn remove(&self, signpost: &Signpost) { self.table.remove(signpost); } fn find_entry(&self, packet: &Packet, link: &Link) -> Option> { let dcid = match packet { Packet::VN(vn) => vn.dcid(), Packet::Retry(retry) => retry.dcid(), Packet::Data(data_packet) => data_packet.dcid(), }; if !dcid.is_empty() { let signpost = Signpost::from(*dcid); self.table.get(&signpost).map(|queue| queue.clone()) } else { let signpost = Signpost::from(link.dst); self.table.get(&signpost).map(|queue| queue.clone()) } } pub async fn try_deliver(&self, packet: Packet, way: Way) -> Result<(), (Packet, Way)> { match self.find_entry(&packet, &way.2) { Some(rcvd_pkt_q) => { rcvd_pkt_q.deliver(packet, way).await; Ok(()) } None => Err((packet, way)), } } pub async fn deliver(&self, packet: Packet, way: Way) { let rcvd_pkt_q = match self.find_entry(&packet, &way.2) { Some(rcvd_pkt_q) => rcvd_pkt_q, None => { // For packets that cannot be routed, this likely 
indicates a new connection. // In some cases, multiple threads (e.g., A and B) may be waiting for the lock, // and both would cause the server to create separate new connections. let mut on_unrouted = self.on_unrouted.lock(); let Some(on_unrouted) = on_unrouted.as_mut() else { // Drain mode, just drop the packet return; }; // Therefore, we retry routing here to allow thread B to route its packet // to the connection created by thread A, instead of creating another new connection. match self.find_entry(&packet, &way.2) { Some(rcvd_pkt_q) => rcvd_pkt_q, None => { (on_unrouted)(packet, way); return; } } } }; rcvd_pkt_q.deliver(packet, way).await; } pub fn on_connectless_packets(&self, sink: S) -> bool where S: Fn(Packet, Way) + Send + 'static, { let mut on_unrouted = self.on_unrouted.lock(); if on_unrouted.is_some() { return false; } *on_unrouted = Some(Box::new(sink)); true } pub fn is_connectless_draining(&self) -> bool { self.on_unrouted.is_drain() } pub fn drain_connectless(&self) { self.on_unrouted.take(); } pub fn registry_on_issuing_scid( self: &Arc, rcvd_pkts_q: Arc, issued_cids: T, ) -> QuicRouterRegistry { QuicRouterRegistry { router: self.clone(), rcvd_pkts_q, issued_cids, } } } impl Default for QuicRouter { fn default() -> Self { Self::new() } } #[derive(Debug, PartialEq, Clone, Copy, Eq, Hash)] pub struct Signpost { cid: ConnectionId, peer: Option, } impl From for Signpost { fn from(value: ConnectionId) -> Self { Self { cid: value, peer: None, } } } impl From for Signpost { fn from(value: SocketAddr) -> Self { Self { cid: ConnectionId::default(), peer: Some(value), } } } #[must_use = "When RouterEntry dropped, this will remove the entry from the router table"] pub struct QuicRouterEntry { signpost: Signpost, queue: Weak, router: Arc, } impl QuicRouterEntry { pub fn signpost(&self) -> Signpost { self.signpost } pub fn remove(&self) { self.router .table .remove_if(&self.signpost, |_, exist_queue| { Weak::ptr_eq(&Arc::downgrade(exist_queue), &self.queue) }); 
} } impl Drop for QuicRouterEntry { fn drop(&mut self) { self.remove(); } } #[derive(Clone)] pub struct QuicRouterRegistry { router: Arc, rcvd_pkts_q: Arc, issued_cids: TX, } impl GenUniqueCid for QuicRouterRegistry where T: Send + Sync + 'static, { fn gen_unique_cid(&self) -> ConnectionId { core::iter::from_fn(|| Some(ConnectionId::random_gen_with_mark(8, 0x80, 0x7F))) .find(|cid| { let signpost = Signpost::from(*cid); let entry = self.router.table.entry(signpost); if matches!(entry, dashmap::Entry::Occupied(..)) { return false; } entry.insert(self.rcvd_pkts_q.clone()); true }) .unwrap() } } impl RetireCid for QuicRouterRegistry where TX: Send + Sync + 'static, { fn retire_cid(&self, cid: ConnectionId) { self.router.remove(&Signpost::from(cid)); } } impl SendFrame for QuicRouterRegistry where TX: SendFrame, { fn send_frame>(&self, iter: I) { self.issued_cids.send_frame(iter); } } impl ReceiveFrame for QuicRouterRegistry where RX: ReceiveFrame, { type Output = (); fn recv_frame(&self, frame: RetireConnectionIdFrame) -> Result { self.issued_cids.recv_frame(frame) } } #[derive(Debug, Clone)] pub struct QuicRouterComponent { router: Arc, } impl QuicRouterComponent { pub fn new(router: Arc) -> Self { Self { router } } pub fn router(&self) -> Arc { self.router.clone() } } impl Component for QuicRouterComponent { fn reinit(&self, _quic_iface: &Interface) {} fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { _ = cx; Poll::Ready(()) } } ================================================ FILE: qinterface/src/component.rs ================================================ use std::{ any::{Any, TypeId}, collections::{HashMap, hash_map}, fmt::Debug, hash::{BuildHasherDefault, Hasher}, task::{Context, Poll, ready}, }; use crate::Interface; pub mod alive; pub mod location; pub mod route; pub trait Component: Any + Debug + Send + Sync { /// Gracefully shutdown the component when IO is unbound. 
fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()>; /// Re-initialize the component after the QuicIO has been rebound /// /// Normally, this method first shuts down the component, /// then re-initializes it with the new QuicIO. /// /// Implementation may override this method for optimization. fn reinit(&self, iface: &Interface); } // With TypeIds as keys, there's no need to hash them. They are already hashes // themselves, coming from the compiler. The IdHasher just holds the u64 of // the TypeId, and then returns it, instead of doing any bit fiddling. #[derive(Default)] pub(super) struct IdHasher(u64); impl Hasher for IdHasher { fn write(&mut self, _: &[u8]) { unreachable!("TypeId calls write_u64"); } #[inline] fn write_u64(&mut self, id: u64) { self.0 = id; } #[inline] fn finish(&self) -> u64 { self.0 } } #[derive(Default)] pub struct Components { pub(super) map: HashMap, BuildHasherDefault>, } impl Components { pub fn new() -> Self { Self::default() } pub fn get(&self) -> Option<&C> { self.map .get(&TypeId::of::()) .and_then(|c| (c.as_ref() as &dyn Any).downcast_ref()) } pub fn exist(&self) -> bool { self.map.contains_key(&TypeId::of::()) } pub fn with(&self, f: impl FnOnce(&C) -> T) -> Option { self.get::().map(f) } pub fn init_with(&mut self, init: impl FnOnce() -> C) -> &mut C { let ref_mut = self .map .entry(TypeId::of::()) .or_insert_with(|| Box::new(init())); (ref_mut.as_mut() as &mut dyn Any).downcast_mut().unwrap() } pub fn try_init_with( &mut self, init: impl FnOnce() -> Result, ) -> Result<&mut C, E> { let entry = self.map.entry(TypeId::of::()); let ref_mut = match entry { hash_map::Entry::Occupied(entry) => entry.into_mut(), hash_map::Entry::Vacant(entry) => entry.insert(Box::new(init()?)), }; Ok((ref_mut.as_mut() as &mut dyn Any).downcast_mut().unwrap()) } pub fn poll_remove(&mut self, cx: &mut Context<'_>) -> Poll<()> where C: Component, { let hash_map::Entry::Occupied(entry) = self.map.entry(TypeId::of::()) else { return Poll::Ready(()); }; 
ready!(entry.get().poll_shutdown(cx)); entry.remove(); Poll::Ready(()) } } ================================================ FILE: qinterface/src/device.rs ================================================ use std::{ collections::HashMap, fmt::Debug, net::IpAddr, sync::{Arc, Mutex, OnceLock, RwLock}, time::Duration, }; use derive_more::{Deref, DerefMut}; pub use netdev::Interface; pub use netwatcher::Error as WatcherError; use netwatcher::WatchHandle; use qbase::{ net::Family, util::{UniqueId, UniqueIdGenerator}, }; use tokio::{ sync::mpsc::{UnboundedReceiver, UnboundedSender}, time::MissedTickBehavior, }; use tokio_util::task::AbortOnDropHandle; #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq, Eq)] pub enum InterfaceEvent { Added { device: String, new_interface: Interface, }, Removed { device: String, old_interface: Interface, }, Changed { device: String, old_interface: Interface, new_interface: Interface, }, } impl InterfaceEvent { pub fn device(&self) -> &str { match self { InterfaceEvent::Added { device, .. } => device, InterfaceEvent::Removed { device, .. } => device, InterfaceEvent::Changed { device, .. } => device, } } pub fn old_interface(&self) -> Option<&Interface> { match self { InterfaceEvent::Removed { old_interface, .. } | InterfaceEvent::Changed { old_interface, .. } => Some(old_interface), _ => None, } } pub fn new_interface(&self) -> Option<&Interface> { match self { InterfaceEvent::Added { new_interface, .. } | InterfaceEvent::Changed { new_interface, .. 
} => Some(new_interface), _ => None, } } } impl InterfaceEvent { pub fn from_update<'i>( old_interfaces: &'i HashMap, new_interfaces: &'i HashMap, ) -> impl Iterator + 'i { new_interfaces .iter() .filter_map(|(name, new_interface)| match old_interfaces.get(name) { Some(old_interface) if new_interface != old_interface => { Some(InterfaceEvent::Changed { device: name.to_owned(), old_interface: old_interface.clone(), new_interface: new_interface.clone(), }) } None => Some(InterfaceEvent::Added { device: name.to_owned(), new_interface: new_interface.clone(), }), _ => None, }) .chain( old_interfaces .iter() .filter(|(name, ..)| !new_interfaces.contains_key(*name)) .map(|(name, old_interface)| InterfaceEvent::Removed { device: name.to_owned(), old_interface: old_interface.clone(), }), ) } } fn scan_interfaces() -> HashMap { netdev::get_interfaces() .into_iter() .map(|mut iface| { // compatibility with windows interface names iface.name = iface .name .trim_start_matches('{') .trim_end_matches('}') .to_string(); iface }) .map(|iface| (iface.name.clone(), iface)) .collect() } type SubscribersMap = RwLock>>>; type InterfacesMap = RwLock>; #[derive(Debug, Deref, DerefMut)] pub struct InterfaceEventReceiver { id: UniqueId, #[deref] #[deref_mut] receiver: UnboundedReceiver>, subscribers: Arc, } impl Drop for InterfaceEventReceiver { fn drop(&mut self) { self.subscribers.write().unwrap().remove(&self.id); } } pub struct InterfacesMonitor { interfaces: HashMap, receiver: InterfaceEventReceiver, } impl InterfacesMonitor { #[inline] pub async fn update(&mut self) -> Option<(&HashMap, Arc)> { self.receiver.recv().await.map(|event| { match event.as_ref() { InterfaceEvent::Added { device, new_interface, } => { self.interfaces .insert(device.clone(), new_interface.clone()); } InterfaceEvent::Removed { device, .. } => { self.interfaces.remove(device); } InterfaceEvent::Changed { device, new_interface, .. 
} => { self.interfaces .insert(device.clone(), new_interface.clone()); } } (self.interfaces(), event) }) } #[inline] pub fn try_update(&mut self) -> Option<(&HashMap, Arc)> { self.receiver.try_recv().ok().map(|event| { match event.as_ref() { InterfaceEvent::Added { device, new_interface, } => { self.interfaces .insert(device.clone(), new_interface.clone()); } InterfaceEvent::Removed { device, .. } => { self.interfaces.remove(device); } InterfaceEvent::Changed { device, new_interface, .. } => { self.interfaces .insert(device.clone(), new_interface.clone()); } } (self.interfaces(), event) }) } #[inline] pub fn interfaces(&self) -> &HashMap { &self.interfaces } pub fn into_inner(self) -> (HashMap, InterfaceEventReceiver) { (self.interfaces, self.receiver) } } #[derive(Debug)] struct State { interfaces: InterfacesMap, subscrib_id_generator: UniqueIdGenerator, subscribers: Arc, } impl Default for State { fn default() -> Self { Self { interfaces: RwLock::new(scan_interfaces()), subscrib_id_generator: UniqueIdGenerator::new(), subscribers: Arc::new(RwLock::new(HashMap::new())), } } } impl State { fn check_network_changes(&self) { let mut interfaces = self.interfaces.write().unwrap(); let subscribers = self.subscribers.read().unwrap(); let old_interfaces = interfaces.clone(); let new_interfaces = scan_interfaces(); for event in InterfaceEvent::from_update(&old_interfaces, &new_interfaces) { let arc_event = Arc::new(event); for sender in subscribers.values() { let _ = sender.send(arc_event.clone()); } } *interfaces = new_interfaces.clone(); } fn monitor(&self) -> (HashMap, InterfaceEventReceiver) { let mut subscribers = self.subscribers.write().unwrap(); let interfaces = self.interfaces.read().unwrap().clone(); let current_interfaces = interfaces; let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let id = self.subscrib_id_generator.generate(); subscribers.insert(id, tx); let observer = InterfaceEventReceiver { id, receiver: rx, subscribers: 
Arc::clone(&self.subscribers), }; (current_interfaces, observer) } fn event_receiver(&self) -> InterfaceEventReceiver { let mut subscribers = self.subscribers.write().unwrap(); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let id = self.subscrib_id_generator.generate(); subscribers.insert(id, tx); InterfaceEventReceiver { id, receiver: rx, subscribers: Arc::clone(&self.subscribers), } } fn interfaces(&self) -> HashMap { self.interfaces.read().unwrap().clone() } fn get(&self, name: &str) -> Option { self.interfaces.read().unwrap().get(name).cloned() } } pub struct Devices { state: Arc, watcher: Mutex>, _timer: AbortOnDropHandle<()>, } impl Debug for Devices { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Devices") .field("state", &self.state) .field("watcher", &"...") .field("_timer", &self._timer) .finish() } } impl Devices { pub fn global() -> &'static Devices { static DEVICES: OnceLock = OnceLock::new(); DEVICES.get_or_init(Self::new) } pub fn new() -> Self { let state = Arc::new(State::default()); let timer = AbortOnDropHandle::new(tokio::spawn({ let state = state.clone(); async move { let mut interval = tokio::time::interval(Duration::from_secs(5)); interval.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { interval.tick().await; state.check_network_changes(); } } })); let watcher = netwatcher::watch_interfaces({ let state = state.clone(); move |_update| { // TODO: use the update info to avoid full scan state.check_network_changes(); } }); if let Err(initial_watcher_error) = &watcher { tracing::warn!(target: "interface", "failed to start interfaces watcher: {initial_watcher_error}"); } Self { state, _timer: timer, watcher: watcher.into(), } } #[inline] pub fn restart_watcher(&self) -> Result<(), WatcherError> { let new_watcher = netwatcher::watch_interfaces({ let state = self.state.clone(); move |_update| { // TODO: use the update info to avoid full scan state.check_network_changes(); } })?; 
*self.watcher.lock().unwrap() = Ok(new_watcher); Ok(()) } #[inline] pub fn on_interface_changed(&self) { self.state.check_network_changes(); } #[inline] pub fn monitor(&self) -> InterfacesMonitor { let (interfaces, receiver) = self.state.monitor(); InterfacesMonitor { interfaces, receiver, } } #[inline] pub fn event_receiver(&self) -> InterfaceEventReceiver { self.state.event_receiver() } #[inline] pub fn interfaces(&self) -> HashMap { self.state.interfaces() } pub fn get(&self, name: &str) -> Option { self.state.get(name) } pub fn resolve(&self, device: &str, family: Family) -> Option { let interface = self.get(device)?; match family { Family::V4 => interface .ipv4 .first() .map(|ipnet| ipnet.addr()) .map(IpAddr::V4), Family::V6 => interface .ipv6 .iter() .map(|ipnet| ipnet.addr()) .find(|ip| !matches!(ip.octets(), [0xfe, 0x80, ..])) .map(IpAddr::V6), } } } impl Default for Devices { #[inline] fn default() -> Self { Self::new() } } ================================================ FILE: qinterface/src/iface.rs ================================================ ================================================ FILE: qinterface/src/io/factory.rs ================================================ use std::task::{Context, Poll, ready}; use crate::{BindUri, IO}; pub trait ProductIO: Send + Sync { fn bind(&self, bind_uri: BindUri) -> Box; fn poll_rebind(&self, cx: &mut Context<'_>, quic_io: &mut Box) -> Poll<()> { _ = ready!(quic_io.poll_close(cx)); *quic_io = self.bind(quic_io.bind_uri()); Poll::Ready(()) } } pub trait ProductIoExt: ProductIO { fn rebind(&self, quic_io: &mut Box) -> impl Future { async { core::future::poll_fn(|cx| self.poll_rebind(cx, quic_io)).await } } } impl ProductIO for F where F: Fn(BindUri) -> Q + Send + Sync, Q: IO + 'static, { #[inline] fn bind(&self, bind_uri: BindUri) -> Box { Box::new((self)(bind_uri)) } } ================================================ FILE: qinterface/src/io/handy.rs ================================================ use 
crate::BindUri; #[cfg(all(feature = "qudp", any(unix, windows)))] pub mod qudp { use std::{ error::{Error, Error as StdError}, fmt::Display, io::{self, IoSliceMut}, net::SocketAddr, sync::Arc, task::{Context, Poll, ready}, }; use bytes::BytesMut; use qbase::{ net::route::{Line, Link, Pathway}, util::Wakers, }; use qudp::BATCH_SIZE; use thiserror::Error; use crate::{BindUri, IO, Route}; pub struct UdpSocketController { bind_uri: BindUri, send_wakers: Arc>, recv_wakers: Arc, io: Result, BindFailed>, } #[derive(Debug, Clone, Copy, Error)] #[error("UdpSocketController closed")] pub struct Closed(()); impl From for io::Error { fn from(error: Closed) -> Self { io::Error::other(error) } } #[derive(Debug, Clone)] pub struct BindFailed(Arc); impl Display for BindFailed { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "failed to bind UdpSocketController") } } impl StdError for BindFailed { fn source(&self) -> Option<&(dyn Error + 'static)> { Some(self.0.as_ref()) } } impl From for io::Error { fn from(error: BindFailed) -> Self { io::Error::other(error) } } impl UdpSocketController { pub fn bind(bind_uri: BindUri) -> Self { let io = SocketAddr::try_from(&bind_uri) .map_err(|e| { io::Error::new( io::ErrorKind::NotFound, format!("Failed to bind {bind_uri}: {e}"), ) }) .and_then(qudp::UdpSocket::bind); UdpSocketController { bind_uri, send_wakers: Arc::new(Wakers::new()), recv_wakers: Arc::new(Wakers::new()), io: io.map(Ok).map_err(|e| BindFailed(Arc::new(e))), } } fn usc(&self) -> io::Result<&qudp::UdpSocket> { self.io .as_ref() .map_err(|e| io::Error::from(e.clone())) .and_then(|result| result.as_ref().map_err(|e| (*e).into())) } } impl IO for UdpSocketController { fn bind_uri(&self) -> BindUri { self.bind_uri.clone() } fn bound_addr(&self) -> io::Result { self.usc()?.local_addr() } fn max_segments(&self) -> io::Result { Ok(BATCH_SIZE) } fn max_segment_size(&self) -> io::Result { Ok(1500) } fn poll_send( &self, cx: &mut Context, pkts: 
&[io::IoSlice], route: Route, ) -> Poll> { let io = self.usc()?; self.send_wakers.combine_with(cx, |cx| { debug_assert_eq!(route.ecn(), None); io.poll_send(cx, pkts, &route.line) }) } fn poll_recv( &self, cx: &mut Context, pkts: &mut [BytesMut], route: &mut [Route], ) -> Poll> { let io = self.usc()?; self.recv_wakers.combine_with(cx, |cx| { let dst = io.local_addr()?; let len = route.len().min(pkts.len()); let mut rcvd_lines = Vec::with_capacity(len); rcvd_lines.resize_with(route.len(), Line::default); let mut bufs = pkts[..len] .iter_mut() .map(|p| IoSliceMut::new(p.as_mut())) .collect::>(); debug_assert_eq!(rcvd_lines.len(), bufs.len()); let nrcvd = ready!(io.poll_recv(cx, &mut bufs, &mut rcvd_lines))?; for (idx, mut line) in rcvd_lines.into_iter().take(nrcvd).enumerate() { let pathway = Pathway::new(line.link.src.into(), dst.into()); line.link = Link::new(line.src, io.local_addr()?).flip(); route[idx] = Route::new(pathway.flip(), line); } Poll::Ready(Ok(nrcvd)) }) } fn poll_close(&mut self, _cx: &mut Context) -> Poll> { self.usc()?; self.send_wakers.wake_all(); self.recv_wakers.wake_all(); self.io = Ok(Err(Closed(()))); Poll::Ready(Ok(())) } } } pub mod unsupported { use std::{ io, net::SocketAddr, task::{Context, Poll}, }; use bytes::BytesMut; use qbase::net::route::Route; use thiserror::Error; use crate::{BindUri, IO}; #[derive(Debug, Clone)] pub struct Unsupported { bind_uri: BindUri, } #[derive(Debug, Clone, Copy, Error)] #[error( "qudp feature is not enabled or target platform is not supported, you should use your own ProductQuicIO implementation, not the default" )] pub struct UnsupportedError(()); impl From for io::Error { fn from(error: UnsupportedError) -> Self { io::Error::new(io::ErrorKind::Unsupported, error) } } impl Unsupported { pub fn bind(bind_uri: BindUri) -> Self { Unsupported { bind_uri } } } impl IO for Unsupported { fn bind_uri(&self) -> BindUri { self.bind_uri.clone() } fn bound_addr(&self) -> io::Result { Err(UnsupportedError(()).into()) 
} fn max_segment_size(&self) -> io::Result { Err(UnsupportedError(()).into()) } fn max_segments(&self) -> io::Result { Err(UnsupportedError(()).into()) } fn poll_send( &self, _: &mut Context, _: &[io::IoSlice], _: Route, ) -> Poll> { Poll::Ready(Err(UnsupportedError(()).into())) } fn poll_recv( &self, _: &mut Context, _: &mut [BytesMut], _: &mut [Route], ) -> Poll> { Poll::Ready(Err(UnsupportedError(()).into())) } fn poll_close(&mut self, _: &mut Context) -> Poll> { Poll::Ready(Ok(())) } } } #[cfg(all(feature = "qudp", any(unix, windows)))] pub static DEFAULT_IO_FACTORY: fn(BindUri) -> qudp::UdpSocketController = |bind_uri| qudp::UdpSocketController::bind(bind_uri); #[cfg(not(all(feature = "qudp", any(unix, windows))))] pub static DEFAULT_IO_FACTORY: fn(BindUri) -> unsupported::Unsupported = |bind_uri| unsupported::Unsupported::bind(bind_uri); const _: () = { use super::ProductIO; const fn assert_product_interface_factory(_: &F) {} assert_product_interface_factory(&DEFAULT_IO_FACTORY); }; ================================================ FILE: qinterface/src/io.rs ================================================ use std::{ any::Any, future::Future, io, net::SocketAddr, sync::Arc, task::{Context, Poll}, }; use bytes::BytesMut; use qbase::net::route::Route; pub mod handy; mod factory; pub use factory::*; use crate::bind_uri::BindUri; /// Network I/O trait /// /// Provides a unified interface for different network transport implementations. /// Note that some implementations may not support all bind address types. /// /// `dquic` uses [`ProductIO`] to create (bind) new [`IO`] instances. /// Read its documentation for more information. /// /// Wrapping a new [`IO`] is easy, /// you can refer to the implementations in the [`handy`] module. 
/// /// [`ProductIO`]: crate::io::ProductIO pub trait IO: Send + Sync + Any { /// Get the bind address that this interface is bound to /// /// This value cannot change after the interface is bound, /// as it is used as the unique identifier for the interface. fn bind_uri(&self) -> BindUri; /// Get the actual address that this interface is bound to. /// /// For example, if this interface is bound to an [`BindUri`], /// this function should return the actual IP address and port /// address of this interface. /// /// Just like [`UdpSocket::local_addr`] may return an error, /// sometimes an interface cannot get its own actual address, /// then the implementation should return an error as well. /// /// [`UdpSocket::local_addr`]: std::net::UdpSocket::local_addr fn bound_addr(&self) -> io::Result; /// Maximum size of a single network segment in bytes fn max_segment_size(&self) -> io::Result; /// Maximum number of segments that can be sent in a single batch fn max_segments(&self) -> io::Result; /// Poll for sending packets /// /// Attempts to send multiple packets in a single operation. /// Return the number of packets sent, fn poll_send( &self, cx: &mut Context, pkts: &[io::IoSlice], route: Route, ) -> Poll>; /// Poll for receiving packets /// /// Attempts to receive multiple packets in a single operation. /// The number of packets received is limited by the smaller of /// `pkts.capacity()` and `hdrs.len()`. fn poll_recv( &self, cx: &mut Context, pkts: &mut [BytesMut], route: &mut [Route], ) -> Poll>; /// Asynchronously destroy the IO. /// /// When it returns [`Poll::Ready`] (whether with `Ok` or `Err`), /// it must indicate that the resource has been completely destroyed, /// and the same [`BindUri`] can be successfully bound again. /// /// Even if this method is not called, /// the implementation should ensure that [`IO`] does not /// leak any resources when it is dropped. 
fn poll_close(&mut self, cx: &mut Context) -> Poll>; } pub trait IoExt: IO { #[inline] fn sendmmsg( &self, mut bufs: &[io::IoSlice<'_>], route: Route, ) -> impl Future> + Send { async move { while !bufs.is_empty() { let sent = core::future::poll_fn(|cx| self.poll_send(cx, bufs, route)).await?; bufs = &bufs[sent..]; } Ok(()) } } fn recvmmsg<'b>( &self, bufs: &'b mut Vec, route: &'b mut Vec, ) -> impl Future + Send + 'b>> + Send { async move { let rcvd = std::future::poll_fn(|cx| { let max_segments = self.max_segments()?; let max_segment_size = self.max_segment_size()?; bufs.resize_with(max_segments, || BytesMut::zeroed(max_segment_size)); route.resize_with(max_segments, Route::empty); self.poll_recv(cx, bufs, route) }) .await?; Ok(bufs .drain(..rcvd) .zip(route.drain(..rcvd)) .map(|(mut seg, route)| { (seg.split_to(seg.len().min(route.seg_size() as _)), route) })) } } #[inline] fn close(&mut self) -> impl Future> + Send { async { core::future::poll_fn(|cx| self.poll_close(cx)).await } } } impl IoExt for I {} pub trait RefIO: Clone + Send + Sync { type Interface: IO + ?Sized; fn iface(&self) -> &Self::Interface; fn same_io(&self, other: &Self) -> bool; } impl RefIO for Arc { type Interface = I; #[inline] fn iface(&self) -> &Self::Interface { self.as_ref() } fn same_io(&self, other: &Self) -> bool { Arc::ptr_eq(self, other) } } ================================================ FILE: qinterface/src/lib.rs ================================================ pub mod bind_uri; pub mod component; pub mod device; pub mod io; pub mod manager; use std::{ error::Error, fmt::Debug, net::SocketAddr, sync::{Arc, Weak}, task::{Context, Poll}, }; use bytes::BytesMut; use qbase::{net::route::Route, util::UniqueId}; use thiserror::Error; use crate::{ bind_uri::BindUri, io::{IO, RefIO}, manager::InterfaceContext, }; #[derive(Debug, Clone)] pub struct BindInterface { context: Arc, } impl BindInterface { pub(crate) fn new(iface: InterfaceContext) -> Self { Self { context: Arc::new(iface), } 
} pub fn bind_uri(&self) -> BindUri { self.context.bind_uri() } pub fn close(&self) -> impl Future> + Send { core::future::poll_fn(|cx| self.context.poll_close(cx)) } pub fn rebind(&self) -> impl Future + Send { core::future::poll_fn(|cx| self.poll_rebind(cx)) } #[inline] pub fn borrow(&self) -> Interface { Interface { bind_id: self.context.bind_id(), bind_iface: self.clone(), } } #[inline] pub fn downgrade(&self) -> WeakBindInterface { WeakBindInterface { context: Arc::downgrade(&self.context), } } #[inline] pub fn borrow_weak(&self) -> WeakInterface { self.borrow().downgrade() } } #[derive(Debug, Clone)] pub struct Interface { bind_id: UniqueId, bind_iface: BindInterface, } #[derive(Debug, Error)] #[error("Interface has been rebinded")] pub struct RebindedError; impl RebindedError { pub fn is_source_of(mut error: &(dyn Error + 'static)) -> bool { loop { if error.is::() { return true; } match error.source() { Some(source) => error = source, None => return false, } } } } impl From for std::io::Error { fn from(value: RebindedError) -> Self { std::io::Error::new(std::io::ErrorKind::ConnectionReset, value) } } impl Interface { #[inline] fn with_io(&self, f: impl FnOnce(&dyn IO) -> T) -> std::io::Result { self.bind_iface .context .with_bind_io(self.bind_id, f) .map_err(Into::into) } #[inline] pub fn bind_interface(&self) -> &BindInterface { &self.bind_iface } #[inline] pub fn downgrade(&self) -> WeakInterface { WeakInterface { bind_uri: self.bind_iface.bind_uri(), bind_id: self.bind_id, weak_iface: self.bind_iface.downgrade(), } } pub fn same_io(&self, other: &Interface) -> bool { self.bind_id == other.bind_id && Arc::ptr_eq(&self.bind_iface.context, &other.bind_iface.context) } } impl RefIO for Interface { type Interface = Self; #[inline] fn iface(&self) -> &Self::Interface { self } fn same_io(&self, other: &Self) -> bool { self.same_io(other) } } impl IO for Interface { #[inline] fn bind_uri(&self) -> BindUri { self.bind_iface.bind_uri() } #[inline] fn 
bound_addr(&self) -> std::io::Result { self.with_io(|io| io.bound_addr())? } #[inline] fn max_segment_size(&self) -> std::io::Result { self.with_io(|io| io.max_segment_size())? } #[inline] fn max_segments(&self) -> std::io::Result { self.with_io(|io| io.max_segments())? } #[inline] fn poll_send( &self, cx: &mut Context, pkts: &[std::io::IoSlice], route: Route, ) -> Poll> { self.with_io(|io| io.poll_send(cx, pkts, route))? } #[inline] fn poll_recv( &self, cx: &mut Context, pkts: &mut [BytesMut], route: &mut [Route], ) -> Poll> { self.with_io(|io| io.poll_recv(cx, pkts, route))? } #[inline] fn poll_close(&mut self, cx: &mut Context) -> Poll> { self.bind_iface.context.poll_close(cx) } } #[derive(Debug, Error)] #[error("Interface has been unbound")] pub struct UnboundError; impl UnboundError { pub fn is_source_of(mut error: &(dyn Error + 'static)) -> bool { loop { if error.is::() { return true; } match error.source() { Some(source) => error = source, None => return false, } } } } impl From for std::io::Error { fn from(value: UnboundError) -> Self { std::io::Error::new(std::io::ErrorKind::ConnectionReset, value) } } #[derive(Debug, Clone)] pub struct WeakBindInterface { context: Weak, } impl WeakBindInterface { pub fn upgrade(&self) -> Result { Ok(BindInterface { context: self.context.upgrade().ok_or(UnboundError)?, }) } pub fn borrow(&self) -> Result { Ok(self.upgrade()?.borrow_weak()) } pub fn same_io(&self, other: &WeakBindInterface) -> bool { Weak::ptr_eq(&self.context, &other.context) } } #[derive(Debug, Clone)] pub struct WeakInterface { bind_uri: BindUri, bind_id: UniqueId, weak_iface: WeakBindInterface, } impl From for WeakInterface { fn from(iface: Interface) -> Self { iface.downgrade() } } impl WeakInterface { pub fn upgrade(&self) -> Result { Ok(Interface { bind_iface: self.weak_iface.upgrade()?, bind_id: self.bind_id, }) } pub fn same_io(&self, other: &WeakInterface) -> bool { self.bind_id == other.bind_id && self.weak_iface.same_io(&other.weak_iface) } } 
impl RefIO for WeakInterface { type Interface = WeakInterface; fn iface(&self) -> &Self::Interface { self } fn same_io(&self, other: &Self) -> bool { self.same_io(other) } } impl IO for WeakInterface { fn bind_uri(&self) -> BindUri { self.bind_uri.clone() } fn bound_addr(&self) -> std::io::Result { self.upgrade()?.bound_addr() } fn max_segment_size(&self) -> std::io::Result { self.upgrade()?.max_segment_size() } fn max_segments(&self) -> std::io::Result { self.upgrade()?.max_segments() } fn poll_send( &self, cx: &mut Context, pkts: &[std::io::IoSlice], route: Route, ) -> Poll> { self.upgrade()?.poll_send(cx, pkts, route) } fn poll_recv( &self, cx: &mut Context, pkts: &mut [BytesMut], route: &mut [Route], ) -> Poll> { self.upgrade()?.poll_recv(cx, pkts, route) } fn poll_close(&mut self, cx: &mut Context) -> Poll> { self.upgrade()?.poll_close(cx) } } ================================================ FILE: qinterface/src/manager.rs ================================================ use std::{ any::Any, fmt::Debug, future::Future, io, mem, net::SocketAddr, ops::{Deref, DerefMut}, sync::{Arc, OnceLock}, task::{Context, Poll, ready}, }; use bytes::BytesMut; use dashmap::{DashMap, Entry}; use futures::FutureExt; use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use qbase::{ net::route, util::{UniqueId, UniqueIdGenerator}, }; use tokio::sync::SetOnce; use tracing::Instrument as _; use crate::{ BindInterface, Interface, RebindedError, WeakBindInterface, bind_uri::BindUri, component::{Component, Components}, io::{IO, IoExt, ProductIO}, }; /// Global [`IO`] manager that manages the lifecycle of all interfaces. /// /// Calling the [`InterfaceManager::bind`] method with a [`BindUri`] returns a [`BindInterface`], primarily used for listening on addresses. /// As long as [`BindInterface`] instances exist, the corresponding [`IO`] for that [`BindUri`] won't be automatically released. 
///
/// For actual data transmission, you need [`Interface`], which can be obtained via [`InterfaceManager::borrow`] or [`BindInterface::borrow`].
/// Like [`BindInterface`], it keeps the [`IO`] alive, but with one key difference: once a rebind occurs,
/// any previous [`Interface`] for that [`BindUri`] becomes invalid, and attempting to send or receive packets
/// will result in [`RebindedError`] errors.
#[derive(Default, Debug)]
pub struct InterfaceManager {
    // NOTE(review): generic parameters were stripped by extraction; this is
    // presumably `DashMap<BindUri, InterfaceEntry>` — confirm against the repo.
    interfaces: DashMap,
    // Issues a fresh unique bind id for every (re)bind of any interface.
    bind_id_generator: UniqueIdGenerator,
}

/// Book-keeping entry kept per [`BindUri`] in [`InterfaceManager`]'s map.
#[derive(Debug)]
struct InterfaceEntry {
    // Weak handle: the map entry must not keep the interface alive by itself.
    weak_iface: WeakBindInterface,
    // Set exactly once when the interface context has finished its async drop.
    // NOTE(review): stripped generic, presumably `Arc<SetOnce<()>>`.
    dropped: Arc>,
}

impl InterfaceEntry {
    // True once the context's asynchronous drop has completed.
    fn is_dropped(&self) -> bool {
        self.dropped.get().is_some()
    }

    // Future resolving when the context's drop completes. Clones the Arc so
    // the returned future is independent of `self`'s borrow.
    fn dropped(&self) -> impl Future + use<> {
        let dropped = self.dropped.clone();
        async move {
            dropped.wait().await;
        }
    }
}

impl InterfaceManager {
    /// Process-wide shared manager, lazily initialized on first use.
    #[inline]
    pub fn global() -> &'static Arc {
        static GLOBAL: OnceLock> = OnceLock::new();
        GLOBAL.get_or_init(Arc::default)
    }

    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    // Creates a fresh binding for `entry`'s key using `factory`, records the
    // weak handle plus its `dropped` signal in the map, and returns the strong
    // handle. Works for both occupied (stale) and vacant entries.
    fn new_binding(
        self: &Arc,
        entry: Entry,
        factory: Arc,
    ) -> BindInterface {
        let context = InterfaceContext {
            factory: factory.clone(),
            binding: RwLock::new(Binding::new(
                factory.bind(entry.key().clone()),
                self.bind_id_generator.generate(),
            )),
            dropped: Arc::new(SetOnce::new()),
            ifaces: self.clone(),
            components: RwLock::new(Components::default()),
        };
        // Keep the `dropped` signal before `context` moves into BindInterface.
        let dropped = context.dropped.clone();
        let iface = BindInterface::new(context);
        let weak_iface = iface.downgrade();
        entry.insert(InterfaceEntry {
            weak_iface,
            dropped,
        });
        iface
    }

    /// Binds (or reuses) the interface for `bind_uri`.
    ///
    /// Loops until one of four cases applies: (1) replace an entry whose
    /// context already finished dropping, (2) create a vacant entry, (3) reuse
    /// a live binding, or (4) wait for a dying binding's drop to complete and
    /// retry from the top.
    pub async fn bind(
        self: &Arc,
        bind_uri: BindUri,
        factory: Arc,
    ) -> BindInterface {
        // TODO: error: rebind with different factory
        loop {
            match self.interfaces.entry(bind_uri.clone()) {
                // (1) new binding: context closed but not yet removed
                Entry::Occupied(entry) if entry.get().is_dropped() => {
                    return self.new_binding(Entry::Occupied(entry), factory);
                }
                // (2) new binding: no existing context
                Entry::Vacant(entry) => {
                    return
self.new_binding(Entry::Vacant(entry), factory); } // try reuse existing binding Entry::Occupied(entry) => match entry.get().weak_iface.upgrade() { // (3) reuse existing binding Ok(iface) => return iface.clone(), // (4) no existing binding: close context and retry Err(..) => { let dropped_future = entry.get().dropped(); drop(entry); dropped_future.await; } }, } } } #[inline] pub fn borrow(&self, bind_uri: &BindUri) -> Option { self.interfaces .get(bind_uri) .and_then(|entry| Some(entry.weak_iface.upgrade().ok()?.borrow())) } #[inline] pub fn get(&self, bind_uri: &BindUri) -> Option { self.interfaces .get(bind_uri) .and_then(|entry| entry.weak_iface.upgrade().ok()) } #[inline] pub fn unbind(self: &Arc, bind_uri: BindUri) -> impl Future + Send + use<> { let Entry::Occupied(entry) = self.interfaces.entry(bind_uri) else { return std::future::ready(()).right_future(); }; match entry.get().weak_iface.upgrade() { Ok(bind_iface) => { let drop_future = bind_iface.context.as_ref().drop(); spawn_on_drop::SpawnOnDrop::new(Box::pin(drop_future)).left_future() } // Dropping by InterfaceContext::Drop Err(..) 
=> entry.get().dropped().right_future(), } .left_future() } } mod spawn_on_drop { use std::{ future::Future, pin::Pin, task::{Context, Poll, ready}, }; use tracing::Instrument; pub(crate) struct SpawnOnDrop + Unpin + Send + 'static> { pub(crate) future: Option, } impl + Unpin + Send + 'static> SpawnOnDrop { pub(crate) fn new(future: F) -> Self { Self { future: Some(future), } } } impl + Unpin + Send + 'static> Future for SpawnOnDrop { type Output = F::Output; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.as_mut().get_mut().future.as_mut() { Some(future) => { let output = ready!(Pin::new(future).poll(cx)); self.future = None; Poll::Ready(output) } None => panic!("future polled after completion"), } } } impl + Unpin + Send + 'static> Drop for SpawnOnDrop { fn drop(&mut self) { if let Some(future) = self.future.take() { // Best-effort: schedule the must-complete future before it is dropped. tokio::spawn(future.in_current_span()); } } } } struct Binding { io: Box, id: UniqueId, span: tracing::Span, } impl Binding { fn new(io: Box, id: UniqueId) -> Self { let bind_uri = io.bind_uri(); let span = tracing::debug_span!( parent: None, "interface", %bind_uri, bind_id = usize::from(id), ); Self { io, id, span } } } pub struct InterfaceContext { factory: Arc, binding: RwLock, // shared with [InterfaceEntry] dropped: Arc>, ifaces: Arc, components: RwLock, } impl InterfaceContext { fn binding(&self) -> RwLockReadGuard<'_, Binding> { self.binding.read_recursive() } fn binding_mut(&self) -> RwLockWriteGuard<'_, Binding> { self.binding.write() } pub fn bind_id(&self) -> UniqueId { self.binding().id } fn with_io(&self, f: impl FnOnce(&dyn IO) -> T) -> T { let binding = self.binding(); let _guard = binding.span.enter(); f(binding.io.as_ref()) } pub(crate) fn with_bind_io( &self, bind_id: UniqueId, f: impl FnOnce(&dyn IO) -> T, ) -> Result { let binding = self.binding(); if binding.id != bind_id { return Err(RebindedError); } let _guard = 
binding.span.enter(); Ok(f(binding.io.as_ref())) } fn components(&self) -> RwLockReadGuard<'_, Components> { self.components.read_recursive() } fn components_mut(&self) -> RwLockWriteGuard<'_, Components> { self.components.write() } pub fn poll_close(&self, cx: &mut Context) -> Poll> { let (mut binding, components) = (self.binding_mut(), self.components()); for (.., component) in &components.map { ready!(component.poll_shutdown(cx)); } binding.io.poll_close(cx) } } impl Debug for InterfaceContext { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Interface") .field("bind_uri", &self.binding().io.bind_uri()) .finish() } } impl BindInterface { pub fn poll_rebind(&self, cx: &mut Context<'_>) -> Poll<()> { let context = self.context.as_ref(); // 降级binding锁 // A: rebind, reinit // B: rebind, reinit // 释放binding锁 // A: lock(B), lock(C), rebind, release(B), reinit, release(C) // B: lock(B), lock(C), rebind, reinit // hold read lock to prevent subsequent rebind, avoid compoents seeing inconsistent state let (new_bind_id, components, span) = { let mut binding = context.binding_mut(); let components = context.components(); ready!(context.factory.poll_rebind(cx, &mut binding.io)); binding.id = context.ifaces.bind_id_generator.generate(); binding.span = tracing::debug_span!( parent: None, "interface", bind_uri = %binding.io.bind_uri(), bind_id = usize::from(binding.id), ); (binding.id, components, binding.span.clone()) }; let iface = Interface { bind_id: new_bind_id, bind_iface: self.clone(), }; let _guard = span.enter(); for (.., component) in &components.map { component.reinit(&iface); } Poll::Ready(()) } pub fn insert_component_with(&self, init: impl FnOnce(&Interface) -> C) { self.with_components_mut(|components, iface| { components.init_with(|| init(iface)); }); } pub fn with_components(&self, f: impl FnOnce(&Components, &Interface) -> T) -> T { let context = self.context.as_ref(); let (binding, components) = (context.binding(), 
context.components()); let _guard = binding.span.enter(); let iface = Interface { bind_id: binding.id, bind_iface: self.clone(), }; f(components.deref(), &iface) } pub fn with_components_mut(&self, f: impl FnOnce(&mut Components, &Interface) -> T) -> T { let context = self.context.as_ref(); let (binding, mut components) = (context.binding(), context.components_mut()); let _guard = binding.span.enter(); let iface = Interface { bind_id: binding.id, bind_iface: self.clone(), }; f(components.deref_mut(), &iface) } } impl Interface { pub fn with_component( &self, f: impl FnOnce(&C) -> T, ) -> Result, RebindedError> { let context = self.bind_iface.context.as_ref(); let (binding, components) = (context.binding(), context.components()); if self.bind_id != binding.id { return Err(RebindedError); } let _guard = binding.span.enter(); Ok(components.with(f)) } pub fn with_components(&self, f: impl FnOnce(&Components) -> T) -> Result { let context = self.bind_iface.context.as_ref(); let (binding, components) = (context.binding(), context.components()); if self.bind_id != binding.id { return Err(RebindedError); } let _guard = binding.span.enter(); Ok(f(components.deref())) } pub fn get_component(&self) -> Result, RebindedError> { self.with_component(C::clone) } } impl IO for InterfaceContext { fn bind_uri(&self) -> BindUri { self.binding().io.bind_uri() } fn bound_addr(&self) -> io::Result { self.with_io(|io| io.bound_addr()) } fn max_segment_size(&self) -> io::Result { self.with_io(|io| io.max_segment_size()) } fn max_segments(&self) -> io::Result { self.with_io(|io| io.max_segments()) } fn poll_send( &self, cx: &mut Context, pkts: &[io::IoSlice], route: route::Route, ) -> Poll> { self.with_io(|io| io.poll_send(cx, pkts, route)) } fn poll_recv( &self, cx: &mut Context, pkts: &mut [BytesMut], route: &mut [route::Route], ) -> Poll> { self.with_io(|io| io.poll_recv(cx, pkts, route)) } fn poll_close(&mut self, cx: &mut Context) -> Poll> { InterfaceContext::poll_close(self, cx) } } 
mod dropping_io { use thiserror::Error; use super::*; #[derive(Debug, Clone, Error)] #[error("QuicIO is dropping and cannot be used anymore, you should never see this error")] pub(crate) struct DroppingIO { pub(crate) bind_uri: BindUri, } impl DroppingIO { pub(crate) fn to_io_error(&self) -> io::Error { io::Error::new(io::ErrorKind::NotConnected, self.clone()) } } impl From for io::Error { fn from(error: DroppingIO) -> Self { error.to_io_error() } } impl IO for DroppingIO { fn bind_uri(&self) -> BindUri { self.bind_uri.clone() } fn bound_addr(&self) -> io::Result { Err(self.to_io_error()) } fn max_segment_size(&self) -> io::Result { Err(self.to_io_error()) } fn max_segments(&self) -> io::Result { Err(self.to_io_error()) } fn poll_send( &self, _: &mut Context, _: &[io::IoSlice], _: route::Route, ) -> Poll> { Poll::Ready(Err(self.to_io_error())) } fn poll_recv( &self, _: &mut Context, _: &mut [BytesMut], _: &mut [route::Route], ) -> Poll> { Poll::Ready(Err(self.to_io_error())) } fn poll_close(&mut self, _: &mut Context) -> Poll> { Poll::Ready(Ok(())) } } } impl Binding { pub fn is_dropping(&self) -> bool { (self.io.as_ref() as &dyn Any).is::() } pub fn take_io(&mut self) -> Option> { if self.is_dropping() { return None; } let bind_uri = self.io.bind_uri(); let dropping_io = Box::new(dropping_io::DroppingIO { bind_uri }); Some(mem::replace(&mut self.io, dropping_io)) } } impl InterfaceContext { fn drop(&self) -> impl Future + Send + use<> { let dropped = self.dropped.clone(); let Some(mut io) = self.binding_mut().take_io() else { return std::future::ready(()).right_future(); }; let ifaces = self.ifaces.clone(); let bind_uri = io.bind_uri(); let components = mem::take(self.components_mut().deref_mut()); async move { for (_, component) in components.map { _ = core::future::poll_fn(|cx| component.poll_shutdown(cx)).await; } _ = io.close().await; dropped.set(()).expect("duplicated drop, this is a bug"); tokio::task::spawn_blocking(move || { ifaces .interfaces 
.remove_if(&bind_uri, |_, entry| entry.is_dropped()); }); } .left_future() } } impl Drop for InterfaceContext { fn drop(&mut self) { if !{ self.binding().is_dropping() } { // Best-effort: schedule async cleanup before the context is dropped. tokio::spawn(InterfaceContext::drop(self).in_current_span()); } } } #[cfg(test)] mod tests { use std::{ sync::{ Arc, atomic::{AtomicUsize, Ordering}, }, task::{Context, Poll}, }; use futures::task::noop_waker_ref; use super::*; use crate::{ component::Component, io::{IO, ProductIO}, }; #[derive(Debug)] struct TestComponent { shutdown_calls: Arc, } impl Component for TestComponent { fn poll_shutdown(&self, _cx: &mut Context<'_>) -> Poll<()> { self.shutdown_calls.fetch_add(1, Ordering::SeqCst); Poll::Ready(()) } fn reinit(&self, _iface: &crate::Interface) {} } #[derive(Debug)] struct TestIo { bind_uri: BindUri, close_calls: Arc, } impl IO for TestIo { fn bind_uri(&self) -> BindUri { self.bind_uri.clone() } fn bound_addr(&self) -> io::Result { Err(io::Error::new(io::ErrorKind::Unsupported, "not needed")) } fn max_segment_size(&self) -> io::Result { Ok(1200) } fn max_segments(&self) -> io::Result { Ok(1) } fn poll_send( &self, _cx: &mut Context, _pkts: &[io::IoSlice], _route: route::Route, ) -> Poll> { Poll::Ready(Ok(0)) } fn poll_recv( &self, _cx: &mut Context, _pkts: &mut [BytesMut], _route: &mut [route::Route], ) -> Poll> { Poll::Pending } fn poll_close(&mut self, _cx: &mut Context) -> Poll> { self.close_calls.fetch_add(1, Ordering::SeqCst); Poll::Ready(Ok(())) } } #[derive(Debug)] struct TestFactory { close_calls: Arc, } impl ProductIO for TestFactory { fn bind(&self, bind_uri: BindUri) -> Box { Box::new(TestIo { bind_uri, close_calls: self.close_calls.clone(), }) } } #[test] fn binding_take_io_is_idempotent_and_switches_to_dropping_io() { let close_calls = Arc::new(AtomicUsize::new(0)); let bind_uri: BindUri = "inet://127.0.0.1:0".into(); let mut binding = Binding::new( Box::new(TestIo { bind_uri: bind_uri.clone(), 
close_calls: close_calls.clone(), }), UniqueIdGenerator::new().generate(), ); let first = binding.take_io(); assert!(first.is_some()); assert!(binding.is_dropping()); let second = binding.take_io(); assert!(second.is_none()); // Ensure the original IO wasn't closed by take_io itself. assert_eq!(close_calls.load(Ordering::SeqCst), 0); } #[test] fn poll_close_shuts_down_components_before_io_close() { let shutdown_calls = Arc::new(AtomicUsize::new(0)); let close_calls = Arc::new(AtomicUsize::new(0)); let bind_uri: BindUri = "inet://127.0.0.1:0".into(); let mut components = Components::new(); components.init_with(|| TestComponent { shutdown_calls: shutdown_calls.clone(), }); let mut cx = Context::from_waker(noop_waker_ref()); let ctx = InterfaceContext { factory: Arc::new(TestFactory { close_calls: close_calls.clone(), }), binding: RwLock::new(Binding::new( Box::new(TestIo { bind_uri, close_calls: close_calls.clone(), }), UniqueIdGenerator::new().generate(), )), dropped: Arc::new(SetOnce::new()), ifaces: Arc::new(InterfaceManager::new()), components: RwLock::new(components), }; let r = ctx.poll_close(&mut cx); assert!(matches!(r, Poll::Ready(Ok(())))); assert_eq!(shutdown_calls.load(Ordering::SeqCst), 1); assert_eq!(close_calls.load(Ordering::SeqCst), 1); // Prevent Drop from spawning without a runtime. let _ = ctx.binding_mut().take_io(); } } ================================================ FILE: qinterface/tests/auto_rebind.rs ================================================ mod common; use std::{sync::Arc, time::Duration}; use common::*; use qinterface::{ component::alive::RebindOnNetworkChangedComponent, device::Devices, manager::InterfaceManager, }; use tokio::time; #[test] fn rebind_on_network_changed_triggers_on_recoverable_failure() { run(async { let Some(bind_uri) = any_iface_bind_uri() else { // No real network interface in this environment; skip. 
return; }; let manager = InterfaceManager::global().clone(); let factory = Arc::new(FakeFactory::new()); let bind_iface = manager.bind(bind_uri.clone(), factory).await; let before = bind_iface.borrow(); let probe = Arc::new(Probe::default()); bind_iface.insert_component_with(|iface| { RebindOnNetworkChangedComponent::new(iface, Devices::global()) }); bind_iface.insert_component_with(|_iface| ProbeComponent::new(probe.clone())); // The component calls try_rebind() once at init. // If alive-check considers the interface unhealthy (recoverable error), it will rebind. let _ = time::timeout(Duration::from_secs(2), async { loop { let now = bind_iface.borrow(); if !now.same_io(&before) { break; } time::sleep(Duration::from_millis(10)).await; } }) .await; // If it did rebind, the probe should have seen reinit. // If it didn't (alive-check passed), that's also acceptable on some systems. let _reinit_calls = probe.reinit_calls.load(std::sync::atomic::Ordering::SeqCst); }) } ================================================ FILE: qinterface/tests/common/mod.rs ================================================ #![allow(unused)] use std::{ future::Future, io, net::{IpAddr, Ipv4Addr, SocketAddr}, sync::{ Arc, Mutex, atomic::{AtomicBool, AtomicUsize, Ordering}, }, task::{Context, Poll}, time::Duration, }; use bytes::BytesMut; use qbase::net::route::{Line, Link, Pathway, Route}; use qinterface::{Interface, bind_uri::BindUri, component::Component, device::Devices, io::IO}; use tokio::{runtime::Runtime, sync::Notify, time}; pub fn run(future: F) -> F::Output { static RT: std::sync::LazyLock = std::sync::LazyLock::new(|| { tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap() }); RT.block_on(async move { match time::timeout(Duration::from_secs(30), future).await { Ok(output) => output, Err(_timedout) => panic!("test timed out"), } }) } pub fn test_bind_uri() -> BindUri { // inet scheme is easiest & does not require real interfaces let base: BindUri = 
"inet://127.0.0.1:0".into(); base.alloc_port() } pub fn any_iface_bind_uri() -> Option { let devices = Devices::global(); let interfaces = devices.interfaces(); // prefer v4 for simplicity for (name, iface) in &interfaces { if !iface.ipv4.is_empty() { return Some(format!("iface://v4.{name}:0").as_str().into()); } } // fallback v6 (non-link-local selection happens in resolve()) for (name, iface) in &interfaces { if !iface.ipv6.is_empty() { return Some(format!("iface://v6.{name}:0").as_str().into()); } } None } #[derive(Debug, Default)] pub struct FakeIoState { pub generation: AtomicUsize, pub close_calls: AtomicUsize, } #[derive(Debug)] pub struct FakeIo { bind_uri: BindUri, bound_addr: SocketAddr, state: Arc, closed: AtomicBool, close_notify: Arc, } impl FakeIo { pub fn new(bind_uri: BindUri, bound_addr: SocketAddr, state: Arc) -> Self { Self { bind_uri, bound_addr, state, closed: AtomicBool::new(false), close_notify: Arc::new(Notify::new()), } } pub fn close_notify(&self) -> Arc { self.close_notify.clone() } } impl IO for FakeIo { fn bind_uri(&self) -> BindUri { self.bind_uri.clone() } fn bound_addr(&self) -> io::Result { Ok(self.bound_addr) } fn max_segment_size(&self) -> io::Result { Ok(1500) } fn max_segments(&self) -> io::Result { Ok(1) } fn poll_send( &self, _cx: &mut Context, pkts: &[io::IoSlice], _route: Route, ) -> Poll> { Poll::Ready(Ok(pkts.len())) } fn poll_recv( &self, _cx: &mut Context, _pkts: &mut [BytesMut], _route: &mut [Route], ) -> Poll> { Poll::Pending } fn poll_close(&mut self, _cx: &mut Context) -> Poll> { if !self.closed.swap(true, Ordering::SeqCst) { self.state.close_calls.fetch_add(1, Ordering::SeqCst); self.close_notify.notify_waiters(); } Poll::Ready(Ok(())) } } #[derive(Debug, Clone)] pub struct FakeFactory { pub state: Arc, pub base_port: u16, } impl FakeFactory { pub fn new() -> Self { Self { state: Arc::new(FakeIoState::default()), base_port: 50000, } } } impl qinterface::io::ProductIO for FakeFactory { fn bind(&self, bind_uri: 
BindUri) -> Box { let generation = self.state.generation.fetch_add(1, Ordering::SeqCst) + 1; let bound_addr = SocketAddr::new( IpAddr::V4(Ipv4Addr::LOCALHOST), self.base_port.saturating_add(generation as u16), ); Box::new(FakeIo::new(bind_uri, bound_addr, self.state.clone())) } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ProbeEventKind { Reinit, Shutdown, } #[derive(Debug, Clone)] pub struct ProbeEvent { pub kind: ProbeEventKind, pub bind_uri: BindUri, } #[derive(Debug, Default)] pub struct Probe { pub shutdown_calls: AtomicUsize, pub reinit_calls: AtomicUsize, pub events: Mutex>, pub last_bind_uri: Mutex>, } #[derive(Debug, Clone)] pub struct ProbeComponent { pub probe: Arc, } impl ProbeComponent { pub fn new(probe: Arc) -> Self { Self { probe } } } impl Component for ProbeComponent { fn poll_shutdown(&self, _cx: &mut Context<'_>) -> Poll<()> { self.probe.shutdown_calls.fetch_add(1, Ordering::SeqCst); let bind_uri = self .probe .last_bind_uri .lock() .unwrap() .clone() .unwrap_or_else(test_bind_uri); self.probe.events.lock().unwrap().push(ProbeEvent { kind: ProbeEventKind::Shutdown, bind_uri, }); Poll::Ready(()) } fn reinit(&self, iface: &Interface) { self.probe.reinit_calls.fetch_add(1, Ordering::SeqCst); *self.probe.last_bind_uri.lock().unwrap() = Some(iface.bind_uri()); self.probe.events.lock().unwrap().push(ProbeEvent { kind: ProbeEventKind::Reinit, bind_uri: iface.bind_uri(), }); } } pub fn dummy_packet_header() -> Route { let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1); let way = Pathway::new(addr.into(), addr.into()); let link = Link::new(addr, addr); let line = Line::new(link, 64, None, 0); Route::new(way, line) } ================================================ FILE: qinterface/tests/components.rs ================================================ mod common; use std::sync::{ Arc, atomic::{AtomicBool, AtomicUsize, Ordering}, }; use common::*; use qinterface::{Interface, component::Component, manager::InterfaceManager}; 
#[derive(Debug, Default)] struct RouterState { shutdown_calls: AtomicUsize, reinit_calls: AtomicUsize, } #[derive(Debug, Clone)] struct RouterComponent { state: Arc, } impl Component for RouterComponent { fn poll_shutdown(&self, _cx: &mut std::task::Context<'_>) -> std::task::Poll<()> { self.state.shutdown_calls.fetch_add(1, Ordering::SeqCst); std::task::Poll::Ready(()) } fn reinit(&self, _iface: &Interface) { self.state.reinit_calls.fetch_add(1, Ordering::SeqCst); } } #[derive(Debug, Default)] struct ClientState { saw_router: AtomicBool, missing_router_reinits: AtomicUsize, } #[derive(Debug, Clone)] struct ClientComponent { state: Arc, } impl Component for ClientComponent { fn poll_shutdown(&self, _cx: &mut std::task::Context<'_>) -> std::task::Poll<()> { std::task::Poll::Ready(()) } fn reinit(&self, iface: &Interface) { let has_router = iface .with_components(|cs| cs.get::().is_some()) .expect("reinit should always see a non-stale iface"); if has_router { self.state.saw_router.store(true, Ordering::SeqCst); } else { self.state .missing_router_reinits .fetch_add(1, Ordering::SeqCst); } } } #[test] fn component_dependency_missing_then_added_is_observable_on_rebind() { run(async { let manager = InterfaceManager::global().clone(); let factory = Arc::new(FakeFactory::new()); let bind_uri = test_bind_uri(); let bind_iface = manager.bind(bind_uri, factory).await; let client_state = Arc::new(ClientState::default()); bind_iface.insert_component_with(|_iface| ClientComponent { state: client_state.clone(), }); // First rebind: client exists, router missing bind_iface.rebind().await; assert!(!client_state.saw_router.load(Ordering::SeqCst)); assert!(client_state.missing_router_reinits.load(Ordering::SeqCst) > 0); // Add dependency later, then rebind again: client should observe it. 
let router_state = Arc::new(RouterState::default()); bind_iface.insert_component_with(|_iface| RouterComponent { state: router_state.clone(), }); bind_iface.rebind().await; assert!(client_state.saw_router.load(Ordering::SeqCst)); assert!(router_state.reinit_calls.load(Ordering::SeqCst) > 0); }) } #[test] fn component_dependency_present_is_visible_inside_reinit() { run(async { let manager = InterfaceManager::global().clone(); let factory = Arc::new(FakeFactory::new()); let bind_uri = test_bind_uri(); let bind_iface = manager.bind(bind_uri, factory).await; let router_state = Arc::new(RouterState::default()); bind_iface.insert_component_with(|_iface| RouterComponent { state: router_state.clone(), }); let client_state = Arc::new(ClientState::default()); bind_iface.insert_component_with(|_iface| ClientComponent { state: client_state.clone(), }); bind_iface.rebind().await; assert!(client_state.saw_router.load(Ordering::SeqCst)); assert!(router_state.reinit_calls.load(Ordering::SeqCst) > 0); }) } ================================================ FILE: qinterface/tests/lifecycle.rs ================================================ mod common; use std::{io::ErrorKind, sync::Arc, time::Duration}; use common::*; use qinterface::{io::IO, manager::InterfaceManager}; use tokio::time; #[test] fn unbind_destroys_and_weak_upgrade_fails() { run(async { let manager = InterfaceManager::global().clone(); let factory = Arc::new(FakeFactory::new()); let state = factory.state.clone(); let bind_uri = test_bind_uri(); let bind_iface: qinterface::BindInterface = manager.bind(bind_uri.clone(), factory).await; let weak_bind = bind_iface.downgrade(); let weak_iface = bind_iface.borrow_weak(); // unbind is async; ensure it completes manager.unbind(bind_uri.clone()).await; // existing strong handle remains upgradeable, but should be unusable let err = bind_iface.borrow().bound_addr().unwrap_err(); assert_eq!(err.kind(), ErrorKind::NotConnected); // ensure IO was actually closed 
time::timeout(Duration::from_secs(2), async { while state.close_calls.load(std::sync::atomic::Ordering::SeqCst) == 0 { time::sleep(Duration::from_millis(10)).await; } }) .await .expect("unbind did not close IO in time"); drop(bind_iface); time::timeout(Duration::from_secs(2), async { loop { if weak_bind.upgrade().is_err() && weak_iface.upgrade().is_err() { break; } time::sleep(Duration::from_millis(10)).await; } }) .await .expect("weak upgrade should eventually fail after unbind + drop"); }) } #[test] fn auto_drop_when_last_ref_gone_allows_rebind() { run(async { let manager = InterfaceManager::global().clone(); let factory = Arc::new(FakeFactory::new()); let state = factory.state.clone(); let bind_uri = test_bind_uri(); // Bind and create a borrowed Interface (strong ref) let bind_iface: qinterface::BindInterface = manager.bind(bind_uri.clone(), factory.clone()).await; let iface = bind_iface.borrow(); drop(bind_iface); drop(iface); // Binding again must wait for the dropped signal, so this also verifies auto-drop. 
let _bind_iface2 = time::timeout(Duration::from_secs(2), async { manager.bind(bind_uri.clone(), factory.clone()).await }) .await .expect("rebind after auto-drop timed out"); assert!(state.close_calls.load(std::sync::atomic::Ordering::SeqCst) > 0); }) } ================================================ FILE: qinterface/tests/locations.rs ================================================ mod common; use std::{sync::Arc, time::Duration}; use common::*; use qinterface::{ component::location::{AddressEvent, Locations, LocationsComponent}, manager::InterfaceManager, }; use tokio::time; #[test] fn locations_component_emits_closed_then_upsert_on_rebind() { run(async { let manager = InterfaceManager::global().clone(); let factory = Arc::new(FakeFactory::new()); let bind_uri = test_bind_uri(); let bind_iface = manager.bind(bind_uri.clone(), factory).await; let locations = Arc::new(Locations::new()); let mut observer = locations.subscribe(); bind_iface.insert_component_with(|iface| { LocationsComponent::new(iface.downgrade(), locations.clone()) }); // initial upsert (bound_addr result) should be delivered to the subscriber let (u_bind, ev) = time::timeout(Duration::from_secs(2), observer.recv()) .await .expect("timeout waiting for initial upsert") .expect("observer closed"); assert_eq!(u_bind, bind_uri); assert!(matches!(ev, AddressEvent::Upsert(_))); // trigger rebind bind_iface.rebind().await; // must see Closed then Upsert for same bind_uri let (c_bind, c_ev) = time::timeout(Duration::from_secs(2), observer.recv()) .await .expect("timeout waiting for closed") .expect("observer closed"); assert_eq!(c_bind, bind_uri); assert!(matches!(c_ev, AddressEvent::Closed)); let (u2_bind, u2_ev) = time::timeout(Duration::from_secs(2), observer.recv()) .await .expect("timeout waiting for upsert") .expect("observer closed"); assert_eq!(u2_bind, bind_uri); assert!(matches!(u2_ev, AddressEvent::Upsert(_))); // sanity: stale interface should not be able to touch component let old_iface = 
bind_iface.borrow();
        bind_iface.rebind().await;
        // A stale Interface must be rejected when touching components.
        let err = old_iface.with_components(|_c| ()).unwrap_err();
        let _ = err;
    })
}

================================================ FILE: qinterface/tests/rebind.rs ================================================

mod common;

use std::{io::ErrorKind, sync::Arc};

use common::*;
use qinterface::{RebindedError, io::IO, manager::InterfaceManager};

/// End-to-end check that a manual rebind invalidates previously borrowed
/// `Interface` handles (both for IO calls and for component access) while a
/// freshly borrowed handle keeps working.
#[test]
fn manual_rebind_makes_old_interface_stale() {
    run(async {
        let manager = InterfaceManager::global().clone();
        let factory = Arc::new(FakeFactory::new());
        let bind_uri = test_bind_uri();
        let bind_iface = manager.bind(bind_uri.clone(), factory).await;
        // Borrow before the rebind — this handle must become stale afterwards.
        let old_iface = bind_iface.borrow();
        // install a component so we can validate stale with_component
        let probe = Arc::new(Probe::default());
        bind_iface.insert_component_with(|_iface| ProbeComponent::new(probe.clone()));
        // rebind -> new bind_id
        bind_iface.rebind().await;
        let new_iface = bind_iface.borrow();
        assert!(!old_iface.same_io(&new_iface));
        // Old iface IO operations should fail with ConnectionReset/RebindedError
        let err = old_iface.bound_addr().unwrap_err();
        assert_eq!(err.kind(), ErrorKind::ConnectionReset);
        assert!(RebindedError::is_source_of(err.get_ref().unwrap()));
        // Old iface component access should fail with RebindedError
        // NOTE(review): turbofish type arguments were stripped by extraction
        // in the call below (presumably `with_component::<ProbeComponent, _>`).
        let err = old_iface
            .with_component::(|_c| ())
            .unwrap_err();
        let _ = err; // it's exactly RebindedError
        // New iface works
        new_iface.bound_addr().expect("new iface should be usable");
        assert!(probe.reinit_calls.load(std::sync::atomic::Ordering::SeqCst) > 0);
    })
}

================================================ FILE: qmacro/Cargo.toml ================================================

[package]
name = "qmacro"
version = "0.5.0"
edition.workspace = true
description = "dquic's proc macros"
readme.workspace = true
repository.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
rust-version.workspace = true

[lib]
proc-macro = true

[dependencies]
darling = "0.23"
proc-macro2 = "1" syn = "2" quote = "1" ================================================ FILE: qmacro/src/derive.rs ================================================ use darling::{FromMeta, ast::NestedMeta}; use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::{ToTokens, format_ident, quote}; use syn::{Error, Expr, ExprRange, Ident, ItemEnum, Token, Variant, punctuated::Punctuated}; pub fn quic_parameters(item: TokenStream) -> Result { let r#enum = syn::parse::(item)?; let enum_name = &r#enum.ident; let mut try_from_varint_match_arms = quote! {}; let mut into_varint_match_arms = quote! {}; // TODO: validate let mut validate_match_arms = quote! {}; let mut default_value_match_arms = quote! {}; let mut value_type_match_arms = quote! {}; for variant in &r#enum.variants { let discriminant = match variant.discriminant.as_ref() { Some((_eq, discriminant)) => discriminant, None => { return Err(Error::new_spanned( variant, "Each variant must have a discriminant, e.g., `= 0`", )); } }; let ident = &variant.ident; try_from_varint_match_arms.extend(quote! { // u64 => Self #discriminant => #enum_name::#ident, }); into_varint_match_arms.extend(quote! { // Self => u64 #enum_name::#ident => #discriminant, }); let param_args = parse_variant_attrs(variant)?; let validate = (param_args.gen_validate(ident)).map_err(|msg| Error::new_spanned(variant, msg))?; validate_match_arms.extend(quote! { #enum_name::#ident => { #validate } }); let default_value = param_args.gen_default_value(); default_value_match_arms.extend(quote! { #enum_name::#ident => { #default_value } }); let value_type = param_args.gen_value_type(); value_type_match_arms.extend(quote! { #enum_name::#ident => #value_type, }); } Ok(quote! 
{ // TODO: try from impl ::core::convert::TryFrom for #enum_name { type Error = Error; fn try_from(value: VarInt) -> Result { Ok(match value.into_u64() { #try_from_varint_match_arms unknown => return Err(Error::UnknownParameterId(value)) }) } } impl From<#enum_name> for VarInt { fn from(value: #enum_name) -> Self { VarInt::from_u64(match value { #into_varint_match_arms }).expect("All variants should have a valid discriminant") } } impl #enum_name { pub fn validate(&self, value: &ParameterValue) -> Result<(), Error> { match self { #validate_match_arms } Ok(()) } pub fn default_value(&self) -> Option { match self { #default_value_match_arms } } pub fn value_type(&self) -> ParameterValueType { match self { #value_type_match_arms } } } }) } fn parse_variant_attrs(variant: &Variant) -> Result { let param_attr = variant .attrs .iter() .find(|attr| attr.path().is_ident("param")) .ok_or_else(|| { Error::new_spanned( variant, "Each variant must have a `#[param(...)]` attribute", ) })?; let param_metas = param_attr .parse_args_with(Punctuated::::parse_terminated)? .into_iter() .collect::>(); ParamArgs::from_list(¶m_metas).map_err(|de| de.into()) } #[derive(darling::FromMeta)] struct ParamArgs { value_type: ParamType, #[darling(default)] default: Option, #[darling(default)] bound: Option, } impl ParamArgs { fn gen_validate(&self, id: &Ident) -> Result { let Some(bound) = &self.bound else { return Ok(quote! {}); }; let value_type = format_ident!("{}", format!("{:?}", self.value_type)); let mut convert_value = quote! { let ParameterValue::#value_type(v) = value else { return Err(Error::InvalidValueType( Self::#id, value.value_type(), )); }; }; convert_value.extend(match self.value_type { ParamType::VarInt => quote! { v.into_u64() }, ParamType::Duration => quote! { v.as_millis() as u64 }, _ => return Err("Bound is only applicable to VarInt or Duration types"), }); Ok(quote! 
{ let value = { #convert_value }; if !(#bound).contains(&value) { return Err(Error::OutOfBounds ( Self::#id, value, #bound, )); } }) } fn gen_default_value(&self) -> TokenStream2 { match &self.default { Some(default) => quote! { Some((#default).into()) }, None => quote! { None }, } } fn gen_value_type(&self) -> TokenStream2 { let value_type = format_ident!("{}", format!("{:?}", self.value_type)); quote! { ParameterValueType::#value_type } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum ParamType { VarInt, Boolean, Bytes, Duration, ResetToken, ConnectionId, PreferredAddress, } impl FromMeta for ParamType { fn from_string(lit: &str) -> ::darling::Result { match lit { "VarInt" => Ok(ParamType::VarInt), "Boolean" => Ok(ParamType::Boolean), "Bytes" => Ok(ParamType::Bytes), "Duration" => Ok(ParamType::Duration), "ResetToken" => Ok(ParamType::ResetToken), "ConnectionId" => Ok(ParamType::ConnectionId), "PreferredAddress" => Ok(ParamType::PreferredAddress), __other => Err(::darling::Error::unknown_value(__other)), } } fn from_expr(expr: &Expr) -> darling::Result { match *expr { Expr::Lit(ref lit) => Self::from_value(&lit.lit), Expr::Group(ref group) => { // syn may generate this invisible group delimiter when the input to the darling // proc macro (specifically, the attributes) are generated by a // macro_rules! (e.g. propagating a macro_rules!'s expr) // Since we want to basically ignore these invisible group delimiters, // we just propagate the call to the inner expression. 
Self::from_expr(&group.expr) } Expr::Path(ref path) => return Self::from_string(&path.to_token_stream().to_string()), _ => Err(darling::Error::unexpected_expr_type(expr)), } .map_err(|e| e.with_span(expr)) } } ================================================ FILE: qmacro/src/lib.rs ================================================ use proc_macro::TokenStream; use syn::Error; mod derive; #[proc_macro_derive(ParameterId, attributes(param))] pub fn quic_parameters(item: TokenStream) -> TokenStream { TokenStream::from(derive::quic_parameters(item).unwrap_or_else(Error::into_compile_error)) } ================================================ FILE: qprotocol/Cargo.toml ================================================ [package] name = "qprotocol" version.workspace = true edition.workspace = true description = "STUN, forward and QUIC packet routing protocol implementation for dquic" readme = "README.md" repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true [dependencies] async-trait = { workspace = true } bon = { workspace = true } bytes = { workspace = true } dashmap = { workspace = true } derive_more = { workspace = true } enum_dispatch = { workspace = true } futures = { workspace = true } bitflags = { workspace = true } nom = { workspace = true } qbase = { workspace = true } qresolve = { workspace = true } qevent = { workspace = true } qinterface = { workspace = true, features = ["qudp"] } qudp = { workspace = true } rand = { workspace = true } rustls = { workspace = true } smallvec = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["sync", "rt", "time", "macros"] } tokio-util = { workspace = true, features = ["rt"] } tracing = { workspace = true } netdev = { workspace = true } [dev-dependencies] clap = { workspace = true } rustls = { workspace = true, features = ["ring"] } tokio = { features = ["fs", "rt-multi-thread"], workspace = true } tokio-test = "0.4" tracing = { 
workspace = true } [dev-dependencies.tracing-subscriber] workspace = true features = ["fmt", "ansi", "env-filter", "time", "tracing-log"] [features] # Enable shorter TTL only for tests (especially integration tests in other crates). test-ttl = [] ================================================ FILE: qprotocol/src/dns.rs ================================================ ================================================ FILE: qprotocol/src/forward.rs ================================================ ================================================ FILE: qprotocol/src/io.rs ================================================ ================================================ FILE: qprotocol/src/lib.rs ================================================ pub mod dns; pub mod forward; pub mod io; pub mod quic; pub mod stun; ================================================ FILE: qprotocol/src/quic.rs ================================================ ================================================ FILE: qprotocol/src/stun/msg.rs ================================================ use std::{io, net::SocketAddr}; use bytes::BufMut; use nom::{ Err, IResult, Parser, combinator::map, error::{Error, ErrorKind}, multi::many0, number::streaming::{be_u8, be_u16}, }; use qbase::net::{AddrFamily, Family, WriteSocketAddr, be_socket_addr}; use rand::RngExt; use thiserror::Error; pub const BINDING_REQUEST: u16 = 0x0001; pub const BINDING_RESPONSE: u16 = 0x0101; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct TransactionId([u8; 16]); impl AsRef<[u8]> for TransactionId { fn as_ref(&self) -> &[u8] { &self.0 } } impl TransactionId { pub fn from_slice(slice: &[u8]) -> Self { let mut id = [0u8; 16]; id.copy_from_slice(slice); TransactionId(id) } pub fn random() -> Self { let mut id = [0u8; 16]; rand::rng().fill(&mut id); TransactionId(id) } } #[derive(Debug)] pub enum Packet { Request(Request), Response(Response), } /// STUN数据包中的Attr类型: #[derive(Debug, Clone, PartialEq)] pub enum Attr { // 
由服务器返回的外网映射地址 MappedAddress(SocketAddr), // 客户端发起请求携带的指定响应地址 ResponseAddress(SocketAddr), // 由客户端请求转发时,携带变换Ip:Port响应的指示 ChangeRequest(u8), // 由服务器返回的Response消息的源地址,即服务器的地址 SourceAddress(SocketAddr), // 由服务器返回的另一台的STUN服务器地址, // 包括不同端口,供后续参考使用 ChangedAddress(SocketAddr), } #[derive(Debug)] pub enum AttrType { MappedAddress(Family), ResponseAddress(Family), // 由客户端请求转发时,携带变换Ip:Port响应的指示 ChangeRequest(u8), // 由服务器返回的Response消息的源地址,即服务器的地址 SourceAddress(Family), // 由服务器返回的另一台的STUN服务器地址, // 包括不同端口,供后续参考使用 ChangedAddress(Family), } #[derive(Debug, Error)] #[error("Invalid attribute type: {0}")] pub struct InvalidAttrType(u8); impl From for u8 { fn from(value: AttrType) -> Self { match value { AttrType::MappedAddress(Family::V4) => 0, AttrType::MappedAddress(Family::V6) => 1, AttrType::ResponseAddress(Family::V4) => 2, AttrType::ResponseAddress(Family::V6) => 3, AttrType::SourceAddress(Family::V4) => 4, AttrType::SourceAddress(Family::V6) => 5, AttrType::ChangedAddress(Family::V4) => 6, AttrType::ChangedAddress(Family::V6) => 7, AttrType::ChangeRequest(flag_set) => 8 | flag_set, } } } impl TryFrom for AttrType { type Error = InvalidAttrType; fn try_from(value: u8) -> Result { match value { 0 => Ok(AttrType::MappedAddress(Family::V4)), 1 => Ok(AttrType::MappedAddress(Family::V6)), 2 => Ok(AttrType::ResponseAddress(Family::V4)), 3 => Ok(AttrType::ResponseAddress(Family::V6)), 4 => Ok(AttrType::SourceAddress(Family::V4)), 5 => Ok(AttrType::SourceAddress(Family::V6)), 6 => Ok(AttrType::ChangedAddress(Family::V4)), 7 => Ok(AttrType::ChangedAddress(Family::V6)), 8..12 => Ok(AttrType::ChangeRequest(value & 0x3)), _ => Err(InvalidAttrType(value)), } } } trait WriteAttr { fn put_attr(&mut self, attr: &Attr); } impl WriteAttr for T { fn put_attr(&mut self, attr: &Attr) { let typ: u8 = attr.typ().into(); match attr { Attr::MappedAddress(socket_addr) => { self.put_u8(typ); self.put_socket_addr(socket_addr); } Attr::ResponseAddress(socket_addr) => { self.put_u8(typ); 
self.put_socket_addr(socket_addr); } Attr::ChangeRequest(flag) => { self.put_u8(typ | *flag); } Attr::SourceAddress(socket_addr) => { self.put_u8(typ); self.put_socket_addr(socket_addr); } Attr::ChangedAddress(socket_addr) => { self.put_u8(typ); self.put_socket_addr(socket_addr); } }; } } impl Attr { pub fn typ(&self) -> AttrType { match self { Attr::MappedAddress(socket_addr) => AttrType::MappedAddress(socket_addr.family()), Attr::ResponseAddress(socket_addr) => AttrType::ResponseAddress(socket_addr.family()), Attr::ChangeRequest(flag_set) => AttrType::ChangeRequest(*flag_set), Attr::SourceAddress(socket_addr) => AttrType::SourceAddress(socket_addr.family()), Attr::ChangedAddress(socket_addr) => AttrType::ChangedAddress(socket_addr.family()), } } fn be_attr(input: &[u8]) -> IResult<&[u8], Self> { if input.is_empty() { return Err(Err::Error(Error::new(input, ErrorKind::Eof))); } let (remain, typ) = be_u8(input)?; let typ: AttrType = typ .try_into() .map_err(|_| Err::Error(Error::new(input, ErrorKind::Alt)))?; match typ { AttrType::MappedAddress(family) => { let (remain, addr) = be_socket_addr(remain, family)?; Ok((remain, Attr::MappedAddress(addr))) } AttrType::ResponseAddress(family) => { let (remain, addr) = be_socket_addr(remain, family)?; Ok((remain, Attr::ResponseAddress(addr))) } AttrType::SourceAddress(family) => { let (remain, addr) = be_socket_addr(remain, family)?; Ok((remain, Attr::SourceAddress(addr))) } AttrType::ChangedAddress(family) => { let (remain, addr) = be_socket_addr(remain, family)?; Ok((remain, Attr::ChangedAddress(addr))) } AttrType::ChangeRequest(flags) => Ok((remain, Attr::ChangeRequest(flags))), } } } #[derive(Debug, PartialEq, Clone)] pub struct Request(Vec); /// 目前用到的Request只有3种,一种是空的默认Request;一种是变换IP、Port来响应;一种是只变换端口来响应 /// 可以看出,ChangeRequest属性不可能有超过一个,为满足这种限制,三种Request均直接构造出来,不再有其他 /// 可变操作函数。 impl Default for Request { fn default() -> Self { Self(Vec::with_capacity(0)) } } pub(crate) trait WriteRequest { fn put_request(&mut self, 
request: &Request); } impl WriteRequest for T { fn put_request(&mut self, request: &Request) { for attr in &request.0 { self.put_attr(attr); } } } pub fn be_request(input: &[u8]) -> IResult<&[u8], Request> { many0(Attr::be_attr).map(Request).parse(input) } pub const CHANGE_PORT: u8 = 0x01; pub const CHANGE_IP: u8 = 0x02; impl Request { pub fn change_ip_and_port() -> Self { let mut request = Request::default(); request.0.push(Attr::ChangeRequest(CHANGE_IP | CHANGE_PORT)); request } pub fn change_port() -> Self { let mut request = Request::default(); request.0.push(Attr::ChangeRequest(CHANGE_PORT)); request } pub fn add_response_address(&mut self, addr: SocketAddr) -> &mut Self { self.0.push(Attr::ResponseAddress(addr)); self } // 仅发送响应地址,移除ChangeRequest属性 pub fn with_response_addr(addr: SocketAddr) -> Self { Request(vec![Attr::ResponseAddress(addr)]) } pub fn change_request(&self) -> Option { for attr in &self.0 { if let Attr::ChangeRequest(flags) = attr { return Some(*flags); } } None } pub fn response_address(&self) -> Option<&SocketAddr> { for attr in &self.0 { if let Attr::ResponseAddress(addr) = attr { return Some(addr); } } None } } #[derive(Debug, Clone, PartialEq)] pub struct Response(pub Vec); pub(crate) trait WriteResponse { fn put_response(&mut self, response: &Response); } impl WriteResponse for T { fn put_response(&mut self, response: &Response) { for attr in &response.0 { self.put_attr(attr); } } } pub fn be_response(input: &[u8]) -> IResult<&[u8], Response> { many0(Attr::be_attr).map(Response).parse(input) } impl Response { pub fn with(attrs: Vec) -> Self { Response(attrs) } pub fn map_addr(&self) -> io::Result { for attr in &self.0 { if let Attr::MappedAddress(addr) = attr { return Ok(*addr); }; } Err(io::Error::other("No mapped address found in response")) } pub fn changed_addr(&self) -> io::Result { for attr in &self.0 { if let Attr::ChangedAddress(addr) = attr { return Ok(*addr); }; } Err(io::Error::other("No changed address found in response")) } 
pub fn source_addr(&self) -> io::Result { for attr in &self.0 { if let Attr::SourceAddress(addr) = attr { return Ok(*addr); }; } Err(io::Error::other("No source address found in response")) } } pub fn be_packet(input: &[u8]) -> IResult<&[u8], (TransactionId, Packet)> { let (remain, typ) = be_u16(input)?; let (txid, remain) = remain.split_at(16); let (remain, packet) = match typ { BINDING_REQUEST => map(be_request, Packet::Request).parse(remain)?, BINDING_RESPONSE => map(be_response, Packet::Response).parse(remain)?, _ => return Err(Err::Error(Error::new(input, ErrorKind::Alt))), }; Ok((remain, (TransactionId::from_slice(txid), packet))) } pub trait WritePacket { fn put_packet(&mut self, txid: &TransactionId, packet: &Packet); } impl WritePacket for T { fn put_packet(&mut self, txid: &TransactionId, packet: &Packet) { match packet { Packet::Request(request) => { self.put_u16(BINDING_REQUEST); self.put_slice(txid.as_ref()); self.put_request(request); } Packet::Response(response) => { self.put_u16(BINDING_RESPONSE); self.put_slice(txid.as_ref()); self.put_response(response); } } } } #[cfg(test)] mod tests { use super::*; #[test] fn attr_deserialize() { assert_eq!( Attr::be_attr(&[4, 78, 34, 127, 0, 0, 1][..]), Ok(( &[][..], Attr::SourceAddress("127.0.0.1:20002".parse().unwrap()) )) ); assert_eq!( Attr::be_attr(&[6, 78, 34, 127, 0, 0, 1][..]), Ok(( &[][..], Attr::ChangedAddress("127.0.0.1:20002".parse().unwrap()) )) ); assert_eq!( Attr::be_attr(&[0, 48, 57, 127, 0, 0, 1][..]), Ok(( &[][..], Attr::MappedAddress("127.0.0.1:12345".parse().unwrap()) )) ) } #[test] fn request_serialize() { let buf = [ 4, 78, 34, 127, 0, 0, 1, 0, 48, 57, 127, 0, 0, 1, 6, 78, 34, 127, 0, 0, 1, ]; let (remain, response) = be_response(&buf).unwrap(); assert_eq!(remain.len(), 0); assert_eq!( response, Response(vec![ Attr::SourceAddress("127.0.0.1:20002".parse().unwrap()), Attr::MappedAddress("127.0.0.1:12345".parse().unwrap()), Attr::ChangedAddress("127.0.0.1:20002".parse().unwrap()) ]) ); } } 
================================================ FILE: qprotocol/src/stun.rs ================================================ pub mod msg; ================================================ FILE: qrecovery/Cargo.toml ================================================ [package] name = "qrecovery" version = "0.5.0" edition.workspace = true description = "The reliable transport part of QUIC, a part of dquic" readme.workspace = true repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true rust-version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] bytes = { workspace = true } derive_more = { workspace = true, features = ["deref"] } enum_dispatch = { workspace = true } futures = { workspace = true } qbase = { workspace = true } qevent = { workspace = true } rand = { workspace = true } rustls = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["io-util", "time"] } tracing = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["test-util", "macros"] } ================================================ FILE: qrecovery/src/crypto.rs ================================================ //! The reliable transmission of the crypto stream. mod send { use std::{ io, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll, Waker}, }; use bytes::{BufMut, Bytes}; use qbase::{ Epoch, frame::CryptoFrame, net::tx::{ArcSendWakers, Signals}, packet::{Package, PacketContent}, varint::{VARINT_MAX, VarInt}, }; use tokio::io::AsyncWrite; use crate::send::SendBuf; #[derive(Debug)] pub(super) struct Sender { sndbuf: SendBuf, writable_waker: Option, flush_waker: Option, tx_wakers: ArcSendWakers, } impl Sender { /// 不再长的像write,因为rust可以多返回值,因此在返回的结果里面将读到的数据返回. /// 调用者一定要自行将其写入到buffer中发送。 /// 一旦这种函数成功使用,try_read_data就可以淘汰了 fn try_load_data

(&mut self, packet: &mut P) -> Result<(), Signals> where P: BufMut + ?Sized, for<'b> (CryptoFrame, &'b [Bytes]): Package

, { let max_size = packet.remaining_mut(); let predicate = |offset: u64| CryptoFrame::estimate_max_capacity(max_size, offset); self.sndbuf .pick_up(predicate, usize::MAX) .map(|(range, _is_fresh, data)| { let frame = CryptoFrame::new( VarInt::from_u64(range.start).unwrap(), VarInt::try_from(range.end - range.start).unwrap(), ); (frame, data.as_slice()).dump(packet).unwrap(); }) } fn on_data_acked(&mut self, crypto_frame: &CryptoFrame) { self.sndbuf.on_data_acked(&crypto_frame.range()); if self.sndbuf.remaining_mut() > 0 && let Some(waker) = self.writable_waker.take() { waker.wake(); } } fn may_loss_data(&mut self, crypto_frame: &CryptoFrame) { self.tx_wakers.wake_all_by(Signals::TRANSPORT); self.sndbuf.may_loss_data(&crypto_frame.range()) } } impl Sender { fn poll_write(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { assert!( self.writable_waker.is_none() || matches!(self.writable_waker, Some(ref waker) if waker.will_wake(cx.waker())) ); assert!( self.flush_waker.is_none() || matches!(self.flush_waker, Some(ref waker) if waker.will_wake(cx.waker())) ); if self.sndbuf.written() + buf.len() as u64 > VARINT_MAX { return Poll::Ready(Err(io::Error::new( io::ErrorKind::WouldBlock, "The largest offset delivered on the crypto stream cannot exceed 2^62-1", ))); } debug_assert!(self.sndbuf.has_remaining_mut()); self.tx_wakers.wake_all_by(Signals::TRANSPORT); self.sndbuf.write(Bytes::copy_from_slice(buf)); Poll::Ready(Ok(buf.len())) } fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { assert!( self.flush_waker.is_none() || matches!(self.flush_waker, Some(ref waker) if waker.will_wake(cx.waker())) ); if self.sndbuf.is_all_rcvd() { Poll::Ready(Ok(())) } else { self.flush_waker = Some(cx.waker().clone()); Poll::Pending } } } pub(super) type ArcSender = Arc>; /// Struct for crypto layer to send crypto data to the peer. 
/// /// To reduce the memory reallcation, if the internal buffer is filled, the [`write`] call will /// be blocked until the data sent been acknowledged by peer. /// /// [`write`]: tokio::io::AsyncWriteExt::write #[derive(Debug, Clone)] pub struct CryptoStreamWriter(pub(super) ArcSender); /// Struct for transport layer to send crypto data. #[derive(Debug, Clone)] pub struct CryptoStreamOutgoing(pub(super) ArcSender); impl AsyncWrite for CryptoStreamWriter { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { self.0.lock().unwrap().poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.0.lock().unwrap().poll_flush(cx) } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { // 永远不会关闭,直到Connection级别的关闭 Poll::Ready(Ok(())) } } impl CryptoStreamOutgoing { /// Try to load the crypto data into the `packet`. pub fn try_load_data_into

(&self, packet: &mut P, force: bool) -> Result<(), Signals> where P: BufMut + ?Sized, for<'b> (CryptoFrame, &'b [Bytes]): Package

, { use std::ops::ControlFlow::*; let mut inner = self.0.lock().unwrap(); if force { inner.sndbuf.resend_flighting(); } let (Continue(result) | Break(result)) = core::iter::from_fn(|| Some(inner.try_load_data(packet))).try_fold( Err(Signals::empty()), |result, once| match (result, once) { (Err(_empty), Ok(())) => Continue(Ok(())), (Err(_empty), Err(signals)) => Break(Err(signals)), (Ok(()), Ok(())) => Continue(Ok(())), (Ok(()), Err(_no_more)) => Break(Ok(())), }, ); result } pub fn package(self, epoch: Epoch) -> CryptoStreamPackage { CryptoStreamPackage { first_load: epoch == Epoch::Initial, outgoing: self, } } /// Called when the crypto frame sent is acknowledged by peer. /// /// Acknowledgment of data may free up a segment in the [`SendBuf`], thus waking up the /// writing task, pub fn on_data_acked(&self, crypto_frame: &CryptoFrame) { self.0.lock().unwrap().on_data_acked(crypto_frame) } /// Called when the crypto frame sent may loss. pub fn may_loss_data(&self, crypto_frame: &CryptoFrame) { self.0.lock().unwrap().may_loss_data(crypto_frame) } } pub struct CryptoStreamPackage { first_load: bool, outgoing: CryptoStreamOutgoing, } impl

Package

for CryptoStreamPackage where P: BufMut + ?Sized, for<'b> (CryptoFrame, &'b [Bytes]): Package

, { fn dump(&mut self, packet: &mut P) -> Result { let force = self.first_load; match self.outgoing.try_load_data_into(packet, force) { Ok(()) => { self.first_load = false; Ok(PacketContent::EffectivePayload) } Err(signals) => Err(signals), } } } pub(super) fn create(tx_wakers: ArcSendWakers) -> ArcSender { Arc::new(Mutex::new(Sender { sndbuf: SendBuf::with_capacity(VARINT_MAX), writable_waker: None, flush_waker: None, tx_wakers, })) } } mod recv { use std::{ io, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll, Waker}, }; use bytes::{BufMut, Bytes}; use qbase::{ error::Error, frame::{CryptoFrame, io::ReceiveFrame}, varint::VARINT_MAX, }; use tokio::io::{AsyncRead, ReadBuf}; use crate::recv::RecvBuf; #[derive(Debug)] pub(super) struct Recver { rcvbuf: RecvBuf, read_waker: Option, } impl Recver { fn recv(&mut self, offset: u64, data: Bytes) { assert!(offset + data.len() as u64 <= VARINT_MAX); self.rcvbuf.recv(offset, data); if self.rcvbuf.is_readable() && let Some(waker) = self.read_waker.take() { waker.wake() } } fn poll_read( &mut self, cx: &mut Context<'_>, buf: &mut T, ) -> Poll> { assert!( self.read_waker.is_none() || matches!(self.read_waker, Some(ref waker) if waker.will_wake(cx.waker())) ); if self.rcvbuf.is_readable() { self.rcvbuf.try_read(buf); Poll::Ready(Ok(())) } else { self.read_waker = Some(cx.waker().clone()); Poll::Pending } } } pub(super) type ArcRecver = Arc>; /// Struct for crypto layer to read crypto data from the peer. #[derive(Debug, Clone)] pub struct CryptoStreamReader(pub(super) ArcRecver); /// Struct for transport layer to deliver the received crypto to crypto layer. 
#[derive(Debug, Clone)] pub struct CryptoStreamIncoming(pub(super) ArcRecver); impl AsyncRead for CryptoStreamReader { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { self.0.lock().unwrap().poll_read(cx, buf) } } impl ReceiveFrame<(CryptoFrame, Bytes)> for CryptoStreamIncoming { type Output = (); fn recv_frame(&self, (frame, data): (CryptoFrame, Bytes)) -> Result { self.0.lock().unwrap().recv(frame.offset(), data); Ok(()) } } pub(super) fn create() -> ArcRecver { Arc::new(Mutex::new(Recver { rcvbuf: RecvBuf::default(), read_waker: None, })) } } use qbase::net::tx::ArcSendWakers; pub use recv::{CryptoStreamIncoming, CryptoStreamReader}; pub use send::{CryptoStreamOutgoing, CryptoStreamWriter}; /// Crypto data stream. #[derive(Debug, Clone)] pub struct CryptoStream { sender: send::ArcSender, recver: recv::ArcRecver, } impl CryptoStream { /// Create a new instance of [`CryptoStream`] with the given buffer size. pub fn new(tx_wakers: ArcSendWakers) -> Self { Self { sender: send::create(tx_wakers), recver: recv::create(), } } /// Create a [`CryptoStreamWriter`] which belong to this crypto stream. pub fn writer(&self) -> CryptoStreamWriter { CryptoStreamWriter(self.sender.clone()) } /// Create a [`CryptoStreamReader`] which belong to this crypto stream. pub fn reader(&self) -> CryptoStreamReader { CryptoStreamReader(self.recver.clone()) } /// Create a [`CryptoStreamOutgoing`] which belong to this crypto stream. pub fn outgoing(&self) -> CryptoStreamOutgoing { CryptoStreamOutgoing(self.sender.clone()) } /// Create a [`CryptoStreamIncoming`] which belong to this crypto stream. 
pub fn incoming(&self) -> CryptoStreamIncoming { CryptoStreamIncoming(self.recver.clone()) } } #[cfg(test)] mod tests { use qbase::{ frame::{CryptoFrame, io::ReceiveFrame}, varint::VarInt, }; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::CryptoStream; #[tokio::test] async fn test_read() { let crypto_stream: CryptoStream = CryptoStream::new(Default::default()); crypto_stream .writer() .write_all(b"hello world") .await .unwrap(); crypto_stream .incoming() .recv_frame(( CryptoFrame::new(VarInt::from_u32(0), VarInt::from_u32(11)), bytes::Bytes::copy_from_slice(b"hello world"), )) .unwrap(); let mut buf = [0u8; 11]; crypto_stream.reader().read_exact(&mut buf).await.unwrap(); assert_eq!(&buf[..], b"hello world"); } } ================================================ FILE: qrecovery/src/journal/rcvd.rs ================================================ use std::{ collections::HashSet, sync::{Arc, RwLock}, }; use bytes::BufMut; use qbase::{ frame::AckFrame, net::tx::Signals, packet::{InvalidPacketNumber, Package, PacketContent, PacketNumber, PacketWriter}, util::{IndexDeque, IndexError}, varint::{VARINT_MAX, VarInt}, }; use tokio::time::{Duration, Instant}; /// 收包记录有以下几种状态 /// - Empty:收包记录为空,未收到该包 /// - PacketReceived:(收包时间,最晚ack时间,过期时间), 如果路径没有驱动 ack,由这里驱动 /// - AckSent:(ack_eliciting,收包时间,淘汰时间,确认了这个包的包号集合),如果set里的任意包号被确认了,则转换成 AckConfirmed 状态 /// - AckConfirmed:(ack_eliciting,收包时间,淘汰时间) #[derive(Debug, Clone, PartialEq, Eq, Default)] enum State { #[default] Empty, PacketReceived(Instant, Option, Instant), AckSent(bool, Instant, Instant, HashSet), AckConfirmed(bool, Instant, Instant), } impl State { // 是否要打包到 ack frame 中,如果需要,PacketReceived 状态转换成 AckSent 状态, AckSent 状态记录 pn fn track_packet_in_ack_frame(&mut self, pn: u64) -> bool { match self { State::PacketReceived(recv_time, latest_ack_time, expire_time) => { *self = State::AckSent( latest_ack_time.is_some(), *recv_time, *expire_time, [pn].into(), ); true } State::AckSent(_, _, _, pns) => { pns.insert(pn); true 
} State::AckConfirmed(_, _, _) => true, State::Empty => false, } } fn could_expire(&self, now: Instant) -> bool { match self { State::Empty => true, State::AckConfirmed(ack_eliciting, _, expire_time) => { !ack_eliciting || *expire_time < now } _ => false, } } } /// 纯碎的一个收包记录,主要用于: /// - 记录包有无收到 /// - 根据某个largest pktno,生成ack frame(ack frame不能超过buf大小) /// - 确定记录不再需要,可以被丢弃,滑走 #[derive(Debug, Default)] struct RcvdJournal { queue: IndexDeque, max_ack_delay: Option, packet_include_ack: HashSet, earliest_not_ack_time: Option<(u64, Instant)>, } impl RcvdJournal { fn with_capacity(capacity: usize, max_ack_delay: Option) -> Self { Self { queue: IndexDeque::with_capacity(capacity), max_ack_delay, packet_include_ack: HashSet::new(), earliest_not_ack_time: None, } } fn decode_pn(&mut self, pkt_number: PacketNumber) -> Result { let expected_pn = self.queue.largest(); let pn = pkt_number.decode(expected_pn); if pn < self.queue.offset() { return Err(InvalidPacketNumber::TooOld); } match self.queue.get(pn) { Some(State::Empty) | None => Ok(pn), _ => Err(InvalidPacketNumber::Duplicate), } } fn on_rcvd_pn(&mut self, pn: u64, is_ack_eliciting: bool, pto: Duration) { let now = tokio::time::Instant::now(); let ack_time = if is_ack_eliciting { Some(now + self.max_ack_delay.unwrap_or_default()) } else { None }; let expire_time = now + pto * 3; if let Some(record) = self.queue.get_mut(pn) { // assert!(matches!(record, State::Empty)); *record = State::PacketReceived(now, ack_time, expire_time); } else if let Err(e @ IndexError::ExceedLimit(..)) = self .queue .insert(pn, State::PacketReceived(now, ack_time, expire_time)) { panic!("packet number never exceed limit: {e}") } if is_ack_eliciting && self.earliest_not_ack_time.is_none() { self.earliest_not_ack_time = Some((pn, now)); } } fn on_rcvd_ack(&mut self, ack_frame: &AckFrame) { let acked_pns: std::collections::HashSet<_> = ack_frame .iter() .flat_map(|range| range.clone()) .filter(|pn| self.packet_include_ack.contains(pn)) .collect(); 
self.packet_include_ack.retain(|pn| !acked_pns.contains(pn)); for record in self.queue.iter_mut() { if let State::AckSent(ack_eliciting, recv_time, expire_time, pns) = record && pns.iter().any(|pn| acked_pns.contains(pn)) { *record = State::AckConfirmed(*ack_eliciting, *recv_time, *expire_time); } } self.rotate_queue(); } fn rotate_queue(&mut self) { let now = tokio::time::Instant::now(); while self .queue .front() .is_some_and(|(_pn, state)| state.could_expire(now)) { self.queue.pop_front(); } } fn gen_ack_frame_util( &mut self, pn: u64, largest: u64, rcvd_time: Instant, mut capacity: usize, ) -> Result { let mut pkts = self .queue .enumerate_mut() .rev() .skip_while(|(pktno, _)| *pktno > largest); // Minimum length with at least ACK frame type, largest, delay, range count, first_range (at least 1 byte for 0) let largest = VarInt::from_u64(largest).unwrap(); let delay = rcvd_time.elapsed().as_micros() as u64; let delay = VarInt::from_u64(delay).unwrap(); let mut first_range = 0_u32; for (_, s) in pkts.by_ref() { if s.track_packet_in_ack_frame(pn) { first_range += 1; } else { break; } } first_range = first_range.saturating_sub(1); let first_range = VarInt::from(first_range); // Frame type + Largest Acknowledged + First Ack Range + Ack Range Count let min_len = 1 + largest.encoding_size() + delay.encoding_size() + first_range.encoding_size() + 1; if capacity < min_len { return Err(Signals::CONGESTION); } capacity -= min_len; fn range_count_size_increment(range_count: usize) -> usize { match range_count { // 接下来需要2字节编码 len if len == (1 << 6) - 1 => 1, // 2 - 1 // 接下来需要4字节编码 len if len == (1 << 14) - 1 => 2, // 4 - 2 // 接下来需要8字节编码 len if len == (1 << 30) - 1 => 4, // 8 - 4 // 放不下了,不可能走到这里 _ => 0, } } let mut ranges = vec![]; use core::ops::ControlFlow::*; let (Continue((gap, ack, last_is_acked)) | Break((gap, ack, last_is_acked))) = pkts .try_fold( // take_while第一个被判否的元素会被消耗,如果它是gap那这里有gap=1,如果是因为迭代器没有更多元素这里gap=1也不影响 (1, 0, false), |(gap, ack, last_is_acked), (_pktno, 
state)| { let range_count = ranges.len(); match (last_is_acked, state.track_packet_in_ack_frame(pn)) { // 本range结束了,看看是否放得下本range,开始新的range (true, false) => { // 修正 let gap = VarInt::from_u32(gap - 1); let ack = VarInt::from_u32(ack - 1); let size = range_count_size_increment(range_count) + gap.encoding_size() + ack.encoding_size(); if capacity < size { // last_is_acked为false,不会被填进去 return Break((0, 0, false)); } capacity -= size; ranges.push((gap, ack)); Continue((1, 0, state.track_packet_in_ack_frame(pn))) } // 如果当前是ack,增加ack,保持gap不变 (false | true, true) => { Continue((gap, ack + 1, state.track_packet_in_ack_frame(pn))) } // 当前和之前都是gap,增加gap (false, false) => { Continue((gap + 1, ack, state.track_packet_in_ack_frame(pn))) } } }, ); // 处理最后一个未来完成的range if last_is_acked { let gap = VarInt::from_u32(gap - 1); let ack = VarInt::from_u32(ack - 1); let size = range_count_size_increment(ranges.len()) + gap.encoding_size() + ack.encoding_size(); if capacity > size { // capacity -= size; unnecessary, never read latter ranges.push((gap, ack)); } } self.packet_include_ack.insert(pn); if let Some((pn, _)) = self.earliest_not_ack_time && largest >= pn { self.earliest_not_ack_time = None; } Ok(AckFrame::new(largest, delay, first_range, ranges, None)) } fn need_ack(&self) -> Option<(u64, Instant)> { let now = tokio::time::Instant::now(); let (_, earliest_not_ack_time) = self.earliest_not_ack_time?; let max_ack_delay = self.max_ack_delay.unwrap_or_default(); if earliest_not_ack_time + max_ack_delay >= now { return None; } let (largest, state) = self.queue.back()?; let recv_time = match state { State::PacketReceived(rt, _, _) | State::AckSent(_, rt, _, _) | State::AckConfirmed(_, rt, _) => *rt, _ => return None, }; Some((largest, recv_time)) } } /// Records for received packets, decode the packet number and generate ack frames. 
// Queue of received packets, shared everywhere. Checking whether a packet was
// received and generating ACK frames only needs the read lock; recording a
// newly received packet, or expiring old packets and sliding the window
// forward, needs the write lock.
//
// NOTE(review): several generic argument lists in this extract appear to have
// been stripped by the extraction tooling (e.g. `Arc>` was presumably
// `Arc<RwLock<RcvdJournal>>`, bare `Result` presumably carried its type
// parameters, `impl Iterator + '_` presumably `impl Iterator<Item = T> + '_`).
// The tokens below are preserved as extracted — restore them from the upstream
// sources before compiling.
#[derive(Debug, Clone, Default)]
pub struct ArcRcvdJournal {
    inner: Arc>,
}

impl ArcRcvdJournal {
    /// Create a new empty records with the given `capacity`.
    ///
    /// The number of records can exceed the `capacity` specified at creation time, but the internal
    /// implementation strives to avoid reallocation.
    pub fn with_capacity(capacity: usize, max_ack_delay: Option) -> Self {
        Self {
            inner: Arc::new(RwLock::new(RcvdJournal::with_capacity(
                capacity,
                max_ack_delay,
            ))),
        }
    }

    /// Decode the pn from peer's packet to actual packet number.
    ///
    /// See [`RFC`](https://www.rfc-editor.org/rfc/rfc9000.html#name-sample-packet-number-decodi)
    /// for more details about decode packet number.
    ///
    /// If the packet is too old or has been received, or the pn is too big, this method will return
    /// an error.
    ///
    /// Note that although the packet number successful decoded, it does not mean that the packet is
    /// valid, and the frames in it are valid.
    ///
    /// The registered packet must be valid, successfully decrypted, and the frames in it must be
    /// valid.
    // When a new packet arrives and it is very old, it is most likely a
    // duplicate and is simply dropped. If its packet number is the largest so
    // far, the gap before it must be recorded as not yet received.
    // Note: a legal packet number does not imply legal contents; only after the
    // packet is decrypted and its frames parsed can it be recorded as received.
    pub fn decode_pn(&self, encoded_pn: PacketNumber) -> Result {
        self.inner.write().unwrap().decode_pn(encoded_pn)
    }

    /// Register the packet has been received.
    ///
    /// The registered packet must be valid, successfully decrypted, and the frames in it must be
    /// valid.
    // Record the packet as received only after the pn is legal, the packet is
    // fully decrypted, and every frame in it parsed correctly.
    pub fn on_rcvd_pn(&self, pn: u64, is_ack_eliciting: bool, pto: Duration) {
        self.inner
            .write()
            .unwrap()
            .on_rcvd_pn(pn, is_ack_eliciting, pto);
    }

    /// Generate an ack frame which ack the received frames until `largest`.
    ///
    /// This method will write an ack frame into the `buf`. The `Ack Delay` field of the frame is
    /// the argument `recv_time` as microsec, the `Largest Acknowledged` field of the frame is the
    /// `largest` frame, the ranges in ack frame will not exceed `largest`.
    pub fn gen_ack_frame_util(
        &self,
        pn: u64,
        largest: u64,
        rcvd_time: Instant,
        capacity: usize,
    ) -> Result {
        self.inner
            .write()
            .unwrap()
            .gen_ack_frame_util(pn, largest, rcvd_time, capacity)
    }

    /// Forward a received ACK frame to the journal, so packets whose ACK state
    /// is now confirmed can be marked and old records rotated out.
    pub fn on_rcvd_ack(&self, ack_frame: &AckFrame) {
        self.inner.write().unwrap().on_rcvd_ack(ack_frame);
    }

    /// Returns `(largest_pn, recv_time)` if an ACK is due now, `None` otherwise.
    pub fn need_ack(&self) -> Option<(u64, Instant)> {
        self.inner.read().unwrap().need_ack()
    }

    /// Override the `max_ack_delay` used to decide when an ACK becomes due.
    pub fn revise_max_ack_delay(&self, max_ack_delay: Duration) {
        self.inner.write().unwrap().max_ack_delay = Some(max_ack_delay);
    }

    /// Bundle an optional precomputed `need_ack` decision with this journal,
    /// to be dumped into an outgoing packet later.
    // NOTE(review): `AckPackege` is a misspelling of `AckPackage`; renaming is a
    // breaking API change, so it is only flagged here.
    pub fn ack_package<'r>(&'r self, need_ack: Option<(u64, Instant)>) -> AckPackege<'r> {
        AckPackege {
            journal: self,
            need_ack,
        }
    }
}

/// A deferred ACK-frame writer: holds the journal and an optional precomputed
/// "ACK needed" decision until a packet buffer is available to dump into.
pub struct AckPackege<'r> {
    journal: &'r ArcRcvdJournal,
    need_ack: Option<(u64, Instant)>,
}

impl<'r, Target> Package for AckPackege<'r>
where
    Target: AsRef> + ?Sized,
    AckFrame: Package,
{
    fn dump(&mut self, target: &mut Target) -> Result {
        self.need_ack
            // Fall back to asking the journal if no decision was precomputed.
            .or_else(|| self.journal.need_ack())
            .ok_or(Signals::TRANSPORT)
            .and_then(|(largest_ack, rcvd_time)| {
                self.journal.gen_ack_frame_util(
                    target.as_ref().packet_number(),
                    largest_ack,
                    rcvd_time,
                    target.as_ref().remaining_mut(),
                )
            })?
            // gen_ack_frame_util already sized the frame against the remaining
            // capacity, so dumping it into the target cannot fail here.
            .dump(target)
            .unwrap();
        // An ACK frame by itself is not ack-eliciting.
        Ok(PacketContent::NonAckEliciting)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rcvd_pkt_records() {
        let records = ArcRcvdJournal::with_capacity(16, None);
        assert_eq!(records.decode_pn(PacketNumber::encode(1, 0)), Ok(1));
        assert_eq!(records.inner.read().unwrap().queue.len(), 0);

        let pto = Duration::from_millis(100);
        records.on_rcvd_pn(1, true, pto);
        assert_eq!(records.inner.read().unwrap().queue.len(), 2);
        // pn 0 was never received; pn 1 was.
        assert_eq!(
            records.inner.read().unwrap().queue.get(0).unwrap(),
            &State::Empty
        );
        assert!(matches!(
            records.inner.read().unwrap().queue.get(1).unwrap(),
            State::PacketReceived(_, _, _)
        ));

        let ack_frame = records.gen_ack_frame_util(0, 1, Instant::now(), 1200);
        assert_eq!(&ack_frame.unwrap().largest(), &1);
        assert!(
            records
                .inner
                .read()
                .unwrap()
                .packet_include_ack
                .contains(&0)
        );
        assert!(matches!(
            records.inner.read().unwrap().queue.get(1).unwrap(),
            State::AckSent(true, _, _, _)
        ));

        // The peer acking packet 0 (which carried our ACK) confirms our ACK.
        let ack_frame = AckFrame::new(0_u32.into(), 100_u32.into(), 0_u32.into(), vec![], None);
        records.on_rcvd_ack(&ack_frame);
        assert_eq!(records.inner.read().unwrap().queue.len(), 1);
        let binding = records.inner.read().unwrap();
        let record = binding.queue.get(1).unwrap();
        assert!(matches!(record, State::AckConfirmed(_, _, _)));
    }

    #[test]
    fn gen_ack_frame() {
        let rcvd_state = State::PacketReceived(Instant::now(), None, Instant::now());
        let unrcvd_state = State::Empty;

        // Received: 1..=10, 12..=44, 50..=54; missing: 11 and 45..=49.
        let mut queue = IndexDeque::with_capacity(45);
        for idx in 1..11 {
            queue.insert(idx, rcvd_state.clone()).unwrap();
        }
        for idx in 11..12 {
            queue.insert(idx, unrcvd_state.clone()).unwrap();
        }
        for idx in 12..45 {
            queue.insert(idx, rcvd_state.clone()).unwrap();
        }
        for idx in 45..50 {
            queue.insert(idx, unrcvd_state.clone()).unwrap();
        }
        for idx in 50..55 {
            queue.insert(idx, rcvd_state.clone()).unwrap();
        }

        let mut rcvd_jornal = RcvdJournal {
            queue,
            max_ack_delay: None,
            packet_include_ack: Default::default(),
            earliest_not_ack_time: None,
        };
        let ack = rcvd_jornal
            .gen_ack_frame_util(0, 52, Instant::now(), 1000)
            .unwrap();
        assert_eq!(
            ack.ranges(),
            &vec![
                (VarInt::from_u32(50 - 45 - 1), VarInt::from_u32(45 - 12 - 1)),
                (VarInt::from_u32(12 - 11 - 1), VarInt::from_u32(11 - 1 - 1))
            ]
        );
        assert_eq!(ack.first_range(), 2)
    }
}

================================================
FILE: qrecovery/src/journal/sent.rs
================================================
use std::{
    collections::VecDeque,
    ops::DerefMut,
    sync::{Arc, Mutex, MutexGuard},
    time::Duration,
};

use derive_more::{Deref, DerefMut};
use qbase::{
    error::{ErrorKind, QuicError},
    frame::{AckFrame, GetFrameType},
    packet::PacketNumber,
    util::IndexDeque,
    varint::VARINT_MAX,
};
use tokio::time::Instant;

/// State of a sent packet, one of:
/// - Flighting: the packet is still in flight
/// - Acked: the packet has been acknowledged
/// - Retransmitted: the packet is considered lost, its frames queued for resending
/// - Skipped: the pn was consumed by a packet carrying only trivial frames
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SentPktState {
    Skipped,
    Flighting {
        nframes: usize,
        sent_time: Instant,
        expire_time: Instant,
        retran_time: Instant,
    },
    Retransmitted {
        nframes: usize,
        sent_time: Instant,
        expire_time: Instant,
    },
    Acked {
        nframes: usize,
        sent_time: Instant,
        expire_time: Instant,
    },
}

impl SentPktState {
    #[allow(dead_code)]
    fn skipped() -> Self {
        Self::Skipped
    }

    /// A freshly sent packet starts in the `Flighting` state.
    fn new(nframes: usize, sent_time: Instant, retran_time: Instant, expire_time: Instant) -> Self {
        Self::Flighting {
            nframes,
            sent_time,
            retran_time,
            expire_time,
        }
    }

    /// Number of reliable frames recorded for this packet.
    fn nframes(&self) -> usize {
        match self {
            SentPktState::Skipped => 0,
            SentPktState::Flighting { nframes, .. } => *nframes,
            SentPktState::Retransmitted { nframes, .. } => *nframes,
            SentPktState::Acked { nframes, .. } => *nframes,
        }
    }

    /// Transition to `Acked`; returns how many frames were newly acknowledged
    /// (0 if the packet was skipped or already acked).
    fn be_acked(&mut self) -> usize {
        match *self {
            SentPktState::Skipped => 0,
            SentPktState::Flighting {
                nframes,
                sent_time,
                expire_time,
                ..
            } => {
                *self = SentPktState::Acked {
                    nframes,
                    sent_time,
                    expire_time,
                };
                nframes
            }
            SentPktState::Retransmitted {
                nframes,
                sent_time,
                expire_time,
                ..
            } => {
                *self = SentPktState::Acked {
                    nframes,
                    sent_time,
                    expire_time,
                };
                nframes
            }
            SentPktState::Acked { .. } => 0,
        }
    }

    /// Mark the packet as possibly lost; returns the number of frames that
    /// should be retransmitted.
    fn maybe_lost(&mut self) -> usize {
        match *self {
            SentPktState::Flighting {
                nframes,
                sent_time,
                expire_time,
                ..
            } => {
                *self = SentPktState::Retransmitted {
                    nframes,
                    sent_time,
                    expire_time,
                };
                nframes
            }
            Self::Retransmitted { nframes, .. } => nframes,
            _ => 0,
        }
    }

    /// If the retransmission deadline passed while still in flight, move to
    /// `Retransmitted` and report `true`.
    fn should_retransmit_after(&mut self, now: &Instant) -> bool {
        match *self {
            SentPktState::Flighting {
                sent_time,
                retran_time,
                expire_time,
                ..
            } if retran_time < *now => {
                *self = SentPktState::Retransmitted {
                    nframes: self.nframes(),
                    sent_time,
                    expire_time,
                };
                true
            }
            _ => false,
        }
    }

    /// Whether this record must still be kept in the journal at `now`.
    fn should_remain_after(&self, pn: u64, now: &Instant) -> bool {
        match self {
            SentPktState::Skipped => false,
            SentPktState::Flighting { .. } => true,
            SentPktState::Retransmitted { expire_time, .. } => {
                if expire_time > now {
                    true
                } else {
                    tracing::trace!(target: "quic", "retransmitted packet {pn} is expired without ack");
                    false
                }
            }
            SentPktState::Acked { .. } => false,
        }
    }
}

/// Records the frames that have been sent, doing its best to avoid memory
/// allocation. `queue` stores every frame ever sent; `sent_packets` records,
/// per packet in send order, how many frames the packet contains and the
/// packet's state. When a packet is sent its frames are appended here; when an
/// acknowledgement arrives the packet state is updated — an acked packet needs
/// nothing further, while a lost packet's frames must be resent.
#[derive(Debug, Default, Deref, DerefMut)]
struct SentJournal {
    #[deref]
    #[deref_mut]
    queue: VecDeque,
    // Per-packet record: effectively a count of records in `queue` plus a state.
    sent_packets: IndexDeque,
    largest_acked_pktno: u64,
}

impl SentJournal {
    /// Mark packet `pn` acked and return the frames it carried.
    fn on_packet_acked(&mut self, pn: u64) -> impl Iterator + '_ {
        let mut len = 0;
        // Frames of packet `pn` start right after all frames of earlier packets.
        let offset = self
            .sent_packets
            .enumerate()
            .take_while(|(pkt_idx, _)| *pkt_idx < pn)
            .map(|(_, s)| s.nframes())
            .sum::();
        if let Some(s) = self.sent_packets.get_mut(pn) {
            len = s.be_acked();
        }
        self.queue
            .range_mut(offset..offset + len)
            .map(|f| f.clone())
    }

    /// Mark packet `pn` possibly lost and return its frames for retransmission.
    fn may_loss_packet(&mut self, pn: u64) -> impl Iterator + '_ {
        let mut len = 0;
        let offset = self
            .sent_packets
            .enumerate()
            // TODO(optimize): the caller iterates packet by packet, and this
            // take_while re-scans from the front on every call; optimizable.
            .take_while(|(pkt_idx, _)| *pkt_idx < pn)
            .map(|(_, s)| s.nframes())
            .sum::();
        if let Some(s) = self.sent_packets.get_mut(pn) {
            len = s.maybe_lost();
        }
        self.queue
            .range_mut(offset..offset + len)
            .map(|f| f.clone())
    }

    /// Collect the frames of every packet below the largest acked pn whose
    /// retransmission deadline has passed.
    fn fast_retransmit(&mut self) -> impl Iterator + '_ {
        self.resize();
        let now = tokio::time::Instant::now();
        self.sent_packets
            .enumerate_mut()
            .take_while(|(pn, _)| *pn < self.largest_acked_pktno)
            // Accumulate each packet's frame range while testing its deadline.
            .scan(0, move |sum, (_, s)| {
                let start = *sum;
                *sum += s.nframes();
                Some((s.should_retransmit_after(&now), start..*sum))
            })
            .filter(|(should_retran, _)| *should_retran)
            .flat_map(|(_, r)| self.queue.range(r))
            .cloned()
    }
}

impl SentJournal {
    fn with_capacity(capacity: usize) -> Self {
        Self {
            // Heuristic: roughly four frames per packet.
            queue: VecDeque::with_capacity(capacity * 4),
            sent_packets: IndexDeque::with_capacity(capacity),
            largest_acked_pktno: 0,
        }
    }

    /// Drop the longest prefix of packet records that no longer need to be
    /// kept, together with their frames.
    fn resize(&mut self) {
        let now = Instant::now();
        let (n, f) = self
            .sent_packets
            .enumerate()
            .take_while(|(pn, s)| !s.should_remain_after(*pn, &now))
            .fold((0usize, 0usize), |(n, f), (_, s)| (n + 1, f + s.nframes()));
        self.sent_packets.advance(n);
        _ = self.queue.drain(..f);
    }
}

/// Records for sent packets and frames in them.
///
/// [`DataStreams`] need to be aware of frame acknowledgment or possible loss, and so does [`CryptoStream`].
/// This structure records some frames (type T) in each packet sent, and feeds back the frames in
/// these packets to [`DataStreams`] and [`CryptoStream`] when the packet is acknowledged or may be
/// lost.
///
/// The interfaces are on the [`NewPacketGuard`] structure and the [`SentRotateGuard`] structure, read their
/// documentation for more. This structure only provides the methods to create them.
///
/// If multiple tasks are recording at the same time, the recording will become confusing, so the
/// [`NewPacketGuard`] and the [`SentRotateGuard`] are designed to be `Guard`, which means that they hold a
/// [`MutexGuard`].
///
///
/// [`DataStreams`]: crate::streams::DataStreams
/// [`CryptoStream`]: crate::crypto::CryptoStream
// NOTE(review): generic argument lists in this extract appear to have been
// stripped by the extraction tooling (e.g. `ArcSentJournal(Arc>>)` was
// presumably `ArcSentJournal<T>(Arc<Mutex<SentJournal<T>>>)`, and the bare
// `impl ... ArcSentJournal` / `SentJournal` below lost their `<T>`). Tokens are
// preserved as extracted — restore them from the upstream sources.
#[derive(Debug, Default)]
pub struct ArcSentJournal(Arc>>);

impl Clone for ArcSentJournal {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl ArcSentJournal {
    /// Create a new empty records with the given `capacity`.
    ///
    /// The number of records can exceed the `capacity` specified at creation time, but the internal
    /// implementation strives to avoid reallocation.
    pub fn with_capacity(capacity: usize) -> Self {
        Self(Arc::new(Mutex::new(SentJournal::with_capacity(capacity))))
    }

    /// Return a [`SentRotateGuard`] to resolve the ack frame from peer.
    pub fn rotate(&self) -> SentRotateGuard<'_, T> {
        SentRotateGuard {
            inner: self.0.lock().unwrap(),
        }
    }

    /// Return a [`NewPacketGuard`] to get the next pn and record frames in the packet.
    pub fn new_packet(&self) -> NewPacketGuard<'_, T> {
        let inner = self.0.lock().unwrap();
        // Remember how many frames were queued before this packet, so the guard
        // can later tell how many frames the new packet contributed.
        let origin_len = inner.queue.len();
        NewPacketGuard {
            trivial: false,
            origin_len,
            inner,
        }
    }
}

/// Handle the peer's ack frame and feed back the frames in the acknowledged or possibly lost packets to other components.
pub struct SentRotateGuard<'a, T> {
    inner: MutexGuard<'a, SentJournal>,
}

impl SentRotateGuard<'_, T> {
    /// Handle the [`Largest Acknowledged`] field of the ack frame from peer.
    ///
    /// [`Largest Acknowledged`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-ack-frames
    pub fn update_largest(&mut self, ack_frame: &AckFrame) -> Result<(), QuicError> {
        // NOTE(review): `sent_packets.largest()` looks like the *next* pn to be
        // assigned (see `NewPacketGuard::pn`); if so, an ACK whose largest
        // equals it also acknowledges an unsent packet and arguably should be
        // rejected as well (`>=`) — confirm `IndexDeque::largest` semantics.
        if ack_frame.largest() > self.inner.sent_packets.largest() {
            return Err(QuicError::new(
                ErrorKind::ProtocolViolation,
                ack_frame.frame_type().into(),
                "ack frame largest pn is larger than the largest pn sent",
            ));
        }
        if ack_frame.largest() > self.inner.largest_acked_pktno {
            self.inner.largest_acked_pktno = ack_frame.largest();
        }
        Ok(())
    }

    /// Called when the packet sent is acked by peer, return the frames in that packet.
    pub fn on_packet_acked(&mut self, pn: u64) -> impl Iterator + '_ {
        self.inner.on_packet_acked(pn)
    }

    /// Called when the packet sent may be lost, return the frames in that packet.
    pub fn may_loss_packet(&mut self, pn: u64) -> impl Iterator + '_ {
        self.inner.may_loss_packet(pn)
    }

    /// Transition every packet below the largest acked pn whose retransmission
    /// deadline has passed, returning all of their frames for resending.
    pub fn fast_retransmit(&mut self) -> impl Iterator + '_ {
        self.inner.fast_retransmit()
    }
}

impl Drop for SentRotateGuard<'_, T> {
    fn drop(&mut self) {
        // Rotate out records that no longer need to be kept.
        self.inner.resize();
    }
}

/// Provide the [encoded] packet number to assemble a packet, and record the frames in packet which
/// will be send.
///
/// One [`NewPacketGuard`] correspond to a packet.
///
/// Even if the next packet number is obtained, the packet may not be sent out. If the packet is not
/// sent out, the packet number will not be consumed.
///
/// Calling [`NewPacketGuard::record_trivial`] or [`NewPacketGuard::record_frame`] means that the
/// packet corresponding to this [`NewPacketGuard`] will be sent, and the packet number will be
/// consumed when the [`NewPacketGuard`] dropped.
///
/// [encoded]: https://www.rfc-editor.org/rfc/rfc9000.html#name-sample-packet-number-encodi
#[derive(Debug)]
pub struct NewPacketGuard<'a, T> {
    trivial: bool,
    origin_len: usize,
    inner: MutexGuard<'a, SentJournal>,
}

impl NewPacketGuard<'_, T> {
    /// Provide a packet number and its [encoded] form to assemble a packet.
    ///
    /// Calling this method multiple times on the same [`NewPacketGuard`] will result in the same pn.
    ///
    /// [encoded]: https://www.rfc-editor.org/rfc/rfc9000.html#name-sample-packet-number-encodi
    pub fn pn(&self) -> (u64, PacketNumber) {
        let pn = self.inner.sent_packets.largest();
        let encoded_pn = PacketNumber::encode(pn, self.inner.largest_acked_pktno);
        (pn, encoded_pn)
    }

    /// Records trivial frames that do not need retransmission, such as Padding, Ping, and Ack.
    /// However, this packet does occupy a packet number. Even if no other reliable frames are sent,
    /// it still needs to be recorded, with the number of reliable frames in this packet being 0.
    pub fn record_trivial(&mut self) {
        self.trivial = true;
    }

    /// Records a frame in the packet being sent.
    ///
    /// Once this method or [`NewPacketGuard::record_trivial`] called, the packet number will be consumed.
    ///
    /// When the packet is acked, or may loss, the frames in packet will been fed back to the
    /// components which sent them.
    pub fn record_frame(&mut self, frame: T) {
        self.inner.deref_mut().push_back(frame);
    }

    /// Consume the guard: commit the packet record with the given retransmission
    /// and expiration timeouts (both relative to now).
    pub fn build_with_time(mut self, retran_timeout: Duration, expire_timeout: Duration) {
        let nframes = self.inner.queue.len() - self.origin_len;
        let sent_time = tokio::time::Instant::now();
        if self.trivial && nframes == 0 {
            // Only trivial frames: the pn is consumed but nothing is tracked.
            self.inner
                .sent_packets
                .push_back(SentPktState::Skipped)
                .expect("packet number never overflow");
        } else if nframes > 0 {
            self.inner
                .sent_packets
                .push_back(SentPktState::new(
                    nframes,
                    sent_time,
                    sent_time + retran_timeout,
                    sent_time + expire_timeout,
                ))
                .expect("packet number never overflow");
        }
        // Neither trivial nor carrying frames: the pn is not consumed at all.
    }

    /// Consume the guard for a packet that carried only trivial frames.
    pub fn build_trivial(mut self) {
        assert_eq!(self.inner.queue.len(), self.origin_len);
        assert!(self.trivial);
        self.inner
            .sent_packets
            .push_back(SentPktState::Skipped)
            .expect("packet number never overflow");
    }
}

================================================
FILE: qrecovery/src/journal.rs
================================================
//! The space that reliably transmits frames.
use std::time::Duration;

mod rcvd;
pub use rcvd::*;
mod sent;
pub use sent::*;

/// The bundle of sent packet records and received packet records.
///
/// The generic `T` is the generic on [`ArcSentJournal`].
///
/// See [`ArcSentJournal`] and [`ArcRcvdJournal`] for more.
#[derive(Debug, Default, Clone)]
pub struct Journal {
    sent: ArcSentJournal,
    rcvd: ArcRcvdJournal,
}

impl Journal {
    /// Create a [`Journal`] containing records with the given `capacity`.
    pub fn with_capacity(capacity: usize, max_ack_delay: Option) -> Self {
        Self {
            sent: ArcSentJournal::with_capacity(capacity),
            rcvd: ArcRcvdJournal::with_capacity(capacity, max_ack_delay),
        }
    }

    /// Get the [`ArcSentJournal`] of space.
    pub fn of_sent_packets(&self) -> ArcSentJournal {
        self.sent.clone()
    }

    /// Get the [`ArcRcvdJournal`] of space.
    pub fn of_rcvd_packets(&self) -> ArcRcvdJournal {
        self.rcvd.clone()
    }
}

// NOTE(review): generic argument lists in this extract appear to have been
// stripped by the extraction tooling (e.g. `AsRef>` / bare `ArcSentJournal` /
// `Option`); tokens are preserved as extracted — restore them from the
// upstream sources before compiling.
impl AsRef> for Journal {
    fn as_ref(&self) -> &ArcSentJournal {
        &self.sent
    }
}

impl AsRef for Journal {
    fn as_ref(&self) -> &ArcRcvdJournal {
        &self.rcvd
    }
}

================================================
FILE: qrecovery/src/lib.rs
================================================
//! Crate to implement reliable transmission.
//!
//! The structures in this crate don't have the ability to send or receive frames directly, but they
//! provide interfaces to generate frames and write them into buffers, handle received frames, and
//! handle frame acknowledgment and loss. This is what [`Incoming`], [`Outgoing`], [`DataStreams`],
//! [`CryptoStreamIncoming`], [`CryptoStreamOutgoing`] and [`CryptoStream`] do.
//!
//! The [`reliable`] module of this crate provides the records for sent and received packets, and a
//! reliable frame queue to ensure that the frames in it will be sent to the peer and confirmed.
//!
//! The sent record can provide a packet number for the new packet (although the QUIC packet number
//! is incremented, the packet number stored in the packet header is encoded).
//!
//! The sent records are also responsible for processing the ack frames sent by the other party.
//! Through the other party's ack frames, which packets have been confirmed can be known, and then
//! the frames in these packets are fed back to [`DataStreams`] and [`CryptoStream`] for processing.
//!
//! The loss of packets is determined by congestion control, and sending records can feed back the
//! frames in possibly lost data packets to [`DataStreams`] and [`CryptoStream`].
//!
//! The received records are used to generate the ack frame, and decode the packet number in the
//! packet received.
//!
//! [`Incoming`]: crate::recv::Incoming
//! [`Outgoing`]: crate::send::Outgoing
//! [`DataStreams`]: crate::streams::DataStreams
//! [`CryptoStreamIncoming`]: crate::crypto::CryptoStreamIncoming
//! [`CryptoStreamOutgoing`]: crate::crypto::CryptoStreamOutgoing
//! [`CryptoStream`]: crate::crypto::CryptoStream
pub mod crypto;
pub mod journal;
pub mod recv;
pub mod reliable;
pub mod send;
pub mod streams;

================================================
FILE: qrecovery/src/recv/incoming.rs
================================================
use std::ops::DerefMut;

use bytes::Bytes;
use qbase::{
    error::{Error, QuicError},
    frame::{MaxStreamDataFrame, ResetStreamFrame, StopSendingFrame, StreamFrame, io::SendFrame},
};

use super::recver::{ArcRecver, Recver};

/// A struct for the protocol layer to manage the receiving part of a stream.
// NOTE(review): generic argument lists in this extract appear to have been
// stripped (e.g. `Incoming(ArcRecver)` was presumably `Incoming<TX>(ArcRecver<TX>)`
// and the `SendFrame` bounds lost their frame-type parameters); tokens are
// preserved as extracted — restore them from the upstream sources.
#[derive(Debug, Clone)]
pub struct Incoming(ArcRecver);

impl Incoming
where
    TX: SendFrame + SendFrame + Clone + Send + 'static,
{
    /// Receive a stream frame from peer.
    ///
    /// The stream frame will be handed over to the receive state machine.
    ///
    /// The data in a stream frame is just a fragment of the data on the stream. The data transmitted
    /// by different stream frames may not be continuous. The data will be assembled by [`RecvBuf`] into
    /// continuous data for the application layer to read through [`Reader`].
    ///
    /// [`RecvBuf`]: crate::recv::RecvBuf
    /// [`Reader`]: crate::recv::Reader
    pub fn recv_data(
        &self,
        stream_frame: StreamFrame,
        body: Bytes,
    ) -> Result<(bool, usize), QuicError> {
        let mut recver = self.0.recver();
        let inner = recver.deref_mut();
        let mut is_into_rcvd = false;
        let mut fresh_data = 0;
        // If a connection error already occurred (inner is Err), the frame is ignored.
        if let Ok(receiving_state) = inner {
            match receiving_state {
                Recver::Recv(r) => {
                    if stream_frame.is_fin() {
                        // A FIN frame fixes the final size of the stream.
                        let mut size_known = r.determin_size(&stream_frame)?;
                        fresh_data = size_known.recv(stream_frame, body)?;
                        if size_known.is_all_rcvd() {
                            is_into_rcvd = true;
                            *receiving_state = Recver::DataRcvd(size_known.upgrade());
                        } else {
                            *receiving_state = Recver::SizeKnown(size_known);
                        }
                    } else {
                        fresh_data = r.recv(stream_frame, body)?;
                    }
                }
                Recver::SizeKnown(r) => {
                    fresh_data = r.recv(stream_frame, body)?;
                    if r.is_all_rcvd() {
                        is_into_rcvd = true;
                        *receiving_state = Recver::DataRcvd(r.upgrade());
                    }
                }
                // DataRcvd / ResetRcvd: late or duplicate frames are ignored.
                _ => {}
            }
        }
        Ok((is_into_rcvd, fresh_data))
    }

    /// Receive a stream reset frame from peer.
    ///
    /// If all data sent by the peer has not been received, receiving a stream reset frame will cause
    /// any read calls to return an error, received data will be discarded.
    // NOTE(review): the `Result` return type's generic arguments were stripped
    // in extraction (presumably `Result<u64, Error>` — confirm upstream).
    pub fn recv_reset(&self, reset_frame: ResetStreamFrame) -> Result {
        // TODO: ResetStream also carries error information (e.g. an HTTP/3 error
        // code); see whether it can be put to use.
        let mut sync_fresh_data = 0;
        let mut recver = self.0.recver();
        let inner = recver.deref_mut();
        if let Ok(receiving_state) = inner {
            match receiving_state {
                Recver::Recv(r) => {
                    sync_fresh_data = r.recv_reset(&reset_frame)?;
                    *receiving_state = Recver::ResetRcvd(reset_frame);
                }
                Recver::SizeKnown(r) => {
                    r.recv_reset(&reset_frame)?;
                    *receiving_state = Recver::ResetRcvd(reset_frame);
                }
                // NOTE(review): reaching here in DataRcvd/ResetRcvd panics;
                // presumably duplicate/late RESET_STREAM frames are filtered by
                // the caller — confirm, otherwise this is a remote-triggerable
                // panic.
                _ => unreachable!(),
            }
        }
        Ok(sync_fresh_data)
    }
}

impl Incoming {
    /// Wrap a shared receiver state machine.
    pub fn new(recver: ArcRecver) -> Self {
        Self(recver)
    }

    /// Called when a connection error occurred
    ///
    /// After the connection error occurred, trying to read the data from [`Reader`] will result in an
    /// Error.
    ///
    /// [`Reader`]: crate::recv::Reader
    pub fn on_conn_error(&self, err: &Error) {
        let mut recver = self.0.recver();
        let inner = recver.deref_mut();
        match inner {
            Ok(receiving_state) => match receiving_state {
                Recver::Recv(r) => r.wake_reader(),
                Recver::SizeKnown(r) => r.wake_reader(),
                // DataRcvd/ResetRcvd: nothing to wake; keep the state as is.
                _ => return,
            },
            // Already in an error state; keep the first error.
            Err(_) => return,
        };
        *inner = Err(err.clone());
    }
}

================================================
FILE: qrecovery/src/recv/rcvbuf.rs
================================================
//! An implementation of the receiving buffer for stream data.
use std::collections::VecDeque;

use bytes::{Buf, BufMut, Bytes};

/// A contiguous fragment of stream data; each fragment is a `Bytes`.
#[derive(Debug, Default)]
struct Segment {
    offset: u64,
    data: Bytes,
}

impl Segment {
    fn new_with_data(offset: u64, data: Bytes) -> Self {
        Segment { offset, data }
    }

    /// Exclusive end offset of this fragment.
    fn end(&self) -> u64 {
        self.offset + self.data.len() as u64
    }
}

/// Received data of a stream is stored in [`RecvBuf`].
///
/// The receiving buffer is relatively simple, as it receives segmented data
/// that may not be continuous. It sequentially stores the received data
/// fragments and then reassembles them into a continuous data stream for
/// future reading by the application layer.
///
/// It implements the [`Buf`] trait and can operate on the **received continuous
/// data** through the [`Buf`] trait. [`Buf::has_remaining`] returns `false` not
/// only when the buffer is empty, but also when no readable continuous data is in
/// the buffer.
// NOTE(review): the element type of `segments` (`VecDeque<Segment>`?) was
// stripped in extraction — confirm upstream.
#[derive(Default, Debug)]
pub struct RecvBuf {
    nread: u64,
    largest_offset: u64,
    // invariant: segments[0].offset >= nread
    segments: VecDeque,
}

impl RecvBuf {
    /// Returns whether the receiving buffer is empty.
    pub fn is_empty(&self) -> bool {
        self.segments.is_empty()
    }

    /// Returns how many continuous data have been read.
    ///
    /// # Example
    ///
    /// ``` rust
    /// # use bytes::{Bytes, BytesMut};
    /// # use qrecovery::recv::RecvBuf;
    /// let mut recvbuf = RecvBuf::default();
    /// assert_eq!(recvbuf.nread(), 0);
    ///
    /// recvbuf.recv(0, Bytes::from("hello"));
    /// assert_eq!(recvbuf.nread(), 0);
    /// // recvbuf: hello
    /// // offset=0 ^
    ///
    /// let mut dst = BytesMut::new();
    /// recvbuf.try_read(&mut dst);
    /// assert_eq!(recvbuf.nread(), 5);
    /// // recvbuf: hello
    /// // offset=5 ^
    /// ```
    pub fn nread(&self) -> u64 {
        self.nread
    }

    /// Returns the largest offset received (exclusive end of the highest fragment).
    ///
    /// For receiver in SizeKnown state, this must be smaller than the `final_size`
    // NOTE(review): a fragment ending exactly at the final size would make this
    // equal to `final_size`; confirm whether the bound is strict.
    pub fn largest_offset(&self) -> u64 {
        self.largest_offset
    }

    /// Receive a fragment of data, returns the consumption of the flow limit
    ///
    /// # Example
    ///
    /// The following example demonstrates how [`RecvBuf`] works.
    ///
    /// The data "hello, world!" is split into four fragments.
    /// ``` rust
    /// # use bytes::{Bytes, BytesMut};
    /// # use qrecovery::recv::RecvBuf;
    /// let mut recvbuf = RecvBuf::default();
    /// // data: "hello, world!"
    /// assert_eq!(recvbuf.recv(0, Bytes::from("hell")), 4);
    /// // recvbuf: "hell"
    /// // new: "hell"
    /// assert_eq!(recvbuf.recv(7, Bytes::from("world")), 8);
    /// // recvbuf: "hell" "world"
    /// // new: "world"
    /// assert_eq!(recvbuf.recv(3, Bytes::from("lo, ")), 0);
    /// // recvbuf: "hello, world"
    /// // new: "o, "
    /// assert_eq!(recvbuf.recv(7, Bytes::from("world!")), 1);
    /// // recvbuf: "hello, world!"
    /// // new: "!"
    /// let mut received = BytesMut::new();
    /// recvbuf.try_read(&mut received);
    /// assert_eq!(received.as_ref(), b"hello, world!");
    /// ```
    pub fn recv(&mut self, offset: u64, mut data: Bytes) -> u64 {
        let previous_largest = self.largest_offset;
        // Drop the leading part of the fragment that has already been read.
        let mut start = offset.max(self.nread);
        data.advance(data.remaining().min((start - offset) as usize));
        loop {
            if data.is_empty() {
                break;
            }
            // Insert from front to back:
            match self.segments.binary_search_by(|seg| seg.offset.cmp(&start)) {
                // The new data starts exactly where an existing segment starts:
                //   | exist_seg | ... |
                //   | new_seg....................|
                // Trim the overlapping head of new_seg, then keep looping:
                //   | exist_seg | ... |
                //               | new_seg........|
                // In the vast majority of cases (in-order delivery) this branch
                // is entered first.
                Ok(exist_seg_index) => {
                    let length_covered = data.len().min(self.segments[exist_seg_index].data.len());
                    data.advance(length_covered);
                    start += length_covered as u64;
                }
                // No existing segment starts exactly at `start`: check overlap
                // with the previous and the next segment and trim accordingly.
                //   | exist_seg1 |            | exist_seg2 |
                //   1.           | new_seg|
                //   2.           | new_seg        |
                // The binary-search result may point at the previous or the next
                // segment. Err(0) means there is no previous segment (the new
                // data goes in front), so only the next segment can overlap;
                // index 0 is therefore handled specially.
                Err(0) => {
                    let uncovered = match self.segments.front() {
                        // Overlaps the next segment: split off the leading,
                        // non-overlapping part of `data`.
                        Some(next_seg) if start + data.len() as u64 > next_seg.offset => {
                            // After the split `start` equals next_seg.offset, so
                            // the next loop iteration enters the Ok(_) branch.
                            // next_seg.offset < start + data.len()
                            // => next_seg.offset - start < data.len(), in bounds.
                            data.split_to((next_seg.offset - start) as usize)
                        }
                        // No overlap (or this is the first fragment): take the
                        // whole fragment; next iteration data.is_empty() => break.
                        Some(..) | None => core::mem::take(&mut data),
                    };
                    let segment = Segment::new_with_data(start, uncovered);
                    start += segment.data.len() as u64;
                    self.largest_offset = self.largest_offset.max(segment.end());
                    self.segments.push_front(segment);
                }
                // seg_index != 0 => seg_index > 0, and start > prev_seg.offset.
                Err(seg_index) => {
                    // First check overlap with the previous segment; after this
                    // step, offset >= prev_seg.end().
                    data = match self.segments.get(seg_index - 1) {
                        // start > prev_seg.offset && end <= prev_seg.end():
                        //   | ---prev_seg-- |
                        //      | new_seg |
                        // Entirely contained in the previous segment: done.
                        Some(prev_seg) if (start + data.len() as u64) <= prev_seg.end() => break,
                        // start > prev_seg.offset && start < prev_seg.end():
                        //   | ---prev_seg-- |
                        //            | ---new_seg--- |
                        // Cut off the part covered by the previous segment; the
                        // remainder is guaranteed non-empty.
                        Some(prev_seg) if start < prev_seg.end() => {
                            // After the cut, `start` equals prev_seg.end(), so
                            // the next loop iteration enters the branch above.
                            // start < prev_seg.end() => 0 < prev_seg.end() - start,
                            // in bounds.
                            let length_covered = prev_seg.end() - start;
                            start += length_covered;
                            data.split_off(length_covered as usize)
                        }
                        // No overlap with the previous segment: keep `data`;
                        // next iteration data.is_empty() => break (if emptied).
                        Some(..) | None => data,
                    };
                    let uncovered = match self.segments.get(seg_index) {
                        // next_seg.offset >= prev_seg.end() && start >= prev_seg.end():
                        //   | ---next_seg--- |
                        //   | ---new_seg-- |
                        // `uncovered` is the data in [prev_seg.end(), next_seg.offset);
                        // if offset == next_seg.offset it is empty: just continue.
                        Some(next_seg) if start == next_seg.offset => continue,
                        //        | --next_seg--- |
                        //   | ---new_seg-- |
                        // Overlaps the next segment: keep only the part before it.
                        Some(next_seg) if start + data.len() as u64 > next_seg.offset => {
                            // After the split `start` equals next_seg.offset, so
                            // the next loop iteration enters the Ok(_) branch.
                            // next_seg.offset < start + data.len()
                            // => next_seg.offset - start < data.len(), in bounds.
                            data.split_to((next_seg.offset - start) as usize)
                        }
                        // No overlap: take the rest of the fragment; next
                        // iteration data.is_empty() => break.
                        Some(..) | None => core::mem::take(&mut data),
                    };
                    let segment = Segment::new_with_data(start, uncovered);
                    start += segment.data.len() as u64;
                    self.largest_offset = self.largest_offset.max(segment.end());
                    self.segments.insert(seg_index, segment);
                }
            }
            // Loop again with whatever remains of the fragment.
        }
        // Flow-control consumption: how much the largest received offset advanced.
        self.largest_offset - previous_largest
    }

    /// Returns the length of continuous unread data.
    pub fn available(&self) -> u64 {
        use core::ops::ControlFlow;
        // Segments are sorted; the continuous prefix ends at the first gap
        // (Break) or after the last segment (Continue).
        let (ControlFlow::Continue(continuous_end) | ControlFlow::Break(continuous_end)) =
            self.segments.iter().try_fold(self.nread, |offset, seg| {
                if seg.offset == offset {
                    ControlFlow::Continue(offset + seg.data.len() as u64)
                } else {
                    ControlFlow::Break(offset)
                }
            });
        continuous_end - self.nread
    }

    /// Once the received data becomes continuous, it becomes readable. If necessary (if the application
    /// layer is blocked on reading), it is necessary to notify the application layer to read.
    pub fn is_readable(&self) -> bool {
        !self.segments.is_empty() && self.segments[0].offset == self.nread
    }

    /// Try to read continuous data from [`RecvBuf`] into the buffer passed in.
    ///
    /// If the following data is not continuous or there is no data, nothing is
    /// written and `0` is returned.
    ///
    /// Otherwise, returns how much data was written to the buffer passed in.
///
/// # Example
///
/// ``` rust
/// # use bytes::{BytesMut, Bytes};
/// # use qrecovery::recv::RecvBuf;
/// let mut recvbuf = RecvBuf::default();
/// recvbuf.recv(0, Bytes::from("012"));
/// recvbuf.recv(3, Bytes::from("345"));
/// recvbuf.recv(7, Bytes::from("789"));
/// // recvbuf:  012345 789
/// // readable: ^^^^^^
///
/// let mut dst1 = BytesMut::new();
/// recvbuf.try_read(&mut dst1);
/// assert_eq!(dst1.as_ref(), b"012345");
///
/// let mut dst2 = BytesMut::new();
/// recvbuf.recv(6, Bytes::from("6"));
/// // recvbuf:  0123456789
/// // readable:       ^^^^
///
/// recvbuf.try_read(&mut dst2);
/// assert_eq!(dst2.as_ref(), b"6789");
/// ```
pub fn try_read(&mut self, dst: &mut impl BufMut) -> usize {
    let origin = dst.remaining_mut();
    // Drain segments from the front as long as they are contiguous with the
    // read cursor and the destination still has room.
    while let Some(seg) = self.segments.front_mut() {
        if seg.offset != self.nread || !dst.has_remaining_mut() {
            break;
        }
        // Copy as much of this segment as fits into `dst`.
        let read = dst.remaining_mut().min(seg.data.len());
        dst.put(seg.data.split_to(read));
        self.nread += read as u64;
        if seg.data.has_remaining() {
            // Partially consumed: bump its offset so it stays contiguous
            // with the advanced read cursor.
            seg.offset += read as u64;
        } else {
            // Fully consumed: drop the segment.
            self.segments.pop_front();
        }
    }
    // Number of bytes actually written into `dst`.
    origin - dst.remaining_mut()
}

/// Try to get the next continuous data segment.
///
/// Compared with [`Self::try_read`], this method is more efficient
/// because it reduces some calculations and copies.
pub fn try_next(&mut self) -> Option { if self.is_readable() { let data = self.segments.pop_front().unwrap().data; self.nread += data.len() as u64; return Some(data); } None } } #[cfg(test)] mod tests { use super::*; #[test] fn test_no_overlap() { let mut buf = RecvBuf::default(); assert_eq!(buf.recv(0, Bytes::from("hello")), 5); assert_eq!(buf.recv(6, Bytes::from("world")), 6); assert_eq!(buf.segments.len(), 2); assert_eq!(buf.segments[0].offset, 0); assert_eq!(buf.segments[1].offset, 6); assert_eq!(buf.recv(5, Bytes::from(" ")), 0); assert_eq!(buf.segments.len(), 3); assert_eq!(buf.segments[0].offset, 0); assert_eq!(buf.segments[1].offset, 5); assert_eq!(buf.segments[2].offset, 6); } #[test] fn test_left_partially_overlap() { let mut buf = RecvBuf::default(); assert_eq!(buf.recv(0, Bytes::from("01234")), 5); assert_eq!(buf.recv(2, Bytes::from("2345")), 1); //left segment partially overlapped this assert_eq!(buf.recv(6, Bytes::from("6789")), 4); // no overlap assert_eq!(buf.segments.len(), 3); assert_eq!(buf.segments[0].offset, 0); assert_eq!(buf.segments[1].offset, 5); assert_eq!(buf.segments[2].offset, 6); assert_eq!(buf.available(), 10); } #[test] fn test_right_partially_overlap() { let mut buf = RecvBuf::default(); assert_eq!(buf.recv(0, Bytes::from("hello")), 5); assert_eq!(buf.recv(6, Bytes::from("world!")), 7); assert_eq!(buf.recv(5, Bytes::from(" wor")), 0); // overlap right assert_eq!(buf.segments.len(), 3); assert_eq!(buf.segments[0].offset, 0); assert_eq!(buf.segments[1].offset, 5); assert_eq!(buf.segments[2].offset, 6); assert_eq!(buf.available(), 12); } #[test] #[doc(alias = "fully_overlap_left")] fn test_same_offset() { let mut buf = RecvBuf::default(); assert_eq!(buf.recv(0, Bytes::from("01234")), 5); assert_eq!(buf.recv(0, Bytes::from("0123456789")), 5); assert_eq!(buf.segments.len(), 2); assert_eq!(buf.segments[0].offset, 0); assert_eq!(buf.segments[1].offset, 5); assert_eq!(buf.available(), 10); } #[test] fn test_fully_overlap_right() { let mut 
buf = RecvBuf::default(); assert_eq!(buf.recv(0, Bytes::from("hello")), 5); assert_eq!(buf.recv(6, Bytes::from("world")), 6); assert_eq!(buf.recv(5, Bytes::from(" world!")), 1); // fully overlap right assert_eq!(buf.segments.len(), 4); assert_eq!(buf.segments[0].offset, 0); assert_eq!(buf.segments[1].offset, 5); assert_eq!(buf.segments[2].offset, 6); assert_eq!(buf.segments[3].offset, 11); assert_eq!(buf.available(), 12); } #[test] fn test_left_fully_overlap() { let mut buf = RecvBuf::default(); assert_eq!(buf.recv(0, Bytes::from("114514")), 6); assert_eq!(buf.recv(2, Bytes::from("45")), 0); // left segment fully overlapped this assert_eq!(buf.recv(2, Bytes::from("4514")), 0); // left segment fully overlapped this assert_eq!(buf.segments.len(), 1); assert_eq!(buf.segments[0].offset, 0); assert_eq!(buf.available(), 6); } #[test] fn test_right_fully_overlapp() { let mut buf = RecvBuf::default(); assert_eq!(buf.recv(0, Bytes::from("114514")), 6); assert_eq!(buf.recv(6, Bytes::from("1919810")), 7); assert_eq!(buf.recv(8, Bytes::from("1981")), 0); // right segment fully overlapped this assert_eq!(buf.recv(8, Bytes::from("19810")), 0); // right segment fully overlapped this assert_eq!(buf.segments.len(), 2); assert_eq!(buf.segments[0].offset, 0); assert_eq!(buf.segments[1].offset, 6); assert_eq!(buf.available(), 13); } #[test] fn test_left_right_partially_overlap() { let mut buf = RecvBuf::default(); assert_eq!(buf.recv(0, Bytes::from("012345")), 6); assert_eq!(buf.recv(7, Bytes::from("789")), 4); assert_eq!(buf.recv(6, Bytes::from("6")), 0); // left and right partially overlapped this assert_eq!(buf.segments.len(), 3); assert_eq!(buf.segments[0].offset, 0); assert_eq!(buf.segments[1].offset, 6); assert_eq!(buf.segments[2].offset, 7); assert_eq!(buf.available(), 10); } #[test] fn test_left_right_fully_overlap() { let mut buf = RecvBuf::default(); assert_eq!(buf.recv(0, Bytes::from("01234")), 5); assert_eq!(buf.recv(5, Bytes::from("56789")), 5); assert_eq!(buf.recv(2, 
Bytes::from("2345678")), 0); // left and right fully overlapped this assert_eq!(buf.segments.len(), 2); assert_eq!(buf.segments[0].offset, 0); assert_eq!(buf.segments[1].offset, 5); assert_eq!(buf.available(), 10); } #[test] fn test_recvbuf_read() { let mut rcvbuf = RecvBuf::default(); assert_eq!(rcvbuf.recv(0, Bytes::from("hello")), 5); assert_eq!(rcvbuf.recv(6, Bytes::from("world")), 6); let mut dst = [0u8; 20]; let mut buf = &mut dst[..]; rcvbuf.try_read(&mut buf); assert_eq!(buf.remaining_mut(), 15); assert_eq!(rcvbuf.recv(5, Bytes::from(" ")), 0); rcvbuf.try_read(&mut buf); assert_eq!(buf.remaining_mut(), 9); assert_eq!(dst[..11], b"hello world"[..]); } } ================================================ FILE: qrecovery/src/recv/reader.rs ================================================ use std::{ io::{self}, ops::DerefMut, pin::Pin, task::{Context, Poll}, }; use bytes::Bytes; use futures::Stream; use qbase::{ frame::{MaxStreamDataFrame, StopSendingFrame, io::SendFrame}, varint::VARINT_MAX, }; use qevent::quic::transport::{GranularStreamStates, StreamSide, StreamStateUpdated}; use tokio::io::{AsyncRead, ReadBuf}; use super::recver::{ArcRecver, Recver}; use crate::streams::error::StreamError; pub trait StopSending { /// Tell peer to stop sending data with the given error code. /// /// If all data has been received (the stream has closed), or the stream has been reset, this method will do /// nothing. /// /// Otherwise, a [`STOP_SENDING frame`] will be sent to the peer, and then the stream will be reset by peer, /// neither new data nor lost data will be sent. /// /// Unlike TCP, stopping a QUIC stream needs an error code, which is used to indicate /// the reason for the stopping. The error code should be a `u64` value, /// defined by the application protocol using QUIC, such as HTTP/3 or gRPC. 
/// /// [`STOP_SENDING frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-stop_sending-frames fn stop(&mut self, error_code: u64); } /// The reader part of a QUIC stream. /// /// A QUIC stream is *reliable*, *ordered*, and *flow-controlled*. /// /// This struct implements the [`AsyncRead`] trait, allowing you to read an ordered byte stream from /// the peer, like [`TcpStream`]. /// /// Try to read from the [`Reader`] into a non-empty buffer, the [`Reader`] will block until some data /// is available, or the stream is closed, or the stream is reset by peer. /// /// # Note /// /// The stream must be closed before [`Reader`] dropped. /// /// The [`read`] returning `Ok(0)` indicates that all data from peer has been read and the stream has /// `closed`, it is okay to drop the [`Reader`] after that. /// /// Alternatively, if the [`read`] result an error, its indicates that the stream has been `reset`, or /// closed duo to other reasons. It's also okay to drop the [`Reader`] after that. /// /// You can call [`stop`] to tell the peer to stop sending data with the given error code, the [`Reader`] /// will be consumed, and the error code will be sent to the peer. /// /// # Example /// /// The [`Reader`] is created by the `open_bi_stream`, `accept_bi_stream`, or `accept_uni_stream` methods /// of `QuicConnection` (in the `quic` crate). 
/// /// The following example demonstrates how to read and write data on a QUIC stream: /// /// ```rust, ignore /// # use tokio::io::{AsyncWriteExt, AsyncReadExt}; /// # async fn example() -> std::io::Result<()> { /// let (reader, writer) = quic_connection.open_bi_stream().await?; /// /// writer.write_all(b"GET README.md\r\n").await?; /// writer.shutdown().await?; /// /// let mut response = String::new(); /// let n = reader.read_to_string(&mut response).await?; /// println!("Response {} bytes: {}", n, response); /// Ok(()) /// # } /// ``` /// /// [`TcpStream`]: tokio::net::TcpStream /// [`read`]: tokio::io::AsyncReadExt::read /// [`stop`]: Reader::stop /// [`RESET_STREAM frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-reset_stream-frames #[derive(Debug)] pub struct Reader { inner: ArcRecver, qlog_span: qevent::telemetry::Span, tracing_span: tracing::Span, } impl Reader { /// Create a new [`Reader`] from the given [`Recver`]. /// /// This method is used by the `accept_bi_stream` and `accept_uni_stream` methods of /// [`QuicConnection`](crate::QuicConnection). 
pub(crate) fn new(inner: ArcRecver) -> Self { Self { inner, qlog_span: qevent::telemetry::Span::current(), tracing_span: tracing::Span::current(), } } #[inline] pub fn poll_read( &mut self, cx: &mut Context<'_>, buf: &mut impl bytes::BufMut, ) -> Poll> where TX: SendFrame, { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); let mut recver = self.inner.recver(); let receiving_state = recver.as_mut().map_err(|e| e.clone())?; // 能相当清楚地看到应用层读取数据驱动的接收状态演变 match receiving_state { Recver::Recv(r) => r.poll_read(cx, buf).map(Ok), Recver::SizeKnown(r) => r.poll_read(cx, buf).map(Ok), Recver::DataRcvd(r) => { r.poll_read(buf); if r.is_all_read() { r.upgrade(); *receiving_state = Recver::DataRead; } Poll::Ready(Ok(())) } Recver::DataRead => Poll::Ready(Ok(())), Recver::ResetRcvd(r) => { qevent::event!(StreamStateUpdated { stream_id: r.stream_id().id(), stream_type: r.stream_id().dir(), old: GranularStreamStates::ResetReceived, new: GranularStreamStates::ResetRead, stream_side: StreamSide::Receiving }); let reset_stream_error = (&*r).into(); *receiving_state = Recver::ResetRead(reset_stream_error); Poll::Ready(Err(reset_stream_error.into())) } Recver::ResetRead(r) => Poll::Ready(Err((*r).into())), } } #[inline] pub fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> where TX: SendFrame, { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); let mut recver = self.inner.recver(); let receiving_state = recver.as_mut().map_err(|e| e.clone())?; // 能相当清楚地看到应用层读取数据驱动的接收状态演变 match receiving_state { Recver::Recv(r) => r.poll_next(cx).map(Ok).map(Some), Recver::SizeKnown(r) => r.poll_next(cx).map(Ok).map(Some), Recver::DataRcvd(r) => { let Some(data) = r.poll_next() else { return Poll::Ready(None); }; if r.is_all_read() { r.upgrade(); *receiving_state = Recver::DataRead; } Poll::Ready(Some(Ok(data))) } Recver::DataRead => Poll::Ready(None), Recver::ResetRcvd(r) => { qevent::event!(StreamStateUpdated { stream_id: r.stream_id().id(), 
stream_type: r.stream_id().dir(), old: GranularStreamStates::ResetReceived, new: GranularStreamStates::ResetRead, stream_side: StreamSide::Receiving }); let reset_stream_error = (&*r).into(); *receiving_state = Recver::ResetRead(reset_stream_error); Poll::Ready(Some(Err(reset_stream_error.into()))) } Recver::ResetRead(r) => Poll::Ready(Some(Err((*r).into()))), } } } impl StopSending for Reader where TX: SendFrame, { /// Tell peer to stop sending data with the given error code. /// /// If all data has been received(the stream has closed), or the stream has been reset, this method will do /// nothing. /// /// Otherwise, a [`STOP_SENDING frame`] will be sent to the peer, and then the stream will be reset by peer. /// /// [`STOP_SENDING frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-stop_sending-frames fn stop(&mut self, error_code: u64) { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); debug_assert!(error_code <= VARINT_MAX); let mut recver = self.inner.recver(); let inner = recver.deref_mut(); if let Ok(receiving_state) = inner { match receiving_state { Recver::Recv(r) => { r.stop(error_code); } Recver::SizeKnown(r) => { r.stop(error_code); } _ => (), } } } } impl AsyncRead for Reader where TX: SendFrame, { #[inline] fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { Reader::poll_read(self.get_mut(), cx, buf).map_err(io::Error::from) } } impl Stream for Reader where TX: SendFrame, { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Reader::poll_next(self, cx) } } impl Drop for Reader { fn drop(&mut self) { let mut recver = self.inner.recver(); let inner = recver.deref_mut(); if let Ok(receiving_state) = inner { match receiving_state { Recver::Recv(r) if !r.is_stopped() => { #[cfg(debug_assertions)] tracing::warn!( target: "quic", "The receiving {} is not stopped with error before dropped!", r.stream_id(), ); #[cfg(not(debug_assertions))] tracing::debug!( 
target: "quic", "The receiving {} is not stopped with error before dropped!", r.stream_id(), ); } Recver::SizeKnown(r) if !r.is_stopped() => { #[cfg(debug_assertions)] tracing::warn!( target: "quic", "The receiving {} is not stopped with error before dropped!", r.stream_id() ); #[cfg(not(debug_assertions))] tracing::debug!( target: "quic", "The receiving {} is not stopped with error before dropped!", r.stream_id() ); } _ => (), } } } } ================================================ FILE: qrecovery/src/recv/recver.rs ================================================ use std::{ io, sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll, Waker}, }; use bytes::{BufMut, Bytes}; use qbase::{ error::{Error, ErrorKind, QuicError}, frame::{ GetFrameType, MaxStreamDataFrame, ResetStreamError, ResetStreamFrame, StopSendingFrame, StreamFrame, io::SendFrame, }, sid::StreamId, varint::{VARINT_MAX, VarInt}, }; use qevent::quic::transport::{ GranularStreamStates, StreamDataLocation, StreamDataMoved, StreamSide, StreamStateUpdated, }; use super::rcvbuf; #[derive(Debug)] pub(super) struct Recv { stream_id: StreamId, rcvbuf: rcvbuf::RecvBuf, read_waker: Option, stop_state: Option, broker: TX, largest: u64, max_stream_data: u64, } impl Recv where TX: SendFrame, { pub(super) fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut impl BufMut) -> Poll<()> { if let Some(_reset) = self.stop_state { // Though STOP_SENDING has been sent, the application layer can still read the data } if !self.rcvbuf.is_readable() { self.read_waker = Some(cx.waker().clone()); return Poll::Pending; } let offset = self.rcvbuf.nread(); let length = self.rcvbuf.try_read(buf) as u64; qevent::event!(StreamDataMoved { stream_id: self.stream_id, offset, length, from: StreamDataLocation::Transport, to: StreamDataLocation::Application, }); let threshold = 1_000_000; if self.rcvbuf.nread() + threshold > self.max_stream_data { let max_stream_data = (self.rcvbuf.nread() + threshold * 2).min(VARINT_MAX); if 
max_stream_data > self.max_stream_data { self.max_stream_data = max_stream_data; self.broker.send_frame([MaxStreamDataFrame::new( self.stream_id, VarInt::from_u64(max_stream_data).unwrap(), )]); } } Poll::Ready(()) } pub(super) fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll { if !self.rcvbuf.is_readable() { self.read_waker = Some(cx.waker().clone()); return Poll::Pending; } let offset = self.rcvbuf.nread(); let data = self.rcvbuf.try_next().expect("is_readable checked"); let length = data.len() as u64; qevent::event!(StreamDataMoved { stream_id: self.stream_id, offset, length, from: StreamDataLocation::Transport, to: StreamDataLocation::Application, }); let threshold = 1_000_000; if self.rcvbuf.nread() + threshold > self.max_stream_data { let max_stream_data = (self.rcvbuf.nread() + threshold * 2).min(VARINT_MAX); if max_stream_data > self.max_stream_data { self.max_stream_data = max_stream_data; self.broker.send_frame([MaxStreamDataFrame::new( self.stream_id, VarInt::from_u64(max_stream_data).unwrap(), )]); } } Poll::Ready(data) } } impl Recv where TX: SendFrame, { pub(super) fn stop(&mut self, err_code: u64) { if self.stop_state.is_none() { self.stop_state = Some(err_code); self.broker.send_frame([StopSendingFrame::new( self.stream_id, VarInt::from_u64(err_code).expect("app error code must not exceed 2^62!"), )]); } } } impl Recv { pub(super) fn determin_size( &mut self, stream_frame: &StreamFrame, ) -> Result, QuicError> { if let Some(waker) = self.read_waker.take() { waker.wake(); } let final_size = stream_frame.offset() + stream_frame.len() as u64; let received_largest_offset = self.rcvbuf.largest_offset(); if received_largest_offset > final_size { return Err(QuicError::new( ErrorKind::FinalSize, stream_frame.frame_type().into(), format!( "{} received a wrong smaller final size {} than the largest rcvd data offset {}", stream_frame.stream_id(), final_size, received_largest_offset ), )); } qevent::event!(StreamStateUpdated { stream_id: 
self.stream_id.id(), stream_type: self.stream_id.dir(), old: GranularStreamStates::Receive, new: GranularStreamStates::SizeKnown, stream_side: StreamSide::Receiving }); Ok(SizeKnown { final_size, stream_id: self.stream_id, rcvbuf: std::mem::take(&mut self.rcvbuf), stop_state: self.stop_state.take(), broker: self.broker.clone(), read_waker: self.read_waker.take(), }) } } impl Recv { pub(super) fn new(stream_id: StreamId, buf_size: u64, broker: TX) -> Self { Self { stream_id, rcvbuf: rcvbuf::RecvBuf::default(), read_waker: None, stop_state: None, broker, largest: 0, max_stream_data: buf_size, } } pub(super) fn stream_id(&self) -> StreamId { self.stream_id } pub(super) fn recv( &mut self, stream_frame: StreamFrame, body: Bytes, ) -> Result { let data_start = stream_frame.offset(); let data_end = data_start + body.len() as u64; if data_end > self.max_stream_data { return Err(QuicError::new( ErrorKind::FlowControl, stream_frame.frame_type().into(), format!( "{} send {data_end} bytes which exceeds the stream data limit {}", stream_frame.stream_id(), self.max_stream_data ), )); } let data_length = body.len() as u64; let fresh_data = self.rcvbuf.recv(data_start, body); qevent::event!( StreamDataMoved { stream_id: self.stream_id, offset: data_start, length: data_length, from: StreamDataLocation::Network, to: StreamDataLocation::Transport, }, fresh_data ); if self.largest < data_end { self.largest = data_end; } if self.rcvbuf.is_readable() && let Some(waker) = self.read_waker.take() { waker.wake() } Ok(fresh_data as _) } pub(super) fn recv_reset( &mut self, reset_frame: &ResetStreamFrame, ) -> Result { let final_size = reset_frame.final_size(); if final_size < self.largest { return Err(QuicError::new( ErrorKind::FinalSize, reset_frame.frame_type().into(), format!( "{} reset with a wrong smaller final size {final_size} than the largest rcvd data offset {}", reset_frame.stream_id(), self.largest ), )); } self.wake_reader(); log_reset_event(self.stream_id, 
GranularStreamStates::Receive); Ok((final_size - self.largest) as _) } pub(super) fn is_stopped(&self) -> bool { self.stop_state.is_some() } pub(super) fn wake_reader(&mut self) { if let Some(waker) = self.read_waker.take() { waker.wake() } } } /// Once the size of the data stream is determined, MAX_STREAM_DATA will no longer /// be updated. Receiving data on this stream is meaningless. At this point, it is /// also meaningless for the application layer to continue receiving data. #[derive(Debug)] pub struct SizeKnown { stream_id: StreamId, rcvbuf: rcvbuf::RecvBuf, read_waker: Option, stop_state: Option, broker: TX, final_size: u64, } impl SizeKnown { pub(super) fn stream_id(&self) -> StreamId { self.stream_id } pub(super) fn recv( &mut self, stream_frame: StreamFrame, data: Bytes, ) -> Result { let data_start = stream_frame.offset(); let data_end = data_start + data.len() as u64; if data_end > self.final_size { return Err(QuicError::new( ErrorKind::FinalSize, stream_frame.frame_type().into(), format!( "{} send {data_end} bytes which exceeds the final_size {}", stream_frame.stream_id(), self.final_size ), )); } if stream_frame.is_fin() && data_end != self.final_size { return Err(QuicError::new( ErrorKind::FinalSize, stream_frame.frame_type().into(), format!( "{} change the final size from {} to {data_end}", stream_frame.stream_id(), self.final_size ), )); } let data_length = data.len() as u64; let fresh_data = self.rcvbuf.recv(data_start, data); qevent::event!( StreamDataMoved { stream_id: self.stream_id, offset: data_start, length: data_length, from: StreamDataLocation::Network, to: StreamDataLocation::Transport, }, fresh_data ); if self.rcvbuf.is_readable() && let Some(waker) = self.read_waker.take() { waker.wake() } Ok(fresh_data as usize) } pub(super) fn is_all_rcvd(&self) -> bool { self.rcvbuf.nread() + self.rcvbuf.available() == self.final_size } #[allow(dead_code)] pub(super) fn read(&mut self, mut buf: &mut [u8]) -> io::Result { if self.rcvbuf.is_readable() 
{ let buflen = buf.remaining_mut(); self.rcvbuf.try_read(&mut buf); Ok(buflen - buf.remaining_mut()) } else { Err(io::ErrorKind::WouldBlock.into()) } } pub(super) fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut impl BufMut) -> Poll<()> { if let Some(_reset) = self.stop_state { // Though STOP_SENDING has been sent, the application layer can still read the data } if !self.rcvbuf.is_readable() { self.read_waker = Some(cx.waker().clone()); return Poll::Pending; } let offset = self.rcvbuf.nread(); let length = self.rcvbuf.try_read(buf) as u64; qevent::event!(StreamDataMoved { stream_id: self.stream_id, offset, length, from: StreamDataLocation::Transport, to: StreamDataLocation::Application, }); Poll::Ready(()) } pub(super) fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll { if !self.rcvbuf.is_readable() { self.read_waker = Some(cx.waker().clone()); return Poll::Pending; } let offset = self.rcvbuf.nread(); let data = self.rcvbuf.try_next().expect("is_readable checked"); let length = data.len() as u64; qevent::event!(StreamDataMoved { stream_id: self.stream_id, offset, length, from: StreamDataLocation::Transport, to: StreamDataLocation::Application, }); Poll::Ready(data) } pub(super) fn recv_reset(&mut self, reset_frame: &ResetStreamFrame) -> Result<(), QuicError> { let final_size = reset_frame.final_size(); if final_size != self.final_size { return Err(QuicError::new( ErrorKind::FinalSize, reset_frame.frame_type().into(), format!( "{} change the final size from {} to {final_size}", reset_frame.stream_id(), self.final_size ), )); } self.wake_reader(); log_reset_event(self.stream_id, GranularStreamStates::SizeKnown); Ok(()) } pub(super) fn is_stopped(&self) -> bool { self.stop_state.is_some() } pub(super) fn wake_reader(&mut self) { if let Some(waker) = self.read_waker.take() { waker.wake() } } } impl SizeKnown where TX: SendFrame + Clone + Send + 'static, { pub(super) fn upgrade(&mut self) -> DataRcvd { qevent::event!(StreamStateUpdated { stream_id: 
self.stream_id.id(), stream_type: self.stream_id.dir(), old: GranularStreamStates::SizeKnown, new: GranularStreamStates::DataReceived, stream_side: StreamSide::Receiving }); self.wake_reader(); DataRcvd { stream_id: self.stream_id, rcvbuf: std::mem::take(&mut self.rcvbuf), } } } impl SizeKnown where TX: SendFrame, { /// Abort can be called multiple times at the application level, /// but only the first call is effective. pub(super) fn stop(&mut self, err_code: u64) { if self.stop_state.is_none() { self.stop_state = Some(err_code); self.broker.send_frame([StopSendingFrame::new( self.stream_id, VarInt::from_u64(err_code).expect("app error code must not exceed 2^62!"), )]); } } } /// Once all the data has been received, STOP_SENDING becomes meaningless. /// If the application layer aborts reading, it will directly result in the termination /// of the lifecycle, leading to the release of all states and data. There is also no /// need for any further readable notifications to wake up. Subsequent reads will /// immediately return the available data until the end. #[derive(Debug)] pub struct DataRcvd { stream_id: StreamId, rcvbuf: rcvbuf::RecvBuf, } impl DataRcvd { /// Unlike the previous states, when there is no more data, it no longer returns /// "WouldBlock" but instead returns 0, which typically indicates the end. #[allow(dead_code)] pub(super) fn read(&mut self, mut buf: &mut [u8]) -> io::Result { let buflen = buf.remaining_mut(); self.rcvbuf.try_read(&mut buf); Ok(buflen - buf.remaining_mut()) } /// Unlike the previous states, when there is no more data, it no longer returns /// "Pending" but instead returns "Ready". However, in reality, nothing has been /// read. This kind of result typically indicates the end. 
// NOTE(review): generic/type parameters appear stripped in this extract
// (e.g. `Option` below should carry its payload type) — confirm against the
// original source before relying on the exact signatures.
/// Read readable data into `buf` and emit a qlog data-moved event.
///
/// In the `DataRcvd` state this never pends: when no data remains it simply
/// reads nothing, which signals the end of the stream to the caller.
pub(super) fn poll_read(&mut self, buf: &mut impl BufMut) {
    let offset = self.rcvbuf.nread();
    let length = self.rcvbuf.try_read(buf) as u64;
    qevent::event!(StreamDataMoved {
        stream_id: self.stream_id,
        offset,
        length,
        from: StreamDataLocation::Transport,
        to: StreamDataLocation::Application,
    });
}

/// Take the next continuous data segment, if any, emitting a qlog
/// data-moved event; `None` means all buffered data has been consumed.
pub(super) fn poll_next(&mut self) -> Option {
    let offset = self.rcvbuf.nread();
    let data = self.rcvbuf.try_next()?;
    let length = data.len() as u64;
    qevent::event!(StreamDataMoved {
        stream_id: self.stream_id,
        offset,
        length,
        from: StreamDataLocation::Transport,
        to: StreamDataLocation::Application,
    });
    Some(data)
}

/// Whether the application has read all received data (the buffer is empty).
pub(super) fn is_all_read(&self) -> bool {
    self.rcvbuf.is_empty()
}
}

/// Emit a qlog state-transition event for a stream entering `ResetReceived`
/// from the given previous state.
fn log_reset_event(stream_id: StreamId, old: GranularStreamStates) {
    qevent::event!(StreamStateUpdated {
        stream_id: stream_id.id(),
        stream_type: stream_id.dir(),
        old,
        new: GranularStreamStates::ResetReceived,
        stream_side: StreamSide::Receiving
    });
}

impl DataRcvd {
    /// Log the final `DataReceived` -> `DataRead` transition; the state enum
    /// itself is switched to `Recver::DataRead` by the caller.
    pub(super) fn upgrade(&self) {
        qevent::event!(StreamStateUpdated {
            stream_id: self.stream_id.id(),
            stream_type: self.stream_id.dir(),
            old: GranularStreamStates::DataReceived,
            new: GranularStreamStates::DataRead,
            stream_side: StreamSide::Receiving
        });
    }
}

/// Receiving stream state machine. In fact, here the state variables such as
/// is_closed/is_reset are replaced by a state machine. This not only provides
/// clearer semantics and aligns with the QUIC RFC specification but also
/// allows the compiler to help us check if the state transitions are correct
#[derive(Debug)]
pub(super) enum Recver {
    // Still receiving; the final size is unknown.
    Recv(Recv),
    // A FIN has fixed the final size, but not all data has arrived yet.
    SizeKnown(SizeKnown),
    // All data up to the final size has been received.
    DataRcvd(DataRcvd),
    // A RESET_STREAM frame was received but not yet surfaced to the reader.
    ResetRcvd(ResetStreamFrame),
    // The application has read everything; terminal success state.
    DataRead,
    // The reset has been surfaced to the reader; terminal error state.
    ResetRead(ResetStreamError),
}

impl Recver {
    /// Create a fresh receiver in the initial `Recv` state.
    pub(super) fn new(stream_id: StreamId, buf_size: u64, frames_tx: TX) -> Self {
        Self::Recv(Recv::new(stream_id, buf_size, frames_tx))
    }
}

/// The internal representations of [`Incoming`] and [`Reader`].
///
/// For the application layer, this structure is represented as [`Reader`].
The application can use it /// to read the data from the peer on the stream, or ask the peer stop sending. /// /// For the protocol layer, this structure is represented as [`Incoming`]. The protocol layer use it to /// manages the status of the `Recver` through it, delivers received data to the application layer and /// sends frames to the peer. /// /// [`Incoming`]: super::Incoming /// [`Reader`]: super::Reader #[derive(Debug, Clone)] pub struct ArcRecver(Arc, Error>>>); impl ArcRecver where TX: SendFrame + SendFrame + Clone + Send + 'static, { #[doc(hidden)] pub(crate) fn new(stream_id: StreamId, buf_size: u64, frames_tx: TX) -> Self { ArcRecver(Arc::new(Mutex::new(Ok(Recver::new( stream_id, buf_size, frames_tx, ))))) } } impl ArcRecver { pub(super) fn recver(&'_ self) -> MutexGuard<'_, Result, Error>> { self.0.lock().unwrap() } } ================================================ FILE: qrecovery/src/recv.rs ================================================ //! Types for receiving data on a Stream. mod incoming; mod rcvbuf; mod reader; mod recver; pub use incoming::Incoming; pub use rcvbuf::RecvBuf; pub use reader::{Reader, StopSending}; pub use recver::ArcRecver; ================================================ FILE: qrecovery/src/reliable.rs ================================================ //! The reliable transmission for frames. use std::{ collections::VecDeque, sync::{Arc, Mutex, MutexGuard}, }; use qbase::{ frame::{EncodeSize, FrameFeature, io::SendFrame}, net::tx::{ArcSendWakers, Signals}, packet::{Package, PacketContent}, }; /// A deque for data space to send reliable frames. /// /// Like its name, it is just a queue. [`DataStreams`] or other components that need to send reliable /// frames write frames to this queue by calling [`SendFrame::send_frame`]. The transport layer can /// load the frames from the queue into the packet by calling [`try_load_frames_into`]. 
/// /// # Example /// ```rust, no_run /// use qbase::frame::{HandshakeDoneFrame, ReliableFrame, io::SendFrame}; /// use qrecovery::reliable::ArcReliableFrameDeque; /// # let data_wakers = Default::default(); /// let mut reliable_frame_deque = ArcReliableFrameDeque::::with_capacity_and_wakers(10, data_wakers); /// reliable_frame_deque.send_frame([HandshakeDoneFrame]); /// ``` /// /// [`DataStreams`]: crate::streams::DataStreams /// [`try_load_frames_into`]: ArcReliableFrameDeque::try_load_frames_into #[derive(Debug, Default)] pub struct ArcReliableFrameDeque { frames: Arc>>, tx_wakers: ArcSendWakers, } impl Clone for ArcReliableFrameDeque { fn clone(&self) -> Self { Self { frames: self.frames.clone(), tx_wakers: self.tx_wakers.clone(), } } } impl ArcReliableFrameDeque { /// Create a new empty deque with at least the specified capacity. pub fn with_capacity_and_wakers(capacity: usize, tx_wakers: ArcSendWakers) -> Self { Self { frames: Arc::new(Mutex::new(VecDeque::with_capacity(capacity))), tx_wakers, } } fn frames_guard(&self) -> MutexGuard<'_, VecDeque> { self.frames.lock().unwrap() } /// Try to load the frame in deque and encode it into the `packet`. pub fn try_load_frames_into(&self, packet: &mut P) -> Result<(), Signals> where for<'a> &'a F: Package

, { let mut deque = self.frames_guard(); if deque.is_empty() { return Err(Signals::TRANSPORT); } while let Some(mut frame) = deque.front() { frame.dump(packet)?; deque.pop_front(); } Ok(()) } } impl Package

for ArcReliableFrameDeque where for<'a> &'a F: Package

, { fn dump(&mut self, packet: &mut P) -> Result { self.try_load_frames_into(packet)?; Ok(PacketContent::EffectivePayload) } } impl SendFrame for ArcReliableFrameDeque where F: EncodeSize + FrameFeature, T: Into, { fn send_frame>(&self, iter: I) { self.frames_guard().extend(iter.into_iter().map(Into::into)); self.tx_wakers.wake_all_by(Signals::TRANSPORT); } } ================================================ FILE: qrecovery/src/send/outgoing.rs ================================================ use std::ops::DerefMut; use bytes::{BufMut, Bytes}; use qbase::{ error::Error as QuicError, frame::{ResetStreamError, StreamFrame}, net::tx::Signals, packet::Package, sid::StreamId, util::ContinuousData, varint::VarInt, }; use qevent::quic::transport::{GranularStreamStates, StreamSide, StreamStateUpdated}; use super::sender::{ArcSender, Sender, SendingSender, StreamData}; /// An struct for protocol layer to manage the sending part of a stream. #[derive(Debug, Clone)] pub struct Outgoing(ArcSender); impl Outgoing { /// Try to load data that the application wants to sent to the packet. /// /// See [`DataStreams::try_load_data_into`] for more about this method. /// /// Return the size of data loaded, and whether the data is fresh. /// /// [`DataStreams::try_load_data_into`]: crate::streams::raw::DataStreams::try_load_data_into // consume the token internally, return the number of fresh data have been written to the buffer. // return None indicates that the stream write no data to the buffer. pub fn try_load_data_into

( &self, packet: &mut P, sid: StreamId, flow_limit: usize, tokens: usize, ) -> Result<(usize, bool), Signals> where P: BufMut + ?Sized, for<'a> (StreamFrame, &'a [Bytes]): Package

, { let origin_len = packet.remaining_mut(); let mut write = |(range, is_fresh, data, is_eos): StreamData| { let mut frame = StreamFrame::new(sid, range.start, (range.end - range.start) as usize); frame.set_eos_flag(is_eos); let strategy = frame.encoding_strategy(origin_len); frame.set_len_bit(strategy.len_bit()); packet.put_bytes(0, strategy.pre_padding()); (frame, data.as_slice()).dump(packet).unwrap(); (ContinuousData::len(data.as_slice()), is_fresh) }; let predicate = |offset| { StreamFrame::estimate_max_capacity(origin_len, sid, offset) .map(|capacity| tokens.min(capacity)) }; let mut sender = self.0.sender(); let sending_state = sender.as_mut().or(Err(Signals::empty()))?; // other(connection closed) match sending_state { Sender::Ready(s) => { let mut s: SendingSender = s.upgrade(); let (result, finished) = s .pick_up(predicate, flow_limit) .map(|payload @ (.., with_eos)| (Ok(write(payload)), with_eos)) .map_err(|s| (Err(s), false)) .unwrap_or_else(|x| x); if finished { *sending_state = Sender::DataSent(s.upgrade()); } else { *sending_state = Sender::Sending(s); } result } Sender::Sending(s) => { let (result, finished) = s .pick_up(predicate, flow_limit) .map(|payload @ (.., with_eos)| (Ok(write(payload)), with_eos)) .map_err(|s| (Err(s), false)) .unwrap_or_else(|x| x); if finished { *sending_state = Sender::DataSent(s.upgrade()); } result } Sender::DataSent(s) => s.pick_up(predicate, flow_limit).map(write), _ => Err(Signals::TRANSPORT), } } } impl Outgoing { /// Create a new instance of [`Outgoing`] pub fn new(sender: ArcSender) -> Self { Self(sender) } /// Update the sending window to `max_data_size` /// /// Callded when the [`MAX_STREAM_DATA frame`] belonging to the stream is received. /// /// [`MAX_STREAM_DATA frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-max_stream_data-frames pub fn update_window(&self, max_stream_data: u64) { self.0.update_window(max_stream_data); } /// Called when the data sent to peer is acknowlwged. 
/// /// * `frame`: the stream frame that has been acknowledged. /// /// Return `true` if the stream is completely acknowledged, all data has been sent and received. /// /// [`SendBuf`]: crate::send::SendBuf pub fn on_data_acked(&self, frame: &StreamFrame) -> bool { let mut sender = self.0.sender(); let inner = sender.deref_mut(); if let Ok(sending_state) = inner { match sending_state { Sender::Ready(_) => { unreachable!("never send data before recv data"); } Sender::Sending(s) => { s.on_data_acked(frame); } Sender::DataSent(s) => { s.on_data_acked(frame); if s.is_all_rcvd() { qevent::event!(StreamStateUpdated { stream_id: frame.stream_id(), stream_type: frame.stream_id().dir(), old: GranularStreamStates::DataSent, new: GranularStreamStates::DataReceived, stream_side: StreamSide::Sending }); *sending_state = Sender::DataRcvd; return true; } } // ignore recv _ => {} } }; false } /// Called when the data sent to peer may lost. /// /// * `frame`: the stream frame that may be lost. pub fn may_loss_data(&self, frame: &StreamFrame) { let mut sender = self.0.sender(); let inner = sender.deref_mut(); if let Ok(sending_state) = inner { match sending_state { Sender::Ready(_) => { unreachable!("never send data before recv data"); } Sender::Sending(s) => { s.may_loss_data(frame); } Sender::DataSent(s) => { s.may_loss_data(frame); } // ignore loss _ => (), } }; } pub fn revise_max_stream_data(&self, zero_rtt_rejected: bool, max_stream_data: u64) { let mut sender = self.0.sender(); let inner = sender.deref_mut(); if let Ok(sending_state) = inner { match sending_state { Sender::Ready(s) => s.revise_max_stream_data(zero_rtt_rejected, max_stream_data), Sender::Sending(s) => s.revise_max_stream_data(zero_rtt_rejected, max_stream_data), Sender::DataSent(s) => s.revise_max_stream_data(zero_rtt_rejected, max_stream_data), _ => (), } }; } /// Called when the [`STOP_SENDING frame`] sent by the peer is received. 
/// /// If the stream has not been closed, the stream will be reset and then a [`RESET_STREAM frame`] will /// be sent to the peer to reset the peer with the `final_size`. /// In this case, the method will return the `final_size`. /// /// If the stream has closed, `None` will be returned, and the method will do nothing. /// /// [`STOP_SENDING frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-stop_sending-frames /// [`STREAM_RESET frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-reset_stream-frames pub fn be_stopped(&self, error_code: u64) -> Option { let mut sender = self.0.sender(); let inner = sender.deref_mut(); match inner { Ok(sending_state) => { // THINK: sending_state.stream_id() -> StreamId, sending_state.state() -> GranularStreamStates let (stream_id, old_state, final_size) = match sending_state { Sender::Ready(s) => { (s.stream_id(), GranularStreamStates::Ready, s.be_stopped()) } Sender::Sending(s) => { (s.stream_id(), GranularStreamStates::Send, s.be_stopped()) } Sender::DataSent(s) => ( s.stream_id(), GranularStreamStates::DataSent, s.be_stopped(), ), _ => return None, }; let reset = ResetStreamError::new( // TODO: many places in the codebase perform VarInt -> u64 -> VarInt conversion // which is redundant and may cause bugs, consider refactor call-chain. VarInt::from_u64(error_code).expect("app error code must not exceed 2^62"), VarInt::from_u64(final_size).expect("final size must not exceed 2^62"), ); qevent::event!(StreamStateUpdated { stream_id: stream_id.id(), stream_type: stream_id.dir(), old: old_state, new: GranularStreamStates::ResetReceived, stream_side: StreamSide::Sending }); *sending_state = Sender::ResetSent(reset); Some(final_size) } Err(_) => None, } } /// Called When the [`RESET_STREAM frame`] previously sent to the peer is acknowledged /// /// [`RESET_STREAM frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-reset_stream-frames // TODO: stream id not from stream state, consider refactor. 
(many other places in qrecovery) pub fn on_reset_acked(&self, sid: StreamId) { let mut sender = self.0.sender(); let inner = sender.deref_mut(); if let Ok(sending_state) = inner { match sending_state { Sender::ResetSent(r) => { qevent::event!(StreamStateUpdated { stream_id: sid.id(), stream_type: sid.dir(), old: GranularStreamStates::ResetSent, new: GranularStreamStates::ResetReceived, stream_side: StreamSide::Sending }); *sending_state = Sender::ResetRcvd(*r); } Sender::ResetRcvd(..) => {} _ => unreachable!( "If no RESET_STREAM has been sent, how can there be a received acknowledgment?" ), } } } /// When a connection-level error occurs, all data streams must be notified. /// Their reading and writing should be terminated, accompanied the error of the connection. pub fn on_conn_error(&self, err: &QuicError) { let mut sender = self.0.sender(); let inner = sender.deref_mut(); match inner { Ok(sending_state) => match sending_state { Sender::Ready(s) => s.wake_all(), Sender::Sending(s) => s.wake_all(), Sender::DataSent(s) => s.wake_all(), _ => return, }, Err(_) => return, }; *inner = Err(err.clone()); } } ================================================ FILE: qrecovery/src/send/sender.rs ================================================ use std::{ ops::Range, sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll, Waker}, }; use bytes::Bytes; use qbase::{ error::Error, frame::{ResetStreamError, ResetStreamFrame, StreamFrame, io::SendFrame}, net::tx::{ArcSendWakers, Signals}, sid::StreamId, varint::{VARINT_MAX, VarInt}, }; use qevent::{ RawInfo, quic::transport::{ DataMovedAdditionalInfo, GranularStreamStates, StreamDataLocation, StreamDataMoved, StreamSide, StreamStateUpdated, }, }; use super::sndbuf::SendBuf; use crate::streams::error::StreamError; fn log_reset_event(sid: StreamId, from_state: GranularStreamStates) { qevent::event!(StreamStateUpdated { stream_id: sid.id(), stream_type: sid.dir(), old: from_state, new: GranularStreamStates::ResetSent, stream_side: 
StreamSide::Sending }); } /// The "Ready" state represents a newly created stream that is able to accept data from the application. /// Stream data might be buffered in this state in preparation for sending. /// An implementation might choose to defer allocating a stream ID to a stream until it sends the first /// STREAM frame and enters this state, which can allow for better stream prioritization. #[derive(Debug)] pub struct ReadySender { stream_id: StreamId, sndbuf: SendBuf, flush_waker: Option, shutdown_waker: Option, broker: TX, tx_wakers: ArcSendWakers, writable_waker: Option, metrics: Option, } impl ReadySender { pub(super) fn new( stream_id: StreamId, buf_size: u64, broker: TX, tx_wakers: ArcSendWakers, metrics: Option, ) -> ReadySender { ReadySender { stream_id, sndbuf: SendBuf::with_capacity(buf_size), flush_waker: None, shutdown_waker: None, broker, tx_wakers, writable_waker: None, metrics, } } pub(super) fn stream_id(&self) -> StreamId { self.stream_id } // /// 非阻塞写,如果没有多余的发送缓冲区,将返回WouldBlock错误。 // /// 但什么时候可写,是没通知的,只能不断去尝试写,直到写入成功。 // /// 仅供展示学习 // #[allow(dead_code)] // fn write(&mut self, buf: &[u8]) -> io::Result { // if self.sndbuf.has_remaining_mut() { // self.tx_wakers.wake_all_by(Signals::WRITTEN); // self.sndbuf.write(Bytes::copy_from_slice(buf)); // Ok(buf.len()) // } else { // Err(io::ErrorKind::WouldBlock.into()) // } // } pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { if self.shutdown_waker.is_some() { return Poll::Ready(Err(StreamError::EosSent)); } if !self.sndbuf.has_remaining_mut() { self.writable_waker = Some(cx.waker().clone()); return Poll::Pending; } Poll::Ready(Ok(())) } pub(crate) fn write(&mut self, data: Bytes) -> Result<(), StreamError> { if self.shutdown_waker.is_some() { return Err(StreamError::EosSent); } let data_len = data.len() as u64; qevent::event!(StreamDataMoved { stream_id: self.stream_id, offset: self.sndbuf.written(), length: data_len, from: StreamDataLocation::Application, to: 
StreamDataLocation::Transport, raw: data.clone() }); // Update metrics when application writes data if let Some(metrics) = &self.metrics { metrics.new_pending(data_len); } self.tx_wakers.wake_all_by(Signals::WRITTEN); self.sndbuf.write(data); Ok(()) } pub(super) fn update_window(&mut self, max_stream_data: u64) { if max_stream_data > self.sndbuf.max_data() { if self.sndbuf.written() > self.sndbuf.max_data() { self.tx_wakers.wake_all_by(Signals::WRITTEN); } self.sndbuf.extend(max_stream_data); if self.sndbuf.has_remaining_mut() && let Some(waker) = self.writable_waker.take() { waker.wake(); } } } pub(super) fn revise_max_stream_data(&mut self, zero_rtt_rejected: bool, max_stream_data: u64) { if zero_rtt_rejected { self.sndbuf.forget_sent_state(); } self.update_window(max_stream_data); } pub(super) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<()> { if self.sndbuf.is_all_rcvd() { Poll::Ready(()) } else { self.flush_waker = Some(cx.waker().clone()); Poll::Pending } } pub(super) fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<()> { // 就算当前没有流量窗口,也可以单独发送一个空StreamFrame,携带fin bit self.tx_wakers.wake_all_by(Signals::TRANSPORT); self.shutdown_waker = Some(cx.waker().clone()); Poll::Pending } pub(super) fn wake_all(&mut self) { if let Some(waker) = self.writable_waker.take() { waker.wake(); } if let Some(waker) = self.flush_waker.take() { waker.wake(); } if let Some(waker) = self.shutdown_waker.take() { waker.wake(); } } pub(super) fn be_stopped(&mut self) -> u64 { self.wake_all(); // ReadyState: no data is sent debug_assert_eq!(self.sndbuf.sent(), 0); self.sndbuf.sent() } } /// 状态升级,ReaderSender => SendingSender impl ReadySender { pub(super) fn upgrade(&mut self) -> SendingSender { qevent::event!(StreamStateUpdated { stream_id: self.stream_id, stream_type: self.stream_id.dir(), old: GranularStreamStates::Ready, new: GranularStreamStates::Send, stream_side: StreamSide::Sending }); SendingSender { stream_id: self.stream_id, sndbuf: std::mem::take(&mut 
self.sndbuf), flush_waker: self.flush_waker.take(), shutdown_waker: self.shutdown_waker.take(), broker: self.broker.clone(), tx_wakers: self.tx_wakers.clone(), writable_waker: self.writable_waker.take(), metrics: self.metrics.clone(), } } } impl ReadySender where TX: SendFrame, { /// 应用层使用,取消发送流 pub(super) fn cancel(&mut self, err_code: u64) -> ResetStreamError { let final_size = self.sndbuf.sent(); let reset_stream_err = ResetStreamError::new( VarInt::from_u64(err_code).expect("app error code must not exceed 2^62"), VarInt::from_u64(final_size).expect("final size must not exceed 2^62"), ); tracing::debug!( target: "quic", "{} is canceled by app layer, with error code {err_code}", self.stream_id ); self.broker .send_frame([reset_stream_err.combine(self.stream_id)]); log_reset_event(self.stream_id, GranularStreamStates::Ready); reset_stream_err } } #[derive(Debug)] pub struct SendingSender { stream_id: StreamId, sndbuf: SendBuf, flush_waker: Option, shutdown_waker: Option, broker: TX, tx_wakers: ArcSendWakers, writable_waker: Option, metrics: Option, } pub type StreamData<'s> = (Range, bool, Vec, bool); impl SendingSender { pub(super) fn stream_id(&self) -> StreamId { self.stream_id } pub(super) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { if self.shutdown_waker.is_some() { return Poll::Ready(Err(StreamError::EosSent)); } if !self.sndbuf.has_remaining_mut() { self.writable_waker = Some(cx.waker().clone()); return Poll::Pending; } Poll::Ready(Ok(())) } pub(super) fn write(&mut self, data: Bytes) -> Result<(), StreamError> { if self.shutdown_waker.is_some() { return Err(StreamError::EosSent); } let data_len = data.len() as u64; qevent::event!(StreamDataMoved { stream_id: self.stream_id, offset: self.sndbuf.written(), length: data_len, from: StreamDataLocation::Application, to: StreamDataLocation::Transport, raw: data.clone() }); // Update metrics when application writes data if let Some(metrics) = &self.metrics { metrics.new_pending(data_len); } 
self.tx_wakers.wake_all_by(Signals::WRITTEN); self.sndbuf.write(data); Ok(()) } /// 传输层使用 pub(super) fn update_window(&mut self, max_stream_data: u64) { if max_stream_data > self.sndbuf.max_data() { if self.sndbuf.written() > self.sndbuf.max_data() { self.tx_wakers.wake_all_by(Signals::WRITTEN); } self.sndbuf.extend(max_stream_data); if self.sndbuf.has_remaining_mut() && let Some(waker) = self.writable_waker.take() { waker.wake(); } } } pub(super) fn pick_up

( &'_ mut self, predicate: P, flow_limit: usize, ) -> Result, Signals> where P: Fn(u64) -> Option, { let total_size = self.total_size(); let sent = self.sndbuf.sent(); self.sndbuf .pick_up(&predicate, flow_limit) .map(|(range, is_fresh, data)| { (range.clone(), is_fresh, data, Some(range.end) == total_size) }) .or_else(|signals| { if total_size == Some(sent) { predicate(sent).ok_or(signals | Signals::CONGESTION)?; Ok((sent..sent, false, Vec::new(), true)) } else { Err(signals) } }) .map(|(range, is_fresh, data, is_eos)| { qevent::event!(StreamDataMoved { stream_id: self.stream_id, offset: range.start, length: range.end - range.start, from: StreamDataLocation::Transport, to: StreamDataLocation::Network, ?additional_info: is_eos.then_some(DataMovedAdditionalInfo::FinSet), raw: RawInfo { data : data.as_slice() } }); (range, is_fresh, data, is_eos) }) } pub(super) fn on_data_acked(&mut self, frame: &StreamFrame) { self.sndbuf.on_data_acked(&frame.range()); if self.sndbuf.is_all_rcvd() && let Some(waker) = self.flush_waker.take() { waker.wake(); } } pub(super) fn may_loss_data(&mut self, frame: &StreamFrame) { self.tx_wakers.wake_all_by(Signals::TRANSPORT); self.sndbuf.may_loss_data(&frame.range()) } pub(super) fn revise_max_stream_data(&mut self, zero_rtt_rejected: bool, max_stream_data: u64) { if zero_rtt_rejected { self.sndbuf.forget_sent_state(); } self.update_window(max_stream_data); } pub(super) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<()> { if self.sndbuf.is_all_rcvd() { Poll::Ready(()) } else { self.flush_waker = Some(cx.waker().clone()); Poll::Pending } } pub(super) fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<()> { self.tx_wakers.wake_all_by(Signals::TRANSPORT); self.shutdown_waker = Some(cx.waker().clone()); Poll::Pending } pub(super) fn total_size(&self) -> Option { if self.shutdown_waker.is_some() { Some(self.sndbuf.written()) } else { None } } pub(super) fn wake_all(&mut self) { if let Some(waker) = self.writable_waker.take() { 
waker.wake(); } if let Some(waker) = self.flush_waker.take() { waker.wake(); } if let Some(waker) = self.shutdown_waker.take() { waker.wake(); } } /// 传输层使用 pub(super) fn be_stopped(&mut self) -> u64 { self.wake_all(); // Actually, these remaining data is not acked and will not be acked self.sndbuf.sent() } } impl SendingSender { pub(super) fn upgrade(&mut self) -> DataSentSender { qevent::event!(StreamStateUpdated { stream_id: self.stream_id, stream_type: self.stream_id.dir(), old: GranularStreamStates::Send, new: GranularStreamStates::DataSent, stream_side: StreamSide::Sending }); DataSentSender { stream_id: self.stream_id, sndbuf: std::mem::take(&mut self.sndbuf), flush_waker: self.flush_waker.take(), shutdown_waker: self.shutdown_waker.take(), broker: self.broker.clone(), tx_wakers: self.tx_wakers.clone(), fin_state: FinState::Sent, } } } impl SendingSender where TX: SendFrame, { pub(super) fn cancel(&mut self, err_code: u64) -> ResetStreamError { let final_size = self.sndbuf.sent(); let reset_stream_err = ResetStreamError::new( VarInt::from_u64(err_code).expect("app error code must not exceed 2^62"), VarInt::from_u64(final_size).expect("final size must not exceed 2^62"), ); tracing::debug!( target: "quic", "{} is canceled by app layer, with error code {err_code}", self.stream_id ); self.broker .send_frame([reset_stream_err.combine(self.stream_id)]); log_reset_event(self.stream_id, GranularStreamStates::Send); reset_stream_err } } #[derive(Debug, PartialEq)] enum FinState { Sent, Lost, Rcvd, } #[derive(Debug)] pub struct DataSentSender { stream_id: StreamId, sndbuf: SendBuf, flush_waker: Option, shutdown_waker: Option, broker: TX, // retran/fin tx_wakers: ArcSendWakers, fin_state: FinState, } impl DataSentSender { pub(super) fn stream_id(&self) -> StreamId { self.stream_id } pub(super) fn pick_up

( &'_ mut self, predicate: P, flow_limit: usize, ) -> Result, Signals> where P: Fn(u64) -> Option, { let total_size = self.sndbuf.written(); self.sndbuf .pick_up(&predicate, flow_limit) .map(|(range, is_fresh, data)| (range.clone(), is_fresh, data, range.end == total_size)) .or_else(|signals| { if self.fin_state == FinState::Lost { self.fin_state = FinState::Sent; Ok((total_size..total_size, false, vec![], true)) } else { Err(signals) } }) .map(|(range, is_fresh, data, is_eos)| { qevent::event!(StreamDataMoved { stream_id: self.stream_id, offset: range.start, length: range.end - range.start, from: StreamDataLocation::Transport, to: StreamDataLocation::Network, ?additional_info: is_eos.then_some(DataMovedAdditionalInfo::FinSet), raw: RawInfo { data : data.as_slice() } },); (range, is_fresh, data, is_eos) }) } pub(super) fn on_data_acked(&mut self, frame: &StreamFrame) { self.sndbuf.on_data_acked(&frame.range()); if frame.is_fin() { self.fin_state = FinState::Rcvd; } if self.is_all_rcvd() { if let Some(waker) = self.flush_waker.take() { waker.wake(); } if let Some(waker) = self.shutdown_waker.take() { waker.wake(); } } } pub(super) fn is_all_rcvd(&self) -> bool { self.sndbuf.is_all_rcvd() && self.fin_state == FinState::Rcvd } pub(super) fn may_loss_data(&mut self, frame: &StreamFrame) { self.tx_wakers.wake_all_by(Signals::TRANSPORT); if frame.is_fin() && self.fin_state != FinState::Rcvd { self.fin_state = FinState::Lost; } self.sndbuf.may_loss_data(&frame.range()) } pub(super) fn revise_max_stream_data(&mut self, zero_rtt_rejected: bool, max_stream_data: u64) { if zero_rtt_rejected { self.sndbuf.forget_sent_state(); } self.sndbuf.extend(max_stream_data); } pub(super) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<()> { debug_assert!(!self.is_all_rcvd()); self.flush_waker = Some(cx.waker().clone()); Poll::Pending } pub(super) fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<()> { debug_assert!(!self.is_all_rcvd()); 
self.tx_wakers.wake_all_by(Signals::TRANSPORT); self.shutdown_waker = Some(cx.waker().clone()); Poll::Pending } pub(super) fn wake_all(&mut self) { if let Some(waker) = self.flush_waker.take() { waker.wake(); } if let Some(waker) = self.shutdown_waker.take() { waker.wake(); } } pub(super) fn be_stopped(&mut self) -> u64 { self.wake_all(); // Actually, these remaining data is not acked and will not be acked self.sndbuf.written() } } impl DataSentSender where TX: SendFrame, { pub(super) fn cancel(&mut self, err_code: u64) -> ResetStreamError { let final_size = self.sndbuf.sent(); let reset_stream_err = ResetStreamError::new( VarInt::from_u64(err_code).expect("app error code must not exceed 2^62"), VarInt::from_u64(final_size).expect("final size must not exceed 2^62"), ); tracing::debug!( target: "quic", "{} is canceled by app layer, with error code {err_code}", self.stream_id ); self.broker .send_frame([reset_stream_err.combine(self.stream_id)]); log_reset_event(self.stream_id, GranularStreamStates::DataSent); reset_stream_err } } #[derive(Debug)] pub(super) enum Sender { Ready(ReadySender), Sending(SendingSender), DataSent(DataSentSender), DataRcvd, ResetSent(ResetStreamError), ResetRcvd(ResetStreamError), } impl Sender { pub fn new( stream_id: StreamId, buf_size: u64, broker: TX, tx_wakers: ArcSendWakers, metrics: Option, ) -> Self { Sender::Ready(ReadySender::new( stream_id, buf_size, broker, tx_wakers, metrics, )) } } /// The internal state representations of [`Outgoing`] and [`Writer`]. /// /// For the application layer, this struct is represented as [`Writer`]. The application can use it to /// write data to the stream, or reset the stream. /// /// For the protocol layer, this struct is represented as [`Outgoing`]. The protocol layer uses it to /// manage the status of the `Sender`, sends data(stream frame),reset frames and other frames to peer. 
/// /// [`Outgoing`]: super::Outgoing /// [`Writer`]: super::Writer #[derive(Debug, Clone)] pub struct ArcSender(Arc, Error>>>); impl ArcSender { #[doc(hidden)] pub(crate) fn new( stream_id: StreamId, buf_size: u64, broker: TX, tx_wakers: ArcSendWakers, metrics: Option, ) -> Self { ArcSender(Arc::new(Mutex::new(Ok(Sender::new( stream_id, buf_size, broker, tx_wakers, metrics, ))))) } } impl ArcSender { // update send window for opened stream. pub(crate) fn update_window(&self, max_stream_data: u64) { assert!(max_stream_data <= VARINT_MAX); match self.sender().as_mut() { Ok(Sender::Ready(s)) => s.update_window(max_stream_data), Ok(Sender::Sending(s)) => s.update_window(max_stream_data), _ => {} } } pub(super) fn sender(&self) -> MutexGuard<'_, Result, Error>> { self.0.lock().unwrap() } } #[cfg(test)] mod tests { use qbase::{role::Role, sid::Dir}; use super::*; #[derive(Debug, Default, Clone)] struct MockBroker(Arc>>); impl SendFrame for MockBroker { fn send_frame>(&self, iter: I) { self.0.lock().unwrap().extend(iter); } } fn create_test_sender() -> ArcSender { let stream_id = StreamId::new(Role::Client, Dir::Bi, 0); let buf_size = 1000; let broker = MockBroker::default(); ArcSender::new(stream_id, buf_size, broker, Default::default(), None) } #[test] fn test_ready_sender_new() { let stream_id = StreamId::new(Role::Client, Dir::Bi, 0); let buf_size = 1000; let broker = MockBroker::default(); let sender = ReadySender::new(stream_id, buf_size, broker, Default::default(), None); assert_eq!(sender.stream_id, stream_id); assert_eq!(sender.sndbuf.max_data(), buf_size); assert!(sender.flush_waker.is_none()); assert!(sender.shutdown_waker.is_none()); assert!(sender.writable_waker.is_none()); } #[test] fn test_ready_sender_write() { let stream_id = StreamId::new(Role::Client, Dir::Bi, 0); let buf_size = 10; let broker = MockBroker::default(); let mut sender = ReadySender::new(stream_id, buf_size, broker, Default::default(), None); let data = Bytes::from_static(b"hello"); let 
result = sender.write(data); assert!(result.is_ok()); // Test write when buffer is full let large_data = Bytes::from_static(include_bytes!("./sender.rs")); let result = sender.write(large_data); assert!(result.is_ok()); } #[tokio::test] async fn test_ready_sender_poll_write() { let stream_id = StreamId::new(Role::Client, Dir::Bi, 0); let buf_size = 10; let broker = MockBroker::default(); let mut sender = ReadySender::new(stream_id, buf_size, broker, Default::default(), None); let data = Bytes::from_static(b"test"); assert!(matches!(sender.write(data.clone()), Ok(()))); // Test poll_write when buffer is full sender.sndbuf.forget_sent_state(); let mut cx = Context::from_waker(futures::task::noop_waker_ref()); let result = sender.poll_ready(&mut cx); assert!(result.is_pending()); } #[test] fn test_sender_state_transitions() { let stream_id = StreamId::new(Role::Client, Dir::Bi, 0); let buf_size = 1000; let broker = MockBroker::default(); let mut ready = ReadySender::new(stream_id, buf_size, broker, Default::default(), None); // Test transition to SendingSender let mut sending = ready.upgrade(); assert_eq!(sending.stream_id, stream_id); assert_eq!(sending.sndbuf.max_data(), buf_size); // Test transition to DataSentSender let data_sent = sending.upgrade(); assert_eq!(data_sent.stream_id, stream_id); assert!(data_sent.fin_state == FinState::Sent); } #[test] fn test_arc_sender() { let sender = create_test_sender(); // Test buffer size revision sender.update_window(2000); // Test sender lock access let guard = sender.sender(); assert!(guard.is_ok()); } #[test] fn test_data_sent_sender() { let stream_id = StreamId::new(Role::Client, Dir::Bi, 0); let buf_size = 1000; let broker = MockBroker::default(); let mut sender = DataSentSender { stream_id, sndbuf: SendBuf::with_capacity(buf_size), flush_waker: None, shutdown_waker: None, broker, tx_wakers: Default::default(), fin_state: FinState::Sent, }; // Test pick_up with empty buffer let predicate = |_| Some(100); let result = 
sender.pick_up(predicate, 1000); assert!(result.is_err()); } #[tokio::test] async fn test_data_sent_sender_polling() { let stream_id = StreamId::new(Role::Client, Dir::Bi, 0); let buf_size = 1000; let broker = MockBroker::default(); let mut sender = DataSentSender { stream_id, sndbuf: SendBuf::with_capacity(buf_size), flush_waker: None, shutdown_waker: None, broker, tx_wakers: Default::default(), fin_state: FinState::Sent, }; let mut cx = Context::from_waker(futures::task::noop_waker_ref()); // Test poll_flush when all data received let result = sender.poll_flush(&mut cx); assert!(result.is_pending()); // Test poll_shutdown when all data received let _ = sender.poll_shutdown(&mut cx); assert!(sender.shutdown_waker.is_some()); } } ================================================ FILE: qrecovery/src/send/sndbuf.rs ================================================ use std::{ cmp::Ordering, collections::VecDeque, fmt::{Debug, Display}, ops::Range, }; use bytes::Bytes; use qbase::net::tx::Signals; /// To indicate the state of a data segment, it is colored. 
#[derive(Default, PartialEq, Eq, Clone, Copy, Debug)]
enum Color {
    #[default]
    Pending,
    Flighting,
    Recved,
    Lost,
}

impl Color {
    // The 2-bit tag stored in the top bits of a `State` word.
    fn prefix(&self) -> u64 {
        match self {
            Self::Pending => 0,
            Self::Flighting => 0b01 << 62,
            Self::Lost => 0b10 << 62,
            Self::Recved => 0b11 << 62,
        }
    }
}

/// A packed (offset, color) pair: the low 62 bits hold the offset, the high
/// 2 bits hold the `Color` of the interval starting at that offset.
#[derive(PartialEq, PartialOrd, Eq, Clone, Copy)]
struct State(u64);

impl State {
    #[allow(dead_code)]
    const PREFIX: u64 = 0b11 << 62;
    const SUFFIX: u64 = u64::MAX >> 2;

    fn encode(pos: u64, color: Color) -> Self {
        Self(color.prefix() | pos)
    }

    fn offset(&self) -> u64 {
        self.0 & Self::SUFFIX
    }

    fn color(&self) -> Color {
        match self.0 >> 62 {
            0b00 => Color::Pending,
            0b01 => Color::Flighting,
            0b10 => Color::Lost,
            0b11 => Color::Recved,
            _ => unreachable!("impossible"),
        }
    }

    fn set_color(&mut self, value: Color) {
        self.0 = (self.0 & Self::SUFFIX) | value.prefix();
    }

    fn decode(&self) -> (u64, Color) {
        (self.offset(), self.color())
    }
}

impl Display for State {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[{:?}: {:?}]", self.offset(), self.color())
    }
}

impl Debug for State {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[{:?}: {:?}]", self.offset(), self.color())
    }
}

/**
 * `self.0` holds the interval states as a `VecDeque<State>`: each element packs
 * an offset in its low 62 bits and a color in its high 2 bits, where the color
 * describes the interval running up to the next element's offset.
 * `self.1` is the (exclusive) end position; the last element's interval runs
 * from its offset up to `self.1`.
 * This layout was chosen because a 64-byte CPU cache line holds 8 states at
 * once, which covers most small streams very efficiently. Even for large
 * streams, adjacent same-colored intervals are merged, so the number of
 * distinct intervals stays small — still much faster than linked lists, skip
 * lists, or segment trees.
 */
#[derive(Default, Debug)]
struct BufMap(VecDeque<State>, u64);

impl BufMap {
    fn size(&self) -> u64 {
        self.1
    }

    // Append-only growth of the tracked region up to `pos`.
    fn extend_to(&mut self, pos: u64) -> u64 {
        debug_assert!(pos < (1 << 62), "pos({pos}) overflow",);
        debug_assert!(pos >= self.size(), "pos({pos}) less than {}", self.size());
        if pos > self.size() {
            // Only push a new Pending interval if the tail isn't already Pending.
            match self.0.back() {
                Some(s) if s.color() == Color::Pending => {}
                _ => self.0.push_back(State::encode(self.size(), Color::Pending)),
            };
            self.1 = pos;
        }
        self.size()
    }

    // Everything before the trailing Pending interval has been sent at least once.
    fn sent(&self) -> u64 {
        match self.0.back() {
            Some(s) if s.color() == Color::Pending => s.offset(),
            _ => self.size(),
        }
    }

    // Pick Lost/Pending data to send. Data nearer the front has higher priority,
    // and Lost (retransmission) intervals sit ahead of Pending ones, so
    // retransmissions naturally go out first.
    fn pick<P>(
        &mut self,
        predicate: P,
        flow_limit: usize,
        send_window_size: u64,
    ) -> Result<(Range<u64>, bool), Signals>
    where
        P: Fn(u64) -> Option<usize>,
    {
        let mut signals = Signals::WRITTEN | Signals::TRANSPORT;
        // Locate the first sendable interval, recolor it Flighting, and keep
        // a copy of its original state.
        self.0
            .iter_mut()
            .enumerate()
            .find(|(.., state)| {
                if state.offset() >= send_window_size {
                    // Beyond the send window: this interval cannot be sent.
                    signals |= Signals::FLOW_CONTROL;
                    return false;
                }
                // Accept Pending intervals (if flow control allows) or Lost ones.
                match state.color() {
                    Color::Pending if flow_limit != 0 => return true,
                    Color::Pending => {
                        signals &= !Signals::WRITTEN;
                        signals |= Signals::FLOW_CONTROL
                    }
                    Color::Lost => return true,
                    _ => {}
                }
                false
            })
            .and_then(|(idx, state)| {
                // If the predicate rejects this offset, nothing can go out: the
                // first candidate already has the smallest offset, so later
                // intervals cannot do better.
                let Some(available) = predicate(state.offset()) else {
                    signals |= Signals::CONGESTION;
                    return None;
                };
                let allowance = if state.color() == Color::Lost {
                    // Retransmissions are exempt from flow control.
                    available
                } else {
                    available.min(flow_limit)
                };
                Some((idx, allowance, state))
            })
            .map(|(index, allowance, state)| {
                let origin_state = *state;
                // The mutable borrow of self.0 ends here.
                state.set_color(Color::Flighting);
                (index, origin_state, allowance)
            })
            .map(|(index, origin_state, allowance)| {
                // Found a suitable interval; if it is longer than the allowance,
                // split it in two.
                let (start, color) = origin_state.decode();
                let mut end = self
                    .0
                    .get(index + 1)
                    .map(|s| s.offset())
                    .unwrap_or(self.size())
                    .min(send_window_size);
                let mut i = self.same_before(index, Color::Flighting);
                if start + (allowance as u64) < end {
                    end = start + allowance as u64;
                    if i < index {
                        // Split in two; reuse an old state slot freed by the merge.
                        *self.0.get_mut(i + 1).unwrap() = State::encode(end, color);
                    } else {
                        self.0.insert(i + 1, State::encode(end, color));
                    }
                    i += 1;
                } else {
                    // TODO: possible optimization — if the next interval is Lost
                    // or Pending, the two could be chained together.
                    self.merge_after(index, Color::Flighting);
                }
                // If i is still below index there are states to drop up to index:
                // a forward merge done with a single drain.
                if i < index {
                    self.0.drain(i + 1..=index);
                }
                (start..end, color == Color::Pending)
            })
            .ok_or(signals)
    }
收到了ack确认,确认的数据不需再发送,对于头部连续确认的数据,就可以删掉。 // 寻找到ack区间所在的位置,将这些区间都染成Recved,然后检查前后是否有需要合并的区间,合并之。 // ack区间,不能ack到Pending的数据,因为Pending的数据尚未发送过,当然无法被ack。 fn ack_rcvd(&mut self, range: &Range) { let pos = self.0.binary_search_by(|s| s.offset().cmp(&range.start)); let (mut drain_start, need_insert_at_start, mut drain_end, mut pre_color) = match pos { Ok(idx) => { let s = self.0.get_mut(idx).unwrap(); let pre_color = s.color(); debug_assert!( pre_color != Color::Pending, "Recved Range({:?}) covered Pending part from {}", range, s.offset() ); s.set_color(Color::Recved); ( self.same_before(idx, Color::Recved) + 1, false, idx + 1, pre_color, ) } Err(idx) => { if idx == 0 { (0, false, 0, Color::Recved) } else { let s = self.0.get(idx - 1).unwrap(); let pre_color = s.color(); debug_assert!( pre_color != Color::Pending, "Recved Range({:?}) covered Pending part from {}", range, s.offset() ); (idx, pre_color != Color::Recved, idx, pre_color) } } }; let mut need_insert_at_end = false; loop { let entry = self.0.get(drain_end); match entry { Some(s) => match s.offset().cmp(&range.end) { Ordering::Less => { debug_assert!( s.color() != Color::Pending, "Recved Range({:?}) covered Pending parts from {}", range, s.offset() ); drain_end += 1; pre_color = s.color(); } Ordering::Equal => { // TODO: nightly版本, overflowing_sub 改为unchecked_sub更好 drain_end = self .same_after(drain_end.overflowing_sub(1).0, Color::Recved) .overflowing_add(1) .0; break; } Ordering::Greater => { need_insert_at_end = pre_color != Color::Recved; break; } }, None => { debug_assert!( range.end <= self.size(), "Recved Range({:?}) over {}", range, self.size() ); need_insert_at_end = range.end < self.size() && pre_color != Color::Recved; break; } } } if need_insert_at_start { if drain_start < drain_end { *self.0.get_mut(drain_start).unwrap() = State::encode(range.start, Color::Recved); } else { self.0 .insert(drain_start, State::encode(range.start, Color::Recved)); } drain_start += 1; } if need_insert_at_end { if 
drain_start < drain_end { *self.0.get_mut(drain_start).unwrap() = State::encode(range.end, pre_color); } else { self.0 .insert(drain_start, State::encode(range.end, pre_color)); } drain_start += 1; } if drain_start < drain_end { self.0.drain(drain_start..drain_end); } } // 寻找第一个不是Recved的位置,意味着之前的数据都已经被确认接收, // 发送缓冲区可以移动到该位置,以让发送缓冲区腾出更多空间 fn shift(&mut self) -> u64 { loop { let entry = self.0.front(); match entry { Some(s) if s.color() == Color::Recved => _ = self.0.pop_front(), Some(s) => return s.offset(), None => return self.size(), } } } // 判定某部分数据丢失,但不一定真的丢失,判定可能有误;丢失的数据需要优先重传。 // 寻找到丢失区间覆盖的范围,其中若遇到Recved的区间,则忽略;只有Flighting/Lost的才可以丢失。 // 然后检查Lost区间前后是否有需要合并的区间,合并之。 // 同样地,Lost区间不能覆盖Pending的数据,因为Pending的数据尚未发送过,无法丢失。 fn may_loss(&mut self, range: &Range) { let pos = self.0.binary_search_by(|s| s.offset().cmp(&range.start)); let (mut drain_start, need_insert_at_start, mut drain_end, mut pre_color) = match pos { Ok(idx) => { let s = self.0.get_mut(idx).unwrap(); debug_assert!( s.color() != Color::Pending, "Lost Range({:?}) covered Pending parts from {}", range, s.offset() ); if s.color() == Color::Recved { // 如果是Recved,那就不需要在前面插入了,直接往后探索 self.may_lost_from(idx + 1, range.end); return; } let pre_color = s.color(); let mut drain_start = idx; if pre_color == Color::Flighting { s.set_color(Color::Lost); // 只有变化了,才会向前寻找同为Lost,寻求合并 // 如果已经是Lost了,那前面的肯定是无法合并的非Lost状态 drain_start = self.same_before(idx, Color::Lost) + 1; } else { // 如果是lost,那这一段状态不需要改变,继续探索下一段需不需要改变 // 如果下一段还是Lost,那下一段可以删掉,往后合并Lost drain_start += 1; } // 肯定不需要在前面插入了,从drain_start开始往后探索,pre_color是当前状态 (drain_start, false, idx + 1, pre_color) } Err(idx) => { if idx == 0 { // 之前的数据都是recved,前面不再需要插入 // 表示从0往后,要尝试变为Lost,就完事儿了 self.may_lost_from(idx, range.end); return; } else { let s = self.0.get(idx - 1).unwrap(); let pre_color = s.color(); debug_assert!( pre_color != Color::Pending, "Lost Range({:?}) covered Pending parts from {}", range, s.offset() ); if pre_color == Color::Recved { // 
另有安排,直接调用,lost_from(idx, range.end); self.may_lost_from(idx, range.end); return; } (idx, pre_color == Color::Flighting, idx, pre_color) } } }; let mut need_insert_at_end = false; loop { // 从drain_end位置的entry开始遍历,看其是否存在,存在看其是否仍在Lost的range区间里 let entry = self.0.get(drain_end); match entry { Some(s) => match s.offset().cmp(&range.end) { Ordering::Less => { // 以s.offset开头的区间,仍在Lost的range区间里 debug_assert!( s.color() != Color::Pending, "Lost Range({:?}) covered Pending parts from {}", range, s.offset() ); if s.color() == Color::Recved { // s是recved,那就s的下一段到range.end都是丢失的,相当于独立的may_lost区间处理 // 接下来只需处理drain_end之前的操作即可 self.may_lost_from(drain_end + 1, range.end); break; } else { // s是Lost/Flighting,那就将s染成Lost,继续往后探索 drain_end += 1; pre_color = s.color(); } } Ordering::Equal => { // s之前的是Lost,从上一个检查后续连续lost状态的有多少个 drain_end = self .same_after(drain_end.overflowing_sub(1).0, Color::Lost) .overflowing_add(1) .0; break; } Ordering::Greater => { // s的offset大于range.end,说明s之后的区间都不在Lost的范围内 // s的前一个是Flighting,它要一分为二,前部分为Lost,后部分为Flighting need_insert_at_end = pre_color == Color::Flighting; break; } }, None => { // 找不到,说明到最后一段了 debug_assert!( range.end <= self.size(), "Lost Range({:?}) over {}", range, self.size() ); // 如果上一段的color是Flighting,它要一分为二,到range.end的部分为Lost,后续部分为Flighting need_insert_at_end = range.end < self.size() && pre_color == Color::Flighting; break; } }; } if need_insert_at_start { if drain_start < drain_end { *self.0.get_mut(drain_start).unwrap() = State::encode(range.start, Color::Lost); } else { self.0 .insert(drain_start, State::encode(range.start, Color::Lost)); } drain_start += 1; } if need_insert_at_end { if drain_start < drain_end { *self.0.get_mut(drain_start).unwrap() = State::encode(range.end, pre_color); } else { self.0 .insert(drain_start, State::encode(range.end, pre_color)); } drain_start += 1; } if drain_start < drain_end { self.0.drain(drain_start..drain_end); } } fn resend_flighting(&mut self) { for state in self.0.iter_mut() { if state.color() == 
Color::Flighting { state.set_color(Color::Lost); } } } } impl BufMap { fn same_before(&self, mut index: usize, color: Color) -> usize { loop { let pre = index.overflowing_sub(1).0; match self.0.get(pre) { Some(s) if s.color() == color => index = pre, _ => break, } } index } fn same_after(&self, mut index: usize, color: Color) -> usize { loop { let next = index.overflowing_add(1).0; match self.0.get(next) { Some(s) if s.color() == color => index = next, _ => break, } } index } fn merge_after(&mut self, index: usize, color: Color) { let same_after = self.same_after(index, color); if index < same_after { self.0.drain(index + 1..=same_after); } } // lost的辅助函数,将idx_start位置的变为Lost,然后向后继续判定丢失 fn may_lost_from(&mut self, mut idx_start: usize, end: u64) { let mut idx = idx_start; let mut pre_color = Color::Recved; let mut need_insert_at_end = false; loop { let entry = self.0.get_mut(idx); match entry { Some(s) => match s.offset().cmp(&end) { Ordering::Less => { debug_assert!( s.color() != Color::Pending, "Lost Range.end({end}) covered Pending parts from {}", s.offset() ); pre_color = s.color(); if s.color() == Color::Recved { // 另有安排,直接调用,lost_from(idx, range.end); self.may_lost_from(idx + 1, end); break; } else { s.set_color(Color::Lost); idx += 1; } } Ordering::Equal => { idx = self .same_after(idx.overflowing_sub(1).0, Color::Lost) .overflowing_add(1) .0; break; } Ordering::Greater => { need_insert_at_end = pre_color == Color::Flighting; break; } }, None => { debug_assert!( end <= self.size(), "Lost Range.end({end}) over {}", self.size() ); need_insert_at_end = end < self.size() && pre_color == Color::Flighting; break; } } } if need_insert_at_end { if idx_start + 1 < idx { *self.0.get_mut(idx_start + 1).unwrap() = State::encode(end, pre_color); } else { self.0.insert(idx_start + 1, State::encode(end, pre_color)); } idx_start += 1; } if idx_start + 1 < idx { self.0.drain(idx_start + 1..idx); } } } /// Data to be reliably sent to the peer will first be cached in 
[`SendBuf`]. /// /// SendBuf will record the status of data that has been or has not been sent. /// /// The transport layer needs to notify that the data it has sent is confirmed([`on_data_acked`]) or lost /// ([`may_loss_data`]), to uopate the state of [`SendBuf`]. /// /// The transport layer can [`pick_up`] a piece of data that needs to be sent. The data may be new data, /// or old data that has been sent but has not been acknowledged. /// /// The data picked up may not continuous, the [`receive buffer`] will assemble the data into continuous before /// passing them to the application layer. /// /// [`pick_up`]: SendBuf::pick_up /// [`on_data_acked`]: SendBuf::on_data_acked /// [`may_loss_data`]: SendBuf::may_loss_data /// [`receive buffer`]: crate::recv::RecvBuf #[derive(Default, Debug)] pub struct SendBuf { offset: u64, // 写入数据的队列,与接收队列不同的是,每一段数据都是前后连续的 data: VecDeque, // 对BufMap::size的限制 max_data: u64, state: BufMap, } impl SendBuf { /// Create a new [`SendBuf`] with the given size. pub fn with_capacity(capacity: u64) -> Self { Self { offset: 0, data: VecDeque::new(), max_data: capacity, state: BufMap::default(), } } /// Write data to the [`SendBuf`]. /// /// When [`SendBuf`] has buffered [`Self::max_data`] amount of data, /// no more data should be written. pub fn write(&mut self, data: Bytes) { // debug_assert!(self.remaining_mut() > 0, "Sendbuf buffers excess data"); if !data.is_empty() { self.state .extend_to((self.written() + data.len() as u64).min(self.max_data)); self.data.push_back(data); } } /// The maximum amount of data that can be sent in the [`SendBuf`]. /// /// For [`DataStreams`], this is the flow control of the stream. /// /// For [`CryptoStream`], there should be no restrictions. /// /// [`DataStreams`]: crate::streams::DataStreams /// [`CryptoStream`]: crate::crypto::CryptoStream pub fn max_data(&self) -> u64 { self.max_data } /// Forget all state of data that has been sent. 
/// /// This is usually called when the zero rtt is rejected by server. /// /// All data sent should be resent as fresh data, /// and for the subsequent correction of max_data, max_data is also reset to 0. pub fn forget_sent_state(&mut self) { self.state = BufMap::default(); self.max_data = 0; } /// Extend the [`Self::max_data`] limit. pub fn extend(&mut self, max_data: u64) { debug_assert!(max_data >= self.max_data, "Cannot reduce sndbuf size"); self.max_data = max_data; self.state.extend_to(self.written().min(self.max_data)); } /// Return whether the [`SendBuf`] is empty. pub fn is_empty(&self) -> bool { self.data.is_empty() } /// Return the total length of data that has been cumulatively written to the send buffer in the past. /// /// Note that data the returned size may be larger than [`Self::max_data`]. pub fn written(&self) -> u64 { self.offset + self.data.iter().map(|data| data.len() as u64).sum::() } /// Return the number of bytes that have been sent. pub fn sent(&self) -> u64 { self.state.sent() } /// Return the number of bytes that can be written without exceeding the [`Self::max_data`] limit. /// /// To prevent [`SendBuf`] from buffering excessive data, data should not be written when this method returns 0. pub fn remaining_mut(&self) -> u64 { self.max_data().saturating_sub(self.written()) } /// Return whether there is remaining space to write data without exceeding the [`Self::max_data`] limit. /// /// When this method returns false, data should not be written. pub fn has_remaining_mut(&self) -> bool { self.max_data() > self.written() } // 无需close:不在写入即可,具体到某个状态,才有close // 无需reset:状态转化间,需要reset,而Sender上下文直接释放即可 // 无需clean:Sender上下文直接释放即可, } type Data<'s> = (Range, bool, Vec); impl SendBuf { /// Pick up data that can be sent. /// /// The selected data is subject to `predicate`, which accepts the starting position of the /// data segment, returns whether the segment could be sent and the maximum amount of bytes could /// take. 
/// /// If the data picked up is new (never sent before), how much data can be sent is also subject /// to `flow_limit`. /// /// ### Returns /// `None` if there is no data picked up. /// /// Otherwise, return a tuple: /// * `Range`: the range of data picked up (start inclusive, end exclusive). /// * `bool`: whether the data is new(not retransmitted). /// * `(&[u8], &[u8])`: the data picked up, duo to the internal buffer is a ring buffer, the data /// picked up is in two parts, the begin of the second slice are the end of the first slice pub fn pick_up

(&mut self, predicate: P, flow_limit: usize) -> Result, Signals> where P: Fn(u64) -> Option, { self.state .pick(predicate, flow_limit, self.max_data()) .map(|(range, is_fresh)| { let iter = self .data .iter() .scan(self.offset, |offset, data| { let current_range = *offset..*offset + data.len() as u64; *offset += data.len() as u64; Some((current_range, data)) }) .filter(move |(slice, ..)| slice.end > range.start && slice.start < range.end) .map(move |(slice, data)| { if slice.start >= range.start && slice.end <= range.end { data.clone() } else { data.slice( (range.start.saturating_sub(slice.start)) as usize ..(range.end.min(slice.end) - slice.start) as usize, ) } }); (range, is_fresh, iter.collect()) }) } /// Called when the `range` of data sent is acknowledged by the peer. /// /// The `range` is the range of data that has been acknowledged. // 通过传输层接收到的对方的ack帧,确认某些包已经被接收到,这些包携带的数据即被确认。 // ack只能确认Flighting/Lost状态的区间;如果确认的是Lost区间,意味着之前的判定丢包是错误的。 pub fn on_data_acked(&mut self, range: &Range) { self.state.ack_rcvd(range); // 对于头部连续确认接收到的,还要前进,以免浪费空间 let min_unrecved_pos = self.state.shift(); if self.offset < min_unrecved_pos { let mut drain_len = (min_unrecved_pos - self.offset) as usize; self.offset = min_unrecved_pos; while !self.data.is_empty() && drain_len > 0 { match drain_len { n if n >= self.data[0].len() => { drain_len -= self.data[0].len(); self.data.pop_front().unwrap(); } n => { self.data[0] = self.data[0].slice(n..); break; } } } } } /// Called when the `range` of data sent may be lost. /// /// The `range` is the range of data that may be lost. // 通过传输层收到的ack帧,判定有些数据包丢失,因为它之后的数据包都被确认了, // 或者距离发送该段数据之后相当长一段时间都没收到它的确认。 pub fn may_loss_data(&mut self, range: &Range) { self.state.may_loss(range); } pub fn resend_flighting(&mut self) { self.state.resend_flighting() } /// Return whether all data currently written has been received(acknowledged) by the peer. 
pub fn is_all_rcvd(&self) -> bool { self.data.is_empty() } } #[cfg(test)] mod tests { use qbase::net::tx::Signals; use super::{BufMap, Color, State}; #[test] fn test_state() { let state = State::encode(100, Color::Pending); assert_eq!(state.offset(), 100); assert_eq!(state.color(), Color::Pending); let mut state = State::encode(100, Color::Pending); state.set_color(Color::Flighting); assert_eq!(state.color(), Color::Flighting); let state = State::encode(100, Color::Pending); assert_eq!(state.decode(), (100, Color::Pending)); // test Dispaly assert_eq!(format!("{state}"), "[100: Pending]"); assert_eq!(format!("{state:?}"), "[100: Pending]"); } #[test] fn test_bufmap_empty() { let buf_map = BufMap::default(); assert!(buf_map.0.is_empty()); } #[test] fn test_bufmap_extend_to() { let mut buf_map = BufMap::default(); buf_map.extend_to(100); assert_eq!(buf_map.0, vec![State::encode(0, Color::Pending)]); assert_eq!(buf_map.1, 100); buf_map.0.get_mut(0).unwrap().set_color(Color::Flighting); buf_map.extend_to(200); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Flighting), State::encode(100, Color::Pending) ] ); assert_eq!(buf_map.1, 200); } #[test] fn test_bufmap_pick() { let mut buf_map = BufMap::default(); let range = buf_map.pick(|_| Some(20), usize::MAX, u64::MAX); assert_eq!(range, Err(Signals::TRANSPORT | Signals::WRITTEN)); assert!(buf_map.0.is_empty()); buf_map.extend_to(200); let (range, is_fresh) = buf_map.pick(|_| Some(20), usize::MAX, u64::MAX).unwrap(); assert_eq!(range, 0..20); assert!(is_fresh); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Flighting), State::encode(20, Color::Pending) ] ); let (range, is_fresh) = buf_map.pick(|_| Some(20), usize::MAX, u64::MAX).unwrap(); assert_eq!(range, 20..40); assert!(is_fresh); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Flighting), State::encode(40, Color::Pending) ] ); buf_map.0.insert(2, State::encode(50, Color::Lost)); buf_map.0.insert(3, State::encode(120, Color::Pending)); let (range, 
is_fresh) = buf_map.pick(|_| Some(20), usize::MAX, u64::MAX).unwrap(); assert_eq!(range, 40..50); assert!(is_fresh); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Flighting), State::encode(50, Color::Lost), State::encode(120, Color::Pending) ] ); buf_map.0.get_mut(0).unwrap().set_color(Color::Recved); let (range, is_fresh) = buf_map.pick(|_| Some(20), usize::MAX, u64::MAX).unwrap(); assert_eq!(range, 50..70); assert!(!is_fresh); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Recved), State::encode(50, Color::Flighting), State::encode(70, Color::Lost), State::encode(120, Color::Pending) ] ); let (range, is_fresh) = buf_map.pick(|_| Some(130), usize::MAX, u64::MAX).unwrap(); assert_eq!(range, 70..120); assert!(!is_fresh); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Recved), State::encode(50, Color::Flighting), State::encode(120, Color::Pending) ] ); let (range, is_fresh) = buf_map.pick(|_| Some(130), usize::MAX, u64::MAX).unwrap(); assert_eq!(range, 120..200); assert!(is_fresh); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Recved), State::encode(50, Color::Flighting), ] ); let result = buf_map.pick(|_| Some(130), usize::MAX, u64::MAX); assert!(result.is_err()); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Recved), State::encode(50, Color::Flighting), ] ); } #[test] fn test_bufmap_sent() { let mut buf_map = BufMap::default(); buf_map.extend_to(200); assert_eq!(buf_map.sent(), 0); assert!(buf_map.pick(|_| Some(120), usize::MAX, u64::MAX).is_ok()); assert_eq!(buf_map.sent(), 120); assert!(buf_map.pick(|_| Some(80), usize::MAX, u64::MAX).is_ok()); assert_eq!(buf_map.sent(), 200); } #[test] fn test_bufmap_recved() { let mut buf_map = BufMap::default(); buf_map.extend_to(200); assert!(buf_map.pick(|_| Some(120), usize::MAX, u64::MAX).is_ok()); buf_map.ack_rcvd(&(0..20)); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Recved), State::encode(20, Color::Flighting), State::encode(120, Color::Pending) ] ); buf_map.ack_rcvd(&(30..50)); 
assert_eq!( buf_map.0, vec![ State::encode(0, Color::Recved), State::encode(20, Color::Flighting), State::encode(30, Color::Recved), State::encode(50, Color::Flighting), State::encode(120, Color::Pending) ] ); buf_map.ack_rcvd(&(25..55)); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Recved), State::encode(20, Color::Flighting), State::encode(25, Color::Recved), State::encode(55, Color::Flighting), State::encode(120, Color::Pending) ] ); buf_map.ack_rcvd(&(20..25)); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Recved), State::encode(55, Color::Flighting), State::encode(120, Color::Pending) ] ); buf_map.0.pop_front(); buf_map.ack_rcvd(&(20..55)); assert_eq!( buf_map.0, vec![ State::encode(55, Color::Flighting), State::encode(120, Color::Pending) ] ); buf_map.ack_rcvd(&(30..70)); assert_eq!( buf_map.0, vec![ State::encode(70, Color::Flighting), State::encode(120, Color::Pending) ] ); buf_map.ack_rcvd(&(100..119)); assert_eq!( buf_map.0, vec![ State::encode(70, Color::Flighting), State::encode(100, Color::Recved), State::encode(119, Color::Flighting), State::encode(120, Color::Pending) ] ); assert!(buf_map.pick(|_| Some(130), usize::MAX, u64::MAX).is_ok()); assert_eq!( buf_map.0, vec![ State::encode(70, Color::Flighting), State::encode(100, Color::Recved), State::encode(119, Color::Flighting), ] ); buf_map.ack_rcvd(&(119..150)); assert_eq!( buf_map.0, vec![ State::encode(70, Color::Flighting), State::encode(100, Color::Recved), State::encode(150, Color::Flighting), ] ); buf_map.ack_rcvd(&(150..200)); assert_eq!( buf_map.0, vec![ State::encode(70, Color::Flighting), State::encode(100, Color::Recved), ] ); } #[test] #[should_panic] fn test_bufmap_invalid_recved() { let mut buf_map = BufMap::default(); buf_map.extend_to(200); assert!(buf_map.pick(|_| Some(120), usize::MAX, u64::MAX).is_ok()); buf_map.ack_rcvd(&(20..40)); buf_map.0.insert(2, State::encode(30, Color::Pending)); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Flighting), 
State::encode(20, Color::Recved), // Alerting: 30..40 is Pending, never been sent, but they will be Recved State::encode(30, Color::Pending), State::encode(40, Color::Flighting), State::encode(120, Color::Pending) ] ); buf_map.ack_rcvd(&(0..50)); } #[test] #[should_panic] fn test_bufmap_recved_overflow() { let mut buf_map = BufMap::default(); buf_map.extend_to(200); assert!(buf_map.pick(|_| Some(120), usize::MAX, u64::MAX).is_ok()); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Flighting), State::encode(120, Color::Pending), ] ); buf_map.ack_rcvd(&(110..121)); } #[test] #[should_panic] fn test_bufmap_recved_over_end() { let mut buf_map = BufMap::default(); buf_map.extend_to(200); assert!(buf_map.pick(|_| Some(200), usize::MAX, u64::MAX).is_ok()); assert_eq!(buf_map.0, vec![State::encode(0, Color::Flighting)]); buf_map.ack_rcvd(&(0..201)); } #[test] fn test_bufmap_lost() { let mut buf_map = BufMap::default(); buf_map.extend_to(200); assert!(buf_map.pick(|_| Some(120), usize::MAX, u64::MAX).is_ok()); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Flighting), State::encode(120, Color::Pending), ] ); buf_map.may_loss(&(0..20)); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Lost), State::encode(20, Color::Flighting), State::encode(120, Color::Pending), ] ); buf_map.may_loss(&(30..50)); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Lost), State::encode(20, Color::Flighting), State::encode(30, Color::Lost), State::encode(50, Color::Flighting), State::encode(120, Color::Pending), ] ); buf_map.ack_rcvd(&(0..10)); buf_map.ack_rcvd(&(70..100)); buf_map.0.pop_front(); assert_eq!( buf_map.0, vec![ State::encode(10, Color::Lost), State::encode(20, Color::Flighting), State::encode(30, Color::Lost), State::encode(50, Color::Flighting), State::encode(70, Color::Recved), State::encode(100, Color::Flighting), State::encode(120, Color::Pending), ] ); buf_map.may_loss(&(15..25)); assert_eq!( buf_map.0, vec![ State::encode(10, Color::Lost), State::encode(25, 
Color::Flighting), State::encode(30, Color::Lost), State::encode(50, Color::Flighting), State::encode(70, Color::Recved), State::encode(100, Color::Flighting), State::encode(120, Color::Pending), ] ); buf_map.may_loss(&(10..20)); assert_eq!( buf_map.0, vec![ State::encode(10, Color::Lost), State::encode(25, Color::Flighting), State::encode(30, Color::Lost), State::encode(50, Color::Flighting), State::encode(70, Color::Recved), State::encode(100, Color::Flighting), State::encode(120, Color::Pending), ] ); buf_map.may_loss(&(60..110)); assert_eq!( buf_map.0, vec![ State::encode(10, Color::Lost), State::encode(25, Color::Flighting), State::encode(30, Color::Lost), State::encode(50, Color::Flighting), State::encode(60, Color::Lost), State::encode(70, Color::Recved), State::encode(100, Color::Lost), State::encode(110, Color::Flighting), State::encode(120, Color::Pending), ] ); buf_map.ack_rcvd(&(20..55)); assert_eq!( buf_map.0, vec![ State::encode(10, Color::Lost), State::encode(20, Color::Recved), State::encode(55, Color::Flighting), State::encode(60, Color::Lost), State::encode(70, Color::Recved), State::encode(100, Color::Lost), State::encode(110, Color::Flighting), State::encode(120, Color::Pending), ] ); buf_map.may_loss(&(40..80)); assert_eq!( buf_map.0, vec![ State::encode(10, Color::Lost), State::encode(20, Color::Recved), State::encode(55, Color::Lost), State::encode(70, Color::Recved), State::encode(100, Color::Lost), State::encode(110, Color::Flighting), State::encode(120, Color::Pending), ] ); buf_map.ack_rcvd(&(20..120)); assert_eq!( buf_map.0, vec![ State::encode(10, Color::Lost), State::encode(20, Color::Recved), State::encode(120, Color::Pending), ] ); buf_map.may_loss(&(50..80)); assert_eq!( buf_map.0, vec![ State::encode(10, Color::Lost), State::encode(20, Color::Recved), State::encode(120, Color::Pending), ] ); buf_map.may_loss(&(2..10)); assert_eq!( buf_map.0, vec![ State::encode(10, Color::Lost), State::encode(20, Color::Recved), State::encode(120, 
Color::Pending), ] ); buf_map.may_loss(&(30..50)); assert_eq!( buf_map.0, vec![ State::encode(10, Color::Lost), State::encode(20, Color::Recved), State::encode(120, Color::Pending), ] ); } #[test] fn test_bufmap_ack_and_lost_all() { let mut buf_map = BufMap::default(); buf_map.extend_to(46); assert!(buf_map.pick(|_| Some(46), usize::MAX, u64::MAX).is_ok()); assert_eq!(buf_map.0, vec![State::encode(0, Color::Flighting)]); buf_map.ack_rcvd(&(0..2)); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Recved), State::encode(2, Color::Flighting) ] ); buf_map.may_loss(&(0..46)); assert_eq!( buf_map.0, vec![ State::encode(0, Color::Recved), State::encode(2, Color::Lost) ] ) } #[test] fn test_bufmap_ack_and_lost_all2() { let mut buf_map = BufMap(vec![State::encode(2, Color::Flighting)].into(), 46); buf_map.may_loss(&(0..46)); assert_eq!(buf_map.0, vec![State::encode(2, Color::Lost)]) } } ================================================ FILE: qrecovery/src/send/writer.rs ================================================ use std::{ ops::DerefMut, pin::Pin, task::{Context, Poll, ready}, }; use bytes::Bytes; use futures::Sink; use qbase::frame::{ResetStreamFrame, io::SendFrame}; use tokio::io::{self, AsyncWrite}; use super::sender::{ArcSender, Sender}; use crate::streams::error::StreamError; pub trait CancelStream { /// Cancels the stream with the given error code. /// /// If all data has been sent and acknowledged by the peer, or the stream has been reset, this /// method will do nothing. /// /// Otherwise, a [`RESET_STREAM frame`] will be sent to the peer, and the stream will be reset, /// neither new data nor lost data will be sent. /// /// Unlike TCP, canceling a QUIC stream needs an error code, which is used to indicate /// the reason for the cancellation. The error code should be a `u64` value, /// defined by the application protocol using QUIC, such as HTTP/3 or gRPC. 
/// /// [`RESET_STREAM frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-reset_stream-frames fn cancel(&mut self, err_code: u64); } /// The writer part of a QUIC stream. /// /// This struct implements the [`AsyncWrite`] trait, allowing you to write data to the stream. /// /// A QUIC stream is *reliable*, *ordered*, and *flow-controlled*. /// /// The amount of data that can be sent is limited by flow control. The [`write`] call will be blocked /// if the amount of data written reaches the flow control limit. /// /// The [`flush`] and [`shutdown`] calls will be blocked until all data written to [`Writer`] has /// been sent and acknowledged by the peer. /// /// # Note /// /// The stream must be cancelled or shutdowned before the [`Writer`] dropped. /// /// Call [`shutdown`] means that there are no more new data will been written to the stream. If all /// of the data written to the stream has been sent and acknowledged by the peer, the stream will be /// `closed`, and the [`shutdown`] call complete with `Ok(())`. /// /// Alternatively, if the operations on the [`Writer`] result an error, its indicates that the stream /// has been cancelled in other reason, such as connection closed, the peer acked local to stop sending. /// /// You can call [`cancel`] to `cancel` the stream with the given error code, The [`Writer`] will be /// consumed, and neither new data nor lost data will be sent anymore. /// /// # Example /// /// The [`Writer`] is created by the `open_bi_stream`, `open_uni_stream`, or `accept_bi_stream` methods of /// `QuicConnection` (in the `dquic` crate). /// /// The following example demonstrates how to read and write data on a QUIC stream. 
/// /// ```rust, ignore /// # use tokio::io::{AsyncWriteExt, AsyncReadExt}; /// # async fn example() -> std::io::Result<()> { /// let (reader, writer) = quic_connection.open_bi_stream().await?; /// /// writer.write_all(b"GET README.md\r\n").await?; /// writer.shutdown().await?; /// /// let mut response = String::new(); /// let n = reader.read_to_string(&mut response).await?; /// println!("Response {} bytes: {}", n, response); /// Ok(()) /// # } /// ``` /// /// [`write`]: tokio::io::AsyncWriteExt::write /// [`flush`]: tokio::io::AsyncWriteExt::flush /// [`shutdown`]: tokio::io::AsyncWriteExt::shutdown /// [`cancel`]: Writer::cancel /// [`STOP_SENDING frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-stop_sending-frames #[derive(Debug)] pub struct Writer { inner: ArcSender, qlog_span: qevent::telemetry::Span, tracing_span: tracing::Span, } impl Writer { pub(crate) fn new(inner: ArcSender) -> Self { Self { inner, qlog_span: qevent::telemetry::Span::current(), tracing_span: tracing::Span::current(), } } } impl CancelStream for Writer where TX: SendFrame, { /// Cancels the stream with the given error code(reset the stream). /// /// If all data has been sent and acknowledged by the peer(the stream has closed), or the stream /// has been reset, this method will do nothing. /// /// Otherwise, a [`RESET_STREAM frame`] will be sent to the peer, and the stream will be reset, /// neither new data nor lost data will be sent. 
/// /// [`RESET_STREAM frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-reset_stream-frames fn cancel(&mut self, err_code: u64) { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); let mut sender = self.inner.sender(); let inner = sender.deref_mut(); if let Ok(sending_state) = inner { match sending_state { Sender::Ready(s) => { *sending_state = Sender::ResetSent(s.cancel(err_code)); } Sender::Sending(s) => { *sending_state = Sender::ResetSent(s.cancel(err_code)); } Sender::DataSent(s) => { *sending_state = Sender::ResetSent(s.cancel(err_code)); } _ => (), } }; } } impl Writer { /// Poll to check whether [`Writer`] can cache more appropriate amount of data. /// /// Even without calling this method in advance, writing data can succeed. /// However, this may cause the QUIC layer to cache excessive data. #[inline] pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); let mut sender = self.inner.sender(); let sending_state = sender.as_mut().map_err(|e| e.clone())?; match sending_state { Sender::Ready(s) => s.poll_ready(cx), Sender::Sending(s) => s.poll_ready(cx), Sender::DataSent(_) => Poll::Ready(Err(StreamError::EosSent)), Sender::DataRcvd => Poll::Ready(Err(StreamError::EosSent)), Sender::ResetSent(reset) => Poll::Ready(Err(StreamError::Reset(*reset))), Sender::ResetRcvd(reset) => Poll::Ready(Err(StreamError::Reset(*reset))), } } /// Write data to the stream. /// /// Although data written by this method can also be sent, /// it is recommended to use the `Sink` or `AsyncWrite` API to avoid excessive data caching at the QUIC layer. 
#[inline] pub fn write(&mut self, buf: Bytes) -> Result<(), StreamError> { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); let mut sender = self.inner.sender(); let sending_state = sender.as_mut().map_err(|e| e.clone())?; match sending_state { Sender::Ready(s) => s.write(buf), Sender::Sending(s) => s.write(buf), Sender::DataSent(_) => Err(StreamError::EosSent), Sender::DataRcvd => Err(StreamError::EosSent), Sender::ResetSent(reset) => Err(StreamError::Reset(*reset)), Sender::ResetRcvd(reset) => Err(StreamError::Reset(*reset)), } } #[inline] pub fn poll_write( &mut self, cx: &mut Context<'_>, data: Bytes, ) -> Poll> { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); let mut sender = self.inner.sender(); let sending_state = sender.as_mut().map_err(|e| e.clone())?; match sending_state { Sender::Ready(s) => { ready!(s.poll_ready(cx)?); Poll::Ready(s.write(data)) } Sender::Sending(s) => { ready!(s.poll_ready(cx)?); Poll::Ready(s.write(data)) } Sender::DataSent(_) => Poll::Ready(Err(StreamError::EosSent)), Sender::DataRcvd => Poll::Ready(Err(StreamError::EosSent)), Sender::ResetSent(reset) => Poll::Ready(Err(StreamError::Reset(*reset))), Sender::ResetRcvd(reset) => Poll::Ready(Err(StreamError::Reset(*reset))), } } #[inline] pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { let _span = (self.qlog_span.enter(), self.tracing_span.enter()); let mut sender = self.inner.sender(); let sending_state = sender.as_mut().map_err(|e| e.clone())?; match sending_state { Sender::Ready(s) => s.poll_flush(cx).map(Ok), Sender::Sending(s) => s.poll_flush(cx).map(Ok), Sender::DataSent(s) => s.poll_flush(cx).map(Ok), Sender::DataRcvd => Poll::Ready(Ok(())), Sender::ResetSent(reset) => Poll::Ready(Err(StreamError::Reset(*reset))), Sender::ResetRcvd(reset) => Poll::Ready(Err(StreamError::Reset(*reset))), } } #[inline] #[doc(alias = "poll_close")] pub fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { let _span = (self.qlog_span.enter(), 
self.tracing_span.enter()); let mut sender = self.inner.sender(); let sending_state = sender.as_mut().map_err(|e| e.clone())?; match sending_state { Sender::Ready(s) => s.poll_shutdown(cx).map(Ok), Sender::Sending(s) => s.poll_shutdown(cx).map(Ok), Sender::DataSent(s) => s.poll_shutdown(cx).map(Ok), Sender::DataRcvd => Poll::Ready(Ok(())), Sender::ResetSent(reset) => Poll::Ready(Err(StreamError::Reset(*reset))), Sender::ResetRcvd(reset) => Poll::Ready(Err(StreamError::Reset(*reset))), } } } impl AsyncWrite for Writer { #[inline] fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { Writer::poll_write(self.get_mut(), cx, Bytes::copy_from_slice(buf)) .map_ok(|()| buf.len()) .map_err(io::Error::from) } #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Writer::poll_flush(self.get_mut(), cx).map_err(io::Error::from) } #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Writer::poll_shutdown(self.get_mut(), cx).map_err(io::Error::from) } } impl Sink for Writer { type Error = StreamError; #[inline] fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Writer::poll_ready(self.get_mut(), cx) } #[inline] fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { Writer::write(self.get_mut(), item) } #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Writer::poll_flush(self.get_mut(), cx) } #[inline] fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Writer::poll_shutdown(self.get_mut(), cx) } } impl Drop for Writer { fn drop(&mut self) { let mut sender = self.inner.sender(); let inner = sender.deref_mut(); if let Ok(sending_state) = inner { match sending_state { Sender::Ready(s) => { #[cfg(debug_assertions)] tracing::warn!( target: "quic", "The sending {} is not closed before dropped!", s.stream_id(), ); #[cfg(not(debug_assertions))] tracing::debug!( target: "quic", "The sending {} is not closed before 
dropped!", s.stream_id(), ); } Sender::Sending(s) => { #[cfg(debug_assertions)] tracing::warn!( target: "quic", "The sending {} is not closed before dropped!", s.stream_id(), ); #[cfg(not(debug_assertions))] tracing::debug!( target: "quic", "The sending {} is not closed before dropped!", s.stream_id(), ); } _ => (), } }; } } ================================================ FILE: qrecovery/src/send.rs ================================================ //! Types for sending data on a Stream. mod outgoing; mod sender; mod sndbuf; mod writer; pub use outgoing::Outgoing; pub use sender::ArcSender; pub use sndbuf::SendBuf; pub use writer::{CancelStream, Writer}; ================================================ FILE: qrecovery/src/streams/error.rs ================================================ use std::io; use qbase::{error::Error, frame::ResetStreamError}; use thiserror::Error; #[derive(Error, Debug, Clone, PartialEq, Eq)] pub enum StreamError { #[error(transparent)] Connection(#[from] Error), #[error(transparent)] Reset(#[from] ResetStreamError), #[error("EOS has been sent")] EosSent, } impl From for io::Error { fn from(value: StreamError) -> Self { match value { error @ (StreamError::Connection(..) 
| StreamError::Reset(..)) => { io::Error::new(io::ErrorKind::BrokenPipe, error) } error @ StreamError::EosSent => io::Error::new(io::ErrorKind::Unsupported, error), } } } ================================================ FILE: qrecovery/src/streams/io.rs ================================================ use std::{ collections::{BTreeMap, HashMap}, ops::{Deref, DerefMut}, sync::{ Arc, Mutex, MutexGuard, atomic::{AtomicU8, Ordering}, }, }; use derive_more::{Deref, DerefMut}; use qbase::{ error::Error as QuicError, sid::{Dir, StreamId}, }; use crate::{recv::Incoming, send::Outgoing}; #[derive(Debug, Clone)] pub(super) struct IOState(Arc); impl IOState { const SENDING: u8 = 0x1; const RECEIVING: u8 = 0x2; pub fn send_only() -> Self { Self(Arc::new(AtomicU8::new(Self::SENDING))) } pub fn receive_only() -> Self { Self(Arc::new(AtomicU8::new(Self::RECEIVING))) } pub fn bidirection() -> Self { Self(Arc::new(AtomicU8::new(Self::SENDING | Self::RECEIVING))) } pub fn is_terminated(&self) -> bool { self.0.load(Ordering::Acquire) == 0 } pub fn shutdown_send(&self) { self.0.fetch_and(!Self::SENDING, Ordering::Release); } pub fn shutdown_receive(&self) { self.0.fetch_and(!Self::RECEIVING, Ordering::Release); } } #[derive(Debug, Clone, Deref, DerefMut)] pub(super) struct Output { #[deref] #[deref_mut] pub(super) outgoings: BTreeMap, IOState)>, pub(super) cursor: Option<(StreamId, usize)>, } impl Output { fn new() -> Self { Self { outgoings: BTreeMap::default(), cursor: None, } } } /// ArcOutput里面包含一个Result类型,一旦发生quic error,就会被替换为Err /// 发生quic error后,其操作将被忽略,不会再抛出QuicError或者panic,因为 /// 有些异步任务可能还未完成,在置为Err后才会完成。 #[derive(Debug, Clone)] pub(super) struct ArcOutput(Arc, QuicError>>>); impl ArcOutput { pub(super) fn new() -> Self { Self(Arc::new(Mutex::new(Ok(Output::new())))) } pub(super) fn streams(&self) -> MutexGuard<'_, Result, QuicError>> { self.0.lock().unwrap() } pub(super) fn guard(&'_ self) -> Result, QuicError> { let streams = self.streams(); match streams.as_ref() { Ok(_) 
=> Ok(ArcOutputGuard(streams)), Err(e) => Err(e.clone()), } } } pub(super) struct ArcOutputGuard<'a, TX>(MutexGuard<'a, Result, QuicError>>); impl Deref for ArcOutputGuard<'_, TX> { type Target = Output; fn deref(&self) -> &Self::Target { match self.0.as_ref() { Ok(output) => output, Err(e) => unreachable!("output is invalid: {e}"), } } } impl DerefMut for ArcOutputGuard<'_, TX> { fn deref_mut(&mut self) -> &mut Self::Target { match self.0.as_mut() { Ok(output) => output, Err(e) => unreachable!("output is invalid: {e}"), } } } impl ArcOutputGuard<'_, TX> { pub(super) fn insert(&mut self, sid: StreamId, outgoing: Outgoing, io_state: IOState) { self.deref_mut().insert(sid, (outgoing, io_state)); } pub(super) fn revise_max_stream_data( &self, zero_rtt_rejected: bool, opened_bidi: u64, opened_uni: u64, bidi_snd_wnd_size: u64, uni_snd_wnd_size: u64, ) { self.deref() .iter() .filter(|(sid, _)| { sid.dir() == Dir::Bi && sid.id() < opened_bidi || sid.dir() == Dir::Uni && sid.id() < opened_uni }) .for_each(|(sid, (outgoing, _))| match sid.dir() { Dir::Bi => outgoing.revise_max_stream_data(zero_rtt_rejected, bidi_snd_wnd_size), Dir::Uni => outgoing.revise_max_stream_data(zero_rtt_rejected, uni_snd_wnd_size), }); } pub(super) fn on_conn_error(&mut self, error: &QuicError) { self.deref() .values() .for_each(|(o, _)| o.on_conn_error(error)); *self.0 = Err(error.clone()); } } /// ArcInput里面包含一个Result类型,一旦发生quic error,就会被替换为Err /// 发生quic error后,其操作将被忽略,不会再抛出QuicError或者panic,因为 /// 有些异步任务可能还未完成,在置为Err后才会完成。 #[allow(clippy::type_complexity)] #[derive(Debug, Clone)] pub(super) struct ArcInput( Arc, IOState)>, QuicError>>>, ); impl Default for ArcInput { fn default() -> Self { Self(Arc::new(Mutex::new(Ok(HashMap::new())))) } } impl ArcInput { #[allow(clippy::type_complexity)] pub(super) fn streams( &self, ) -> MutexGuard<'_, Result, IOState)>, QuicError>> { self.0.lock().unwrap() } pub(super) fn guard(&self) -> Result, QuicError> { let guard = self.0.lock().unwrap(); match 
guard.as_ref() { Ok(_) => Ok(ArcInputGuard { inner: guard }), Err(e) => Err(e.clone()), } } } #[allow(clippy::type_complexity)] pub(super) struct ArcInputGuard<'a, TX> { inner: MutexGuard<'a, Result, IOState)>, QuicError>>, } impl ArcInputGuard<'_, TX> { pub(super) fn insert(&mut self, sid: StreamId, incoming: Incoming, io_state: IOState) { match self.inner.as_mut() { Ok(set) => set.insert(sid, (incoming, io_state)), Err(e) => unreachable!("input is invalid: {e}"), }; } pub(super) fn on_conn_error(&mut self, error: &QuicError) { match self.inner.as_ref() { Ok(set) => set.values().for_each(|(o, _)| o.on_conn_error(error)), Err(e) => unreachable!("output is invalid: {e}"), }; *self.inner = Err(error.clone()); } } ================================================ FILE: qrecovery/src/streams/listener.rs ================================================ use std::{ collections::VecDeque, future::Future, pin::Pin, sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll, Waker, ready}, }; use qbase::{ error::Error as QuicError, frame::{ResetStreamFrame, io::SendFrame}, param::{ArcParameters, ParameterId}, sid::StreamId, }; use crate::{ recv::{ArcRecver, Reader}, send::{ArcSender, Writer}, }; #[derive(Debug)] struct Listener { // 对方主动创建的流 #[allow(clippy::type_complexity)] bi_streams: VecDeque<(StreamId, (ArcRecver, ArcSender))>, uni_streams: VecDeque<(StreamId, ArcRecver)>, bi_waker: Option, uni_waker: Option, } impl Listener { fn new() -> Self { Self { bi_streams: VecDeque::with_capacity(4), uni_streams: VecDeque::with_capacity(2), bi_waker: None, uni_waker: None, } } fn push_bi_stream(&mut self, sid: StreamId, stream: (ArcRecver, ArcSender)) { self.bi_streams.push_back((sid, stream)); if let Some(waker) = self.bi_waker.take() { waker.wake(); } } fn push_recv_stream(&mut self, sid: StreamId, stream: ArcRecver) { self.uni_streams.push_back((sid, stream)); if let Some(waker) = self.uni_waker.take() { waker.wake(); } } #[allow(clippy::type_complexity)] fn poll_accept_bi_stream( 
&mut self, cx: &mut Context<'_>, arc_params: &ArcParameters, ) -> Poll, Writer)), QuicError>> { let mut params = arc_params.lock_guard()?; let snd_buf_size = match params.get_remote(ParameterId::InitialMaxStreamDataBidiLocal) { Some(value) => value, None => { ready!(params.poll_ready(cx)); return self.poll_accept_bi_stream(cx, arc_params); } }; if let Some((sid, (recver, sender))) = self.bi_streams.pop_front() { sender.update_window(snd_buf_size); // recver.update_window(rcv_buf_size); Poll::Ready(Ok((sid, (Reader::new(recver), Writer::new(sender))))) } else { self.bi_waker = Some(cx.waker().clone()); Poll::Pending } } fn poll_accept_recv_stream( &mut self, cx: &mut Context<'_>, ) -> Poll), QuicError>> { if let Some((sid, recver)) = self.uni_streams.pop_front() { // recver.update_window(rcv_buf_size); Poll::Ready(Ok((sid, Reader::new(recver)))) } else { self.uni_waker = Some(cx.waker().clone()); Poll::Pending } } } #[derive(Debug, Clone)] pub struct ArcListener(Arc, QuicError>>>); impl ArcListener { pub(crate) fn new() -> Self { Self(Arc::new(Mutex::new(Ok(Listener::new())))) } pub(crate) fn guard(&self) -> Result, QuicError> { let guard = self.0.lock().unwrap(); match guard.as_ref() { Ok(_) => Ok(ListenerGuard { inner: guard }), Err(e) => Err(e.clone()), } } pub fn accept_bi_stream<'a>(&'a self, params: &'a ArcParameters) -> AcceptBiStream<'a, TX> { AcceptBiStream { listener: self, params, } } pub fn accept_uni_stream(&self) -> AcceptUniStream<'_, TX> { AcceptUniStream { listener: self } } #[allow(clippy::type_complexity)] pub fn poll_accept_bi_stream( &self, cx: &mut Context<'_>, arc_params: &ArcParameters, ) -> Poll, Writer)), QuicError>> { match self.0.lock().unwrap().as_mut() { Ok(set) => set.poll_accept_bi_stream(cx, arc_params), Err(e) => Poll::Ready(Err(e.clone())), } } pub fn poll_accept_uni_stream( &self, cx: &mut Context<'_>, ) -> Poll), QuicError>> { match self.0.lock().unwrap().as_mut() { Ok(set) => set.poll_accept_recv_stream(cx), Err(e) => 
Poll::Ready(Err(e.clone())), } } } pub(crate) struct ListenerGuard<'a, TX> { inner: MutexGuard<'a, Result, QuicError>>, } impl ListenerGuard<'_, TX> where TX: SendFrame + Clone + Send + 'static, { pub(crate) fn push_bi_stream(&mut self, sid: StreamId, stream: (ArcRecver, ArcSender)) { match self.inner.as_mut() { Ok(set) => set.push_bi_stream(sid, stream), Err(e) => unreachable!("listener is invalid: {e}"), } } pub(crate) fn push_uni_stream(&mut self, sid: StreamId, stream: ArcRecver) { match self.inner.as_mut() { Ok(set) => set.push_recv_stream(sid, stream), Err(e) => unreachable!("listener is invalid: {e}"), } } pub(crate) fn on_conn_error(&mut self, e: &QuicError) { match self.inner.as_mut() { Ok(set) => { if let Some(waker) = set.bi_waker.take() { waker.wake(); } if let Some(waker) = set.uni_waker.take() { waker.wake(); } } Err(e) => unreachable!("listener is invalid: {e}"), }; *self.inner = Err(e.clone()); } } /// Future to accept a bidirectional stream. /// /// This future is created by `accept_bi_stream` method of `QuicConnection`. /// /// When the peer created a new bidirectional stream, the future will resolve with a [`Reader`] and /// a [`Writer`] to read and write data on the stream. #[derive(Debug, Clone)] pub struct AcceptBiStream<'a, TX> { listener: &'a ArcListener, params: &'a ArcParameters, } impl Future for AcceptBiStream<'_, TX> where TX: SendFrame + Clone + Send + 'static, { type Output = Result<(StreamId, (Reader, Writer)), QuicError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.listener.poll_accept_bi_stream(cx, self.params) } } /// Future to accept a bidirectional stream. /// /// This future is created by `accept_uni_stream` method of `QuicConnection`. /// /// When the peer created a new bidirectional stream, the future will resolve with a [`Reader`] to /// read data on the stream. 
#[derive(Debug, Clone)] pub struct AcceptUniStream<'l, TX> { listener: &'l ArcListener, } impl Future for AcceptUniStream<'_, TX> where TX: SendFrame + Clone + Send + 'static, { type Output = Result<(StreamId, Reader), QuicError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.listener.poll_accept_uni_stream(cx) } } ================================================ FILE: qrecovery/src/streams/raw.rs ================================================ use std::{ sync::{ Arc, atomic::{AtomicBool, Ordering::*}, }, task::{Context, Poll, ready}, }; use bytes::{BufMut, Bytes}; use qbase::{ error::{Error, ErrorKind, QuicError}, flow::ArcSendControler, frame::{ DataBlockedFrame, FrameType, GetFrameType, ResetStreamFrame, STREAM_FRAME_MAX_ENCODING_SIZE, StreamCtlFrame, StreamFrame, io::{ReceiveFrame, SendFrame}, }, net::tx::{ArcSendWakers, Signals}, packet::{Package, PacketContent}, param::{ArcParameters, ParameterId, core::Parameters}, role::Role, sid::{ ControlStreamsConcurrency, Dir, StreamId, StreamIds, remote_sid::{AcceptSid, ExceedLimitError}, }, varint::VarInt, }; use super::{ Ext, io::{ArcInput, ArcOutput, IOState}, listener::{AcceptBiStream, AcceptUniStream, ArcListener}, }; use crate::{ recv::{ArcRecver, Incoming, Reader}, send::{ArcSender, Outgoing, Writer}, }; /// Manage all streams in the connection, send and receive frames, handle frame loss, and acknowledge. /// /// The struct dont truly send and receive frames, this struct provides interfaces to generate frames /// will be sent to the peer, receive frames, handle frame loss, and acknowledge. /// /// [`Outgoing`], [`Incoming`] , [`Writer`] and [`Reader`] dont truly send and receive frames, too. /// /// # Send frames /// /// ## Stream frame /// /// When the application wants to send data to the peer, it will call [`write`] method on [`Writer`] /// to write data to the [`SendBuf`]. 
/// /// Protocol layer will call [`try_load_data_into`] to read data from the streams into stream frames and /// write the frame into the quic packet. /// /// ## Stream control frame /// /// Be different from the stream frame, the stream control frame is much samller in size. /// /// The struct has a generic type `T`, which must implement the [`SendFrame`] trait. The trait has /// a method [`send_frame`], which will be called to send the stream control frame to the peer, see /// [`SendFrame`] for more details. /// /// # Receive frames, handle frame loss and acknowledge /// /// Frames received, frames lost or acknowledgmented will be delivered to the corresponding method. /// | method on [`DataStreams`] | corresponding method | /// | -------------------------------------------------------- | -------------------------------------------------- | /// | [`recv_data`] | [`Incoming::recv_data`] | /// | [`recv_stream_control`] ([`RESET_STREAM frame`]) | [`Incoming::recv_reset`] | /// | [`recv_stream_control`] ([`STOP_SENDING frame`]) | [`Outgoing::be_stopped`] | /// | [`recv_stream_control`] ([`MAX_STREAM_DATA frame`]) | [`Outgoing::update_window`] | /// | [`recv_stream_control`] ([`STREAM_DATA_BLOCKED frame`]) | none(the frame will be ignored) | /// | [`recv_stream_control`] ([`MAX_STREAMS frame`]) | [`ArcLocalStreamIds::recv_max_streams_frame`] | /// | [`recv_stream_control`] ([`STREAMS_BLOCKED frame`]) | [`ArcRemoteStreamIds::recv_streams_blocked_frame`] | /// | [`on_data_acked`] | [`Outgoing::on_data_acked`] | /// | [`may_loss_data`] | [`Outgoing::may_loss_data`] | /// | [`on_reset_acked`] | [`Outgoing::on_reset_acked`] | /// /// # Create and accept streams /// /// Stream frames and stream control frames have the function of creating flows. If a steam frame is /// received but the corresponding stream has not been created, a stream will be created passively. 
/// /// [`AcceptBiStream`] and [`AcceptUniStream`] are provided to the application layer to `accept` a /// stream (obtain a passively created stream). These future will be resolved when a stream is created /// by peer. /// /// Alternatively, sending a stream frame or a stream control frame will create a stream actively. /// [`OpenBiStream`] and [`OpenUniStream`] are provided to the application layer to `open` a stream. /// These future will be resolved when the connection established. /// /// [`write`]: tokio::io::AsyncWriteExt::write /// [`SendBuf`]: crate::send::SendBuf /// [`send_frame`]: SendFrame::send_frame /// [`try_load_data_into`]: DataStreams::try_load_data_into /// [`recv_data`]: DataStreams::recv_data /// [`recv_stream_control`]: DataStreams::recv_stream_control /// [`on_data_acked`]: DataStreams::on_data_acked /// [`may_loss_data`]: DataStreams::may_loss_data /// [`on_reset_acked`]: DataStreams::on_reset_acked /// [`RESET_STREAM frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-reset_stream-frame /// [`STOP_SENDING frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-stop_sending-frames /// [`MAX_STREAM_DATA frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-max_stream_data-frame /// [`MAX_STREAMS frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-max_streams-frame /// [`STREAM_DATA_BLOCKED frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-stream_data_blocked-frame /// [`STREAMS_BLOCKED frame`]: https://www.rfc-editor.org/rfc/rfc9000.html#name-streams_blocked-frame /// [`OpenBiStream`]: crate::streams::OpenBiStream /// [`OpenUniStream`]: crate::streams::OpenUniStream /// [`ArcLocalStreamIds::recv_max_streams_frame`]: qbase::sid::ArcLocalStreamIds::recv_max_streams_frame /// [`ArcRemoteStreamIds::recv_streams_blocked_frame`]: qbase::sid::ArcRemoteStreamIds::recv_streams_blocked_frame /// #[derive(Debug)] pub struct DataStreams { // 该queue与space中的transmitter中的frame_queue共享,为了方便向transmitter中写入帧 ctrl_frames: TX, role: 
Role, stream_ids: StreamIds, Ext>, // 所有流的待写端,要发送数据,就得向这些流索取 output: ArcOutput>, // 所有流的待读端,收到了数据,交付给这些流 input: ArcInput>, // 对方主动创建的流 listener: ArcListener>, tls_fin: AtomicBool, tx_wakers: ArcSendWakers, initial_max_stream_data_bidi_local: u64, initial_max_stream_data_bidi_remote: u64, initial_max_stream_data_uni: u64, metrics: Option, } fn wrapper_error(fty: FrameType) -> impl FnOnce(ExceedLimitError) -> QuicError { move |e| QuicError::new(ErrorKind::StreamLimit, fty.into(), e.to_string()) } impl DataStreams where TX: SendFrame + Clone + Send + 'static, { /// Try to load data from streams into the `packet`, /// with a `flow_limit` which limits the max size of fresh data. /// Returns the size of fresh data. fn try_load_data_into_once( &self, packet: &mut P, flow_ctrl: &ArcSendControler, zero_rtt: bool, ) -> Result<(), Signals> where P: BufMut + ?Sized, for<'a> (StreamFrame, &'a [Bytes]): Package

, FTX: SendFrame, { // todo: use core::range instead in rust 2024 use core::ops::Bound::*; if packet.remaining_mut() < STREAM_FRAME_MAX_ENCODING_SIZE { return Err(Signals::CONGESTION); } let mut guard = self.output.streams(); let output = guard.as_mut().map_err(|_| Signals::empty())?; // connection closed if zero_rtt && self.tls_fin.load(Acquire) { return Err(Signals::TLS_FIN); // should load 1rtt } let Ok(mut credit) = flow_ctrl.credit(packet.remaining_mut()) else { return Err(Signals::empty()); // connection closed }; fn try_load_data_into_once<'s, P, TX: 's + Clone>( streams: impl Iterator, IOState), usize)>, packet: &mut P, flow_limit: usize, ) -> Result<(StreamId, usize, usize), Signals> where P: BufMut + ?Sized, for<'a> (StreamFrame, &'a [Bytes]): Package

, { let mut signals = Signals::TRANSPORT; for (sid, (outgoing, _ios), tokens) in streams { match outgoing.try_load_data_into(packet, sid, flow_limit, tokens) { Ok((data_len, is_fresh)) => { let remain_tokens = tokens - data_len; let fresh_bytes = if is_fresh { data_len } else { 0 }; return Ok((sid, remain_tokens, fresh_bytes)); } Err(s) => signals |= s, } } Err(signals) } // 不一定所有流都允许被发送,比如,0rtt被拒绝max_streams会倒缩,此时大于max_streams的流就不允许被发送 let remote_role = self.stream_ids.remote.role(); let max_streams_bidi = self.stream_ids.local.opened_streams(Dir::Bi); let max_streams_uni = self.stream_ids.local.opened_streams(Dir::Uni); let stream_allowed = |sid: &StreamId| { sid.role() == remote_role || sid.dir() == Dir::Bi && sid.id() < max_streams_bidi || sid.dir() == Dir::Uni && sid.id() < max_streams_uni }; // 该tokens是令牌桶算法的token,为了多条Stream的公平性,给每个流定期地发放tokens,不累积 // 各流轮流按令牌桶算法发放的tokens来整理数据去发送 const DEFAULT_TOKENS: usize = 4096; let (sid, remain_tokens, fresh_bytes) = match &output.cursor { // rev([..=sid]) + rev([sid+1..]) Some((sid, tokens)) if *tokens == 0 => try_load_data_into_once( (output.outgoings.range(..=sid).rev()) .chain(output.outgoings.range((Excluded(sid), Unbounded)).rev()) .map(|(sid, outgoing)| (*sid, outgoing, DEFAULT_TOKENS)) .filter(|(sid, ..)| stream_allowed(sid)), packet, credit.available(), ), // [sid] + rev([..sid]) + rev([sid+1..]) Some((sid, tokens)) => try_load_data_into_once( Option::into_iter( output .outgoings .get(sid) .map(|outgoing| (*sid, outgoing, *tokens)), ) .chain( (output.outgoings.range(..sid).rev()) .chain(output.outgoings.range((Excluded(sid), Unbounded)).rev()) .map(|(sid, outgoing)| (*sid, outgoing, DEFAULT_TOKENS)), ) .filter(|(sid, ..)| stream_allowed(sid)), packet, credit.available(), ), // rev([..]) None => try_load_data_into_once( (output.outgoings.range(..).rev()) .map(|(sid, outgoing)| (*sid, outgoing, DEFAULT_TOKENS)) .filter(|(sid, ..)| stream_allowed(sid)), packet, credit.available(), ), }?; output.cursor = Some((sid, 
remain_tokens)); credit.post_sent(fresh_bytes); // Update metrics when fresh data is sent if fresh_bytes > 0 && let Some(metrics) = &self.metrics { metrics.on_data_sent(fresh_bytes as u64); } Ok(()) } #[inline] pub fn package( self: &Arc, flow_ctrl: ArcSendControler, zero_rtt: bool, ) -> StreamFramePackages where TX: SendFrame, { StreamFramePackages { data_stream: self.clone(), flow_ctrl, zero_rtt, } } /// Try to load data from streams into the packet. /// /// # Fairness /// /// It's fair between streams. /// /// We have implemented a token bucket algorithm, and this method will read the data of each stream /// sequentially. Starting from the first stream, when a stream exhausts its tokens (default is 4096, /// depending on the priority of the stream), or there is no data to send, the method will move to /// the next stream, and so on. /// /// # Flow control /// /// QUIC employs a limit-based flow control scheme where a receiver advertises the limit of total /// bytes it is prepared to receive on a given stream or for the entire connection. This leads to /// two levels of data flow control in QUIC, stream level and connection level. /// /// Stream-level flow control had limited by the [`write`] calls on [`Writer`], if the application /// wants to write more data than the stream's flow control limit , the [`write`] call will be /// blocked until the sending window is updated. /// /// For connection-level flow control, it's limited by the parameter `flow_limit` of this method. /// The amount of new data(never sent) will be read from the stream is less or equal to `flow_limit`. /// /// # Returns /// /// If no data written to the buffer, the method will return [`None`], or a tuple will be /// returned: /// /// * [`StreamFrame`]: The stream frame to be sent. /// * [`usize`]: The number of bytes written to the buffer. /// * [`usize`]: The number of new data writen to the buffer. 
/// /// [`write`]: tokio::io::AsyncWriteExt::write pub fn try_load_data_into( &self, packet: &mut P, flow_ctrl: &ArcSendControler, zero_rtt: bool, ) -> Result<(), Signals> where P: BufMut + ?Sized, for<'a> (StreamFrame, &'a [Bytes]): Package

, FTX: SendFrame, { use core::ops::ControlFlow::*; // 取唯一一个最新的错误(如果有) let (Continue(result) | Break(result)) = core::iter::from_fn(|| Some(self.try_load_data_into_once(packet, flow_ctrl, zero_rtt))) .try_fold(Err(Signals::empty()), |result, once| match (result, once) { (_, Ok(())) => Continue(Ok(())), (Ok(()), Err(_no_more)) => Break(Ok(())), (Err(_), Err(signals)) => Break(Err(signals)), }); result } /// Called when the stream frame acked. /// /// Actually calls the [`Outgoing::on_data_acked`] method of the corresponding stream. pub fn on_data_acked(&self, frame: StreamFrame) { if let Ok(set) = self.output.streams().as_mut() { let mut is_all_rcvd = false; if let Some((o, s)) = set.get(&frame.stream_id()) { is_all_rcvd = o.on_data_acked(&frame); // Update metrics when data is acknowledged let acked_len = frame.range().end - frame.range().start; if acked_len > 0 && let Some(metrics) = &self.metrics { metrics.on_data_acked(acked_len); } if is_all_rcvd { s.shutdown_send(); if s.is_terminated() { self.stream_ids.remote.on_end_of_stream(frame.stream_id()); } } } if is_all_rcvd { set.remove(&frame.stream_id()); } } } /// Called when the stream frame may lost. /// /// Actually calls the [`Outgoing::may_loss_data`] method of the corresponding stream. pub fn may_loss_data(&self, stream_frame: &StreamFrame) { if let Some((o, _s)) = self .output .streams() .as_mut() .ok() .and_then(|set| set.get(&stream_frame.stream_id())) { o.may_loss_data(stream_frame); } } /// Called when the stream reset frame acked. /// /// Actually calls the [`Outgoing::on_reset_acked`] method of the corresponding stream. 
pub fn on_reset_acked(&self, reset_frame: ResetStreamFrame) { if let Ok(set) = self.output.streams().as_mut() && let Some((o, s)) = set.remove(&reset_frame.stream_id()) { o.on_reset_acked(reset_frame.stream_id()); s.shutdown_send(); if s.is_terminated() { self.stream_ids .remote .on_end_of_stream(reset_frame.stream_id()); } } // 如果流是双向的,接收部分的流独立地管理结束。其实是上层应用决定接收的部分是否同时结束 } /// Called when a stream frame which from peer is received by local. /// /// If the correspoding stream is not exist, `accept` the stream. /// /// Actually calls the [`Incoming::recv_data`] method of the corresponding stream. pub fn recv_data( &self, (stream_frame, body): (StreamFrame, bytes::Bytes), ) -> Result { let sid = stream_frame.stream_id(); // 对方必须是发送端,才能发送此帧 if sid.role() != self.role { // 对方的sid,看是否跳跃,把跳跃的流给创建好 self.try_accept_sid(sid) .map_err(wrapper_error(stream_frame.frame_type()))?; } else { // 我方的sid,那必须是双向流才能收到对方的数据,否则就是错误 if sid.dir() == Dir::Uni { return Err(QuicError::new( ErrorKind::StreamState, stream_frame.frame_type().into(), format!("local {sid} cannot receive STREAM_FRAME"), )); } } if let Ok(set) = self.input.streams().as_mut() && let Some((incoming, s)) = set.get(&sid) { let (is_into_rcvd, fresh_data) = incoming.recv_data(stream_frame, body.clone())?; if is_into_rcvd { // 数据被接收完的,忽略后续的ResetStreamFrame s.shutdown_receive(); if s.is_terminated() { self.stream_ids.remote.on_end_of_stream(sid); } set.remove(&sid); } return Ok(fresh_data); } Ok(0) } /// Called when a stream control frame which from peer is received by local. /// /// If the correspoding stream is not exist, `accept` the stream first. /// /// Actually calls the corresponding method of the corresponding stream for the corresponding frame type. 
pub fn recv_stream_control( &self, stream_ctl_frame: StreamCtlFrame, ) -> Result { let mut sync_fresh_data = 0; match stream_ctl_frame { StreamCtlFrame::ResetStream(reset) => { let sid = reset.stream_id(); // 对方必须是发送端,才能发送此帧 if sid.role() != self.role { self.try_accept_sid(sid) .map_err(wrapper_error(reset.frame_type()))?; } else { // 我方创建的流必须是双向流,对方才能发送ResetStream,否则就是错误 if sid.dir() == Dir::Uni { return Err(QuicError::new( ErrorKind::StreamState, reset.frame_type().into(), format!("local {sid} cannot receive RESET_STREAM frame"), )); } } if let Ok(set) = self.input.streams().as_mut() && let Some((incoming, s)) = set.remove(&sid) { sync_fresh_data = incoming.recv_reset(reset)?; s.shutdown_receive(); if s.is_terminated() { self.stream_ids.remote.on_end_of_stream(reset.stream_id()); } } } StreamCtlFrame::StopSending(stop_sending) => { let sid = stop_sending.stream_id(); // 对方必须是接收端,才能发送此帧 if sid.role() != self.role { // 对方创建的单向流,接收端是我方,不可能收到对方的StopSendingFrame if sid.dir() == Dir::Uni { return Err(QuicError::new( ErrorKind::StreamState, stop_sending.frame_type().into(), format!("remote {sid} must not send STOP_SENDING_FRAME"), )); } self.try_accept_sid(sid) .map_err(wrapper_error(stop_sending.frame_type()))?; } if let Some(final_size) = self .output .streams() .as_mut() .ok() .and_then(|set| set.get(&sid)) .and_then(|(outgoing, _s)| outgoing.be_stopped(stop_sending.app_err_code())) { self.ctrl_frames.send_frame([StreamCtlFrame::ResetStream( stop_sending.reset_stream(VarInt::from_u64(final_size).unwrap()), )]); } } StreamCtlFrame::MaxStreamData(max_stream_data) => { let sid = max_stream_data.stream_id(); // 对方必须是接收端,才能发送此帧 if sid.role() != self.role { // 对方创建的单向流,接收端是我方,不可能收到对方的MaxStreamData if sid.dir() == Dir::Uni { return Err(QuicError::new( ErrorKind::StreamState, max_stream_data.frame_type().into(), format!("remote {sid} must not send MAX_STREAM_DATA_FRAME"), )); } self.try_accept_sid(sid) .map_err(wrapper_error(max_stream_data.frame_type()))?; } if let 
Some((outgoing, _s)) = self .output .streams() .as_ref() .ok() .and_then(|set| set.get(&sid)) { outgoing.update_window(max_stream_data.max_stream_data()); } } StreamCtlFrame::StreamDataBlocked(stream_data_blocked) => { let sid = stream_data_blocked.stream_id(); // 对方必须是发送端,才能发送此帧 if sid.role() != self.role { self.try_accept_sid(sid) .map_err(wrapper_error(stream_data_blocked.frame_type()))?; } else { // 我方创建的,必须是双向流,对方才是发送端,才能发出StreamDataBlocked;否则就是错误 if sid.dir() == Dir::Uni { return Err(QuicError::new( ErrorKind::StreamState, stream_data_blocked.frame_type().into(), format!("local {sid} cannot receive STREAM_DATA_BLOCKED_FRAME"), )); } } // 仅仅起到通知作用?主动更新窗口的,此帧没多大用,或许要进一步放大缓冲区大小;被动更新窗口的,此帧有用 } StreamCtlFrame::MaxStreams(max_streams) => { // 主要更新我方能创建的单双向流 _ = self.stream_ids.local.recv_frame(max_streams); } StreamCtlFrame::StreamsBlocked(streams_blocked) => { // 在某些流并发策略中,收到此帧,可能会更新MaxStreams _ = self.stream_ids.remote.recv_frame(streams_blocked); } } Ok(sync_fresh_data) } /// Called when a connection error occured. /// /// After the method called, read on [`Reader`] or write on [`Writer`] will return an error, /// the resouces will be released. pub fn on_conn_error(&self, error: &Error) { let mut output = match self.output.guard() { Ok(out) => out, Err(_) => return, }; let mut input = match self.input.guard() { Ok(input) => input, Err(_) => return, }; let mut listener = match self.listener.guard() { Ok(listener) => listener, Err(_) => return, }; output.on_conn_error(error); input.on_conn_error(error); listener.on_conn_error(error); } } pub struct StreamFramePackages { data_stream: Arc>, flow_ctrl: ArcSendControler, zero_rtt: bool, } impl Package

for StreamFramePackages where TX: SendFrame + SendFrame + Clone + Send + 'static, P: BufMut + ?Sized, for<'a> (StreamFrame, &'a [Bytes]): Package

, { #[inline] fn dump(&mut self, packet: &mut P) -> Result { self.data_stream .try_load_data_into_once(packet, &self.flow_ctrl, self.zero_rtt)?; Ok(PacketContent::EffectivePayload) } } impl DataStreams where TX: SendFrame + Clone + Send + 'static, { pub(super) fn new( role: Role, local_params: &Parameters, remote_params: &Parameters, ctrl: Box, ctrl_frames: TX, tx_wakers: ArcSendWakers, metrics: Option, ) -> Self { use ParameterId::*; Self { role, stream_ids: StreamIds::new( role, local_params .get::(InitialMaxStreamsBidi) .expect("unreachable: default value will be got if the value unset"), local_params .get::(InitialMaxStreamsUni) .expect("unreachable: default value will be got if the value unset"), remote_params .get::(InitialMaxStreamsBidi) .expect("unreachable: default value will be got if the value unset"), remote_params .get::(InitialMaxStreamsUni) .expect("unreachable: default value will be got if the value unset"), Ext(ctrl_frames.clone()), ctrl, tx_wakers.clone(), ), output: ArcOutput::new(), input: ArcInput::default(), listener: ArcListener::new(), ctrl_frames, tls_fin: AtomicBool::new(false), tx_wakers, initial_max_stream_data_bidi_local: local_params .get::(ParameterId::InitialMaxStreamDataBidiLocal) .expect("unreachable: default value will be got if the value unset"), initial_max_stream_data_bidi_remote: local_params .get::(ParameterId::InitialMaxStreamDataBidiRemote) .expect("unreachable: default value will be got if the value unset"), initial_max_stream_data_uni: local_params .get::(ParameterId::InitialMaxStreamDataUni) .expect("unreachable: default value will be got if the value unset"), metrics, } } pub fn revise_params(&self, zero_rtt_rejected: bool, remote_params: &Parameters) { if let Ok(output) = self.output.guard() { // enter 1rtt state, old state must be 0rtt self.tls_fin.store(true, Release); let opened_bidi = self.stream_ids.local.opened_streams(Dir::Bi); let opened_uni = self.stream_ids.local.opened_streams(Dir::Uni); let 
opened_bidi_snd_wnd_size = remote_params .get::(ParameterId::InitialMaxStreamDataBidiRemote) .expect("unreachable: default value will be got if the value unset"); let opened_uni_snd_wnd_size = remote_params .get::(ParameterId::InitialMaxStreamDataUni) .expect("unreachable: default value will be got if the value unset"); output.revise_max_stream_data( zero_rtt_rejected, opened_bidi, opened_uni, opened_bidi_snd_wnd_size, opened_uni_snd_wnd_size, ); let max_streams_bidi = remote_params .get::(ParameterId::InitialMaxStreamsBidi) .expect("unreachable: default value will be got if the value unset"); let max_streams_uni = remote_params .get::(ParameterId::InitialMaxStreamsUni) .expect("unreachable: default value will be got if the value unset"); self.stream_ids.local.revise_max_streams( zero_rtt_rejected, max_streams_bidi, max_streams_uni, ); } } #[allow(clippy::type_complexity)] pub(super) fn poll_open_bi_stream( &self, cx: &mut Context<'_>, arc_params: &ArcParameters, ) -> Poll>, Writer>))>, Error>> { let mut output = self.output.guard()?; let mut input = self.input.guard()?; let mut params = arc_params.lock_guard()?; let snd_buf_size = match params.remembered() { Some(remembered) => remembered .get(ParameterId::InitialMaxStreamDataBidiRemote) .expect("unreachable: default value will be got if the value unset"), None => match params.get_remote(ParameterId::InitialMaxStreamDataBidiRemote) { Some(value) => value, None => { ready!(params.poll_ready(cx)); // tail recursion should be optimized by compiler return self.poll_open_bi_stream(cx, arc_params); } }, }; let Some(sid) = ready!(self.stream_ids.local.poll_alloc_sid(cx, Dir::Bi)) else { return Poll::Ready(Ok(None)); }; let arc_sender = self.create_sender(sid, snd_buf_size); let arc_recver = self.create_recver(sid, self.initial_max_stream_data_bidi_local); let io_state = IOState::bidirection(); output.insert(sid, Outgoing::new(arc_sender.clone()), io_state.clone()); input.insert(sid, Incoming::new(arc_recver.clone()), 
io_state); Poll::Ready(Ok(Some(( sid, (Reader::new(arc_recver), Writer::new(arc_sender)), )))) } #[allow(clippy::type_complexity)] pub(super) fn poll_open_uni_stream( &self, cx: &mut Context<'_>, arc_params: &ArcParameters, ) -> Poll>)>, Error>> { let mut output = self.output.guard()?; let mut params = arc_params.lock_guard()?; let snd_buf_size = match params.remembered() { Some(remembered) => remembered .get(ParameterId::InitialMaxStreamDataUni) .expect("unreachable: default value will be got if the value unset"), None => match params.get_remote(ParameterId::InitialMaxStreamDataBidiRemote) { Some(value) => value, None => { ready!(params.poll_ready(cx)); // tail recursion should be optimized by compiler return self.poll_open_uni_stream(cx, arc_params); } }, }; let Some(sid) = ready!(self.stream_ids.local.poll_alloc_sid(cx, Dir::Uni)) else { return Poll::Ready(Ok(None)); }; let arc_sender = self.create_sender(sid, snd_buf_size); let io_state = IOState::send_only(); output.insert(sid, Outgoing::new(arc_sender.clone()), io_state); Poll::Ready(Ok(Some((sid, Writer::new(arc_sender))))) } pub(super) fn accept_bi<'a>( &'a self, params: &'a ArcParameters, ) -> AcceptBiStream<'a, Ext> { self.listener.accept_bi_stream(params) } pub(super) fn accept_uni(&self) -> AcceptUniStream<'_, Ext> { self.listener.accept_uni_stream() } fn try_accept_sid(&self, sid: StreamId) -> Result<(), ExceedLimitError> { match sid.dir() { Dir::Bi => self.try_accept_bi_sid(sid), Dir::Uni => self.try_accept_uni_sid(sid), } } fn try_accept_bi_sid(&self, sid: StreamId) -> Result<(), ExceedLimitError> { let Ok(mut output) = self.output.guard() else { return Ok(()); }; let Ok(mut input) = self.input.guard() else { return Ok(()); }; let Ok(mut listener) = self.listener.guard() else { return Ok(()); }; let result = self.stream_ids.remote.try_accept_sid(sid)?; match result { AcceptSid::Old => Ok(()), AcceptSid::New(need_create) => { for sid in need_create { let arc_recver = self.create_recver(sid, 
self.initial_max_stream_data_bidi_remote); // buf_size will be revised by Listener::poll_accept_bi_stream let arc_sender = self.create_sender(sid, 0); let io_state = IOState::bidirection(); input.insert(sid, Incoming::new(arc_recver.clone()), io_state.clone()); output.insert(sid, Outgoing::new(arc_sender.clone()), io_state); listener.push_bi_stream(sid, (arc_recver, arc_sender)); } Ok(()) } } } fn try_accept_uni_sid(&self, sid: StreamId) -> Result<(), ExceedLimitError> { let mut input = match self.input.guard() { Ok(input) => input, Err(_) => return Ok(()), }; let mut listener = match self.listener.guard() { Ok(listener) => listener, Err(_) => return Ok(()), }; let result = self.stream_ids.remote.try_accept_sid(sid)?; match result { AcceptSid::Old => Ok(()), AcceptSid::New(need_create) => { for sid in need_create { let arc_receiver = self.create_recver(sid, self.initial_max_stream_data_uni); let io_state = IOState::receive_only(); input.insert(sid, Incoming::new(arc_receiver.clone()), io_state); listener.push_uni_stream(sid, arc_receiver); } Ok(()) } } } fn create_sender(&self, sid: StreamId, buf_size: u64) -> ArcSender> { ArcSender::new( sid, buf_size, Ext(self.ctrl_frames.clone()), self.tx_wakers.clone(), self.metrics.clone(), ) } fn create_recver(&self, sid: StreamId, buf_size: u64) -> ArcRecver> { ArcRecver::new(sid, buf_size, Ext(self.ctrl_frames.clone())) } } ================================================ FILE: qrecovery/src/streams.rs ================================================ //! The internal implementation of the QUIC stream. //! //! If you want to know how to create a stream, see the `QuicConnection` in another crate for more. //! //! If you want to know how to use a stream, see the [`Reader`] and [`Writer`] for more details. //! //! The structure in this module does not have the ability to actually send and receive frames, or //! sense the loss or confirmation of frames. These functions are implemented by other modules. This //! 
module provides the ability to generate frames, process frames, handle frames being lost or acknowledged, and
//! manage the state of all streams.
//!
//! [`DataStreams`] provides a large number of APIs for other modules to call to achieve the above functions.
//! It corresponds to all streams on the connection.
//!
//! [`Incoming`] and [`Outgoing`] correspond to the input and output of a stream. They manage the sending and
//! receiving state machines and provide APIs for [`DataStreams`] to use.
#[derive(Debug, Clone, Deref)] pub struct DataStreams(Arc>) where TX: SendFrame + Clone + Send + 'static; impl DataStreams where TX: SendFrame + Clone + Send + 'static, { /// Creates a new instance of [`DataStreams`]. /// /// The `ctrl_frames` is the frame sender, read [`raw::DataStreams`] for more details. pub fn new( role: Role, local_params: &Parameters, remote_params: &Parameters, ctrl: Box, ctrl_frames: TX, tx_wakers: ArcSendWakers, metrics: Option, ) -> Self { Self(Arc::new(raw::DataStreams::new( role, local_params, remote_params, ctrl, ctrl_frames, tx_wakers, metrics, ))) } /// Create a bidirectional stream, see the method of the same name on `QuicConnection` for more. #[inline] pub fn open_bi<'a>(&'a self, params: &'a ArcParameters) -> OpenBiStream<'a, TX> { OpenBiStream { streams: self, params, } } /// Create a unidirectional stream, see the method of the same name on `QuicConnection` for more. #[inline] pub fn open_uni<'a>(&'a self, params: &'a ArcParameters) -> OpenUniStream<'a, TX> { OpenUniStream { streams: self, params, } } /// accept a bidirectional stream, see the method of the same name on `QuicConnection` for more. #[inline] pub fn accept_bi<'a>(&'a self, params: &'a ArcParameters) -> AcceptBiStream<'a, Ext> { self.0.accept_bi(params) } /// accept a unidirectional stream, see the method of the same name on `QuicConnection` for more. #[inline] pub fn accept_uni(&self) -> AcceptUniStream<'_, Ext> { self.0.accept_uni() } } impl ReceiveFrame for DataStreams where TX: SendFrame + Clone + Send + 'static, { type Output = usize; fn recv_frame(&self, frame: StreamCtlFrame) -> Result { self.0.recv_stream_control(frame).map_err(Error::Quic) } } impl ReceiveFrame<(StreamFrame, Bytes)> for DataStreams where TX: SendFrame + Clone + Send + 'static, { type Output = usize; fn recv_frame(&self, frame: (StreamFrame, Bytes)) -> Result { self.0.recv_data(frame).map_err(Error::Quic) } } /// Future to open a bidirectional stream. 
/// /// The creation of the stream is limited by the stream id. Once the stream id is available, the /// future will complete immediately. /// /// If a connection error occurred, the future will return an error. /// /// Although this is a bidirectional stream, the peer will not be aware of this stream until we send /// a frame on this stream. pub struct OpenBiStream<'d, TX> where TX: SendFrame + Clone + Send + 'static, { streams: &'d raw::DataStreams, params: &'d ArcParameters, } impl Future for OpenBiStream<'_, TX> where TX: SendFrame + Clone + Send + 'static, { type Output = Result>, Writer>))>, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.streams.poll_open_bi_stream(cx, self.params) } } /// Future to open a unidirectional stream. /// /// The creation of the stream is limited by the stream id. Once the stream id is available, the /// future will complete immediately. /// /// If a connection error occurred, the future will return an error. /// /// Note that the peer will not be aware of this stream until we send a frame on this stream. 
pub struct OpenUniStream<'a, TX> where TX: SendFrame + Clone + Send + 'static, { streams: &'a raw::DataStreams, params: &'a ArcParameters, } impl Future for OpenUniStream<'_, TX> where TX: SendFrame + Clone + Send + 'static, { type Output = Result>)>, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.streams.poll_open_uni_stream(cx, self.params) } } ================================================ FILE: qresolve/Cargo.toml ================================================ [package] name = "qresolve" version = "0.5.0" edition.workspace = true description = "dquic's dns abstractions" readme.workspace = true repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true [dependencies] futures = { workspace = true } tokio = { workspace = true } qinterface = { workspace = true } qbase = { workspace = true } ================================================ FILE: qresolve/src/lib.rs ================================================ use std::{ fmt::{Debug, Display}, io, sync::Arc, }; use futures::{FutureExt, TryFutureExt, future::BoxFuture, stream::BoxStream}; pub use qbase::net::{Family, addr::EndpointAddr}; pub type PublishFuture<'a> = BoxFuture<'a, io::Result<()>>; pub trait Publish: Display + Debug { fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a>; } #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Source { Mdns { nic: Arc, family: Family }, Http { server: Arc }, System, Dht, } impl Display for Source { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Source::Mdns { nic, family } => write!(f, "MDNS Resolver({nic} {family})"), Source::Http { server } => write!(f, "HTTP DNS Resolver({server})"), Source::System => write!(f, "System DNS Resolver"), Source::Dht => write!(f, "DHT"), } } } pub type Record = (Source, EndpointAddr); pub type RecordStream = BoxStream<'static, Record>; pub type ResolveResult = io::Result; pub type 
ResolveFuture<'r> = BoxFuture<'r, ResolveResult>; /// Resolves names into QUIC peer endpoints. /// /// The result is a stream to allow implementations that yield endpoints over time /// (e.g. multi-source resolvers, H3x Dns, Mdns). pub trait Resolve: Send + Sync + Display + Debug { fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l>; } use futures::{StreamExt, stream}; /// Default resolver backed by `tokio::net::lookup_host`. #[derive(Debug, Default, Clone, Copy)] pub struct SystemResolver; impl Display for SystemResolver { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Display::fmt(&Source::System, f) } } impl Resolve for SystemResolver { fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { let source = Source::System; tokio::net::lookup_host(name.to_owned()) .map_ok(|addrs| { stream::iter(addrs.map(move |addr| { let ep = EndpointAddr::direct(addr); (source.clone(), ep) })) .boxed() }) .boxed() } } ================================================ FILE: qtraversal/Cargo.toml ================================================ [package] name = "qtraversal" version.workspace = true edition.workspace = true description = "NAT traversal utilities for QUIC, a part of dquic" readme = "README.md" repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true [dependencies] async-trait = { workspace = true } bon = { workspace = true } bytes = { workspace = true } dashmap = { workspace = true } derive_more = { workspace = true } enum_dispatch = { workspace = true } futures = { workspace = true } bitflags = { workspace = true } nom = { workspace = true } qbase = { workspace = true } qresolve = { workspace = true } qevent = { workspace = true } qinterface = { workspace = true, features = ["qudp"] } qudp = { workspace = true } rand = { workspace = true } rustls = { workspace = true } smallvec = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["sync", 
"rt", "time", "macros"] } tokio-util = { workspace = true, features = ["rt"] } tracing = { workspace = true } netdev = { workspace = true } [dev-dependencies] clap = { workspace = true } rustls = { workspace = true, features = ["ring"] } tokio = { features = ["fs", "rt-multi-thread"], workspace = true } tokio-test = "0.4" tracing = { workspace = true } [dev-dependencies.tracing-subscriber] workspace = true features = ["fmt", "ansi", "env-filter", "time", "tracing-log"] [features] # Enable shorter TTL only for tests (especially integration tests in other crates). test-ttl = [] [[example]] name = "stun_client" [[example]] name = "stun_server" ================================================ FILE: qtraversal/README.md ================================================ # qtraversal `qtraversal` is a NAT traversal library designed for QUIC. It implements sophisticated hole-punching strategies to establish peer-to-peer connections even behind difficult NATs (Symmetric, Restricted, etc.). ## Features - **STUN Client**: Detects NAT type and external IP/Port. - **Hole Punching**: Implements various strategies including: - Direct Connection (Full Cone) - Reverse Punching - Birthday Attack (for Symmetric NATs) - Port prediction ## STUN Configuration The library uses `nat.genmeta.net:20004` as the default STUN server in examples. You can configure your own STUN server when initializing the client. ## Usage See `examples/` for details on how to use the `Client` and `Puncher`. 
================================================ FILE: qtraversal/examples/stun_client.rs ================================================ use std::{io::Result, net::SocketAddr, sync::Arc}; use clap::Parser; use qinterface::{ component::location::Locations, io::{IO, ProductIO, handy::DEFAULT_IO_FACTORY}, }; use qtraversal::{ nat::{client::StunClient, router::StunRouter}, route::ReceiveAndDeliverPacket, }; use tracing::info; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] pub struct Arguments { #[arg(long, default_value = "0.0.0.0:12345")] pub bind: SocketAddr, #[arg(long, default_value = "nat.genmeta.net:20004")] pub stun_svr: String, } #[tokio::main(flavor = "current_thread")] async fn main() -> Result<()> { init_logger().unwrap(); let args = Arguments::parse(); let stun_server = tokio::net::lookup_host(&args.stun_svr) .await? .find(|addr| addr.is_ipv4() == args.bind.is_ipv4()) .ok_or_else(|| std::io::Error::other("failed to resolve stun server"))?; let bind_uri = format!("inet://{}", args.bind).into(); let iface: Arc = Arc::from(DEFAULT_IO_FACTORY.bind(bind_uri)); let stun_router = StunRouter::new(); let stun_client = StunClient::new(iface.clone(), stun_router.clone(), stun_server, None); let _task = ReceiveAndDeliverPacket::task() .stun_router(stun_router) .iface_ref(iface.clone()) .spawn(); let outer_addr = stun_client .outer_addr() .await .expect("failed to get outer addr"); info!("Outer addr: {} Agent addr {}", outer_addr, stun_server); // Ok(()) let nat_type = stun_client.nat_type().await; let mut observer = Locations::global().subscribe(); while let Some(event) = observer.recv().await { info!("Location event: {:?}", event); info!("Nat type: {:?}", nat_type); } Ok(()) // unreachable!("Observer never return None") } fn init_logger() -> std::io::Result<()> { tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) .init(); Ok(()) } ================================================ FILE: qtraversal/examples/stun_server.rs 
================================================ use std::{io::Result, net::SocketAddr, sync::Arc}; use clap::Parser; use qinterface::io::{IO, ProductIO, handy::DEFAULT_IO_FACTORY}; use qtraversal::{ nat::{ router::StunRouter, server::{StunServer, StunServerConfig}, }, route::{Forwarder, ReceiveAndDeliverPacket}, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] pub struct Arguments { #[arg(long, default_value = "127.0.0.1:20002")] pub bind_addr1: SocketAddr, #[arg(long, default_value = "127.0.0.1:4433")] pub bind_addr2: SocketAddr, #[arg(long, default_value = "127.0.0.1:20002")] pub change_addr: SocketAddr, #[arg(long, default_value = "127.0.0.1:20002")] pub outer_addr1: SocketAddr, #[arg(long, default_value = "127.0.0.1:20002")] pub outer_addr2: SocketAddr, } #[tokio::main(flavor = "current_thread")] async fn main() -> Result<()> { let args = Arguments::parse(); init_logger(&args)?; let factory: Arc = Arc::new(DEFAULT_IO_FACTORY); let bind_uri1 = format!("inet://{}", args.bind_addr1).into(); let iface1: Arc = Arc::from(factory.bind(bind_uri1)); let stun_router1 = StunRouter::new(); let _iface1_recv_task = ReceiveAndDeliverPacket::task() .stun_router(stun_router1.clone()) .forwarder(Forwarder::Server { outer_addr: args.outer_addr1, }) .iface_ref(iface1.clone()) .spawn(); let bind_uri2 = format!("inet://{}", args.bind_addr2).into(); let iface2: Arc = Arc::from(factory.bind(bind_uri2)); let stun_router2 = StunRouter::new(); let _iface2_recv_task = ReceiveAndDeliverPacket::task() .stun_router(stun_router2.clone()) .forwarder(Forwarder::Server { outer_addr: args.outer_addr2, }) .iface_ref(iface2.clone()) .spawn(); let server1 = StunServer::new( iface1, stun_router1, StunServerConfig::builder() .change_port(args.bind_addr2.port()) .change_address(args.change_addr) .init(), ); let server2 = StunServer::new( iface2, stun_router2, StunServerConfig::builder() 
.change_port(args.bind_addr1.port()) .change_address(args.change_addr) .init(), ); _ = tokio::try_join!(server1.spawn(), server2.spawn())?; Ok(()) } fn init_logger(args: &Arguments) -> std::io::Result<()> { let log_name = args.bind_addr1.ip().to_string() + "-stun.log"; let file = std::fs::OpenOptions::new() .create(true) .write(true) .truncate(true) .open(log_name)?; let _ = tracing_subscriber::registry() .with( tracing_subscriber::fmt::layer() .with_target(true) .with_ansi(false) .with_writer(file), ) .try_init(); Ok(()) } ================================================ FILE: qtraversal/src/addr.rs ================================================ use std::{ collections::{HashMap, HashSet, hash_map::Entry}, net::SocketAddr, ops::Deref, }; use futures::io; use qbase::{ frame::{AddAddressFrame, RemoveAddressFrame}, net::{NatType, addr::EndpointAddr}, }; use qinterface::bind_uri::BindUri; use qresolve::Source; #[derive(Default)] pub struct AddressBook { local: HashMap, remote: HashMap, local_endpoint: HashSet<(BindUri, EndpointAddr)>, /// Remote endpoints with their DNS [`Source`] so the puncher can enforce /// source-specific constraints (e.g. mDNS endpoints are tied to a NIC). 
remote_endpoint: HashMap, largest_seq_num: u32, } impl AddressBook { pub(crate) fn add_local_address( &mut self, bind: BindUri, addr: SocketAddr, tire: u32, nat_type: NatType, ) -> io::Result { if self .local .values() .any(|(_local, frame)| *frame.deref() == addr) { tracing::debug!(target: "quic", %addr, "Duplicate local address"); return Err(io::Error::other("Duplicate local address")); } let frame = AddAddressFrame::new(self.largest_seq_num, addr, tire, nat_type); self.local.insert(self.largest_seq_num, (bind, frame)); self.largest_seq_num += 1; Ok(frame) } pub(crate) fn add_local_endpoint( &mut self, bind: BindUri, addr: EndpointAddr, ) -> io::Result<()> { if !self.local_endpoint.insert((bind, addr)) { return Err(io::Error::other("Duplicate local endpoint")); } Ok(()) } pub(crate) fn add_peer_endpoint( &mut self, endpoint: EndpointAddr, source: Source, ) -> io::Result<()> { match self.remote_endpoint.entry(endpoint) { Entry::Occupied(_) => return Err(io::Error::other("Duplicate remote endpoint")), Entry::Vacant(e) => { e.insert(source); } } Ok(()) } pub(crate) fn remote_endpoint(&self) -> &HashMap { &self.remote_endpoint } pub(crate) fn local_endpoint(&self) -> &HashSet<(BindUri, EndpointAddr)> { &self.local_endpoint } pub(crate) fn remove_local_address( &mut self, addr: SocketAddr, ) -> io::Result { let Some(seq_num) = self .local .iter() .find(|(_, (_local, frame))| *frame.deref() == addr) .map(|(key, _)| *key) else { tracing::debug!(target: "quic", %addr, "No matching local address to remove"); return Err(io::Error::other("No matching local address")); }; self.local.remove(&seq_num).map(|(_local, _frame)| seq_num); Ok(RemoveAddressFrame { seq_num: seq_num.into(), }) } pub(crate) fn get_local_address(&self, seq_num: &u32) -> Option<(BindUri, AddAddressFrame)> { self.local.get(seq_num).cloned() } pub(crate) fn add_remote_address(&mut self, remote: AddAddressFrame) -> io::Result<()> { match self.remote.entry(remote.seq_num()) { Entry::Occupied(_) => { 
tracing::debug!(target: "quic", remote_seq_num = remote.seq_num(), "Duplicate remote address"); return Err(io::Error::other("Duplicate remote address")); } Entry::Vacant(entry) => { entry.insert(remote); } } Ok(()) } pub(crate) fn remove_remote_address(&mut self, seq_num: u32) -> Option { self.remote.remove(&seq_num) } pub(crate) fn pick_local_address( &self, remote: &AddAddressFrame, ) -> io::Result<(BindUri, AddAddressFrame)> { let mut addrs: Vec<_> = self .local .iter() .filter(|(_seq, (_local, frame))| { frame.tire() == remote.tire() && frame.is_ipv4() == remote.is_ipv4() }) .map(|(_, addr)| addr.clone()) .collect(); if addrs.is_empty() { tracing::debug!(target: "quic", ?remote, "No matching local address for remote address"); return Err(io::Error::other("No matching local address")); } const NAT_PRIORITY: [NatType; 5] = [ NatType::FullCone, NatType::RestrictedCone, NatType::RestrictedPort, NatType::Dynamic, NatType::Symmetric, ]; addrs.sort_by_key(|(_addr, frame)| { NAT_PRIORITY .iter() .position(|&x| x == frame.nat_type()) .unwrap_or(usize::MAX) }); let (bind, frame) = addrs .iter() .find(|(_, frame)| *frame != *remote) .ok_or_else(|| io::Error::other("No matching local address"))?; Ok((bind.clone(), *frame)) } } ================================================ FILE: qtraversal/src/future.rs ================================================ use std::{ mem, ops::{Deref, DerefMut}, sync::{Mutex, MutexGuard}, task::{Context, Poll}, }; use qbase::util::WakerVec; #[derive(Debug)] enum FutureState { Demand(WakerVec), Ready(T), } impl Default for FutureState { fn default() -> Self { Self::Demand(Default::default()) } } #[derive(Debug)] pub struct ReadyFuture<'f, T>(MutexGuard<'f, FutureState>); impl Deref for ReadyFuture<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { match self.0.deref() { FutureState::Demand(..) => unreachable!(), FutureState::Ready(item) => item, } } } /// A value which will be resolved in the future. 
///
/// Unlike [`futures::Future`], this is a value, not a computation.
///
/// The [`Future`] is usually assigned once, and the value can be read multiple
/// times (so `T` must be [`Clone`]). If `assign` is called again, the new value
/// replaces the old one, and the previous value is returned as [`Some`].
///
/// The task can attempt to get the value synchronously by calling [`try_get`], or asynchronously by
/// calling [`get`]. There is also a [`poll_get`] method for the task to poll the value. Read their
/// documentation for more details about the behavior.
///
/// # Examples
/// ```rust, ignore
/// # async fn some_work() -> &'static str { "Hello world" }
/// # async fn test() {
/// use std::sync::Arc;
///
/// let fut = Arc::new(Future::new());
/// let t1 = tokio::spawn({
///     let fut = fut.clone();
///     async move {
///         assert_eq!(fut.get().await, "Hello world");
///         // the value can be read multiple times
///         assert_eq!(fut.get().await, "Hello world");
///         assert_eq!(fut.get().await, "Hello world");
///     }
/// });
///
/// let t2 = tokio::spawn({
///     let fut = fut.clone();
///     async move {
///         // do some work to get the value
///         let value = some_work().await;
///         fut.assign(value);
///
///         // assigning again replaces the value and returns the old one
///         assert_eq!(fut.assign("Hi World"), Some("Hello world"));
///     }
/// });
///
/// _ = tokio::join!(t1, t2);
/// # }
/// ```
/// /// [`Arc`]: std::sync::Arc #[inline] #[allow(dead_code)] pub fn with(item: T) -> Self { Self { state: Mutex::new(FutureState::Ready(item)), } } fn state(&'_ self) -> MutexGuard<'_, FutureState> { self.state.lock().unwrap() } /// Assign the value to the [`Future`]. /// /// Return the old value as [`Some`] if the future is already assigned. #[inline] pub fn assign(&self, item: T) -> Option { match std::mem::replace(self.state().deref_mut(), FutureState::Ready(item)) { FutureState::Demand(mut wakers) => { mem::take(&mut wakers).wake_all(); None } FutureState::Ready(old) => Some(old), } } /// Poll the value of the [`Future`]. /// /// If the value is ready, the value will be returned as [`Poll::Ready`]. If the value is not /// ready, this method will return [`Poll::Pending`] and the waker will be stored. #[inline] pub fn poll_get(&'_ self, cx: &mut Context<'_>) -> Poll> { let mut state = self.state(); match state.deref_mut() { FutureState::Demand(wakers) => { wakers.register(cx.waker()); Poll::Pending } FutureState::Ready(..) => Poll::Ready(ReadyFuture(state)), } } /// Try to get the value of the [`Future`]. /// /// If the value is ready, the value will be returned as [`Some`]. If the value is not ready, this /// method will return [`None`]. pub fn try_get(&'_ self) -> Option> { let state = self.state(); match state.deref() { FutureState::Demand(..) => None, FutureState::Ready(_) => Some(ReadyFuture(state)), } } /// Get the value of the [`Future`] asynchronously. 
#[inline] #[allow(unused)] pub async fn get(&'_ self) -> ReadyFuture<'_, T> { std::future::poll_fn(|cx| self.poll_get(cx)).await } pub fn clear(&self) { let mut state = self.state(); *state = match state.deref_mut() { FutureState::Demand(wakers) => FutureState::Demand(mem::take(wakers)), FutureState::Ready(_) => FutureState::Demand(WakerVec::default()), }; } } impl Default for Future { fn default() -> Self { Self { state: Mutex::new(Default::default()), } } } #[cfg(test)] mod tests { use std::{sync::Arc, time::Duration}; use futures::future::join_all; use tokio::{sync::Notify, time::timeout}; use super::*; #[test] fn new() { let future = Future::new(); assert_eq!(future.try_get().as_deref(), None); assert_eq!(future.assign("Hello world"), None); assert_eq!(future.try_get().as_deref(), Some(&"Hello world")); let future = Future::with("Hello World"); assert_eq!(future.try_get().as_deref(), Some(&"Hello World")); assert_eq!(future.assign("Hi"), Some("Hello World")); } #[tokio::test] async fn wait() { let future = Arc::new(Future::<&str>::new()); let write = Arc::new(Notify::new()); let task = tokio::spawn({ let future = future.clone(); let write = write.clone(); async move { core::future::poll_fn(|cx| { assert!(matches!(future.poll_get(cx), Poll::Pending)); write.notify_one(); Poll::Ready(()) }) .await; assert_eq!(*future.get().await, "Hello world"); } }); write.notified().await; assert_eq!(future.assign("Hello world"), None); task.await.unwrap(); } #[tokio::test] async fn change() { let future = Arc::new(Future::<&str>::new()); let write = Arc::new(Notify::new()); let task = tokio::spawn({ let future = future.clone(); let write = write.clone(); async move { core::future::poll_fn(|cx| { assert!(matches!(future.poll_get(cx), Poll::Pending)); write.notify_one(); Poll::Ready(()) }) .await; assert_eq!(*future.get().await, "Hello world"); assert_eq!(*future.get().await, "Hello world"); write.notify_one(); } }); write.notified().await; 
assert_eq!(future.try_get().as_deref(), None); assert_eq!(future.assign("Hello world"), None); write.notified().await; assert_eq!(future.assign("Changed"), Some("Hello world")); task.await.unwrap(); } #[tokio::test] async fn multiple_wait() { let future = Arc::new(Future::<&str>::new()); let timeout_task = tokio::spawn({ let future = future.clone(); async move { let _ = timeout(Duration::from_millis(100), future.get()).await; let _ = future.assign("Hello world"); } }); let task = tokio::spawn({ let future = future.clone(); async move { assert_eq!(*future.get().await, "Hello world"); } }); join_all([task, timeout_task]).await; } #[tokio::test] async fn clear() { let future = Arc::new(Future::<&str>::new()); future.assign("Hello world"); assert_eq!(*future.get().await, "Hello world"); future.clear(); assert_eq!(future.try_get().as_deref(), None); let task = tokio::spawn({ let future = future.clone(); async move { assert_eq!(*future.get().await, "New Hello world"); } }); future.assign("New Hello world"); task.await.unwrap(); } } ================================================ FILE: qtraversal/src/lib.rs ================================================ use qbase::net::addr::EndpointAddr; pub mod addr; mod future; pub mod nat; pub mod packet; pub mod punch; pub mod route; pub type PathWay = qbase::net::route::Pathway; ================================================ FILE: qtraversal/src/nat/client.rs ================================================ use std::{ collections::HashMap, fmt, io::{self}, net::SocketAddr, ops::{ControlFlow, Deref}, pin::pin, sync::{ Arc, Mutex, MutexGuard, atomic::{AtomicBool, AtomicU8, Ordering::SeqCst}, }, task::{Context, Poll, ready}, time::Duration, }; use futures::{FutureExt, StreamExt, stream::FuturesUnordered}; use qbase::net::{Family, addr::EndpointAddr}; pub use qbase::net::{NatType, NetFeature}; use qinterface::{ Interface, RebindedError, WeakInterface, component::{ Component, location::{IfaceLocations, LocationsComponent}, }, 
io::{IO, RefIO}, }; use qresolve::Resolve; use thiserror::Error; use tokio::{sync::Notify, task::JoinSet}; use tokio_util::task::AbortOnDropHandle; use tracing::Instrument; use super::{router::StunRouter, tx::Transaction}; use crate::{ future::Future, nat::{iface::StunIO, msg::Request, router::StunRouterComponent}, }; #[derive(Error, Clone)] #[error(transparent)] pub struct ArcIoError(Arc); impl fmt::Debug for ArcIoError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.as_ref().fmt(f) } } impl From for ArcIoError { fn from(source: io::Error) -> Self { Self(source.into()) } } impl From for io::Error { fn from(source: ArcIoError) -> io::Error { io::Error::other(source) } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum ClientState { Active = 0, Inactive = 1, Closing = 2, } #[derive(Debug, Clone)] struct ArcClientState { state: Arc, observers: [Arc; 3], } impl ArcClientState { pub fn new() -> Self { Self { state: Arc::new(AtomicU8::new(ClientState::Active as u8)), observers: <[_; 3]>::default(), } } pub fn try_update(&self, old_state: ClientState, new_state: ClientState) -> bool { match self .state .compare_exchange(old_state as u8, new_state as u8, SeqCst, SeqCst) { Ok(_old) => { self.observers[new_state as usize].notify_waiters(); true } Err(_current) => false, } } pub fn get(&self) -> ClientState { match self.state.load(SeqCst) { 0 => ClientState::Active, 1 => ClientState::Inactive, 2 => ClientState::Closing, _ => unreachable!(), } } pub fn set(&self, new_state: ClientState) -> ClientState { let old_state = self.state.swap(new_state as u8, SeqCst); if old_state != new_state as u8 { self.observers[new_state as usize].notify_waiters(); } match old_state { 0 => ClientState::Active, 1 => ClientState::Inactive, 2 => ClientState::Closing, _ => unreachable!(), } } pub fn wait(&self, expect: ClientState) -> impl futures::Future + use<> { let notify = self.observers[expect as usize].clone(); let state = self.state.clone(); async move { let mut notified 
= pin!(notify.notified()); loop { notified.as_mut().enable(); if state.load(SeqCst) == expect as u8 { return; } notified.as_mut().await; notified.set(notify.notified()); } } } } #[derive(Debug, Clone)] pub struct StunClient { #[allow(clippy::type_complexity)] outer_addr: Arc>>, nat_type: Arc>>, ref_iface: I, // 可能被复制进keep_alive_task stun_router: StunRouter, stun_agent: SocketAddr, locations: Option>, state: ArcClientState, tasks: Arc>>, } pub type ClientLocationData = Result; impl StunClient { pub fn new( ref_iface: I, stun_router: StunRouter, stun_agent: SocketAddr, locations: Option>, ) -> Self { let client = Self { nat_type: Default::default(), outer_addr: Default::default(), stun_agent, ref_iface, stun_router, locations, state: ArcClientState::new(), tasks: Arc::new(Mutex::new(JoinSet::new())), }; tracing::debug!(target: "stun", %stun_agent, "created new STUN client"); { let mut tasks = client.lock_tasks(); tasks.spawn(client.keep_alive_task()); if !client.ref_iface.iface().bind_uri().is_temporary() { tasks.spawn(client.nat_detect_task()); } } client } fn lock_tasks(&self) -> MutexGuard<'_, JoinSet<()>> { self.tasks.lock().expect("StunClient tasks lock poisoned") } fn keep_alive_task(&self) -> impl futures::Future + use { let outer_addr = self.outer_addr.clone(); let stun_agent = self.stun_agent; let stun_router = self.stun_router.clone(); tracing::debug!(target: "stun", %stun_agent, "starting STUN client keep alive task"); let ref_iface = self.ref_iface.clone(); let bind_uri = ref_iface.iface().bind_uri(); let locations = self.locations.clone(); let client_state = self.state.clone(); let keep_alive_task = async move { let log_detect_result = |detect_result: &io::Result| match &detect_result { Ok(new_outer_addr) => match outer_addr.try_get().as_deref().cloned() { Some(Ok(old_outer)) if old_outer == *new_outer_addr => { tracing::trace!(target: "stun", %new_outer_addr, "Keep alive, outer addr unchanged"); } Some(old_state) => { tracing::debug!(target: "stun", 
?old_state, %new_outer_addr, "keep alive, outer addr changed"); } None => { tracing::debug!(target: "stun", %new_outer_addr, "detected outer addr"); } }, Err(error) => { tracing::trace!(target: "stun", ?error, "Detect outer addr failed"); } }; tracing::trace!(target: "stun", "Starting keep alive task"); loop { let detect_result = detect_outer_addr( ref_iface.clone(), stun_router.clone(), stun_agent, 3, Duration::from_millis(300), ) .await; match &detect_result { Ok(_) => client_state.try_update(ClientState::Inactive, ClientState::Active), Err(_) => client_state.try_update(ClientState::Active, ClientState::Inactive), }; log_detect_result(&detect_result); let timeout = match detect_result { Ok(_) => Duration::from_secs(30), Err(_) => Duration::from_secs(1), }; let detect_result = detect_result.map_err(ArcIoError::from); if !bind_uri.is_temporary() && let Some(locations) = locations.as_ref() { locations.r#for(&ref_iface, |locations, bind_uri| { let data = detect_result .clone() .map(|outer| EndpointAddr::with_agent(stun_agent, outer)); locations.upsert::(bind_uri, Arc::new(data)); }); } outer_addr.assign(detect_result); tokio::time::sleep(timeout).await; } }; let bind_uri = self.ref_iface.iface().bind_uri(); keep_alive_task.instrument(tracing::debug_span!( target: "stun", "keep_alive_task", %bind_uri, %stun_agent, )) } pub fn poll_outer_addr(&self, cx: &mut Context) -> Poll> { if self.state.get() == ClientState::Closing { return Poll::Ready(Err(RebindedError.into())); } self.outer_addr .poll_get(cx) .map(|result| result.clone().map_err(io::Error::from)) } pub async fn outer_addr(&self) -> io::Result { core::future::poll_fn(|cx| self.poll_outer_addr(cx)).await } pub fn agent_addr(&self) -> SocketAddr { self.stun_agent } pub fn get_outer_addr(&self) -> Option> { if self.state.get() == ClientState::Closing { return Some(Err(RebindedError.into())); } self.outer_addr .try_get() .map(|result| result.clone().map_err(io::Error::from)) } fn nat_detect_task(&self) -> impl 
futures::Future + use { let nat_type = self.nat_type.clone(); let ref_iface = self.ref_iface.clone(); let stun_router = self.stun_router.clone(); let stun_agent = self.stun_agent; let bind_uri = ref_iface.iface().bind_uri(); // Note: 原来的逻辑是 nat 探测会新建 iface,但是有的服务器只能开放指定端口,所以还是用监听的端口进行探测 // 又因为Dynamic 总是会新建 iface 进行打洞,所以这里污染了影响不会很大 let task = async move { tracing::debug!(target: "stun", "starting NAT type detection"); let timeout = Duration::from_millis(100); _ = nat_type.assign( detect_nat_type(ref_iface, stun_router, stun_agent, 30, timeout) .await .map_err(ArcIoError::from), ); }; task.instrument(tracing::debug_span!( target: "stun", "nat_type_task", %bind_uri, %stun_agent, )) } pub fn poll_nat_type(&self, cx: &mut Context) -> Poll> { if self.state.get() == ClientState::Closing { return Poll::Ready(Err(RebindedError.into())); } self.nat_type .poll_get(cx) .map(|result| result.clone().map_err(io::Error::from)) } pub async fn nat_type(&self) -> io::Result { core::future::poll_fn(|cx| self.poll_nat_type(cx)).await } pub fn get_nat_type(&self) -> Option> { if self.state.get() == ClientState::Closing { return Some(Err(RebindedError.into())); } self.nat_type .try_get() .map(|result| result.clone().map_err(io::Error::from)) } // fn restart(&mut self) -> io::Result<()> { // self.stun_router.clear(); // *self = RunningClient::new( // self.ref_iface.clone(), // self.stun_router.clone(), // self.stun_agent, // ); // Ok(()) // } pub fn poll_close(&self, cx: &mut Context) -> Poll<()> { if self.state.set(ClientState::Closing) == ClientState::Closing { return Poll::Ready(()); } self.lock_tasks().abort_all(); while ready!(self.lock_tasks().poll_join_next(cx)).is_some() {} self.nat_type.clear(); self.outer_addr.clear(); Poll::Ready(()) } } #[derive(Debug)] pub struct StunClientComponent { client: Mutex>, } impl StunClientComponent { pub fn new(client: StunClient) -> Self { Self { client: Mutex::new(client), } } fn lock_client(&self) -> MutexGuard<'_, StunClient> { 
self.client.lock().expect("StunClient lock poisoned") } pub fn client(&self) -> StunClient { self.lock_client().clone() } } impl Component for StunClientComponent { fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { self.lock_client().poll_close(cx) } fn reinit(&self, iface: &Interface) { let mut client = self.lock_client(); if client.ref_iface.same_io(&iface.downgrade()) { return; } let Ok(locations) = iface.with_component(|loc: &LocationsComponent| { loc.reinit(iface); loc.clone() }) else { return; }; let new_client = StunClient::new( iface.downgrade(), client.stun_router.clone(), client.stun_agent, locations, ); *client = new_client; } } type StunClientsMap = HashMap>; #[derive(Debug)] struct StunClientsInner { ref_iface: I, clients: Arc>>, resolver: Arc, server: Arc, task: Option>, } pub const DEFAULT_STUN_SERVER: &str = "nat.genmeta.net:20004"; impl StunClientsInner { pub const MIN_AGENTS: usize = 3; pub fn new( ref_iface: I, router: StunRouter, resolver: Arc, server: Arc, agents: impl IntoIterator, locations: Option>, ) -> Self { let new_stun_client = { let ref_iface = ref_iface.clone(); move |agent_addr: SocketAddr| { let local_addr = ref_iface.iface().local_addr().ok()?; if local_addr.is_ipv4() != agent_addr.is_ipv4() { return None; } let stun_router = router.clone(); Some(StunClient::new( ref_iface.clone(), stun_router, agent_addr, locations.clone(), )) } }; let clients: Arc>> = Arc::new(Mutex::new( agents .into_iter() .filter_map(|agent| { tracing::trace!(target: "stun", %agent, "Initializing STUN client for agent"); new_stun_client(agent).map(|client| (agent, client)) }) .collect(), )); let task = AbortOnDropHandle::new(tokio::spawn({ let clients = clients.clone(); let resolver = resolver.clone(); let server = server.clone(); let ref_iface = ref_iface.clone(); async move { let lock_clients = || clients.lock().expect("StunClients mutex poisoned"); let should_lookup_agents = |clients: &StunClientsMap| match clients .values() .try_fold((0, 0), 
|(active, inactive), client| { match client.state.get() { ClientState::Active => ControlFlow::Continue((active + 1, inactive)), ClientState::Inactive => ControlFlow::Continue((active, inactive + 1)), ClientState::Closing => ControlFlow::Break(()), } }) { ControlFlow::Continue((active, _inactive)) => active < Self::MIN_AGENTS, ControlFlow::Break(_) => false, }; let wait_too_few_agents = |clients: &StunClientsMap| { let clients_len = clients.len(); debug_assert!(clients_len >= Self::MIN_AGENTS); let mut stream = clients .iter() .map(|(.., client)| client.state.wait(ClientState::Inactive)) .collect::>() .skip(clients_len.saturating_sub(Self::MIN_AGENTS)); async move { _ = stream.next().await } }; loop { while !{ should_lookup_agents(&lock_clients()) } { { wait_too_few_agents(&lock_clients()) }.await; } // 保证两次 lookup 至少间隔 10s,同时限时 10s 防止 resolver 卡住 let deadline = tokio::time::Instant::now() + Duration::from_secs(10); _ = tokio::time::timeout_at(deadline, async { let Ok(stream) = resolver.lookup(server.as_ref()).await else { return }; let is_ipv4 = ref_iface.iface().bind_uri().family() == Family::V4; let mut stream = std::pin::pin!(stream); while let Some((_, addr)) = stream.next().await { let EndpointAddr::Direct { addr } = addr else { continue }; if addr.is_ipv4() != is_ipv4 { continue } let done = { let mut clients = lock_clients(); if clients.contains_key(&addr) { continue } if let Some(client) = new_stun_client(addr) { tracing::debug!(target: "stun", %addr, "discovered new STUN agent"); clients.insert(addr, client); !should_lookup_agents(&clients) } else { false } }; if done { break } } }).await; tokio::time::sleep_until(deadline).await; } } })); Self { ref_iface, clients, resolver, server, task: Some(task), } } fn lock_clients(&self) -> MutexGuard<'_, StunClientsMap> { self.clients .lock() .expect("StunClientsComponentInner lock poisoned") } pub fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<()> { if let Some(task) = self.task.as_mut() { task.abort(); _ 
= ready!(task.poll_unpin(cx)); self.task.take(); } for (.., client) in self.lock_clients().iter() { ready!(client.poll_close(cx)) } Poll::Ready(()) } } #[derive(Debug, Clone)] pub struct StunClients { clients: Arc>>, } impl StunClients { pub fn new( ref_iface: I, router: StunRouter, resolver: Arc, server: impl Into>, agents: impl IntoIterator, locations: Option>, ) -> Self { Self { clients: Arc::new(Mutex::new(StunClientsInner::new( ref_iface, router, resolver, server.into(), agents, locations, ))), } } fn lock_clients(&self) -> MutexGuard<'_, StunClientsInner> { self.clients .lock() .expect("StunClientsComponent lock poisoned") } pub fn with_clients(&self, f: impl FnOnce(&StunClientsMap) -> T) -> T { f(self.lock_clients().lock_clients().deref()) } pub fn poll_close(&self, cx: &mut Context<'_>) -> Poll<()> { self.lock_clients().poll_close(cx) } } pub type StunClientsComponent = StunClients; impl Component for StunClientsComponent { fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { self.lock_clients().poll_close(cx) } fn reinit(&self, iface: &Interface) { let mut clients = self.lock_clients(); if clients.ref_iface.same_io(&iface.downgrade()) { return; } _ = iface.with_components(|components| { let Some(router) = components.with(|router: &StunRouterComponent| { router.reinit(iface); router.router() }) else { return; }; let locations = components.with(|locations: &LocationsComponent| { locations.reinit(iface); locations.clone() }); let new_clinets = StunClientsInner::new( iface.downgrade(), router, clients.resolver.clone(), clients.server.clone(), clients.lock_clients().keys().copied(), locations, ); *clients = new_clinets; }); } } fn no_response_error() -> io::Error { io::Error::new(io::ErrorKind::TimedOut, "No response from STUN server") } async fn detect_outer_addr( ref_iface: I, stun_router: StunRouter, stun_agent: SocketAddr, retry_times: u8, timeout: Duration, ) -> io::Result { let request = Request::default(); let response = 
Transaction::begin(ref_iface, stun_router, retry_times, timeout) .send_request(request, stun_agent) .await? .ok_or_else(no_response_error)?; response.map_addr() } pub static VISUALIZE_NAT_DETECTION: AtomicBool = AtomicBool::new(false); macro_rules! visualize_nat_detection { ($($tt:tt)*) => {{ if VISUALIZE_NAT_DETECTION.load(std::sync::atomic::Ordering::Relaxed) { tracing::info!($($tt)*); } else { tracing::trace!(target: "stun", $($tt)*); } }}; } pub const RESTRICTED_RETRY_TIMES: u8 = 3; async fn detect_nat_type( ref_iface: I, stun_router: StunRouter, stun_agent: SocketAddr, retry_times: u8, timeout: Duration, ) -> io::Result { let local_addr = ref_iface.iface().local_addr()?; visualize_nat_detection!("Starting NAT detection with local address: {local_addr}"); let stun_agent1 = stun_agent; visualize_nat_detection!("Access Test: probing server {stun_agent1}"); let request = Request::default(); let response = Transaction::begin(ref_iface.clone(), stun_router.clone(), retry_times, timeout) .send_request(request, stun_agent1) .await?; let Some(response) = response else { visualize_nat_detection!("Result: No response after {retry_times} attempts"); visualize_nat_detection!( "Conclusion: The network feature is {:?}, NAT Type is {:?}\n", NetFeature::Blocked, NatType::Blocked ); return Ok(NatType::Blocked); }; let mut net_features = NetFeature::empty(); let mapped_addr1 = response.map_addr()?; let stun_agent2 = response.changed_addr()?; visualize_nat_detection!("Result: Received from {stun_agent1}, external addr: {mapped_addr1}"); if mapped_addr1 == local_addr { // Public IP visualize_nat_detection!( "Conclusion: Address {local_addr} has public IP, Proceeding to filtering behavior test.\n" ); visualize_nat_detection!( "Filtering Test: probing server {stun_agent2}. 
Request server to respond from a changed IP:port", ); net_features |= NetFeature::Public; let request = Request::change_ip_and_port(); let response = Transaction::begin(ref_iface.clone(), stun_router.clone(), retry_times, timeout) .send_request(request, stun_agent2) .await?; if let Some(response) = response { let mapped_addr2 = response.map_addr()?; visualize_nat_detection!( "Result: received from {}, external addr: {mapped_addr2}", response.source_addr()? ); visualize_nat_detection!("Conclusion: Destination IP independent filtering\n"); } else { net_features |= NetFeature::Restricted; visualize_nat_detection!("Result: No response after {retry_times} attempts"); visualize_nat_detection!("Conclusion: Filters packets based on destination IP\n"); } visualize_nat_detection!( "Filtering Test: probing server {stun_agent2}. Request server to respond from a changed port", ); let request = Request::change_port(); let response = Transaction::begin(ref_iface.clone(), stun_router.clone(), retry_times, timeout) .send_request(request, stun_agent2) .await?; if let Some(response) = response { let mapped_addr2 = response.map_addr()?; visualize_nat_detection!( "Result: received from {}, external addr: {mapped_addr2}", response.source_addr()? ); visualize_nat_detection!("Conclusion: Destination port independent filtering\n"); } else { net_features |= NetFeature::PortRestricted; visualize_nat_detection!("Result: No response after {retry_times} attempts"); visualize_nat_detection!("Conclusion: Filters packets based on destination port\n"); } let nat_type = NatType::from(net_features); visualize_nat_detection!( "NAT detection completed. 
Network features: {:?}, NAT Type: {:?}", net_features, nat_type ); Ok(nat_type) } else { // Private IP visualize_nat_detection!("Conclusion: Address {local_addr} has private IP.\n"); visualize_nat_detection!("Mapping Test1: probing server {stun_agent2}"); let request = Request::default(); let response = Transaction::begin(ref_iface.clone(), stun_router.clone(), retry_times, timeout) .send_request(request, stun_agent2) .await? .ok_or_else(no_response_error)?; let stun_agent3 = response.changed_addr()?; let mapped_addr2 = response.map_addr()?; if mapped_addr1 != mapped_addr2 { net_features |= NetFeature::Symmetric; visualize_nat_detection!( "Result: Received from {stun_agent2}, external addr: {mapped_addr2}" ); visualize_nat_detection!( "Conclusion: The mapped address is different and destination-dependent.\n" ); // 判断规律 visualize_nat_detection!("Mapping Test2: probing server {stun_agent3}"); let request = Request::default(); let response = Transaction::begin(ref_iface.clone(), stun_router.clone(), retry_times, timeout) .send_request(request, stun_agent3) .await?; let Some(response) = response else { visualize_nat_detection!("Result: No response after {retry_times} attempts"); visualize_nat_detection!( "Conclusion: Unable to determine port mapping behavior due to lack of response from third server.\n" ); return Ok(NatType::from(net_features)); }; let mapped_addr3 = response.map_addr()?; let step1 = mapped_addr2.port() as i32 - mapped_addr1.port() as i32; let step2 = mapped_addr3.port() as i32 - mapped_addr2.port() as i32; visualize_nat_detection!( "Result: Received from {stun_agent3}, external addr: {mapped_addr3}" ); if step1 == step2 { visualize_nat_detection!( "Conclusion: The port changes regularly with step {step1}\n" ); } else { visualize_nat_detection!("Conclusion: The Ports change randomly.\n"); } Ok(NatType::from(net_features)) } else { // 不是对称型 // Open test // 发给 server2 换 ip and port 即 server3 回, server3 可能不响应 // server1: ip1:port1 // server2: ip2:port2 // 
server3: ip3:port1 // server4: ip1:port2 // server5: ip2:port1 // server6: ip3:port2 visualize_nat_detection!( "Filtering Test: probing server {stun_agent2}. Request server to respond from a changed IP and port", ); let request = Request::change_ip_and_port(); // 可能会不响应,超时太久会导致探测很久 let response = Transaction::begin( ref_iface.clone(), stun_router.clone(), RESTRICTED_RETRY_TIMES, timeout, ) .send_request(request, stun_agent2) .await?; if let Some(response) = response { let mapped_addr2 = response.map_addr()?; visualize_nat_detection!( "Result: received from {}, external addr: {mapped_addr2}", response.source_addr()? ); visualize_nat_detection!("Conclusion: Destination IP independent filtering\n"); } else { net_features |= NetFeature::Restricted; visualize_nat_detection!( "Result: No response after {RESTRICTED_RETRY_TIMES} attempts" ); visualize_nat_detection!("Conclusion: Filters packets based on destination IP\n"); } visualize_nat_detection!( "Filtering Test: probing server {stun_agent2}. Request server to respond from a changed port", ); // Restricted test // server2 换 port 即 server5 回,可能不响应 // 可能会不响应,超时太久会导致探测很久 let request = Request::change_port(); let response = Transaction::begin( ref_iface.clone(), stun_router.clone(), RESTRICTED_RETRY_TIMES, timeout, ) .send_request(request, stun_agent2) .await?; if let Some(response) = response { let mapped_addr2 = response.map_addr()?; visualize_nat_detection!( "Result: received from {}, external addr: {mapped_addr2}", response.source_addr()? 
); visualize_nat_detection!("Conclusion: Destination port independent filtering\n"); } else { net_features |= NetFeature::PortRestricted; visualize_nat_detection!( "Result: No response after {RESTRICTED_RETRY_TIMES} attempts" ); visualize_nat_detection!("Conclusion: Filters packets based on destination port\n"); } // dynamic test, 请求 server3 visualize_nat_detection!("Dynamic Test: probing server {stun_agent3}",); let request = Request::default(); let response = Transaction::begin(ref_iface.clone(), stun_router.clone(), retry_times, timeout) .send_request(request, stun_agent3) .await?; if let Some(response) = response { // 回包,但是映射地址不一致,为动态型 let mapped_addr3 = response.map_addr()?; visualize_nat_detection!( "Result: received from {}, external addr: {mapped_addr3}", response.source_addr()? ); if mapped_addr1 != mapped_addr3 { net_features |= NetFeature::Dynamic; visualize_nat_detection!( "Conclusion: Mapping inconsistency indicates Address-Dependent Mapping, a Dynamic NAT type\n" ); } else { visualize_nat_detection!( "Conclusion: The mapping address is consistent, not Dynamic\n" ); } } else { // 不回包也视为动态型 net_features |= NetFeature::Dynamic; visualize_nat_detection!("Result: No response after 3 attempts"); visualize_nat_detection!( "Conclusion: Absence of server response may indicates Dynamic NAT behavior\n" ); } let nat_type = NatType::from(net_features); visualize_nat_detection!( "NAT detection completed. 
Network features: {:?}, NAT Type: {:?}", net_features, nat_type ); Ok(nat_type) } } } ================================================ FILE: qtraversal/src/nat/iface.rs ================================================ use std::{io, net::SocketAddr}; use bytes::{BufMut, BytesMut}; use qbase::net::route::{Line, Link, Route}; use qinterface::io::{IO, IoExt}; use crate::{ nat::msg::{Packet, TransactionId, WritePacket}, packet::{StunHeader, WriteStunHeader}, }; pub trait StunIO: IO { fn local_addr(&self) -> io::Result { self.bound_addr() } fn send_stun_packet( &self, packet: Packet, txid: TransactionId, dst: SocketAddr, ) -> impl Future> + Send { async move { let mut buf = BytesMut::zeroed(128); let (mut stun_hdr, mut stun_body) = buf.split_at_mut(StunHeader::encoding_size()); // put stun header stun_hdr.put_stun_header(&StunHeader::new(0)); // put stun body let origin = stun_body.remaining_mut(); stun_body.put_packet(&txid, &packet); let consumed = origin - stun_body.remaining_mut(); buf.truncate(StunHeader::encoding_size() + consumed); let bufs = &[io::IoSlice::new(&buf)]; // assemble packet header let link = Link::new(self.bound_addr()?, dst); let pathway = link.into(); let line = Line::new(link, 64, None, 0); let hdr = Route::new(pathway, line); self.sendmmsg(bufs, hdr).await } } } impl StunIO for I {} ================================================ FILE: qtraversal/src/nat/msg.rs ================================================ use std::{io, net::SocketAddr}; use bytes::BufMut; use nom::{ Err, IResult, Parser, combinator::map, error::{Error, ErrorKind}, multi::many0, number::streaming::{be_u8, be_u16}, }; use qbase::net::{AddrFamily, Family, WriteSocketAddr, be_socket_addr}; use rand::RngExt; use thiserror::Error; pub const BINDING_REQUEST: u16 = 0x0001; pub const BINDING_RESPONSE: u16 = 0x0101; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct TransactionId([u8; 16]); impl AsRef<[u8]> for TransactionId { fn as_ref(&self) -> &[u8] { &self.0 } } impl 
TransactionId { pub fn from_slice(slice: &[u8]) -> Self { let mut id = [0u8; 16]; id.copy_from_slice(slice); TransactionId(id) } pub fn random() -> Self { let mut id = [0u8; 16]; rand::rng().fill(&mut id); TransactionId(id) } } #[derive(Debug)] pub enum Packet { Request(Request), Response(Response), } /// STUN数据包中的Attr类型: #[derive(Debug, Clone, PartialEq)] pub enum Attr { // 由服务器返回的外网映射地址 MappedAddress(SocketAddr), // 客户端发起请求携带的指定响应地址 ResponseAddress(SocketAddr), // 由客户端请求转发时,携带变换Ip:Port响应的指示 ChangeRequest(u8), // 由服务器返回的Response消息的源地址,即服务器的地址 SourceAddress(SocketAddr), // 由服务器返回的另一台的STUN服务器地址, // 包括不同端口,供后续参考使用 ChangedAddress(SocketAddr), } #[derive(Debug)] pub enum AttrType { MappedAddress(Family), ResponseAddress(Family), // 由客户端请求转发时,携带变换Ip:Port响应的指示 ChangeRequest(u8), // 由服务器返回的Response消息的源地址,即服务器的地址 SourceAddress(Family), // 由服务器返回的另一台的STUN服务器地址, // 包括不同端口,供后续参考使用 ChangedAddress(Family), } #[derive(Debug, Error)] #[error("Invalid attribute type: {0}")] pub struct InvalidAttrType(u8); impl From for u8 { fn from(value: AttrType) -> Self { match value { AttrType::MappedAddress(Family::V4) => 0, AttrType::MappedAddress(Family::V6) => 1, AttrType::ResponseAddress(Family::V4) => 2, AttrType::ResponseAddress(Family::V6) => 3, AttrType::SourceAddress(Family::V4) => 4, AttrType::SourceAddress(Family::V6) => 5, AttrType::ChangedAddress(Family::V4) => 6, AttrType::ChangedAddress(Family::V6) => 7, AttrType::ChangeRequest(flag_set) => 8 | flag_set, } } } impl TryFrom for AttrType { type Error = InvalidAttrType; fn try_from(value: u8) -> Result { match value { 0 => Ok(AttrType::MappedAddress(Family::V4)), 1 => Ok(AttrType::MappedAddress(Family::V6)), 2 => Ok(AttrType::ResponseAddress(Family::V4)), 3 => Ok(AttrType::ResponseAddress(Family::V6)), 4 => Ok(AttrType::SourceAddress(Family::V4)), 5 => Ok(AttrType::SourceAddress(Family::V6)), 6 => Ok(AttrType::ChangedAddress(Family::V4)), 7 => Ok(AttrType::ChangedAddress(Family::V6)), 8..12 => Ok(AttrType::ChangeRequest(value & 
0x3)), _ => Err(InvalidAttrType(value)), } } } trait WriteAttr { fn put_attr(&mut self, attr: &Attr); } impl WriteAttr for T { fn put_attr(&mut self, attr: &Attr) { let typ: u8 = attr.typ().into(); match attr { Attr::MappedAddress(socket_addr) => { self.put_u8(typ); self.put_socket_addr(socket_addr); } Attr::ResponseAddress(socket_addr) => { self.put_u8(typ); self.put_socket_addr(socket_addr); } Attr::ChangeRequest(flag) => { self.put_u8(typ | *flag); } Attr::SourceAddress(socket_addr) => { self.put_u8(typ); self.put_socket_addr(socket_addr); } Attr::ChangedAddress(socket_addr) => { self.put_u8(typ); self.put_socket_addr(socket_addr); } }; } } impl Attr { pub fn typ(&self) -> AttrType { match self { Attr::MappedAddress(socket_addr) => AttrType::MappedAddress(socket_addr.family()), Attr::ResponseAddress(socket_addr) => AttrType::ResponseAddress(socket_addr.family()), Attr::ChangeRequest(flag_set) => AttrType::ChangeRequest(*flag_set), Attr::SourceAddress(socket_addr) => AttrType::SourceAddress(socket_addr.family()), Attr::ChangedAddress(socket_addr) => AttrType::ChangedAddress(socket_addr.family()), } } fn be_attr(input: &[u8]) -> IResult<&[u8], Self> { if input.is_empty() { return Err(Err::Error(Error::new(input, ErrorKind::Eof))); } let (remain, typ) = be_u8(input)?; let typ: AttrType = typ .try_into() .map_err(|_| Err::Error(Error::new(input, ErrorKind::Alt)))?; match typ { AttrType::MappedAddress(family) => { let (remain, addr) = be_socket_addr(remain, family)?; Ok((remain, Attr::MappedAddress(addr))) } AttrType::ResponseAddress(family) => { let (remain, addr) = be_socket_addr(remain, family)?; Ok((remain, Attr::ResponseAddress(addr))) } AttrType::SourceAddress(family) => { let (remain, addr) = be_socket_addr(remain, family)?; Ok((remain, Attr::SourceAddress(addr))) } AttrType::ChangedAddress(family) => { let (remain, addr) = be_socket_addr(remain, family)?; Ok((remain, Attr::ChangedAddress(addr))) } AttrType::ChangeRequest(flags) => Ok((remain, 
Attr::ChangeRequest(flags))),
        }
    }
}

/// A STUN binding request: an ordered list of attributes.
#[derive(Debug, PartialEq, Clone)]
pub struct Request(Vec<Attr>);

/// Only three kinds of request are used today: the empty default request, one
/// asking the server to change both IP and port when responding, and one asking
/// it to change only the port. Since a request may carry at most one
/// `ChangeRequest` attribute, all three are produced by dedicated constructors
/// and no other mutating operations are exposed.
impl Default for Request {
    fn default() -> Self {
        Self(Vec::new())
    }
}

pub(crate) trait WriteRequest {
    fn put_request(&mut self, request: &Request);
}

// NOTE(review): the generic parameter/bounds were lost in extraction; `BufMut`
// matches the sibling `Write*` impls in this module — confirm against the original.
impl<T: BufMut> WriteRequest for T {
    fn put_request(&mut self, request: &Request) {
        request.0.iter().for_each(|attr| self.put_attr(attr));
    }
}

/// Parse a [`Request`]: zero or more attributes until the input is exhausted.
pub fn be_request(input: &[u8]) -> IResult<&[u8], Request> {
    many0(Attr::be_attr).map(Request).parse(input)
}

/// Flag bit asking the server to respond from a different port.
pub const CHANGE_PORT: u8 = 0x01;
/// Flag bit asking the server to respond from a different IP address.
pub const CHANGE_IP: u8 = 0x02;

impl Request {
    /// Request a response from a different IP *and* a different port.
    pub fn change_ip_and_port() -> Self {
        Self(vec![Attr::ChangeRequest(CHANGE_IP | CHANGE_PORT)])
    }

    /// Request a response from a different port only.
    pub fn change_port() -> Self {
        Self(vec![Attr::ChangeRequest(CHANGE_PORT)])
    }

    /// Append a `ResponseAddress` attribute directing where the reply should go.
    pub fn add_response_address(&mut self, addr: SocketAddr) -> &mut Self {
        self.0.push(Attr::ResponseAddress(addr));
        self
    }

    /// Build a request carrying only a response address (no `ChangeRequest`).
    pub fn with_response_addr(addr: SocketAddr) -> Self {
        Request(vec![Attr::ResponseAddress(addr)])
    }

    /// The flags of the first `ChangeRequest` attribute, if present.
    pub fn change_request(&self) -> Option<u8> {
        self.0.iter().find_map(|attr| match attr {
            Attr::ChangeRequest(flags) => Some(*flags),
            _ => None,
        })
    }

    /// The first `ResponseAddress` attribute, if present.
    pub fn response_address(&self) -> Option<&SocketAddr> {
        self.0.iter().find_map(|attr| match attr {
            Attr::ResponseAddress(addr) => Some(addr),
            _ => None,
        })
    }
}

/// A STUN binding response: an ordered list of attributes.
#[derive(Debug, Clone, PartialEq)]
pub struct Response(pub Vec<Attr>);

pub(crate) trait WriteResponse {
    fn put_response(&mut self, response: &Response);
}

// NOTE(review): generic bounds restored after extraction stripped them — confirm.
impl<T: BufMut> WriteResponse for T {
    fn put_response(&mut self, response: &Response) {
        response.0.iter().for_each(|attr| self.put_attr(attr));
    }
}

/// Parse a [`Response`]: zero or more attributes until the input is exhausted.
pub fn be_response(input: &[u8]) -> IResult<&[u8], Response> {
    many0(Attr::be_attr).map(Response).parse(input)
}

impl Response {
    /// Build a response from an attribute list.
    pub fn with(attrs: Vec<Attr>) -> Self {
        Response(attrs)
    }

    pub fn
map_addr(&self) -> io::Result<SocketAddr> {
        for attr in &self.0 {
            if let Attr::MappedAddress(addr) = attr {
                return Ok(*addr);
            }
        }
        Err(io::Error::other("No mapped address found in response"))
    }

    /// The `ChangedAddress` attribute (alternate server address), or an error if absent.
    pub fn changed_addr(&self) -> io::Result<SocketAddr> {
        for attr in &self.0 {
            if let Attr::ChangedAddress(addr) = attr {
                return Ok(*addr);
            }
        }
        Err(io::Error::other("No changed address found in response"))
    }

    /// The `SourceAddress` attribute (address the server replied from), or an error if absent.
    pub fn source_addr(&self) -> io::Result<SocketAddr> {
        for attr in &self.0 {
            if let Attr::SourceAddress(addr) = attr {
                return Ok(*addr);
            }
        }
        Err(io::Error::other("No source address found in response"))
    }
}

/// Parse a full STUN packet: a 16-bit message type, a 16-byte transaction id,
/// then the attribute payload (request or response, depending on the type).
pub fn be_packet(input: &[u8]) -> IResult<&[u8], (TransactionId, Packet)> {
    let (remain, typ) = be_u16(input)?;
    // BUGFIX: `split_at(16)` panics when fewer than 16 bytes remain. This parser
    // runs on datagrams received from the network, so a truncated or malformed
    // packet must produce a parse error rather than a panic.
    if remain.len() < 16 {
        return Err(Err::Error(Error::new(input, ErrorKind::Eof)));
    }
    let (txid, remain) = remain.split_at(16);
    let (remain, packet) = match typ {
        BINDING_REQUEST => map(be_request, Packet::Request).parse(remain)?,
        BINDING_RESPONSE => map(be_response, Packet::Response).parse(remain)?,
        _ => return Err(Err::Error(Error::new(input, ErrorKind::Alt))),
    };
    Ok((remain, (TransactionId::from_slice(txid), packet)))
}

/// Serialize a STUN packet (type + transaction id + attributes) into a buffer.
pub trait WritePacket {
    fn put_packet(&mut self, txid: &TransactionId, packet: &Packet);
}

// NOTE(review): the generic parameter/bounds were lost in extraction; `BufMut`
// matches what `put_u16`/`put_slice` require — confirm against the original.
impl<T: BufMut> WritePacket for T {
    fn put_packet(&mut self, txid: &TransactionId, packet: &Packet) {
        match packet {
            Packet::Request(request) => {
                self.put_u16(BINDING_REQUEST);
                self.put_slice(txid.as_ref());
                self.put_request(request);
            }
            Packet::Response(response) => {
                self.put_u16(BINDING_RESPONSE);
                self.put_slice(txid.as_ref());
                self.put_response(response);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn attr_deserialize() {
        assert_eq!(
            Attr::be_attr(&[4, 78, 34, 127, 0, 0, 1][..]),
            Ok((
                &[][..],
                Attr::SourceAddress("127.0.0.1:20002".parse().unwrap())
            ))
        );
        assert_eq!(
            Attr::be_attr(&[6, 78, 34, 127, 0, 0, 1][..]),
            Ok((
                &[][..],
                Attr::ChangedAddress("127.0.0.1:20002".parse().unwrap())
            ))
        );
        assert_eq!(
            Attr::be_attr(&[0, 48, 57, 127, 0, 0, 1][..]),
            Ok((
                &[][..],
                Attr::MappedAddress("127.0.0.1:12345".parse().unwrap())
            ))
        )
    }

    #[test]
    fn request_serialize() {
        let buf = [
            4,
78, 34, 127, 0, 0, 1, 0, 48, 57, 127, 0, 0, 1, 6, 78, 34, 127, 0, 0, 1, ]; let (remain, response) = be_response(&buf).unwrap(); assert_eq!(remain.len(), 0); assert_eq!( response, Response(vec![ Attr::SourceAddress("127.0.0.1:20002".parse().unwrap()), Attr::MappedAddress("127.0.0.1:12345".parse().unwrap()), Attr::ChangedAddress("127.0.0.1:20002".parse().unwrap()) ]) ); } } ================================================ FILE: qtraversal/src/nat/router.rs ================================================ use std::{ net::SocketAddr, sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll}, }; use dashmap::DashMap; use qbase::{net::route::Link, util::ArcAsyncDeque}; use qinterface::{Interface, WeakInterface, component::Component}; use tokio::sync::SetOnce; use tracing::debug; use super::msg::{self, Packet, Request, Response, TransactionId}; type ResponseRouter = Arc>>>; #[derive(Default, Debug, Clone)] pub struct StunRouter { request_router: ArcAsyncDeque<(Request, TransactionId, SocketAddr)>, response_router: ResponseRouter, } impl StunRouter { pub fn new() -> Self { Self::default() } pub fn deliver_stun_packet(&self, txid: TransactionId, packet: Packet, link: Link) { match packet { msg::Packet::Request(request) => { self.request_router.push_back((request, txid, link.dst)); } msg::Packet::Response(response) => { if let Some((_id, recv_resp)) = self.response_router.remove(&txid) { let _ = recv_resp.set((response, link.dst)); } else { debug!( target: "stun", ?txid, %link, from =% link.dst, "Unknown request transaction id", ); } } } } pub async fn receive_request(&self) -> Option<(Request, TransactionId, SocketAddr)> { self.request_router.pop().await } /// Close the router, causing any pending `receive_request()` to return `None`. /// Called by `StunRouterComponent::reinit()` before replacing with a new router, /// so that a running `StunServer` task can detect the rebind and exit cleanly. 
pub fn close(&self) {
        self.request_router.close();
        self.response_router.clear();
    }

    /// Register a pending transaction so a later response can be delivered to `future`.
    // NOTE(review): the generic argument of `Arc<...>` was lost in extraction;
    // `SetOnce<(Response, SocketAddr)>` matches what `deliver_stun_packet`
    // fulfils via `recv_resp.set((response, link.dst))` — confirm against the original.
    pub(super) fn register(
        &self,
        transaction_id: TransactionId,
        future: Arc<SetOnce<(Response, SocketAddr)>>,
    ) {
        self.response_router.insert(transaction_id, future);
    }

    /// Forget a pending transaction, e.g. when its initiating `Transaction` is dropped.
    pub(super) fn remove(&self, transaction_id: &TransactionId) {
        let _ = self.response_router.remove(transaction_id);
    }
}

#[derive(Debug)]
struct StunRouterComponentInner {
    router: StunRouter,
    ref_iface: WeakInterface,
}

/// Interface component owning the [`StunRouter`] of one bound interface.
#[derive(Debug)]
pub struct StunRouterComponent {
    inner: Mutex<StunRouterComponentInner>,
}

impl StunRouterComponent {
    /// Create a component holding a fresh router bound to `ref_iface`.
    pub fn new(ref_iface: WeakInterface) -> Self {
        let inner = StunRouterComponentInner {
            router: StunRouter::new(),
            ref_iface,
        };
        Self {
            inner: Mutex::new(inner),
        }
    }

    fn lock_inner(&self) -> MutexGuard<'_, StunRouterComponentInner> {
        self.inner.lock().expect("StunRouter lock poisoned")
    }

    /// The interface this router is currently associated with.
    pub fn ref_iface(&self) -> WeakInterface {
        self.lock_inner().ref_iface.clone()
    }

    /// A clone of the current router (cheap: internally reference-counted).
    pub fn router(&self) -> StunRouter {
        self.lock_inner().router.clone()
    }
}

impl Component for StunRouterComponent {
    fn reinit(&self, iface: &Interface) {
        let mut inner = self.lock_inner();
        if inner.ref_iface.same_io(&iface.downgrade()) {
            return;
        }
        // Close the old router so any running StunServer task can detect the rebind and exit.
inner.router.close(); *inner = StunRouterComponentInner { router: StunRouter::new(), ref_iface: iface.downgrade(), }; } fn poll_shutdown(&self, _cx: &mut Context<'_>) -> Poll<()> { Poll::Ready(()) } } ================================================ FILE: qtraversal/src/nat/server.rs ================================================ use std::{ io, net::SocketAddr, pin::Pin, sync::Mutex, task::{Context, Poll, ready}, }; use qinterface::{Interface, WeakInterface, component::Component, io::RefIO}; use tokio_util::task::AbortOnDropHandle; use tracing::{info, trace}; use super::{ msg::{Attr, Request, Response}, router::StunRouter, }; use crate::nat::{ iface::StunIO, msg::{CHANGE_IP, CHANGE_PORT, Packet}, router::StunRouterComponent, }; #[derive(Debug, Clone, Default)] pub struct StunServerConfig { change_port: Option, change_address: Option, } #[bon::bon] impl StunServerConfig { #[builder(finish_fn = init)] pub fn new(change_port: Option, change_address: Option) -> Self { Self { change_port, change_address, } } } #[derive(Debug)] pub struct StunServer { ref_iface: I, stun_router: StunRouter, config: StunServerConfig, } impl StunServer { pub fn new(ref_iface: I, stun_router: StunRouter, config: StunServerConfig) -> Self { info!( target: "stun", local_addr = ?ref_iface.iface().local_addr(), change_port = ?config.change_port, change_address = ?config.change_address, "new stun server", ); Self { ref_iface, stun_router, config, } } pub fn spawn(self) -> AbortOnDropHandle> { AbortOnDropHandle::new(tokio::spawn(async move { serve_loop(self.ref_iface, self.stun_router, self.config).await })) } } async fn serve_loop( ref_iface: I, stun_router: StunRouter, config: StunServerConfig, ) -> io::Result<()> { info!(target: "stun", "Server started"); let local_addr = ref_iface.iface().local_addr()?; while let Some((request, txid, src)) = stun_router.receive_request().await { trace!(target: "stun", ?request, "recv request"); match (request.change_request(), request.response_address()) { 
(Some(changes), _) => { let Ok(addr) = select_change_target(src, changes, local_addr, &config) else { trace!( target: "stun", changes, change_port = ?config.change_port, change_address = ?config.change_address, "drop request: server lacks requested change capability", ); continue; }; let request = Request::with_response_addr(src); trace!(target: "stun", ?request, to = %addr, "send request"); ref_iface .iface() .send_stun_packet(Packet::Request(request), txid, addr) .await?; } (None, Some(&response_addr)) => { let mut attrs = vec![ Attr::SourceAddress(local_addr), Attr::MappedAddress(response_addr), ]; if let Some(addr) = config.change_address { attrs.push(Attr::ChangedAddress(addr)); } let response = Response::with(attrs); trace!(target: "stun", ?response, to = %response_addr, "send response"); ref_iface .iface() .send_stun_packet(Packet::Response(response), txid, response_addr) .await?; } _ => { let mut attrs = vec![Attr::SourceAddress(local_addr), Attr::MappedAddress(src)]; if let Some(addr) = config.change_address { attrs.push(Attr::ChangedAddress(addr)); } let response = Response::with(attrs); trace!(target: "stun", ?response, to = %src, "send response"); ref_iface .iface() .send_stun_packet(Packet::Response(response), txid, src) .await?; } } } trace!(target: "stun", "Request handler finished - no more requests"); Ok(()) } fn select_change_target( src: SocketAddr, changes: u8, local_addr: SocketAddr, config: &StunServerConfig, ) -> io::Result { let wants_ip = changes & CHANGE_IP != 0; let wants_port = changes & CHANGE_PORT != 0; match (wants_ip, wants_port) { (false, false) => Ok(src), (true, false) => { // CHANGE_IP: respond from a different IP (complete change_address, port may differ) config.change_address.ok_or_else(|| { io::Error::new(io::ErrorKind::Unsupported, "CHANGE_IP not supported") }) } (false, true) => { let port = config.change_port.ok_or_else(|| { io::Error::new(io::ErrorKind::Unsupported, "CHANGE_PORT not supported") })?; 
Ok(SocketAddr::new(local_addr.ip(), port)) } (true, true) => { let addr = config.change_address.ok_or_else(|| { io::Error::new( io::ErrorKind::Unsupported, "CHANGE_IP and CHANGE_PORT not supported", ) })?; Ok(addr) } } } #[derive(Debug)] struct StunServerComponentInner { ref_iface: WeakInterface, config: StunServerConfig, task: Option>>, } #[derive(Debug)] pub struct StunServerComponent { inner: Mutex, } impl StunServerComponent { pub fn new( ref_iface: WeakInterface, stun_router: StunRouter, config: StunServerConfig, ) -> Self { let task = Some(StunServer::new(ref_iface.clone(), stun_router.clone(), config.clone()).spawn()); Self { inner: Mutex::new(StunServerComponentInner { ref_iface, config, task, }), } } fn lock_inner(&self) -> std::sync::MutexGuard<'_, StunServerComponentInner> { self.inner.lock().unwrap() } } impl Component for StunServerComponent { fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { let mut inner = self.lock_inner(); if let Some(task) = inner.task.as_mut() { task.abort(); _ = ready!(Pin::new(task).poll(cx)); inner.task = None; } Poll::Ready(()) } fn reinit(&self, iface: &Interface) { let mut inner = self.lock_inner(); if inner.ref_iface.same_io(&iface.downgrade()) { return; } _ = iface.with_components(|components| { let Some(router) = components.with(|router: &StunRouterComponent| { router.reinit(iface); router.router() }) else { return; }; if let Some(task) = inner.task.take() { task.abort(); } inner.ref_iface = iface.downgrade(); inner.task = Some( StunServer::new(inner.ref_iface.clone(), router, inner.config.clone()).spawn(), ); }); } } ================================================ FILE: qtraversal/src/nat/tx.rs ================================================ use std::{io, net::SocketAddr, sync::Arc, time::Duration}; use qinterface::io::RefIO; use tokio::{sync::SetOnce, time::timeout}; use super::{ msg::{Packet, Request, Response, TransactionId}, router::StunRouter, }; use crate::nat::iface::StunIO; #[derive(Clone)] pub 
struct Transaction { stun_router: StunRouter, ref_iface: I, transaction_id: TransactionId, pending_response: Arc>, retry_times: u8, timeout: Duration, } impl Transaction { pub fn begin( ref_iface: I, stun_router: StunRouter, retry_times: u8, timeout: Duration, ) -> Self { let pending_response = Arc::new(SetOnce::new()); let transaction_id = TransactionId::random(); stun_router.register(transaction_id, pending_response.clone()); Self { stun_router, ref_iface, transaction_id, pending_response, retry_times, timeout, } } pub async fn send_request( &self, request: Request, dst: SocketAddr, ) -> io::Result> { let mut retry_times = self.retry_times; while retry_times > 0 { match timeout(self.timeout, self.do_tick(dst, request.clone())).await { Ok(result) => return result.map(Some), Err(_error) => retry_times -= 1, } } Ok(None) } async fn do_tick(&self, dst: SocketAddr, request: Request) -> io::Result { self.ref_iface .iface() .send_stun_packet(Packet::Request(request), self.transaction_id, dst) .await?; let (response, _src) = self.pending_response.wait().await.clone(); Ok(response) } } impl Drop for Transaction { fn drop(&mut self) { self.stun_router.remove(&self.transaction_id); } } ================================================ FILE: qtraversal/src/nat.rs ================================================ pub mod client; pub mod iface; pub mod msg; pub mod router; pub mod server; pub mod tx; ================================================ FILE: qtraversal/src/packet.rs ================================================ use bytes::BufMut; use qbase::net::{ Family, addr::{EndpointAddr, WriteEndpointAddr, be_endpoint_addr}, }; use crate::PathWay; const STUN_HEADER_MASK: u8 = 0b1111_1110; const STUN_HEADER_BITS: u8 = 0b1100_0010; const FORWARD_HEADER_MASK: u8 = 0b1110_0000; const FORWARD_VERSION_MASK: u8 = 0b1111_0000; const FORWARD_HEADER_BITS: u8 = 0b0110_0000; const FORWARD_BIT: u8 = 0b0000_1000; const FORWARD_FAMILY_BIT: u8 = 0b0000_0100; const FORWARD_SRC_TYPE_BIT: u8 = 
0b0000_0010; const FORWARD_DST_TYPE_BIT: u8 = 0b0000_0001; #[derive(PartialEq, Eq, Debug)] pub enum HeaderType { Stun(u8), // 最后 bit Forward(u8), // 最后 5bit } // Stun Packet { // Header Form (1) = 1, // Fixed Bit (1) = 1, // Stun Hdr (6), // Request 0b000010 #Response 0b000011 // Version (32) = 0, // DDIL(8) = 0, // 伪装0长度的目标连接ID // SDIL(8) = 0, // 伪装0长度的源连接ID // Ver(16), // 2个字节,表示我们自定义的版本号,方便未来升级 // ... Stun payload // } #[derive(Clone, Copy)] pub struct StunHeader { version: u16, } impl StunHeader { pub fn new(version: u16) -> Self { Self { version } } pub fn encoding_size() -> usize { 1 + 4 + 4 } } pub fn be_stun_header(input: &[u8]) -> nom::IResult<&[u8], StunHeader> { let (remain, version) = nom::number::streaming::be_u16(input)?; Ok((remain, StunHeader { version })) } pub trait WriteStunHeader { fn put_stun_header(&mut self, stun_header: &StunHeader); } impl WriteStunHeader for T { fn put_stun_header(&mut self, stun_header: &StunHeader) { self.put_u8(STUN_HEADER_BITS); self.put_u32(0); self.put_u8(0); self.put_u8(0); self.put_u16(stun_header.version); } } // Forward Packet { // Header Form (1) = 0, // Fixed Bit (1) = 1, // Spin Bit (1) = 1, // 1表示带有转发包头 // Remain (5), // 使其等于真正QUIC包第一字节的后5bit,飘忽不定,伪装够深 // Version (4), // Forward (1) = 1, // Family (1), // 0表示IPv4,1表示IPv6 // Src type(1), // 0表示直连,1表示带agent // Dst type(1), // 0表示直连,1表示带agent // Src endpoint, // 根据src type,是Endpoint::Agent还是Direct // Dst endpoint, // 根据dst type,是Endpoint::Agent还是Direct // ... Real Quic Packet // } #[derive(Debug, Clone, Copy)] pub struct ForwardHeader { remian: u8, // 后 5bits version: u8, // 前 4bits pathway: PathWay, } impl ForwardHeader { pub fn encoding_size(pathway: &PathWay) -> usize { if matches!(pathway.remote(), EndpointAddr::Direct { .. 
}) { return 0; } 1 + 1 + pathway.local().encoding_size() + pathway.remote().encoding_size() } pub fn pathway(&self) -> PathWay { self.pathway } pub fn new(version: u8, pathway: &PathWay, buffer: &[u8]) -> Self { let remian = buffer[0] & 0b0001_1111; Self { remian, version, pathway: *pathway, } } } pub trait WriteForwardHeader { fn put_forward_header(&mut self, forward_header: &ForwardHeader); } impl WriteForwardHeader for T { fn put_forward_header(&mut self, forward_header: &ForwardHeader) { self.put_u8(FORWARD_HEADER_BITS | forward_header.remian); let mut flag = (forward_header.version << 4) | FORWARD_BIT; if forward_header.pathway.local().ip().is_ipv6() { flag |= FORWARD_FAMILY_BIT; } if matches!(forward_header.pathway.local(), EndpointAddr::Agent { .. }) { flag |= FORWARD_SRC_TYPE_BIT; } if matches!(forward_header.pathway.remote(), EndpointAddr::Agent { .. }) { flag |= FORWARD_DST_TYPE_BIT; } self.put_u8(flag); self.put_endpoint_addr(forward_header.pathway.local()); self.put_endpoint_addr(forward_header.pathway.remote()); } } pub fn be_forward_header(input: &[u8]) -> nom::IResult<&[u8], ForwardHeader> { let (remain, first) = nom::number::streaming::be_u8(input)?; let version = (first & FORWARD_VERSION_MASK) >> 4; let flag = first & !FORWARD_VERSION_MASK; let family = match flag & FORWARD_FAMILY_BIT { 0 => Family::V4, _ => Family::V6, }; let src_ep_typ = flag & FORWARD_SRC_TYPE_BIT; let dst_ep_typ = flag & FORWARD_DST_TYPE_BIT; let (remain, src) = be_endpoint_addr(remain, src_ep_typ, family)?; let (remain, dst) = be_endpoint_addr(remain, dst_ep_typ, family)?; let pathway = PathWay::new(src, dst); Ok(( remain, ForwardHeader { remian: first, version, pathway, }, )) } #[derive(Clone, Copy)] pub enum Header { Stun(StunHeader), Forward(ForwardHeader), } pub fn be_header_type(input: &[u8]) -> nom::IResult<&[u8], HeaderType> { let (remain, first) = nom::number::streaming::be_u8(input)?; if first & STUN_HEADER_MASK == STUN_HEADER_BITS { let (remain, version) = 
nom::number::streaming::be_u32(remain)?; if version == 0 { let (remain, _) = nom::number::streaming::be_u8(remain)?; let (remain, _) = nom::number::streaming::be_u8(remain)?; return Ok((remain, HeaderType::Stun(first & 1))); } } else if first & FORWARD_HEADER_MASK == FORWARD_HEADER_BITS { return Ok((remain, HeaderType::Forward(first & 0b0001_1111))); } Err(nom::Err::Error(nom::error::make_error( input, nom::error::ErrorKind::Alt, ))) } pub fn be_header(input: &[u8]) -> nom::IResult<&[u8], Header> { let (remain, ty) = be_header_type(input)?; match ty { HeaderType::Stun(_ty) => { let (remain, stun_hdr) = be_stun_header(remain)?; Ok((remain, Header::Stun(stun_hdr))) } HeaderType::Forward(_ty) => { let (remain, forward_hdr) = be_forward_header(remain)?; Ok((remain, Header::Forward(forward_hdr))) } } } pub trait WriteHeader { fn put_header(&mut self, header: &Header); } impl WriteHeader for T { fn put_header(&mut self, header: &Header) { match header { Header::Stun(stun_header) => { self.put_stun_header(stun_header); } Header::Forward(forward_header) => { self.put_forward_header(forward_header); } } } } #[cfg(test)] mod tests { use bytes::BytesMut; use super::*; #[test] fn test_stun_header() { let stun_hdr = StunHeader::new(0); let mut buf = BytesMut::with_capacity(StunHeader::encoding_size()); buf.put_stun_header(&stun_hdr); let (remain, hdr) = be_header_type(&buf[..]).unwrap(); assert_eq!(hdr, HeaderType::Stun(0)); let (remain, stun_hdr) = be_stun_header(remain).unwrap(); assert_eq!(stun_hdr.version, 0); assert_eq!(remain.len(), 0) } } ================================================ FILE: qtraversal/src/punch/predictor.rs ================================================ use std::{ collections::{HashMap, VecDeque}, future::poll_fn, io, net::SocketAddr, str::FromStr, sync::Arc, time::Duration, }; use qbase::{frame::PunchHelloFrame, net::route::Link}; use qinterface::{ Interface, bind_uri::{BindUri, Scheme}, component::route::{QuicRouter, QuicRouterComponent}, io::{IO, 
ProductIO}, manager::InterfaceManager, }; use crate::{ punch::{ scheduler::SCHEDULER, tx::{PunchId, Transaction}, }, route::ReceiveAndDeliverPacket, }; const MAX_CONCURRENT_SOCKETS: usize = 60; const MIN_PORT: u16 = 1024; const PACKET_TTL: u8 = 64; const FIRST_PROBE_ID: u32 = 1; const MAX_PROBES: u32 = 300; const PACING_INTERVAL: Duration = Duration::from_millis(20); pub struct PortPredictor { ifaces: Arc, factory: Arc, quic_router: Arc, bind_uri: BindUri, dst: SocketAddr, device: String, probes: ProbeTable, quota_held: u32, probes_created: u32, } #[derive(Debug)] struct PendingProbe { bind_uri: BindUri, iface: Interface, port: u16, } pub type PacketSendFn = Arc< dyn Fn( &Interface, Link, u8, PunchHelloFrame, ) -> std::pin::Pin> + Send + '_>> + Send + Sync, >; struct ProbeTable { pending: HashMap, active_ports: HashMap, order: VecDeque, next_probe_id: u32, } impl ProbeTable { fn new() -> Self { Self { pending: HashMap::new(), active_ports: HashMap::new(), order: VecDeque::new(), next_probe_id: FIRST_PROBE_ID, } } fn len(&self) -> usize { self.pending.len() } fn contains_port(&self, port: u16) -> bool { self.active_ports.contains_key(&port) } fn allocate_probe_id(&mut self) -> u32 { let probe_id = self.next_probe_id; self.next_probe_id = self.next_probe_id.wrapping_add(1); if self.next_probe_id < FIRST_PROBE_ID { self.next_probe_id = FIRST_PROBE_ID; } probe_id } fn insert(&mut self, probe_id: u32, bind_uri: BindUri, iface: Interface, port: u16) { self.active_ports.insert(port, probe_id); self.order.push_back(probe_id); self.pending.insert( probe_id, PendingProbe { bind_uri, iface, port, }, ); } fn take(&mut self, probe_id: u32) -> Option { let probe = self.pending.remove(&probe_id)?; self.active_ports.remove(&probe.port); self.order.retain(|&id| id != probe_id); Some(probe) } fn oldest_probe_id(&self) -> Option { self.order.front().copied() } fn pending_probe_ids(&self) -> Vec { self.pending.keys().copied().collect() } fn drain_bind_uris(&mut self) -> Vec { 
self.active_ports.clear(); self.order.clear(); self.pending .drain() .map(|(_, probe)| probe.bind_uri) .collect() } } impl PortPredictor { pub fn new( ifaces: Arc, factory: Arc, quic_router: Arc, bind_uri: BindUri, dst: SocketAddr, ) -> io::Result { let device = match bind_uri.scheme() { Scheme::Iface => bind_uri.as_iface_bind_uri().unwrap().1.to_string(), Scheme::Inet => bind_uri.as_inet_bind_uri().unwrap().ip().to_string(), _ => return Err(io::ErrorKind::Unsupported.into()), }; tracing::debug!( target: "punch", bind_uri = %bind_uri, dst = %dst, device = %device, "Created port predictor" ); Ok(Self { ifaces, factory, quic_router, bind_uri, dst, device, probes: ProbeTable::new(), quota_held: 0, probes_created: 0, }) } fn release_quota(&mut self, count: u32) -> io::Result<()> { SCHEDULER .lock() .unwrap() .release_port(count, self.dst, self.device.clone())?; self.quota_held = self.quota_held.saturating_sub(count); Ok(()) } fn port_to_bind_uri(&self, port: u16) -> BindUri { match self.bind_uri.scheme() { Scheme::Iface => { let (ip_family, device, _) = self.bind_uri.as_iface_bind_uri().unwrap(); let bind_uri = format!( "iface://{ip_family}.{device}:{port}?{}=true", BindUri::TEMPORARY_PROP ); BindUri::from_str(bind_uri.as_str()).unwrap_or_else(|e| { panic!("Constructed invalid iface bind URI {bind_uri}: {e}") }) } Scheme::Inet => { let socket_addr = self.bind_uri.as_inet_bind_uri().unwrap(); let ip = socket_addr.ip(); let bind_uri = format!("inet://{ip}:{port}?{}=true", BindUri::TEMPORARY_PROP); BindUri::from_str(bind_uri.as_str()) .unwrap_or_else(|e| panic!("Constructed invalid inet bind URI {bind_uri}: {e}")) } _ => unreachable!("Unsupported bind URI scheme for port prediction"), } } async fn release_interface(&mut self, bind_uri: BindUri) { self.ifaces.unbind(bind_uri).await; if let Err(error) = self.release_quota(1) { tracing::warn!(target: "punch", %error, "failed to release quota for interface"); } } async fn release_probe(&mut self, probe_id: u32) -> bool { let 
Some(probe) = self.probes.take(probe_id) else { return false; }; self.release_interface(probe.bind_uri).await; true } fn check_and_claim(&mut self, tx: &Transaction) -> Option<(BindUri, Interface)> { let (_, frame) = tx.try_punch_done()?; let probe_id = frame.probe_id(); tracing::debug!(target: "punch", probe_id, "punchDone received, attempting to claim probe"); self.claim_probe(probe_id) } async fn evict_if_needed( &mut self, tx: &Transaction, ) -> io::Result> { while self.probes.len() >= MAX_CONCURRENT_SOCKETS { if let Some(result) = self.check_and_claim(tx) { return Ok(Some(result)); } let Some(oldest_id) = self.probes.oldest_probe_id() else { break; }; self.release_probe(oldest_id).await; tracing::trace!(target: "punch", oldest_id, active_probes = self.probes.len(), "evicted oldest probe"); } Ok(None) } fn claim_probe(&mut self, probe_id: u32) -> Option<(BindUri, Interface)> { let probe = self.probes.take(probe_id)?; Some((probe.bind_uri, probe.iface)) } async fn finalize( &mut self, result: (BindUri, Interface), ) -> io::Result> { if let Err(error) = self.release_all().await { tracing::warn!(target: "punch", %error, "failed to cleanup remaining probes after success"); } Ok(Some(result)) } pub(super) async fn predict( &mut self, punch_id: PunchId, tx: Arc, packet_send_fn: PacketSendFn, ) -> io::Result> { tracing::debug!(target: "punch", %punch_id, "starting port prediction"); while self.probes_created < MAX_PROBES { // Check if PunchDone has been received for an active probe if let Some(result) = self.check_and_claim(tx.as_ref()) { if let Err(error) = self.release_quota(1) { tracing::warn!(target: "punch", %error, "failed to release quota for claimed probe"); } return self.finalize(result).await; } // Evict oldest probe if at capacity if let Some(result) = self.evict_if_needed(tx.as_ref()).await? 
{ if let Err(error) = self.release_quota(1) { tracing::warn!(target: "punch", %error, "failed to release quota for claimed probe"); } return self.finalize(result).await; } // Create and send one probe if let Err(error) = self.create_and_send_probe(punch_id, &packet_send_fn).await { tracing::trace!(target: "punch", %punch_id, %error, "probe creation failed, continuing"); } // Pacing: wait interval or return early if PunchDone arrives if tx.try_punch_done().is_none() { tokio::time::timeout(PACING_INTERVAL, tx.wait_punch_done()) .await .ok(); } } // Final check before giving up if let Some(result) = self.check_and_claim(tx.as_ref()) { if let Err(error) = self.release_quota(1) { tracing::warn!(target: "punch", %error, "failed to release quota for claimed probe"); } return self.finalize(result).await; } if let Err(e) = self.release_all().await { tracing::error!(target: "punch", %punch_id, %e, "failed to cleanup resources"); } tracing::debug!(target: "punch", %punch_id, probes_created = self.probes_created, "port prediction finished without match"); Ok(None) } async fn create_and_send_probe( &mut self, punch_id: PunchId, packet_send_fn: &PacketSendFn, ) -> io::Result<()> { self.acquire_quota(1).await?; let (bind_uri, iface) = match self.create_interface().await { Ok(result) => result, Err(e) => { if let Err(error) = self.release_quota(1) { tracing::warn!(target: "punch", %error, "failed to release quota on interface creation failure"); } return Err(e); } }; let socket_addr = match iface.bound_addr() { Ok(addr) => addr, Err(_) => { self.release_interface(bind_uri).await; return Err(io::Error::new( io::ErrorKind::AddrNotAvailable, "failed to get bound addr", )); } }; let port = socket_addr.port(); let probe_id = self.probes.allocate_probe_id(); let link = Link::new(socket_addr, self.dst); let frame = PunchHelloFrame::new(punch_id.local_seq, punch_id.remote_seq, probe_id); if packet_send_fn(&iface, link, PACKET_TTL, frame) .await .is_ok() { self.probes.insert(probe_id, 
bind_uri, iface, port); self.probes_created += 1; Ok(()) } else { self.release_interface(bind_uri).await; Err(io::Error::new( io::ErrorKind::BrokenPipe, "failed to send probe", )) } } async fn create_interface(&mut self) -> io::Result<(BindUri, Interface)> { for _ in 0..10 { let port = rand::random::() % (u16::MAX - MIN_PORT) + MIN_PORT; if self.probes.contains_port(port) { continue; } let bind_addr = self.port_to_bind_uri(port); let bind_iface = self .ifaces .bind(bind_addr.clone(), self.factory.clone()) .await; bind_iface.with_components_mut(|components, iface| { components.init_with(|| QuicRouterComponent::new(self.quic_router.clone())); components.init_with(|| { ReceiveAndDeliverPacket::builder(iface.downgrade()) .quic_router(self.quic_router.clone()) .init() }); }); let iface = bind_iface.borrow(); match iface.bound_addr() { Ok(_bound_addr) => { return Ok((bind_addr, iface)); } Err(_) => { self.ifaces.unbind(bind_addr).await; continue; } } } tracing::warn!(target: "punch", bind_uri = %self.bind_uri, dst = %self.dst, "failed to create interface after 10 attempts"); Err(io::Error::new( io::ErrorKind::AddrNotAvailable, "Failed to bind port after max retries", )) } async fn release_all(&mut self) -> io::Result<()> { tracing::debug!(target: "punch", active_probes = self.probes.len(), "starting resource cleanup"); let probe_ids = self.probes.pending_probe_ids(); for probe_id in probe_ids { self.release_probe(probe_id).await; } if self.quota_held > 0 { let orphaned = self.quota_held; tracing::warn!(target: "punch", orphaned, "releasing orphaned quota without pending probes"); self.release_quota(orphaned)?; } tracing::debug!(target: "punch", "resource cleanup completed"); Ok(()) } async fn acquire_quota(&mut self, count: u32) -> io::Result { let count = count.min(MAX_PROBES - self.probes_created); if count == 0 { return Err(io::Error::new( io::ErrorKind::ResourceBusy, format!("Would exceed maximum limit of {}", MAX_PROBES), )); } let granted = poll_fn(|cx| { SCHEDULER 
.lock() .unwrap() .poll_allocate(cx, self.dst, self.device.clone(), count) }) .await?; self.quota_held += granted; Ok(granted) } } impl Drop for PortPredictor { fn drop(&mut self) { let quota_held = self.quota_held; self.quota_held = 0; if quota_held > 0 && let Err(error) = SCHEDULER .lock() .unwrap() .release_port(quota_held, self.dst, self.device.clone()) { tracing::warn!(target: "punch", %error, quota_held, "failed to release predictor quota during drop"); } let bind_uris = self.probes.drain_bind_uris(); let futures: Vec<_> = bind_uris .into_iter() .map(|bind_uri| self.ifaces.unbind(bind_uri)) .collect(); if !futures.is_empty() { tokio::spawn(async move { futures::future::join_all(futures).await; }); } } } ================================================ FILE: qtraversal/src/punch/puncher.rs ================================================ use std::{ collections::HashSet, io, net::SocketAddr, ops::Deref, str::FromStr, sync::{Arc, Mutex}, time::Duration, }; use dashmap::{DashMap, DashSet, Entry}; use qbase::{ frame::{ AddAddressFrame, PunchDoneFrame, PunchHelloFrame, PunchMeNowFrame, ReliableFrame, RemoveAddressFrame, io::{ReceiveFrame, SendFrame}, }, net::{ AddrFamily, NatType, addr::EndpointAddr, route::{Line, Link, Route}, tx::Signals, }, packet::{ Package, PacketSpace, ProductHeader, header::short::OneRttHeader, io::{AssemblePacket, Packages, PadTo20}, }, }; use qevent::telemetry::Instrument; use qinterface::{ Interface, WeakInterface, bind_uri::BindUri, component::route::{QuicRouter, QuicRouterComponent}, io::{IO, IoExt, ProductIO}, manager::InterfaceManager, }; use tokio::{task::AbortHandle, time::timeout}; use tracing::Instrument as _; use crate::{ PathWay, addr::AddressBook, nat::{client::StunClientComponent, router::StunRouterComponent}, punch::{ predictor::{PacketSendFn, PortPredictor}, tx::{AsPunchId, PunchId, Transaction}, }, route::ReceiveAndDeliverPacket, }; type StunClient = crate::nat::client::StunClient; // type StunProtocol = 
crate::nat::protocol::StunProtocol; // TTL const HELLO_TTL: u8 = 64; const DEFAULT_PROBE_ID: u32 = 0; #[cfg(any(test, feature = "test-ttl"))] pub const KNOCK_TTL: u8 = 1; #[cfg(not(any(test, feature = "test-ttl")))] pub const KNOCK_TTL: u8 = 5; // Timeout const KNOCK_TIMEOUT: Duration = Duration::from_millis(100); const PUNCH_TIMEOUT: Duration = Duration::from_secs(3); const PUNCH_ME_NOW_TIMEOUT: Duration = Duration::from_secs(1); const COLLISION_TIMEOUT: Duration = Duration::from_secs(3); // Birthday attack timeout: must exceed PortPredictor's full run time (~6s for 300 probes × 20ms) const BIRTHDAY_TIMEOUT: Duration = Duration::from_secs(8); // Quantity const MAX_RETRIES: usize = 5; const COLLISION_PORTS: u32 = 800; pub struct ArcPuncher(Arc>); impl Clone for ArcPuncher { fn clone(&self) -> Self { Self(self.0.clone()) } } impl ArcPuncher where TX: SendFrame + Send + Sync + Clone + 'static, PH: ProductHeader + Send + Sync + 'static, S: PacketSpace + Send + Sync + 'static, { pub fn new( broker: TX, product_header: PH, packet_space: Arc, ifaces: Arc, iface_factory: Arc, quic_router: Arc, stun_servers: Arc<[SocketAddr]>, ) -> Self { Self(Arc::new(Puncher::new( broker, product_header, packet_space, ifaces, iface_factory, quic_router, stun_servers, ))) } } pub struct Puncher { transaction: DashMap)>, punch_history: DashSet, product_header: PH, packet_space: Arc, ifaces: Arc, iface_factory: Arc, quic_router: Arc, stun_servers: Arc<[SocketAddr]>, address_book: Mutex, punch_ifaces: DashMap, broker: TX, } impl Puncher where TX: SendFrame + Send + Sync + Clone + 'static, PH: ProductHeader + Send + Sync + 'static, S: PacketSpace + Send + Sync + 'static, { pub fn new( broker: TX, product_header: PH, packet_space: Arc, ifaces: Arc, iface_factory: Arc, quic_router: Arc, stun_servers: Arc<[SocketAddr]>, ) -> Self { Self { transaction: DashMap::new(), punch_history: DashSet::new(), product_header, packet_space, ifaces, iface_factory, quic_router, stun_servers, address_book: 
Mutex::new(AddressBook::default()), punch_ifaces: DashMap::new(), broker, } } pub async fn send_packet

( &self, iface: &(impl IO + ?Sized), link: Link, ttl: u8, packages: P, ) -> io::Result<()> where P: for<'b> Package>, PadTo20: for<'b> Package>, { let mut buffer = [0; 128]; let sent_bytes = (|| { let mut packet = self .packet_space .new_packet(self.product_header.new_header()?, &mut buffer)?; packet.assemble_packet(&mut Packages((packages, PadTo20)))?; let (sent_bytes, _props) = packet.encrypt_and_protect_packet(); Result::<_, Signals>::Ok(sent_bytes) })() .map_err(|s| io::Error::other(format!("Failed to assemble packet: {s:?}")))?; let line = Line::new(link, ttl, None, sent_bytes as u16); let route = Route::new(link.into(), line); iface .sendmmsg(&[io::IoSlice::new(&buffer[..sent_bytes])], route) .await } async fn collision( &self, iface: &Interface, link: Link, punch_id: PunchId, ttl: u8, ) -> io::Result<()> where PadTo20: for<'b> Package>, PunchHelloFrame: for<'b> Package>, { tracing::debug!(target: "punch", %punch_id, %link, ttl, "starting collision attack"); let mut random_ports = HashSet::new(); let dst = link.dst; let ip = dst.ip(); while random_ports.len() < COLLISION_PORTS as usize { let port = rand::random::() % (u16::MAX - 1024) + 1024; let dst = SocketAddr::new(ip, port); if !random_ports.insert(port) { continue; } let link = Link::new(link.src, dst); let frame = PunchHelloFrame::new(punch_id.local_seq, punch_id.remote_seq, DEFAULT_PROBE_ID); self.send_packet(iface, link, ttl, frame).await?; } Ok(()) } } impl Drop for Puncher { fn drop(&mut self) { for entry in self.transaction.iter() { entry.value().0.abort(); } self.transaction.clear(); self.punch_history.clear(); let futures: Vec<_> = self .punch_ifaces .iter() .map(|entry| self.ifaces.unbind(entry.key().clone())) .collect(); if !futures.is_empty() { tokio::spawn( async move { futures::future::join_all(futures).await; } .instrument_in_current() .in_current_span(), ); } self.punch_ifaces.clear(); } } impl ArcPuncher where TX: SendFrame + Send + Sync + Clone + 'static, PH: ProductHeader + Send + Sync 
+ 'static, S: PacketSpace + Send + Sync + 'static, for<'b> PunchDoneFrame: Package>, for<'b> PunchHelloFrame: Package>, for<'b> PadTo20: Package>, { pub fn add_local_address( &self, bind_uri: BindUri, local_addr: SocketAddr, nat_type: NatType, tire: u32, ) -> io::Result<()> { if nat_type == NatType::Dynamic { let puncher = self.clone(); let ifaces = self.0.ifaces.clone(); let iface_factory = self.0.iface_factory.clone(); let stun_servers = self.0.stun_servers.clone(); let quic_router = self.0.quic_router.clone(); tokio::spawn( async move { let (iface, stun_client) = dynamic_iface(&bind_uri, &ifaces, &iface_factory, &quic_router, &stun_servers) .await?; let dynamic_bind = iface.bind_uri(); let outer = stun_client.outer_addr().await.inspect_err(|error| { tracing::warn!(target: "punch", %error, bind_uri = %dynamic_bind, "failed to detect outer address for dynamic interface, unbinding"); let ifaces = ifaces.clone(); let dynamic_bind = dynamic_bind.clone(); tokio::spawn(async move { ifaces.unbind(dynamic_bind).await }); })?; puncher .0 .punch_ifaces .insert(dynamic_bind.clone(), iface.clone()); let mut address_book = puncher.0.address_book.lock().unwrap(); let frame = address_book.add_local_address(dynamic_bind.clone(), outer, tire, nat_type)?; tracing::trace!(target: "punch", bind_uri = %dynamic_bind, %outer, nat_type = ?nat_type, "sending AddAddress frame for dynamic"); puncher .0 .broker .send_frame([ReliableFrame::AddAddress(frame)]); Ok::<_, io::Error>(()) } .instrument_in_current() .in_current_span(), ); return Ok(()); } let mut address_book = self.0.address_book.lock().unwrap(); let frame = address_book.add_local_address(bind_uri.clone(), local_addr, tire, nat_type)?; tracing::trace!(target: "punch", bind_uri = %bind_uri, %local_addr, nat_type = ?nat_type, "sending AddAddress frame"); self.0.broker.send_frame([ReliableFrame::AddAddress(frame)]); Ok(()) } pub fn add_local_endpoint( &self, bind: BindUri, addr: EndpointAddr, ) -> io::Result> { let mut address_book = 
self.0.address_book.lock().unwrap(); address_book.add_local_endpoint(bind.clone(), addr)?; let mut ways = Vec::new(); for (remote_ep, source) in address_book.remote_endpoint().iter() { if let Ok(way) = self.resolve_punch_connection(&bind, &addr, remote_ep, source) { ways.push(way); } } Ok(ways) } pub fn add_peer_endpoint( &self, endpoint: EndpointAddr, source: qresolve::Source, ) -> io::Result> { let mut address_book = self.0.address_book.lock().unwrap(); address_book.add_peer_endpoint(endpoint, source.clone())?; let mut ways = Vec::new(); for (bind, local_ep) in address_book.local_endpoint().iter() { if let Ok(way) = self.resolve_punch_connection(bind, local_ep, &endpoint, &source) { ways.push(way); } } Ok(ways) } pub fn remove_local_address(&self, addr: SocketAddr) -> io::Result<()> { let mut address_book = self.0.address_book.lock().unwrap(); let frame = address_book.remove_local_address(addr)?; self.0 .broker .send_frame([ReliableFrame::RemoveAddress(frame)]); Ok(()) } fn recv_remove_address_frame(&self, remove_address_frame: RemoveAddressFrame) { let mut address_book = self.0.address_book.lock().unwrap(); address_book.remove_remote_address(remove_address_frame.deref().into_u64() as u32); } fn recv_add_address_frame(&self, add_address_frame: AddAddressFrame) -> io::Result<()> { // The lock on address_book must be released before accessing the transaction map // to avoid a deadlock with recv_punch_me_now, which holds the transaction lock // while trying to acquire the address_book lock. 
let (bind, local) = { let mut address_book = self.0.address_book.lock().unwrap(); address_book.add_remote_address(add_address_frame)?; let (bind, local) = address_book.pick_local_address(&add_address_frame)?; (bind.clone(), local) }; let punch_id = (&local, &add_address_frame).punch_id(); if self.0.punch_history.contains(&punch_id) { tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", Some(local.nat_type()), Some(add_address_frame.nat_type())), "punch already completed, skipping"); return Ok(()); } match self.0.transaction.entry(punch_id) { Entry::Occupied(_) => { tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", Some(local.nat_type()), Some(add_address_frame.nat_type())), "dup transaction for punch"); return Ok(()); } Entry::Vacant(entry) => { let tx = Arc::new(Transaction::new()); let task = tokio::spawn( { let puncher = self.clone(); let tx = tx.clone(); async move { let result = puncher .punch_actively(bind, &local, &add_address_frame, tx) .await; puncher.0.punch_history.insert(punch_id); puncher.0.transaction.remove(&punch_id); result } } .instrument_in_current() .in_current_span(), ) .abort_handle(); entry.insert((task, tx.clone())); } }; Ok(()) } fn recv_punch_me_now( &self, pathway: PathWay, punch_me_now_frame: PunchMeNowFrame, ) -> io::Result<()> { let punch_id = punch_me_now_frame.punch_id().flip(); if self.0.punch_history.contains(&punch_id) { tracing::debug!(target: "punch", %punch_id, "punch already completed, skipping"); return Ok(()); } let crate_punch_task = || { let tx = Arc::new(Transaction::new()); let task = tokio::spawn({ let puncher = self.clone(); let tx = tx.clone(); let address_book = self.0.address_book.lock().unwrap(); let (bind, local_address) = address_book .get_local_address(&punch_me_now_frame.remote_seq()) .ok_or_else(|| { io::Error::new(io::ErrorKind::NotFound, "local address not matched") })?; tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", 
Some(local_address.nat_type()), Some(punch_me_now_frame.nat_type())), "received punch me now frame, start passive punch"); async move { let result = puncher .punch_passively(bind, &local_address, &punch_me_now_frame, tx) .await; puncher.0.punch_history.insert(punch_id); puncher.0.transaction.remove(&punch_id); result } .instrument_in_current() .in_current_span() }) .abort_handle(); Ok::<_, io::Error>((task, tx.clone())) }; match self.0.transaction.entry(punch_id) { Entry::Occupied(mut entry) => { if pathway.local() < pathway.remote() { let (task, tx) = crate_punch_task()?; tx.store_punch_me_now(punch_me_now_frame); let old_task = entry.get().0.clone(); old_task.abort(); entry.insert((task, tx.clone())); tracing::trace!(target: "punch", %punch_id, "new passive transaction for punch"); } else { let tx = entry.get().1.clone(); tracing::trace!(target: "punch", %punch_id, "using existing active transaction to respond to PunchMeNow"); tx.store_punch_me_now(punch_me_now_frame); } } Entry::Vacant(entry) => { let (task, tx) = crate_punch_task()?; entry.insert((task, tx.clone())); tracing::trace!(target: "punch", %punch_id, "new passive transaction"); } }; Ok(()) } async fn punch_actively( &self, bind_uri: BindUri, local: &AddAddressFrame, remote: &AddAddressFrame, tx: Arc, ) -> io::Result<()> { let local_nat = local.nat_type(); let remote_nat = remote.nat_type(); let bind_addr = SocketAddr::try_from(bind_uri.clone()) .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; let link = Link::new(bind_addr, *remote.deref()); let punch_id = (local, remote).punch_id(); tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "starting active punch"); let mut punch_me_now = PunchMeNowFrame::new( local.seq_num(), remote.seq_num(), *local.deref(), local.tire(), local_nat, ); let ifaces = self.0.ifaces.clone(); let dynamic_iface = { let ifaces = self.0.ifaces.clone(); let iface_factory = self.0.iface_factory.clone(); let quic_router 
= self.0.quic_router.clone(); let stun_servers = self.0.stun_servers.clone(); async move |bind_uri: &BindUri| { dynamic_iface( bind_uri, &ifaces, &iface_factory, &quic_router, &stun_servers, ) .await } }; let broker = self.0.broker.clone(); let punch_ifaces = &self.0.punch_ifaces; // local \ remote ·FullCone RestrictedCone RestrictedPort Symmetric Dynamic // FullCone 1 6 6 6 6 // RestrictedCone 1 6 6 6 6 // RestrictedPort 1 6 6 7 6 // Symmetric 1 4 3 / 8 // Dynamic 1 5 5 2 5 // 1: Remote is FullCone // Send direct Hello to remote, expecting Hello(Done). // 2: Local Dynamic, Remote Symmetric -> New Interface & Birthday Attack // Send PunchMeNow, expect PunchMeNow. After receiving, start collision, expect Hello(Done). // 3: Local Symmetric, Remote RestrictedPort -> Birthday Attack // Send PunchMeNow, expect PunchMeNow. Use random socket collision, expect Hello(Done). // 4: Local Symmetric, Remote RestrictedCone -> Reverse Punching // Send PunchMeNow, expect remote to open hole and respond PunchMeNow. Then send direct Hello, expect Hello(Done). // 5: Local Dynamic // New Interface, detect external address. Then send PunchMeNow and Hello, expect Hello(Done). // 6: General Punching // Send Hello with TTL and PunchMeNow. Expect Hello, then respond Hello(Done). // 7: Local RestrictedPort, Remote Symmetric -> Birthday Attack (Hold Hole) // Send packets to 300 random ports, then notify with PunchMeNow. Expect Hello, then respond Hello(Done). // 8: Local Symmetric, Remote Dynamic // Hold holes on 30 random ports, send PunchMeNow. Expect Collision, then respond PunchMeNow. // Repeat until 300 sockets used. use NatType::*; let result: io::Result<()> = match (local_nat, remote_nat) { (Blocked, _) | (_, Blocked) | (Symmetric, Symmetric) => { return Err(io::Error::other("Unsupported nat type")); } // 1: Remote is FullCone // Send direct Hello to remote, expecting Hello(Done). 
// Strategy 1: the remote NAT is FullCone — anything we send to its mapped
// address is delivered, so a direct Hello with exponential backoff suffices.
(_, FullCone) => {
    tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "strategy: Remote FullCone, sending direct Hello");
    let iface = ifaces
        .borrow(&bind_uri)
        .ok_or_else(|| io::Error::other("No interface found"))?;
    let time = Duration::from_millis(100);
    // Up to 5 rounds: send Hello, then race backoff timer vs. incoming
    // Hello / PunchDone.
    for i in 0..5 {
        tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %link, "sending Hello expecting Hello(Done) or receiving Hello");
        self.0
            .send_packet(
                &iface,
                link,
                HELLO_TTL,
                PunchHelloFrame::new(
                    punch_id.local_seq,
                    punch_id.remote_seq,
                    DEFAULT_PROBE_ID,
                ),
            )
            .await?;
        // Exponential backoff: 100ms, 200ms, 400ms, ...
        let timeout_duration = time * (1 << i);
        tokio::select! {
            _ = tokio::time::sleep(timeout_duration) => {
                // continue loop
            }
            // Peer's Hello arrived: confirm over the reliable broker channel.
            Ok((_, punch_hello)) = async { Ok::<_, io::Error>(tx.wait_punch_hello().await) } => {
                tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "received Hello, sending broker PunchDone confirmation");
                broker.send_frame([ReliableFrame::PunchDone(PunchDoneFrame::respond_to(&punch_hello))]);
                return Ok(());
            }
            // Peer already confirmed with PunchDone: success.
            _ = tx.wait_punch_done() => {
                tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "punch success");
                return Ok(());
            }
        }
    }
    tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "punch failed");
    return Err(io::Error::new(io::ErrorKind::TimedOut, "punch timeout"));
}
// 2. Local Dynamic, Remote Symmetric -> New Interface & Birthday Attack
// Send PunchMeNow, expect PunchMeNow. After receiving, start collision, expect Hello(Done).
(Dynamic, Symmetric) => {
    tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "strategy: Local Dynamic, Remote Symmetric, new interface & birthday attack");
    // TODO: Creating a new iface is not strictly necessary; could reuse an available temporary address.
let (iface, stun_client) = dynamic_iface(&bind_uri).await?; let bind_uri = iface.bind_uri(); punch_ifaces.insert(bind_uri.clone(), iface.clone()); let outer_addr = stun_client.outer_addr().await?; punch_me_now.set_addr(outer_addr); tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "sending PunchMeNow expecting PunchMeNow then collision"); broker.send_frame([ReliableFrame::PunchMeNow(punch_me_now)]); let link = Link::new(iface.bound_addr()?, link.dst); let mut collided = false; let result: io::Result<()> = loop { tokio::select! { _ = tokio::time::sleep(BIRTHDAY_TIMEOUT)=> break Err(io::Error::new(io::ErrorKind::TimedOut, "Punch timeout")), _ = tx.wait_punch_me_now(), if !collided => { collided = true; self.0.collision(&iface, link, punch_id, KNOCK_TTL).await?; } Ok((link, punch_hello)) = async { Ok::<_, io::Error>(tx.wait_punch_hello().await) } => { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %link, "received Hello, sending broker PunchDone confirmation"); broker.send_frame([ReliableFrame::PunchDone(PunchDoneFrame::respond_to(&punch_hello))]); break Ok(()); } _ = tx.wait_punch_done() => break Ok(()), }; }; // If punch failed, clean up the interface if result.is_err() { punch_ifaces.remove(&bind_uri); ifaces.unbind(bind_uri).await; } result } // 3. Local Symmetric, Remote RestrictedPort -> Birthday Attack // Send PunchMeNow, expect PunchMeNow. Use random socket collision, expect Hello(Done). 
(Symmetric, RestrictedPort) => { // Send PunchMeNow first tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "sending PunchMeNow expecting PunchMeNow then rush"); broker.send_frame([ReliableFrame::PunchMeNow(punch_me_now)]); if timeout(COLLISION_TIMEOUT, tx.wait_punch_me_now()) .await .is_ok() { // Use new consolidated PortPredictor birthday attack let mut predictor = PortPredictor::new( ifaces.clone(), self.0.iface_factory.clone(), self.0.quic_router.clone(), bind_uri.clone(), link.dst, )?; // Create packet send function let puncher_ref = self.0.clone(); let packet_send_fn: PacketSendFn = Arc::new(move |iface, link, ttl, frame| { let puncher = puncher_ref.clone(); Box::pin(async move { puncher.send_packet(iface, link, ttl, frame).await }) }); tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "starting consolidated birthday attack"); match predictor .predict(punch_id, tx.clone(), packet_send_fn) .await { Ok(Some((bind_uri, iface))) => { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %bind_uri, "birthday attack succeeded"); self.0.punch_ifaces.insert(bind_uri.clone(), iface); return Ok(()); } Ok(None) => { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "birthday attack completed without success"); } Err(e) => { tracing::warn!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %e, "birthday attack failed"); } } } return Err(io::Error::new(io::ErrorKind::TimedOut, "punch timeout")); } // 4. Local Symmetric, Remote RestrictedCone -> Reverse Punching // Send PunchMeNow, expect remote to open hole and respond PunchMeNow. Then send direct Hello, expect Hello(Done). 
(Symmetric, RestrictedCone) => { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "strategy: Local Symmetric, Remote RestrictedCone, reverse punching"); tracing::trace!(target: "punch", %punch_id, "sending PunchMeNow expecting PunchMeNow then Hello"); broker.send_frame([ReliableFrame::PunchMeNow(punch_me_now)]); if timeout(PUNCH_ME_NOW_TIMEOUT, tx.wait_punch_me_now()) .await .is_err() { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "wait for PunchMeNow timeout, try to connect blindly"); } let iface = ifaces .borrow(&bind_uri) .ok_or_else(|| io::Error::other("No interface found"))?; for i in 0..5 { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %link, "sending Hello expecting Hello(Done)"); self.0 .send_packet( &iface, link, HELLO_TTL, PunchHelloFrame::new( punch_id.local_seq, punch_id.remote_seq, DEFAULT_PROBE_ID, ), ) .await?; if (timeout(KNOCK_TIMEOUT * (1 << i), tx.wait_punch_done()).await).is_ok() { tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "punch success"); return Ok(()); } } tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "punch failed"); return Err(io::Error::new(io::ErrorKind::TimedOut, "punch timeout")); } // 5. Local Dynamic // New Interface, detect external address. Then send PunchMeNow and Hello, expect Hello(Done). (Dynamic, _) => { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "strategy: Local Dynamic, new interface & send PunchMeNow + Hello"); // Use new iface, update PunchMeNow address. // TODO: Creating a new iface is not strictly necessary; could reuse an available temporary address. 
let (iface, stun_client) = dynamic_iface(&bind_uri).await?; let outer_addr = stun_client.outer_addr().await?; let bind_uri = iface.bind_uri(); punch_ifaces.insert(bind_uri.clone(), iface.clone()); punch_me_now.set_addr(outer_addr); tracing::trace!(target: "punch", %punch_id, "sending PunchMeNow + Hello expecting Hello(Done)"); broker.send_frame([ReliableFrame::PunchMeNow(punch_me_now)]); let link = Link::new(iface.bound_addr()?, link.dst); let time = Duration::from_millis(100); for i in 0..MAX_RETRIES { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %link, "sending Hello expecting Hello(Done)"); self.0 .send_packet( &iface, link, HELLO_TTL, PunchHelloFrame::new( punch_id.local_seq, punch_id.remote_seq, DEFAULT_PROBE_ID, ), ) .await?; let timeout_duration = time * (1 << i); tokio::select! { _ = tokio::time::sleep(timeout_duration) => { // continue loop } Ok((_, punch_hello)) = async { Ok::<_, io::Error>(tx.wait_punch_hello().await) } => { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "received Hello, sending broker PunchDone confirmation"); broker.send_frame([ReliableFrame::PunchDone(PunchDoneFrame::respond_to(&punch_hello))]); return Ok(()); } _ = tx.wait_punch_done() => { tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "punch success"); return Ok(()); } } } // Punch failed, remove the interface punch_ifaces.remove(&bind_uri); ifaces.unbind(bind_uri).await; Err(io::Error::new(io::ErrorKind::TimedOut, "punch timeout")) } // 6. General Punching // Send Hello with TTL and PunchMeNow. Expect Hello, then respond Hello(Done). 
(FullCone | RestrictedCone, Symmetric) | (FullCone | RestrictedCone | RestrictedPort, Dynamic) | (_, RestrictedCone | RestrictedPort) => { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "strategy: General punching, send Hello with TTL & PunchMeNow"); tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "sending PunchMeNow + Hello expecting Hello then Hello(Done)"); broker.send_frame([ReliableFrame::PunchMeNow(punch_me_now)]); let iface = ifaces .borrow(&bind_uri) .ok_or_else(|| io::Error::other("No interface found"))?; tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %link, "sending Hello expecting Hello"); self.0 .send_packet( &iface, link, HELLO_TTL, PunchHelloFrame::new( punch_id.local_seq, punch_id.remote_seq, DEFAULT_PROBE_ID, ), ) .await?; if let Ok((_, punch_hello)) = timeout(PUNCH_TIMEOUT, tx.wait_punch_hello()).await { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "sending broker PunchDone confirmation"); broker.send_frame([ReliableFrame::PunchDone(PunchDoneFrame::respond_to( &punch_hello, ))]); tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "actively punch success"); return Ok(()); } tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "punch failed"); return Err(io::Error::new(io::ErrorKind::TimedOut, "punch timeout")); } // 7. Local RestrictedPort, Remote Symmetric -> Birthday Attack (Hold Hole) // Send packets to 300 random ports, then notify with PunchMeNow. Expect Hello, then respond Hello(Done). 
(RestrictedPort, Symmetric) => {
    tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "strategy: Local RestrictedPort, Remote Symmetric, birthday attack hold hole");
    // Pre-open short-TTL holes toward many random remote ports so our
    // port-restricted NAT holds mappings the symmetric peer can land on.
    let hole_iface = ifaces
        .borrow(&bind_uri)
        .ok_or_else(|| io::Error::other("No interface found"))?;
    self.0
        .collision(&hole_iface, link, punch_id, KNOCK_TTL)
        .await?;

    // Then invite the peer to start probing those holes.
    tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "sending PunchMeNow expecting Hello then Hello(Done)");
    broker.send_frame([ReliableFrame::PunchMeNow(punch_me_now)]);

    match timeout(BIRTHDAY_TIMEOUT, tx.wait_punch_hello()).await {
        Ok((link, punch_hello)) => {
            // A Hello landed in one of our holes: confirm via the broker.
            tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %link, "sending broker PunchDone confirmation");
            broker.send_frame([ReliableFrame::PunchDone(PunchDoneFrame::respond_to(
                &punch_hello,
            ))]);
            tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "punch success with collision");
            return Ok(());
        }
        Err(_elapsed) => {
            return Err(io::Error::new(io::ErrorKind::TimedOut, "punch timeout"));
        }
    }
}
// 8. Local Symmetric, Remote Dynamic
// Hold holes on 30 random ports, send PunchMeNow. Expect Collision, then respond PunchMeNow.
// Repeat until 300 sockets used.
(Symmetric, Dynamic) => { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "strategy: Local Symmetric, Remote Dynamic, hold holes & send PunchMeNow"); // Use new consolidated PortPredictor birthday attack let mut predictor = PortPredictor::new( ifaces.clone(), self.0.iface_factory.clone(), self.0.quic_router.clone(), bind_uri.clone(), link.dst, )?; // Create packet send function let puncher_ref = self.0.clone(); let packet_send_fn: PacketSendFn = Arc::new(move |iface, link, ttl, frame| { let puncher = puncher_ref.clone(); Box::pin(async move { puncher.send_packet(iface, link, ttl, frame).await }) }); // Send initial PunchMeNow to notify peer tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "sending initial PunchMeNow for Dynamic strategy"); broker.send_frame([ReliableFrame::PunchMeNow(punch_me_now)]); tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "starting consolidated birthday attack for Dynamic strategy"); match predictor .predict(punch_id, tx.clone(), packet_send_fn) .await { Ok(Some((bind_uri, iface))) => { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %bind_uri, "birthday attack succeeded for Dynamic strategy"); self.0.punch_ifaces.insert(bind_uri.clone(), iface); return Ok(()); } Ok(None) => { tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "birthday attack completed without success for Dynamic strategy"); } Err(e) => { tracing::warn!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %e, "birthday attack failed for Dynamic strategy"); } } return Err(io::Error::new(io::ErrorKind::TimedOut, "punch timeout")); } }; result } async fn punch_passively( &self, bind: BindUri, local_address: &AddAddressFrame, remote_address: &PunchMeNowFrame, tx: Arc, ) -> 
io::Result<()> { use NatType::*; let remote_nat = remote_address.nat_type(); let local_nat = local_address.nat_type(); let punch_id = PunchId::new(local_address.seq_num(), remote_address.local_seq()); tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "starting passive punch"); let socket_addr = SocketAddr::try_from(bind.clone()) .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; if local_nat == Blocked || remote_nat == Blocked || (local_nat == Symmetric && remote_nat == Symmetric) { return Err(io::Error::other("Unsupported nat type")); } let link = Link::new(socket_addr, remote_address.address()); let ifaces = self.0.ifaces.clone(); let broker = self.0.broker.clone(); // Note: Receiving PunchMeNow implies we sent an AddAddress frame. // For Dynamic NAT, we don't need to create a new interface here; // it should have been created before sending AddAddress. // 1. Local Dynamic, Remote Symmetric // Remote has opened hole. We use new interface to collide, expecting Hello(Done). // 2. Local RestrictedPort, Remote Symmetric // We open holes on 300 random ports, send PunchMeNow. Expect Hello collision, then respond Hello(Done). // 3. Local Symmetric, Remote RestrictedPort | Dynamic // We use random socket collision to open hole, expecting Hello(Done). // 4. Local RestrictedCone, Remote Symmetric // Reflect, hello then Send PunchmeNow, wait for hello, send Hello(Done). // 5. General Punching // Received PunchMeNow implies remote has opened hole. We send direct Hello, expecting Hello(Done). match (local_nat, remote_nat) { // 1. Local Dynamic, Remote Symmetric // Remote has opened hole. We use new interface to collide, expecting Hello(Done). 
(Dynamic, Symmetric) => {
    // Passive strategy 1: we are behind a Dynamic NAT and the remote is
    // Symmetric. Wait for the peer's PunchMeNow, run one short-TTL collision
    // burst toward many random remote ports, then wait for a Hello or
    // PunchDone to confirm the hole.
    // NOTE(review): the log message says "use new interface", but this reuses
    // the existing `bind` interface — confirm whether a fresh interface was
    // intended here.
    tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "passive strategy: Local Dynamic, Remote Symmetric, use new interface to collide");
    let iface = ifaces
        .borrow(&bind)
        .ok_or_else(|| io::Error::other("No interface found"))?;
    // Guard so the collision burst fires at most once.
    let mut collided = false;
    loop {
        tokio::select! {
            // Overall deadline for the whole exchange.
            _ = tokio::time::sleep(BIRTHDAY_TIMEOUT)=> return Err(io::Error::new(io::ErrorKind::TimedOut, "Punch timeout")),
            // First PunchMeNow from the peer triggers the collision burst.
            _ = tx.wait_punch_me_now(), if !collided => {
                collided = true;
                self.0.collision(&iface, link, punch_id, KNOCK_TTL).await?;
            }
            // A Hello got through: acknowledge with PunchDone via the broker.
            Ok((link, punch_hello)) = async { Ok::<_, io::Error>(tx.wait_punch_hello().await) } => {
                tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %link, "received Hello, sending broker PunchDone confirmation");
                broker.send_frame([ReliableFrame::PunchDone(PunchDoneFrame::respond_to(&punch_hello))]);
                return Ok(());
            }
            // Peer already confirmed: done.
            _ = tx.wait_punch_done() => return Ok::<(), io::Error>(()),
        };
    }
}
// 2. Local RestrictedPort, Remote Symmetric
// We open holes on 300 random ports, send PunchMeNow. Expect Hello collision, then respond Hello(Done).
(RestrictedPort, Symmetric) => {
    tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "passive strategy: Local RestrictedPort, Remote Symmetric, open holes & send PunchMeNow");
    let hole_iface = ifaces
        .borrow(&bind)
        .ok_or_else(|| io::Error::other("No interface found"))?;
    // Pre-open short-TTL holes on many random remote ports before inviting
    // the symmetric peer to probe us.
    self.0
        .collision(&hole_iface, link, punch_id, KNOCK_TTL)
        .await?;
    let punch_me_now = PunchMeNowFrame::new(
        punch_id.local_seq,
        punch_id.remote_seq,
        *local_address.deref(),
        local_address.tire(),
        local_nat,
    );
    tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "sending PunchMeNow expecting Hello then Hello(Done)");
    broker.send_frame([ReliableFrame::PunchMeNow(punch_me_now)]);
    match tokio::time::timeout(BIRTHDAY_TIMEOUT, tx.wait_punch_hello()).await {
        Ok((link, punch_hello)) => {
            // A Hello arrived within the attack window: confirm via broker.
            tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %link, "sending broker PunchDone confirmation");
            broker.send_frame([ReliableFrame::PunchDone(PunchDoneFrame::respond_to(
                &punch_hello,
            ))]);
            return Ok(());
        }
        // No Hello in time: fall through to the function-level timeout error.
        Err(_elapsed) => {}
    }
}
// 3. Local Symmetric, Remote RestrictedPort
// Use new consolidated PortPredictor birthday attack. Expect Hello(Done).
// Case: local NAT is symmetric, remote is port-restricted or dynamic.
// Our own outbound mapping is unpredictable, so run the consolidated
// PortPredictor "birthday attack" to discover a working port pair.
(Symmetric, RestrictedPort | Dynamic) => {
    let mut predictor = PortPredictor::new(
        ifaces.clone(),
        self.0.iface_factory.clone(),
        self.0.quic_router.clone(),
        bind.clone(),
        link.dst,
    )?;
    // Create packet send function
    let puncher_ref = self.0.clone();
    let packet_send_fn: PacketSendFn = Arc::new(move |iface, link, ttl, frame| {
        let puncher = puncher_ref.clone();
        Box::pin(async move { puncher.send_packet(iface, link, ttl, frame).await })
    });
    tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "starting consolidated birthday attack");
    match predictor
        .predict(punch_id, tx.clone(), packet_send_fn)
        .await
    {
        Ok(Some((bind_uri, iface))) => {
            tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %bind_uri, "birthday attack succeeded");
            // Remember the interface that worked so later traffic reuses it.
            self.0.punch_ifaces.insert(bind_uri.clone(), iface);
            return Ok(());
        }
        Ok(None) => {
            tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "birthday attack completed without success");
        }
        Err(e) => {
            tracing::warn!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %e, "birthday attack failed");
        }
    }
    // No success: fall through to the shared "punch timeout" error below.
}
// 4. Local RestrictedCone, Remote Symmetric
// Reflect, Hello and PunchmeNow, wait for hello, send Hello(Done)
(RestrictedCone, Symmetric) => {
    tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "passive strategy: Local RestrictedCone, Remote Symmetric, reflect & send PunchMeNow");
    let iface = ifaces
        .borrow(&bind)
        .ok_or_else(|| io::Error::other("No interface found"))?;
    let punch_me_now = PunchMeNowFrame::new(
        punch_id.local_seq,
        punch_id.remote_seq,
        *local_address.deref(),
        local_address.tire(),
        local_nat,
    );
    tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "sending PunchMeNow expecting Hello then Hello(Done)");
    // Send one direct Hello first to open our cone mapping toward the remote.
    let punch_hello_frame =
        PunchHelloFrame::new(punch_id.local_seq, punch_id.remote_seq, DEFAULT_PROBE_ID);
    self.0
        .send_packet(&iface, link, HELLO_TTL, punch_hello_frame)
        .await?;
    broker.send_frame([ReliableFrame::PunchMeNow(punch_me_now)]);
    // Wait for the remote's Hello; confirm the path with a reliable PunchDone.
    if let Ok((link, punch_hello)) =
        tokio::time::timeout(PUNCH_TIMEOUT, tx.wait_punch_hello()).await
    {
        tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %link, "sending broker PunchDone confirmation");
        broker.send_frame([ReliableFrame::PunchDone(PunchDoneFrame::respond_to(
            &punch_hello,
        ))]);
        return Ok(());
    }
    // Timed out: fall through to the shared "punch timeout" error below.
}
// 5. General Punching
// Received PunchMeNow implies remote has opened hole. We send direct Hello, expecting Hello(Done).
// Default case: generic punching. A received PunchMeNow implies the remote
// has already opened a hole, so send direct Hellos with exponential backoff
// until we see a Hello (reply PunchDone) or the remote's PunchDone.
_ => {
    tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "passive strategy: General punching, send direct Hello");
    let iface = ifaces
        .borrow(&bind)
        .ok_or_else(|| io::Error::other("No interface found"))?;
    // Base retry delay; doubled each attempt: 100ms, 200ms, 400ms, ...
    let time = Duration::from_millis(100);
    for i in 0..MAX_RETRIES {
        tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), %link, "sending Hello expecting Hello(Done)");
        self.0
            .send_packet(
                &iface,
                link,
                HELLO_TTL,
                PunchHelloFrame::new(
                    punch_id.local_seq,
                    punch_id.remote_seq,
                    DEFAULT_PROBE_ID,
                ),
            )
            .await?;
        let timeout_duration = time * (1 << i);
        tokio::select! {
            _ = tokio::time::sleep(timeout_duration) => {
                // continue loop
            }
            Ok((_, punch_hello)) = async { Ok::<_, io::Error>(tx.wait_punch_hello().await) } => {
                tracing::trace!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "received Hello, sending broker PunchDone confirmation");
                broker.send_frame([ReliableFrame::PunchDone(PunchDoneFrame::respond_to(&punch_hello))]);
                return Ok(());
            }
            _ = tx.wait_punch_done() => {
                tracing::debug!(target: "punch", %punch_id, nat_pair = %format!("{:?}->{:?}", local_nat, remote_nat), "passively punch success");
                return Ok(());
            }
        }
    }
}
};
// All strategies above exhausted without a confirmed path.
Err(io::Error::new(io::ErrorKind::TimedOut, "punch timeout"))
}

/// Resolves a (local, remote) endpoint pair into a concrete
/// (bind, link, pathway) triple for punching, rejecting mismatched or
/// degenerate combinations.
fn resolve_punch_connection(
    &self,
    bind: &BindUri,
    local: &EndpointAddr,
    remote: &EndpointAddr,
    source: &qresolve::Source,
) -> io::Result<(BindUri, Link, PathWay)> {
    // mDNS-discovered peers are only reachable via the interface/family
    // they were discovered on; enforce that the bind URI matches.
    if let qresolve::Source::Mdns { nic, family } = source {
        let matches_iface = bind
            .as_iface_bind_uri()
            .is_some_and(|(lf, ln, _)| lf == *family && ln == nic.as_ref());
        if !matches_iface {
            return Err(io::Error::other(
                "Bind URI does not match source constraint",
            ));
        }
    }
    // Punching to oneself is meaningless.
    if local == remote {
        return Err(io::Error::other("Local and remote endpoints are identical"));
    }
    let (local_addr, remote_addr) = self.extract_addresses(bind, local, remote)?;
    // An IPv4 endpoint cannot punch to an IPv6 endpoint and vice versa.
    if local_addr.family() != remote_addr.family() {
        return Err(io::Error::other(
            "Local and remote addresses must be in the same address family",
        ));
    }
    let link = Link::new(local_addr, remote_addr);
    let pathway = PathWay::new(*local, *remote);
    Ok((bind.clone(), link, pathway))
}

/// Extracts concrete socket addresses for the endpoint pair:
/// Direct/Direct uses the advertised addresses; Agent/Agent uses the bound
/// interface address locally and the remote's agent address. Mixed endpoint
/// kinds are unsupported.
fn extract_addresses(
    &self,
    bind: &BindUri,
    local: &EndpointAddr,
    remote: &EndpointAddr,
) -> io::Result<(SocketAddr, SocketAddr)> {
    use EndpointAddr::*;
    match (local, remote) {
        (Direct { addr: local_addr }, Direct { addr: remote_addr }) => {
            Ok((*local_addr, *remote_addr))
        }
        (
            Agent { .. },
            Agent {
                agent: remote_agent,
                ..
            },
        ) => {
            let iface = self.0.ifaces.borrow(bind).ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::NotFound,
                    format!("Interface not found for bind URI: {:?}", bind),
                )
            })?;
            Ok((iface.bound_addr()?, *remote_agent))
        }
        _ => Err(io::Error::other(
            "Unsupported endpoint type combination for punching",
        )),
    }
}
}

// NOTE(review): the generic parameters and bounds below were garbled by
// extraction (angle-bracket contents stripped, e.g. `ArcPuncher`, `Package>`,
// `Result`); text preserved as-is — restore from the repository.
impl ReceiveFrame<(BindUri, PathWay, Link, ReliableFrame)> for ArcPuncher where TX: SendFrame + Send + Sync + Clone + 'static, PH: ProductHeader + Send + Sync + 'static, S: PacketSpace + Send + Sync + 'static, for<'b> PunchDoneFrame: Package>, for<'b> PunchHelloFrame: Package>, for<'b> PadTo20: Package>, {
    type Output = ();

    /// Dispatches reliable traversal frames (AddAddress / PunchMeNow /
    /// RemoveAddress / PunchDone) to the matching handler or pending
    /// transaction.
    fn recv_frame(
        &self,
        (_bind, pathway, link, frame): (BindUri, PathWay, Link, ReliableFrame),
    ) -> Result {
        tracing::debug!(target: "punch", %pathway, %link, frame = ?frame, "received reliable punch frame");
        match frame {
            ReliableFrame::AddAddress(add_address_frame) => {
                _ = self.recv_add_address_frame(add_address_frame);
            }
            ReliableFrame::PunchMeNow(punch_me_now_frame) => {
                _ = self.recv_punch_me_now(pathway, punch_me_now_frame);
            }
            ReliableFrame::RemoveAddress(remove_address_frame) => {
                self.recv_remove_address_frame(remove_address_frame);
            }
            ReliableFrame::PunchDone(frame) => {
                // Flip: the peer's (local, remote) pair is our (remote, local).
                let punch_id = frame.punch_id().flip();
                match self.0.transaction.entry(punch_id) {
                    Entry::Occupied(mut entry) => {
                        let tx = entry.get_mut().1.clone();
                        _ = tx.recv_frame((link, frame));
                    }
                    Entry::Vacant(_) => {
tracing::debug!(target: "punch", %punch_id, frame = ?frame, %link, "received unexpected punch done frame"); } } } frame => { tracing::debug!(target: "punch", frame = ?frame, "received unexpected reliable punch frame"); } }; Ok(()) } } impl ReceiveFrame<(BindUri, PathWay, Link, PunchHelloFrame)> for ArcPuncher where TX: SendFrame + Send + Sync + Clone + 'static, PH: ProductHeader + Send + Sync + 'static, S: PacketSpace + Send + Sync + 'static, for<'b> PunchDoneFrame: Package>, for<'b> PunchHelloFrame: Package>, for<'b> PadTo20: Package>, { type Output = (); fn recv_frame( &self, (_bind, pathway, link, frame): (BindUri, PathWay, Link, PunchHelloFrame), ) -> Result { tracing::debug!(target: "punch", %pathway, %link, frame = ?frame, "received punch hello frame"); let punch_id = frame.punch_id().flip(); match self.0.transaction.entry(punch_id) { Entry::Occupied(mut entry) => { let tx = entry.get_mut().1.clone(); _ = tx.recv_frame((link, frame)); } Entry::Vacant(_) => { tracing::trace!(target: "punch", %punch_id, frame = ?frame, %link, "received unsolicited punch hello, replying with broker PunchDone"); self.0 .broker .send_frame([ReliableFrame::PunchDone(PunchDoneFrame::respond_to(&frame))]); } } Ok(()) } } #[inline] async fn dynamic_iface( bind_uri: &BindUri, ifaces: &Arc, iface_factory: &Arc, quic_router: &Arc, stun_servers: &[SocketAddr], ) -> io::Result<(Interface, StunClient)> { const MIN_PORT: u16 = 1024; const MAX_PORT: u16 = u16::MAX; let (ip_family, device, _port) = bind_uri.as_iface_bind_uri().ok_or_else(|| { let error = "Invalid bind uri, expected bind uri with iface schema"; io::Error::new(io::ErrorKind::InvalidInput, error) })?; let port = rand::random::() % (MAX_PORT - MIN_PORT) + MIN_PORT; let bind_uri = format!( "iface://{ip_family}.{device}:{port}?{}=true", BindUri::TEMPORARY_PROP ); let bind_uri = BindUri::from_str(bind_uri.as_str()) .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; ifaces .bind(bind_uri, iface_factory.clone()) .await 
.with_components_mut(|components, iface| { // Ensure this temporary iface can receive+deliver QUIC packets to the connection. // Must use the connection-owned router. components.init_with(|| QuicRouterComponent::new(quic_router.clone())); let local_addr = iface.bound_addr()?; let stun_server = *stun_servers .iter() .find(|addr| addr.is_ipv4() == local_addr.is_ipv4()) .ok_or_else(|| io::Error::other("No STUN server matches local address family"))?; let stun_router = components .init_with(|| { let ref_iface = iface.downgrade(); StunRouterComponent::new(ref_iface) }) .router(); let stun_client = components .init_with(|| { let client = StunClient::new(iface.downgrade(), stun_router.clone(), stun_server, None); StunClientComponent::new(client) }) .client(); components.init_with(|| { ReceiveAndDeliverPacket::builder(iface.downgrade()) .quic_router(quic_router.clone()) .stun_router(stun_router) .init() }); Ok((iface.to_owned(), stun_client)) }) } ================================================ FILE: qtraversal/src/punch/scheduler.rs ================================================ use std::{ collections::{HashMap, VecDeque}, io, net::SocketAddr, sync::{Arc, LazyLock, Mutex}, task::{Context, Poll, Waker}, time::Duration, }; use qbase::net::{AddrFamily, Family}; use tokio::time::Instant; pub static SCHEDULER: LazyLock>> = LazyLock::new(|| Arc::new(Mutex::new(Scheduler::new()))); const MAX_SOCKETS_PER_DEVICE: u32 = 300; const MAX_PORTS_PER_DEVICE: u32 = 600; const MAX_TOTAL_SOCKETS: u32 = 600; const MAX_TOTAL_PORTS: u32 = 1200; const PORT_COOLING_INTERVAL: Duration = Duration::from_secs(60); pub struct Scheduler { devices: HashMap, pub(crate) total_sockets: u32, pub(crate) total_ports: u32, cooling: VecDeque<(Instant, u32)>, waiters: VecDeque, } impl Scheduler { fn new() -> Self { Self { devices: HashMap::new(), total_sockets: 0, total_ports: 0, cooling: VecDeque::new(), waiters: VecDeque::new(), } } fn reap_cooling(&mut self) { let now = Instant::now(); 
// Return quota for cooling entries whose interval has elapsed;
// `retain` drops expired entries in place.
self.cooling.retain(|(time, count)| {
    if now - *time > PORT_COOLING_INTERVAL {
        self.total_ports = self.total_ports.saturating_sub(*count);
        false
    } else {
        true
    }
});
}

/// Remaining global allocation capacity: the tighter of the socket budget
/// and the (cooling-inclusive) port budget.
fn global_available(&self) -> u32 {
    let by_socket = MAX_TOTAL_SOCKETS.saturating_sub(self.total_sockets);
    let by_port = MAX_TOTAL_PORTS.saturating_sub(self.total_ports);
    by_socket.min(by_port)
}

/// Tries to allocate up to `count` sockets/ports for `device` toward `dest`.
/// Grants min(global avail, device avail, requested). Returns Pending (and
/// registers the caller's waker) when nothing can be granted right now.
pub fn poll_allocate(
    &mut self,
    cx: &Context,
    dest: SocketAddr,
    device: String,
    count: u32,
) -> Poll> {
    // Reclaim any quota whose cooling interval has elapsed first.
    self.reap_cooling();
    let global_avail = self.global_available();
    // Ledgers are keyed by (device name, destination address family).
    let key = DeviceKey::new(device.clone(), dest.ip().family());
    let ledger = self.devices.entry(key).or_insert_with(DeviceLedger::new);
    ledger.reap_cooling();
    let device_avail = ledger.available();
    // Partial grants are allowed: give as much as both budgets permit.
    let granted = global_avail.min(device_avail).min(count);
    tracing::trace!(target: "punch", global_avail, device_avail, granted, count, total_sockets = self.total_sockets, total_ports = self.total_ports, "Poll allocate" );
    if granted > 0 {
        ledger.sockets += granted;
        ledger.ports += granted;
        *ledger.per_dest.entry(dest).or_insert(0) += granted;
        self.total_sockets += granted;
        self.total_ports += granted;
        tracing::trace!(target: "punch", ?dest, device, granted, "Port allocated");
        Poll::Ready(Ok(granted))
    } else {
        // Avoid registering the same task's waker more than once.
        if !self.waiters.iter().any(|w| w.will_wake(cx.waker())) {
            self.waiters.push_back(cx.waker().clone());
        }
        tracing::trace!(target: "punch", ?dest, device, count, "Port allocation pending");
        Poll::Pending
    }
}

/// Releases `count` previously-allocated sockets for (`device`, `dst`).
/// Sockets are freed immediately; port quota enters a cooling period
/// (PORT_COOLING_INTERVAL) before being returned. Wakes all pending waiters.
pub fn release_port(&mut self, count: u32, dst: SocketAddr, device: String) -> io::Result<()> {
    let key = DeviceKey::new(device.clone(), dst.ip().family());
    let ledger = self.devices.get_mut(&key).ok_or_else(|| {
        tracing::trace!(target: "punch", ?dst, device, "Device not found");
        io::Error::other("device not found")
    })?;
    if ledger.sockets < count {
        tracing::trace!(target: "punch", sockets = ledger.sockets, count, ?dst, "Insufficient sockets");
        return Err(io::Error::other("insufficient sockets"));
    }
    let dest_count =
ledger.per_dest.get(&dst).copied().unwrap_or(0); if dest_count < count { tracing::trace!(target: "punch", ?dst, dest_count, count, "Socket count mismatch"); return Err(io::Error::other("socket count mismatch")); } // Device: release sockets immediately, ports enter cooling ledger.sockets -= count; let now = Instant::now(); ledger.cooling.push_back((now, count)); if dest_count > count { ledger.per_dest.insert(dst, dest_count - count); } else { ledger.per_dest.remove(&dst); } // Global: release sockets immediately, ports enter cooling self.total_sockets = self.total_sockets.saturating_sub(count); self.cooling.push_back((now, count)); tracing::trace!(target: "punch", ?dst, device, count, "Port released"); for waker in self.waiters.drain(..) { waker.wake(); } Ok(()) } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct DeviceKey { device: String, family: Family, } impl DeviceKey { fn new(device: String, family: Family) -> Self { Self { device, family } } } struct DeviceLedger { sockets: u32, ports: u32, cooling: VecDeque<(Instant, u32)>, per_dest: HashMap, } impl DeviceLedger { fn new() -> Self { Self { sockets: 0, ports: 0, cooling: VecDeque::new(), per_dest: HashMap::new(), } } fn reap_cooling(&mut self) { let now = Instant::now(); self.cooling.retain(|(time, count)| { if now - *time > PORT_COOLING_INTERVAL { self.ports = self.ports.saturating_sub(*count); false } else { true } }); } fn available(&self) -> u32 { let by_socket = MAX_SOCKETS_PER_DEVICE.saturating_sub(self.sockets); let by_port = MAX_PORTS_PER_DEVICE.saturating_sub(self.ports); by_socket.min(by_port) } } #[cfg(test)] mod tests { use std::net::{IpAddr, Ipv4Addr}; use futures::task::noop_waker_ref; use super::*; fn test_addr() -> SocketAddr { SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080) } fn test_cx() -> Context<'static> { Context::from_waker(noop_waker_ref()) } #[test] fn test_scheduler_new() { let s = Scheduler::new(); assert_eq!(s.total_sockets, 0); assert_eq!(s.total_ports, 0); 
assert!(s.devices.is_empty()); assert!(s.cooling.is_empty()); } #[test] fn test_allocation_success() { let mut s = Scheduler::new(); let cx = test_cx(); let dest = test_addr(); let result = s.poll_allocate(&cx, dest, "eth0".into(), 10); assert!(matches!(result, Poll::Ready(Ok(10)))); assert_eq!(s.total_sockets, 10); assert_eq!(s.total_ports, 10); let key = DeviceKey::new("eth0".into(), dest.ip().family()); let ledger = s.devices.get(&key).unwrap(); assert_eq!(ledger.sockets, 10); assert_eq!(ledger.ports, 10); assert_eq!(*ledger.per_dest.get(&dest).unwrap(), 10); } #[test] fn test_allocation_pending() { let mut s = Scheduler::new(); let cx = test_cx(); let dest = test_addr(); // Fill device to its limit let _ = s.poll_allocate(&cx, dest, "eth0".into(), MAX_SOCKETS_PER_DEVICE); let result = s.poll_allocate(&cx, dest, "eth0".into(), 1); assert!(matches!(result, Poll::Pending)); assert_eq!(s.waiters.len(), 1); } #[test] fn test_release_port() { let mut s = Scheduler::new(); let cx = test_cx(); let dest = test_addr(); let _ = s.poll_allocate(&cx, dest, "eth0".into(), 10); assert!(s.release_port(10, dest, "eth0".into()).is_ok()); assert_eq!(s.total_sockets, 0); assert_eq!(s.cooling.len(), 1); let key = DeviceKey::new("eth0".into(), dest.ip().family()); let ledger = s.devices.get(&key).unwrap(); assert_eq!(ledger.sockets, 0); assert!(!ledger.per_dest.contains_key(&dest)); } #[test] fn test_global_limits() { let mut s = Scheduler::new(); let cx = test_cx(); let dest = test_addr(); // Fill eth0 to device limit (300) let r1 = s.poll_allocate(&cx, dest, "eth0".into(), MAX_SOCKETS_PER_DEVICE); assert!(matches!(r1, Poll::Ready(Ok(c)) if c == MAX_SOCKETS_PER_DEVICE)); assert_eq!(s.total_sockets, MAX_SOCKETS_PER_DEVICE); // Fill eth1 with remaining global capacity let remain = MAX_TOTAL_SOCKETS - MAX_SOCKETS_PER_DEVICE; let r2 = s.poll_allocate(&cx, dest, "eth1".into(), remain); assert!(matches!(r2, Poll::Ready(Ok(c)) if c == remain)); assert_eq!(s.total_sockets, 
MAX_TOTAL_SOCKETS); // Global full → Pending let r3 = s.poll_allocate(&cx, dest, "eth0".into(), 1); assert!(matches!(r3, Poll::Pending)); } #[test] fn test_device_limits() { let mut s = Scheduler::new(); let cx = test_cx(); let dest = test_addr(); let r = s.poll_allocate(&cx, dest, "eth0".into(), MAX_SOCKETS_PER_DEVICE); assert!(matches!(r, Poll::Ready(Ok(c)) if c == MAX_SOCKETS_PER_DEVICE)); let key = DeviceKey::new("eth0".into(), dest.ip().family()); assert_eq!(s.devices.get(&key).unwrap().sockets, MAX_SOCKETS_PER_DEVICE); let r = s.poll_allocate(&cx, dest, "eth0".into(), MAX_SOCKETS_PER_DEVICE); assert!(matches!(r, Poll::Pending)); } #[test] fn test_global_not_updated_on_device_pending() { let mut s = Scheduler::new(); let cx = test_cx(); let dest = test_addr(); let _ = s.poll_allocate(&cx, dest, "eth0".into(), 10); // Manually max out the device let key = DeviceKey::new("eth0".into(), dest.ip().family()); if let Some(ledger) = s.devices.get_mut(&key) { ledger.sockets = MAX_SOCKETS_PER_DEVICE; ledger.ports = MAX_PORTS_PER_DEVICE; } // Device full → Pending, global unchanged let r = s.poll_allocate(&cx, dest, "eth0".into(), 1); assert!(matches!(r, Poll::Pending)); assert_eq!(s.total_sockets, 10); assert_eq!(s.total_ports, 10); } #[test] fn test_mutex_protection() { use std::sync::Arc; use tokio::sync::Mutex; let scheduler = Arc::new(Mutex::new(Scheduler::new())); let mut handles = vec![]; for _ in 0..5 { let s = Arc::clone(&scheduler); handles.push(std::thread::spawn(move || { let mut s = s.blocking_lock(); s.total_sockets += 1; s.total_ports += 1; true })); } for h in handles { assert!(h.join().unwrap()); } let s = scheduler.try_lock().unwrap(); assert_eq!(s.total_sockets, 5); assert_eq!(s.total_ports, 5); } } ================================================ FILE: qtraversal/src/punch/tx.rs ================================================ use std::fmt; use qbase::{ frame::{AddAddressFrame, PunchDoneFrame, PunchHelloFrame, PunchMeNowFrame, io::ReceiveFrame}, 
net::route::Link,
};
use tokio::sync::SetOnce;

/// Identifies one punch transaction by its (local, remote) sequence pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct PunchId {
    pub local_seq: u32,
    pub remote_seq: u32,
}

impl PunchId {
    /// Builds an id from the local and remote sequence numbers.
    pub fn new(local_seq: u32, remote_seq: u32) -> Self {
        Self {
            local_seq,
            remote_seq,
        }
    }

    /// Swaps the two sequence numbers: the peer's id for the same transaction.
    pub fn flip(self) -> Self {
        Self {
            local_seq: self.remote_seq,
            remote_seq: self.local_seq,
        }
    }
}

impl fmt::Display for PunchId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "({}, {})", self.local_seq, self.remote_seq)
    }
}

/// Extracts the PunchId carried by a traversal frame.
pub(crate) trait AsPunchId {
    fn punch_id(&self) -> PunchId;
}

impl AsPunchId for PunchHelloFrame {
    fn punch_id(&self) -> PunchId {
        PunchId::new(self.local_seq(), self.remote_seq())
    }
}

impl AsPunchId for PunchDoneFrame {
    fn punch_id(&self) -> PunchId {
        PunchId::new(self.local_seq(), self.remote_seq())
    }
}

impl AsPunchId for PunchMeNowFrame {
    fn punch_id(&self) -> PunchId {
        PunchId::new(self.local_seq(), self.remote_seq())
    }
}

impl AsPunchId for (&AddAddressFrame, &AddAddressFrame) {
    fn punch_id(&self) -> PunchId {
        PunchId::new(self.0.seq_num(), self.1.seq_num())
    }
}

/// Per-punch rendezvous state: each expected frame is latched at most once
/// (SetOnce) and can be awaited by the punching task.
// NOTE(review): `SetOnce,` below lost its generic argument to extraction
// (likely `SetOnce<PunchMeNowFrame>`); text preserved as-is.
pub(crate) struct Transaction {
    punch_me_now_frame: SetOnce,
    punch_hello_frame: SetOnce<(Link, PunchHelloFrame)>,
    punch_done_frame: SetOnce<(Link, PunchDoneFrame)>,
}

impl Transaction {
    pub fn new() -> Self {
        Self {
            punch_me_now_frame: SetOnce::new(),
            punch_hello_frame: SetOnce::new(),
            punch_done_frame: SetOnce::new(),
        }
    }

    /// Waits until a PunchDone frame has arrived for this transaction.
    pub async fn wait_punch_done(&self) -> (Link, PunchDoneFrame) {
        *self.punch_done_frame.wait().await
    }

    /// Non-blocking check for an already-received PunchDone frame.
    pub fn try_punch_done(&self) -> Option<(Link, PunchDoneFrame)> {
        self.punch_done_frame.get().copied()
    }

    /// Waits until a PunchHello frame has arrived for this transaction.
    pub async fn wait_punch_hello(&self) -> (Link, PunchHelloFrame) {
        *self.punch_hello_frame.wait().await
    }

    /// Waits until a PunchMeNow frame has arrived for this transaction.
    pub async fn wait_punch_me_now(&self) -> PunchMeNowFrame {
        *self.punch_me_now_frame.wait().await
    }

    /// Latches a PunchMeNow frame; subsequent frames for the same id are
    /// silently ignored (SetOnce::set result discarded).
    pub fn store_punch_me_now(&self, frame: PunchMeNowFrame) {
        _ = self.punch_me_now_frame.set(frame);
    }
}

impl ReceiveFrame<(Link, PunchHelloFrame)> for Transaction {
    type
Output = (); fn recv_frame( &self, (link, frame): (Link, PunchHelloFrame), ) -> Result { _ = self.punch_hello_frame.set((link, frame)); Ok(()) } } impl ReceiveFrame<(Link, PunchDoneFrame)> for Transaction { type Output = (); fn recv_frame( &self, (link, frame): (Link, PunchDoneFrame), ) -> Result { _ = self.punch_done_frame.set((link, frame)); Ok(()) } } impl ReceiveFrame<(Link, PunchMeNowFrame)> for Transaction { type Output = (); fn recv_frame( &self, (_link, frame): (Link, PunchMeNowFrame), ) -> Result { _ = self.punch_me_now_frame.set(frame); Ok(()) } } ================================================ FILE: qtraversal/src/punch.rs ================================================ pub(super) mod predictor; pub mod puncher; pub(super) mod scheduler; pub(super) mod tx; ================================================ FILE: qtraversal/src/route.rs ================================================ use std::{ convert::identity, io, net::SocketAddr, pin::Pin, sync::{Arc, Mutex, MutexGuard}, task::{ Context, Poll::{self, Ready}, ready, }, }; use bytes::BytesMut; use qbase::{ net::{ addr::EndpointAddr, route::{Line, Link, Route}, }, util::ArcAsyncDeque, }; use qinterface::{ Interface, WeakInterface, component::{ Component, route::{QuicRouter, QuicRouterComponent}, }, io::{IO, IoExt, RefIO}, }; use smallvec::SmallVec; use tokio_util::task::AbortOnDropHandle; pub type ArcRecvQueue = ArcAsyncDeque<(BytesMut, PathWay, Link)>; use crate::{ PathWay, nat::{ client::{StunClients, StunClientsComponent}, router::{StunRouter, StunRouterComponent}, }, packet::{ForwardHeader, StunHeader}, }; #[derive(Debug, Clone)] pub enum Forwarder { Clients { stun_clients: StunClients }, Server { outer_addr: SocketAddr }, } impl Forwarder { pub fn outers(&self) -> SmallVec<[SocketAddr; 8]> { match self { Forwarder::Clients { stun_clients } => stun_clients.with_clients(|clients| { clients .values() .filter_map(|client| client.get_outer_addr()?.ok()) .collect() }), Forwarder::Server { outer_addr } => 
SmallVec::from_iter([*outer_addr]), } } pub fn should_forward(&self, dst: EndpointAddr) -> Option { let outers = self.outers(); if outers.is_empty() { return None; } let EndpointAddr::Agent { agent, outer: dst_outer, } = dst else { return None; }; for outer in outers { if outer == dst_outer { return None; } if outer == agent { return Some(dst_outer); } } Some(agent) } } #[derive(Debug)] pub struct ForwardersComponent { forward: Mutex>, } impl ForwardersComponent { pub fn new(forwarder: Forwarder) -> Self { Self { forward: Mutex::new(forwarder), } } pub fn new_client(stun_clients: StunClients) -> Self { Self::new(Forwarder::Clients { stun_clients }) } pub fn new_server(outer_addr: SocketAddr) -> Self { Self::new(Forwarder::Server { outer_addr }) } fn lock_forwarders(&self) -> MutexGuard<'_, Forwarder> { self.forward.lock().expect("Forwarder lock poisoned") } pub fn forwarder(&self) -> Forwarder { self.lock_forwarders().clone() } } impl Component for ForwardersComponent { fn poll_shutdown(&self, _cx: &mut Context<'_>) -> Poll<()> { Poll::Ready(()) } fn reinit(&self, iface: &Interface) { _ = iface.with_component(|clients: &StunClientsComponent| { clients.reinit(iface); *self.lock_forwarders() = Forwarder::Clients { stun_clients: clients.clone(), }; }); } } #[derive(Debug)] pub struct ReceiveAndDeliverPacket { task: Mutex>>>, quic: bool, stun: bool, forward: bool, } pub type ReceiveAndDeliverPacketComponent = ReceiveAndDeliverPacket; #[bon::bon] impl ReceiveAndDeliverPacket { #[builder(finish_fn = init)] pub fn new( #[builder(start_fn)] weak_iface: WeakInterface, quic_router: Option>, stun_router: Option, forwarder: Option>, ) -> Self { let enable_quic = quic_router.is_some(); let enable_stun = stun_router.is_some(); let enable_forward = forwarder.is_some(); let task = Self::task() .maybe_quic_router(quic_router) .maybe_stun_router(stun_router) .maybe_forwarder(forwarder) .iface_ref(weak_iface) .spawn(); Self { task: Mutex::new(Some(task)), quic: enable_quic, stun: 
enable_stun, forward: enable_forward, } } #[builder(finish_fn = spawn)] pub fn task( quic_router: Option>, stun_router: Option, forwarder: Option>, iface_ref: I, ) -> AbortOnDropHandle> { AbortOnDropHandle::new(tokio::spawn(async move { let iface = iface_ref.iface(); let bind_uri = iface.bind_uri(); let deliver_quic_packet = async |pkt: BytesMut, route: Route| { let Some(quic_router) = quic_router.as_ref() else { return; }; use qbase::packet::{self, Packet, PacketReader}; fn is_initial_packet(pkt: &Packet) -> bool { matches!(pkt, Packet::Data(packet) if matches!(packet.header, packet::DataHeader::Long(packet::long::DataHeader::Initial(..)))) } let size = pkt.len(); let bind_uri = bind_uri.clone(); for (packet, way) in PacketReader::new(pkt, 8) .flatten() .filter(move |pkt| !(is_initial_packet(pkt) && size < 1100)) .map(move |pkt| (pkt, (bind_uri.clone(), route.pathway(), route.link()))) { quic_router.deliver(packet, way).await; } }; let deliver_stun_packet = async |mut pkt: BytesMut, route: Route| { let Some(stun_router) = stun_router.as_ref() else { return; }; use crate::nat::msg::be_packet; let pkt = pkt.split_off(StunHeader::encoding_size()); let Ok((.., (txid, packet))) = be_packet(&pkt) else { return; }; stun_router.deliver_stun_packet(txid, packet, route.link()); }; let deliver_forward_packet = async |mut pkt: BytesMut, mut route: Route, fhdr: ForwardHeader| { if let Some(forwarder) = forwarder.as_ref() && let Some(target) = forwarder.should_forward(fhdr.pathway().remote()) { let bufs = &[io::IoSlice::new(&pkt)]; let new_link = Link::new(iface.bound_addr()?, target); let new_line = Line::new(new_link, 64, None, pkt.len() as u16); let new_route = Route::new(route.link.into(), new_line); return iface.sendmmsg(bufs, new_route).await; }; // split_off forward header, deliver the rest as quic packet let pkt = pkt.split_off(ForwardHeader::encoding_size(&fhdr.pathway())); route.seg_size = pkt.len() as _; let new_route = 
Route::new(fhdr.pathway().flip().map(Into::into), route.line); deliver_quic_packet(pkt, new_route).await; Ok(()) }; let (mut bufs, mut hdrs) = (vec![], vec![]); loop { use crate::packet::{Header, be_header}; for (pkt, hdr) in iface.recvmmsg(&mut bufs, &mut hdrs).await? { match be_header(&pkt) { // quic Err(_) => deliver_quic_packet(pkt, hdr).await, // stun Ok((_remain, Header::Stun(_stun_header))) => { deliver_stun_packet(pkt, hdr).await } // forward Ok((_remain, Header::Forward(forward_header))) => { deliver_forward_packet(pkt, hdr, forward_header).await? } } } } })) } } impl ReceiveAndDeliverPacket { fn lock_task(&self) -> MutexGuard<'_, Option>>> { self.task.lock().unwrap() } pub fn reinit(&self, iface: &Interface) { _ = iface.with_components(|components| { let quic_router = (self.quic) .then(|| components.with(QuicRouterComponent::router)) .and_then(identity); let stun_router = self .stun .then(|| components.with(StunRouterComponent::router)) .and_then(identity); let forwarder = self .forward .then(|| components.with(ForwardersComponent::forwarder)) .and_then(identity); *self.lock_task() = Some( Self::task() .maybe_quic_router(quic_router) .maybe_stun_router(stun_router) .maybe_forwarder(forwarder) .iface_ref(iface.downgrade()) .spawn(), ); }); } } impl Component for ReceiveAndDeliverPacket { fn poll_shutdown(&self, cx: &mut Context<'_>) -> std::task::Poll<()> { let mut task_guard = self.lock_task(); if let Some(task) = task_guard.as_mut() { task.abort(); _ = ready!(Pin::new(task).poll(cx)); *task_guard = None; } Ready(()) } fn reinit(&self, iface: &Interface) { self.reinit(iface); } } ================================================ FILE: qtraversal/tests/detect.rs ================================================ use std::{ io, sync::{Arc, LazyLock}, }; use qinterface::io::{IO, ProductIO, handy::DEFAULT_IO_FACTORY}; use qtraversal::{ nat::{ client::{NatType, StunClient}, router::StunRouter, }, route::ReceiveAndDeliverPacket, }; use tracing::{Instrument, 
info_span}; use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt}; #[derive(Debug, Clone, Copy)] pub struct TestCase { pub bind_addr: &'static str, pub outer_addr: &'static str, pub nat_type: NatType, } pub const STUN_AGENT: &str = "10.10.0.64:20002"; pub const CASES: [TestCase; 10] = [ TestCase { bind_addr: "192.168.0.98:6001", outer_addr: "10.10.0.98:6001", nat_type: NatType::FullCone, }, TestCase { bind_addr: "192.168.0.96:6002", outer_addr: "10.10.0.96:6002", nat_type: NatType::RestrictedCone, }, TestCase { bind_addr: "192.168.0.88:6003", outer_addr: "10.10.0.88:6003", nat_type: NatType::RestrictedPort, }, TestCase { bind_addr: "192.168.0.86:6004", outer_addr: "10.10.0.86:6004", nat_type: NatType::Dynamic, }, TestCase { bind_addr: "192.168.0.84:6005", outer_addr: "10.10.0.84:6005", nat_type: NatType::Symmetric, }, TestCase { bind_addr: "172.16.0.48:6006", outer_addr: "10.10.0.48:6006", nat_type: NatType::FullCone, }, TestCase { bind_addr: "172.16.0.46:6007", outer_addr: "10.10.0.46:6007", nat_type: NatType::RestrictedCone, }, TestCase { bind_addr: "172.16.0.38:6008", outer_addr: "10.10.0.38:6008", nat_type: NatType::RestrictedPort, }, TestCase { bind_addr: "172.16.0.36:6009", outer_addr: "10.10.0.36:6009", nat_type: NatType::Dynamic, }, TestCase { bind_addr: "172.16.0.34:6010", outer_addr: "10.10.0.34:6010", nat_type: NatType::Symmetric, }, ]; pub fn init_tracing() -> io::Result<()> { let file = std::fs::OpenOptions::new() .create(true) .write(true) .truncate(true) .open("tests.log")?; let filter = tracing_subscriber::filter::filter_fn(|metadata| { !metadata.target().contains("netlink_packet_route") }); _ = tracing_subscriber::registry() .with(tracing_subscriber::Layer::with_filter( tracing_subscriber::fmt::layer() .with_target(true) .with_ansi(false) .with_file(true) .with_line_number(true), filter.clone(), )) .with(tracing_subscriber::Layer::with_filter( tracing_subscriber::fmt::layer().with_writer(file), filter, )) 
.try_init(); Ok(()) } fn run + Send + 'static>( test_name: &'static str, f: F, ) -> F::Output { static RT: LazyLock = LazyLock::new(|| { init_tracing().expect("failed to init tracing"); tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap() }); RT.block_on(f.instrument(info_span!("test", test_name))) } async fn test_detect_case(case: usize) { let stun_agent = STUN_AGENT.parse().unwrap(); let case = CASES[case]; let bind_uri = format!("inet://{}", case.bind_addr); let iface: Arc = Arc::from(DEFAULT_IO_FACTORY.bind(bind_uri.into())); let stun_router = StunRouter::new(); let stun_client = StunClient::new(iface.clone(), stun_router.clone(), stun_agent, None); let _route_task = ReceiveAndDeliverPacket::task() .stun_router(stun_router) .iface_ref(iface.clone()) .spawn(); let outer_addr = stun_client .outer_addr() .await .expect("failed to get outer addr"); tracing::info!("Outer addr: {} Agent addr {}", outer_addr, stun_agent); let nat_type = stun_client .nat_type() .await .expect("failed to get nat type"); tracing::info!(case.bind_addr, case.outer_addr, ?nat_type, ?case.nat_type); assert!(nat_type == case.nat_type); } macro_rules! test_detect { (async fn $test_name:ident = test_detect_case($case:expr) $($tt:tt)*) => { #[test] #[ignore] // run manually fn $test_name() { run(stringify!($test_name), async move { test_detect_case($case).await }) } test_detect!($($tt)*); }; () => {} } // ip netns exec nsa cargo test --package qtraversal test_detect -- --include-ignored --nocapture test_detect! 
{
    async fn test_detect_full_cone_client = test_detect_case(0)
    async fn test_detect_restricted_cone_client = test_detect_case(1)
    async fn test_detect_port_restricted_client = test_detect_case(2)
    async fn test_detect_dynamic_client = test_detect_case(3)
    async fn test_detect_symmetric_client = test_detect_case(4)
    async fn test_detect_full_cone_server = test_detect_case(5)
    async fn test_detect_restricted_cone_server = test_detect_case(6)
    async fn test_detect_port_restricted_server = test_detect_case(7)
    async fn test_detect_dynamic_server = test_detect_case(8)
    async fn test_detect_symmetric_server = test_detect_case(9)
}

================================================
FILE: qtraversal/tools/build_nat.sh
================================================
#!/bin/bash
# Build a virtual NAT test topology out of network namespaces, veth pairs,
# bridges, and iptables rules, so the NAT-type detection tests can run
# against FullCone / RestrictedCone / PortRestrictedCone / Dynamic /
# Symmetric / UDPBlocked behaviors. Must be run as root.
# set -x
set -e

# Create the LAN bridges.
ip link add brlan1 type bridge
ip link set dev brlan1 up
iptables -A FORWARD -o brlan1 -m comment --comment "allow packets to pass from lxd lan bridge" -j ACCEPT
iptables -A FORWARD -i brlan1 -m comment --comment "allow input packets to pass to lxd lan bridge" -j ACCEPT
ip link add brlan2 type bridge
ip link set dev brlan2 up
iptables -A FORWARD -o brlan2 -m comment --comment "allow packets to pass from lxd lan bridge" -j ACCEPT
iptables -A FORWARD -i brlan2 -m comment --comment "allow input packets to pass to lxd lan bridge" -j ACCEPT

# Create the WAN bridge.
ip link add brwan type bridge
ip link set dev brwan up
iptables -A FORWARD -o brwan -m comment --comment "allow packets to pass from lxd wan bridge" -j ACCEPT
iptables -A FORWARD -i brwan -m comment --comment "allow input packets to pass to lxd wan bridge" -j ACCEPT

# Create intranet Host A, with multiple NICs.
ip netns add nsa
ip netns exec nsa ip link set lo up

# create_new PAIR BRIDGE NETNS DEV ADDR GATEWAY TABLE
#   Creates a veth pair "<PAIR>0"/"<PAIR>1", attaches "<PAIR>1" to BRIDGE,
#   moves "<PAIR>0" into NETNS renamed to DEV with address ADDR/24, and
#   installs a per-source policy route via GATEWAY in routing table TABLE.
create_new() {
  local devpair=$1  # e.g. aveth0
  local devbr=$2    # e.g. brlan1
  local virtnet=$3  # e.g. nsa
  local devhost=$4  # e.g. eth0
  local devaddr=$5  # e.g. 192.168.0.98
  local gateway=$6  # e.g. 192.168.0.1
  local routemap=$7 # e.g. 101
  local dveth0="${devpair}0"
  local dveth1="${devpair}1"
  ip link add "$dveth0" type veth peer name "$dveth1"
  ip link set dev "$dveth1" master "$devbr"
  ip link set dev "$dveth1" up
  ip link set dev "$dveth0" netns "$virtnet"
  ip netns exec "$virtnet" ip link set dev "$dveth0" name "$devhost"
  ip netns exec "$virtnet" ip addr add "$devaddr/24" dev "$devhost"
  ip netns exec "$virtnet" ip link set dev "$devhost" up
  ip netns exec "$virtnet" ip route add default via "$gateway" dev "$devhost" src "$devaddr" table "$routemap"
  ip netns exec "$virtnet" ip rule add from "$devaddr" table "$routemap"
}

create_new "aveth0" "brlan1" "nsa" "eth0" "192.168.0.98" "192.168.0.1" "101"
create_new "aveth1" "brlan1" "nsa" "eth1" "192.168.0.96" "192.168.0.1" "102"
create_new "aveth2" "brlan1" "nsa" "eth2" "192.168.0.88" "192.168.0.1" "103"
create_new "aveth3" "brlan1" "nsa" "eth3" "192.168.0.86" "192.168.0.1" "104"
create_new "aveth4" "brlan1" "nsa" "eth4" "192.168.0.84" "192.168.0.1" "105"
# Open Internet, FullCone
create_new "aveth5" "brwan" "nsa" "eth5" "10.10.0.108" "10.10.0.1" "201"
# Open Internet, RestrictedCone
create_new "aveth6" "brwan" "nsa" "eth6" "10.10.0.106" "10.10.0.1" "202"
# Open Internet, PortRestrictedCone
create_new "aveth7" "brwan" "nsa" "eth7" "10.10.0.104" "10.10.0.1" "203"
# Open Internet, UDPBlocked
create_new "aveth8" "brwan" "nsa" "eth8" "10.10.0.102" "10.10.0.1" "204"
create_new "aveth9" "brlan2" "nsa" "eth9" "172.16.0.48" "172.16.0.1" "301"
create_new "avetha" "brlan2" "nsa" "etha" "172.16.0.46" "172.16.0.1" "302"
create_new "avethb" "brlan2" "nsa" "ethb" "172.16.0.38" "172.16.0.1" "303"
create_new "avethc" "brlan2" "nsa" "ethc" "172.16.0.36" "172.16.0.1" "304"
create_new "avethd" "brlan2" "nsa" "ethd" "172.16.0.34" "172.16.0.1" "305"

ip netns exec nsa ip route add default via 192.168.0.1
# Default deny for UDP in/out of Host A; per-interface rules below open it
# up with the exact firewall behavior each NAT type requires.
ip netns exec nsa iptables -t filter -P OUTPUT DROP
ip netns exec nsa iptables -t filter -P INPUT DROP
ip netns exec nsa iptables -t filter -A OUTPUT ! -p udp -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT ! -p udp -j ACCEPT

# eth0:192.168.0.98, NAT, FullCone
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o eth0 -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i eth0 -j ACCEPT
# eth1:192.168.0.96, NAT, RestrictedCone
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o eth1 -m recent --rdest --set --name pubtrack1 -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i eth1 -m recent --rsource --rcheck --seconds 300 --name pubtrack1 -j ACCEPT
# eth2:192.168.0.88, NAT, PortRestrictedCone
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o eth2 -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i eth2 -m state --state ESTABLISHED,RELATED -j ACCEPT
# eth3:192.168.0.86, NAT, Dynamic
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o eth3 -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i eth3 -m state --state ESTABLISHED,RELATED -j ACCEPT
# eth4:192.168.0.84, NAT, Symmetric
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o eth4 -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i eth4 -m state --state ESTABLISHED,RELATED -j ACCEPT
# eth5:10.10.0.108, Open Internet, FullCone
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o eth5 -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i eth5 -j ACCEPT
# eth6:10.10.0.106, Open Internet, RestrictedCone
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o eth6 -m recent --rdest --set --name pubtrack6 -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i eth6 -m recent --rsource --rcheck --seconds 300 --name pubtrack6 -j ACCEPT
# eth7:10.10.0.104, Open Internet, PortRestrictedCone
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o eth7 -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i eth7 -m state --state ESTABLISHED,RELATED -j ACCEPT
# eth8:10.10.0.102, Open Internet, UDPBlocked
# covered by the default DROP policy
# eth9:172.16.0.48, NAT, FullCone
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o eth9 -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i eth9 -j ACCEPT
# etha:172.16.0.46, NAT, RestrictedCone
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o etha -m recent --rdest --set --name pubtrack1 -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i etha -m recent --rsource --rcheck --seconds 300 --name pubtrack1 -j ACCEPT
# ethb:172.16.0.38, NAT, PortRestrictedCone
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o ethb -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i ethb -m state --state ESTABLISHED,RELATED -j ACCEPT
# ethc:172.16.0.36, NAT, Dynamic
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o ethc -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i ethc -m state --state ESTABLISHED,RELATED -j ACCEPT
# ethd:172.16.0.34, NAT, Symmetric
ip netns exec nsa iptables -t filter -A OUTPUT -p udp -o ethd -j ACCEPT
ip netns exec nsa iptables -t filter -A INPUT -p udp -i ethd -m state --state ESTABLISHED,RELATED -j ACCEPT

# Create intranet Host B.
ip netns add nsb
ip netns exec nsb ip link set lo up
ip link add bveth0 type veth peer name bveth1
ip link set dev bveth1 master brlan1
ip link set dev bveth1 up
ip link set dev bveth0 netns nsb
ip netns exec nsb ip link set dev bveth0 name eth0
ip netns exec nsb ip addr add 192.168.0.100/24 dev eth0
ip netns exec nsb ip link set dev eth0 up
ip netns exec nsb ip route add 192.168.0.1 dev eth0
ip netns exec nsb ip route add default via 192.168.0.1

# Create external Host O (the NAT gateway between brlan1 and brwan).
ip netns add nso
ip netns exec nso ip link set lo up
ip link add oveth00 type veth peer name oveth01
ip link set oveth00 netns nso
ip netns exec nso ip link set dev oveth00 name eth0
ip netns exec nso ip addr add 192.168.0.1/24 dev eth0
ip netns exec nso ip link set dev eth0 up
ip netns exec nso ip rule add from 192.168.0.1/24 dev eth0
ip netns exec nso sysctl -w net.ipv4.conf.eth0.proxy_arp=1
ip link set dev oveth01 master brlan1
ip link set dev oveth01 up
ip link add oveth10 type veth peer name oveth11
ip link set oveth10 netns nso
ip netns exec nso ip link set dev oveth10 name eth1
# ip netns exec nso ip addr add 10.10.0.1/24 dev eth1
ip netns exec nso ip addr add 10.10.0.98/24 dev eth1
ip netns exec nso ip addr add 10.10.0.96/24 dev eth1
ip netns exec nso ip addr add 10.10.0.88/24 dev eth1
ip netns exec nso ip addr add 10.10.0.86/24 dev eth1
ip netns exec nso ip addr add 10.10.0.84/24 dev eth1
ip netns exec nso ip link set dev eth1 up
ip netns exec nso ip route add default dev eth1
ip link set dev oveth11 master brwan
ip link set dev oveth11 up
ip netns exec nso iptables -A FORWARD -j LOG --log-prefix "FORWARD:" --log-level 3
ip netns exec nso iptables -t nat -A PREROUTING -j LOG --log-prefix "DNAT:" --log-level 3
ip netns exec nso iptables -t nat -A POSTROUTING -j LOG --log-prefix "SNAT:" --log-level 3
# 192.168.0.98 nat to 10.10.0.98: outbound and inbound allowed; combined
# with the iptables rules on Host A this becomes FullCone.
ip netns exec nso iptables -t nat -A POSTROUTING -o eth1 -s 192.168.0.98 -d 10.10.0.1/24 -j SNAT --to-source 10.10.0.98
ip netns exec nso iptables -t nat -A PREROUTING -i eth1 -d 10.10.0.98 -s 10.10.0.1/24 -j DNAT --to-destination 192.168.0.98
# 192.168.0.96 nat to 10.10.0.96: outbound and inbound allowed, and the
# mapped address never changes; Host A's rules make this RestrictedCone.
ip netns exec nso iptables -t nat -A POSTROUTING -o eth1 -s 192.168.0.96 -d 10.10.0.1/24 -j SNAT --to-source 10.10.0.96
ip netns exec nso iptables -t nat -A PREROUTING -i eth1 -d 10.10.0.96 -s 10.10.0.1/24 -j DNAT --to-destination 192.168.0.96
# 192.168.0.88 nat to 10.10.0.88: outbound and inbound allowed, and the
# mapped address never changes; Host A's rules make this PortRestrictedCone.
ip netns exec nso iptables -t nat -A POSTROUTING -o eth1 -s 192.168.0.88 -d 10.10.0.1/24 -j SNAT --to-source 10.10.0.88
ip netns exec nso iptables -t nat -A PREROUTING -i eth1 -d 10.10.0.88 -s 10.10.0.1/24 -j DNAT --to-destination 192.168.0.88
# 192.168.0.86 nat to 10.10.0.86: if inbound traffic arrives before any
# outbound, the port mapping is randomized; otherwise only the IP is
# mapped. This behaves as Dynamic.
ip netns exec nso iptables -t nat -A PREROUTING -i eth1 -d 10.10.0.86 -s 10.10.0.1/24 -m recent --rsource --set --name strangers -j DNAT --to-destination 192.168.0.1 # note: deliberately DNATs to a wrong address
ip netns exec nso iptables -t nat -A POSTROUTING -o eth1 -s 192.168.0.86 -d 10.10.0.1/24 -m recent --rdest --rcheck --seconds 3600 --name strangers -j SNAT --to-source 10.10.0.86 --random
ip netns exec nso iptables -t nat -A POSTROUTING -o eth1 -s 192.168.0.86 -d 10.10.0.1/24 -j SNAT --to-source 10.10.0.86
# 192.168.0.84 nat to 10.10.0.84: outbound only, with randomized port
# mapping on the way out. This behaves as Symmetric.
ip netns exec nso iptables -t nat -A POSTROUTING -o eth1 -s 192.168.0.84 -d 10.10.0.1/24 -j SNAT --to-source 10.10.0.84 --random

# Create external Host N (the NAT gateway between brlan2 and brwan).
ip netns add nsn
ip netns exec nsn ip link set lo up
ip link add nveth00 type veth peer name nveth01
ip link set nveth00 netns nsn
ip netns exec nsn ip link set dev nveth00 name eth0
ip netns exec nsn ip addr add 172.16.0.1/24 dev eth0
ip netns exec nsn ip link set dev eth0 up
ip netns exec nsn ip rule add from 172.16.0.1/24 dev eth0
ip netns exec nsn sysctl -w net.ipv4.conf.eth0.proxy_arp=1
ip link set dev nveth01 master brlan2
ip link set dev nveth01 up
ip link add nveth10 type veth peer name nveth11
ip link set nveth10 netns nsn
ip netns exec nsn ip link set dev nveth10 name eth1
# ip netns exec nsn ip addr add 10.10.0.2/24 dev eth1
ip netns exec nsn ip addr add 10.10.0.48/24 dev eth1
ip netns exec nsn ip addr add 10.10.0.46/24 dev eth1
ip netns exec nsn ip addr add 10.10.0.38/24 dev eth1
ip netns exec nsn ip addr add 10.10.0.36/24 dev eth1
ip netns exec nsn ip addr add 10.10.0.34/24 dev eth1
ip netns exec nsn ip link set dev eth1 up
ip netns exec nsn ip route add default dev eth1
ip link set dev nveth11 master brwan
ip link set dev nveth11 up
ip netns exec nsn iptables -A FORWARD -j LOG --log-prefix "FORWARD:" --log-level 3
ip netns exec nsn iptables -t nat -A PREROUTING -j LOG --log-prefix "DNAT:" --log-level 3
ip netns exec nsn iptables -t nat -A POSTROUTING -j LOG --log-prefix "SNAT:" --log-level 3
# 172.16.0.48 nat to 10.10.0.48: outbound and inbound allowed; combined
# with the iptables rules on Host A this becomes FullCone.
ip netns exec nsn iptables -t nat -A POSTROUTING -o eth1 -s 172.16.0.48 -d 10.10.0.1/24 -j SNAT --to-source 10.10.0.48
ip netns exec nsn iptables -t nat -A PREROUTING -i eth1 -d 10.10.0.48 -s 10.10.0.1/24 -j DNAT --to-destination 172.16.0.48
# 172.16.0.46 nat to 10.10.0.46: outbound and inbound allowed, and the
# mapped address never changes; Host A's rules make this RestrictedCone.
ip netns exec nsn iptables -t nat -A POSTROUTING -o eth1 -s 172.16.0.46 -d 10.10.0.1/24 -j SNAT --to-source 10.10.0.46
ip netns exec nsn iptables -t nat -A PREROUTING -i eth1 -d 10.10.0.46 -s 10.10.0.1/24 -j DNAT --to-destination 172.16.0.46
# 172.16.0.38 nat to 10.10.0.38: outbound and inbound allowed, and the
# mapped address never changes; Host A's rules make this PortRestrictedCone.
ip netns exec nsn iptables -t nat -A POSTROUTING -o eth1 -s 172.16.0.38 -d 10.10.0.1/24 -j SNAT --to-source 10.10.0.38
ip netns exec nsn iptables -t nat -A PREROUTING -i eth1 -d 10.10.0.38 -s 10.10.0.1/24 -j DNAT --to-destination 172.16.0.38
# 172.16.0.36 nat to 10.10.0.36: if inbound traffic arrives before any
# outbound, the port mapping is randomized; otherwise only the IP is
# mapped. This behaves as Dynamic.
ip netns exec nsn iptables -t nat -A PREROUTING -i eth1 -d 10.10.0.36 -s 10.10.0.1/24 -m recent --rsource --set --name strangers -j DNAT --to-destination 172.16.0.1 # note: deliberately DNATs to a wrong address
ip netns exec nsn iptables -t nat -A POSTROUTING -o eth1 -s 172.16.0.36 -d 10.10.0.1/24 -m recent --rdest --rcheck --seconds 3600 --name strangers -j SNAT --to-source 10.10.0.36 --random
ip netns exec nsn iptables -t nat -A POSTROUTING -o eth1 -s 172.16.0.36 -d 10.10.0.1/24 -j SNAT --to-source 10.10.0.36
# 172.16.0.34 nat to 10.10.0.34: outbound only, with randomized port
# mapping on the way out. This behaves as Symmetric.
ip netns exec nsn iptables -t nat -A POSTROUTING -o eth1 -s 172.16.0.34 -d 10.10.0.1/24 -j SNAT --to-source 10.10.0.34 --random

# Host S (the STUN servers live here, directly on the WAN).
ip netns add nss
ip netns exec nss ip link set lo up
create_new "sveth0" "brwan" "nss" "eth0" "10.10.0.64" "10.10.0.1" "401"
create_new "sveth1" "brwan" "nss" "eth1" "10.10.0.66" "10.10.0.1" "402"
create_new "sveth2" "brwan" "nss" "eth2" "10.10.0.68" "10.10.0.1" "403"

# Create Host H (the WAN hub/gateway at 10.10.0.1).
ip netns add nshub
ip netns exec nshub ip link set lo up
ip link add hubveth0 type veth peer name hubveth1
ip link set dev hubveth1 master brwan
ip link set dev hubveth1 up
ip link set dev hubveth0 netns nshub
ip netns exec nshub ip link set dev hubveth0 name eth0
ip netns exec nshub ip addr add 10.10.0.1/24 dev eth0
ip netns exec nshub ip link set dev eth0 up
# ip netns exec nshub ip rule add from 10.10.0.1/24 dev eth0
ip netns exec nshub ip route add default dev eth0

================================================
FILE: qtraversal/tools/clear_nat.sh
================================================
#!/bin/bash
# Tear down everything build_nat.sh created. Must be run as root.
# set -x
set -e

ip netns exec nsa ip route flush table 101
ip netns exec nsa ip route flush table 102
ip netns exec nsa ip route flush table 103
ip netns exec nsa ip route flush table 104
ip netns exec nsa ip route flush table 105
ip netns exec nsa ip route flush table 201
ip netns exec nsa ip route flush table 202
ip netns exec nsa ip route flush table 203
ip netns exec nsa ip route flush table 204
ip netns exec nsa ip route flush table 301
ip netns exec nsa ip route flush table 302
ip netns exec nsa ip route flush table 303
ip netns exec nsa ip route flush table 304
ip netns exec nsa ip route flush table 305
ip netns exec nsa ip route flush cache
ip netns exec nss ip route flush table 401
ip netns exec nss ip route flush table 402
ip netns exec nss ip route flush table 403
ip netns exec nss ip route flush cache
ip netns del nsa
ip netns del nsb
ip netns del nso
ip netns del nss
ip netns del nsn
ip netns del nshub
ip link del brlan1
ip link del brlan2
ip link del brwan
iptables -D FORWARD -o brlan1 -m comment --comment "allow packets to pass from lxd lan bridge" -j ACCEPT
iptables -D FORWARD -i brlan1 -m comment --comment "allow input packets to pass to lxd lan bridge" -j ACCEPT
iptables -D FORWARD -o brlan2 -m comment --comment "allow packets to pass from lxd lan bridge" -j ACCEPT
iptables -D FORWARD -i brlan2 -m comment --comment "allow input packets to pass to lxd lan bridge" -j ACCEPT
iptables -D
FORWARD -o brwan -m comment --comment "allow packets to pass from lxd wan bridge" -j ACCEPT
iptables -D FORWARD -i brwan -m comment --comment "allow input packets to pass to lxd wan bridge" -j ACCEPT
# ip link del aveth1
# ip link del bveth1
# ip link del oveth1

================================================
FILE: qtraversal/tools/dockerfile
================================================
# Build environment for the NAT-traversal tests: Ubuntu with the network
# tooling (iproute2/iptables) the topology scripts need, plus nightly Rust.
ARG TARGETPLATFORM=linux/amd64
FROM --platform=$TARGETPLATFORM ubuntu:24.04

ENV DEBIAN_FRONTEND=noninteractive \
    CARGO_HOME=/usr/local/cargo \
    PATH=/usr/local/cargo/bin:$PATH

# # 1. Use the Aliyun APT mirror (disabled by default)
# RUN sed -i 's/archive.ubuntu.com/mirrors.aliyun.com/g' /etc/apt/sources.list && \
#     sed -i 's/security.ubuntu.com/mirrors.aliyun.com/g' /etc/apt/sources.list

# 2. Install system packages in their own layer (leverages the Docker build cache)
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    git \
    iproute2 \
    iptables \
    libssl-dev \
    pkg-config \
    tcpdump \
    && rm -rf /var/lib/apt/lists/*

# 3. Install Rust (separate layer)
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly --profile minimal --no-modify-path

# # 4. Configure a Cargo registry mirror in TOML format (disabled by default)
# RUN mkdir -p $CARGO_HOME && \
#     printf '[source.crates-io]\nreplace-with = "tuna"\n\n[source.tuna]\nregistry = "sparse+https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/"\n\n[net]\ngit-fetch-with-cli = true\nretry = 2\n' > $CARGO_HOME/config.toml
# RUN rustup override set nightly

# 5.
验证工具链
RUN rustc --version && cargo --version

================================================
FILE: qtraversal/tools/run_stun.sh
================================================
#!/bin/bash
# Build the NAT test topology, build the STUN server example, and launch
# three STUN servers inside Host S's namespace. Each server's CHANGE-ADDRESS
# points at one of the other two, so address/port-change probes work.
# Must be run as root (network namespaces).
set -e

# Build the virtual NAT topology first; abort everything if it fails.
bash qtraversal/tools/build_nat.sh

cargo build --example stun_server --release

ip netns exec nss nohup target/release/examples/stun_server --bind-addr1 10.10.0.64:20002 --bind-addr2 10.10.0.64:20003 --change-addr 10.10.0.66:20002 --outer-addr1 10.10.0.64:20002 --outer-addr2 10.10.0.64:20003 &
ip netns exec nss nohup target/release/examples/stun_server --bind-addr1 10.10.0.66:20002 --bind-addr2 10.10.0.66:20003 --change-addr 10.10.0.68:20002 --outer-addr1 10.10.0.66:20002 --outer-addr2 10.10.0.66:20003 &
ip netns exec nss nohup target/release/examples/stun_server --bind-addr1 10.10.0.68:20002 --bind-addr2 10.10.0.68:20003 --change-addr 10.10.0.64:20002 --outer-addr1 10.10.0.68:20002 --outer-addr2 10.10.0.68:20003 &

================================================
FILE: qudp/Cargo.toml
================================================
[package]
name = "qudp"
version = "0.5.0"
edition.workspace = true
description = "High-performance UDP encapsulation for QUIC"
readme.workspace = true
repository.workspace = true
license.workspace = true
keywords = ["async", "socket", "udp", "gso", "gro"]
categories.workspace = true
rust-version.workspace = true

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
bytes = { workspace = true }
cfg-if = { workspace = true }
libc = "0.2"
qbase = { workspace = true }
tracing = { workspace = true }
socket2 = { workspace = true }
tokio = { workspace = true, features = ["net"] }
nix = { version = "0.31", features = ["socket", "uio", "net"] }

[target.'cfg(windows)'.dependencies]
windows-sys = { version = "0.61", features = [
    "Win32_Foundation",
    "Win32_System_IO",
    "Win32_Networking_WinSock",
] }

[dev-dependencies]
clap = { workspace = true }
tokio = { workspace = true, features = ["test-util", "macros"] }
[dev-dependencies.tracing-subscriber]
workspace = true
features = ["env-filter", "time"]

[[example]]
name = "send"
path = "examples/send.rs"

[[example]]
name = "receive"
path = "examples/receive.rs"

[features]
gso = []

================================================
FILE: qudp/examples/receive.rs
================================================
// Receive-side diagnostic example: binds a qudp socket and logs every
// received batch forever.
use clap::Parser;
use qudp::UdpSocket;

// NOTE: plain `//` comments are used around the clap structs on purpose —
// `///` doc comments would change the generated --help text.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
    // Local address to bind the UDP socket to.
    #[arg(short,long, default_value_t = String::from("127.0.0.1:12345"))]
    bind: String,
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    tracing_subscriber::fmt()
        .with_max_level(tracing::level_filters::LevelFilter::TRACE)
        .init();
    let args = Args::parse();
    let addr = args.bind.parse().unwrap();
    let socket = UdpSocket::bind(addr).expect("failed to create socket");
    let mut receiver = socket.receiver();
    // Receive batches forever. `lines[0]` holds the route metadata
    // (src/dst/segment size) that recvmsg filled in for the batch.
    loop {
        match receiver.recv().await {
            Ok(n) => {
                tracing::info!(
                    "Received {} packets, dst {}, src {} len {}",
                    n,
                    receiver.lines[0].dst,
                    receiver.lines[0].src,
                    receiver.lines[0].seg_size
                );
            }
            Err(e) => {
                tracing::error!("Receive failed: {}", e);
            }
        }
    }
}

================================================
FILE: qudp/examples/send.rs
================================================
// Send-side diagnostic example: sends `msg_count` identical payloads of
// `msg_size` bytes to `dst` with a single batched send call.
use std::io::IoSlice;

use clap::Parser;
use qbase::net::route::{Line, Link};
use qudp::UdpSocket;

// NOTE: plain `//` comments are used around the clap structs on purpose —
// `///` doc comments would change the generated --help text.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
    // Local address to bind (port 0 = ephemeral).
    #[arg(long, default_value_t = String::from("127.0.0.1:0"))]
    src: String,
    // Destination address.
    #[arg(long, default_value_t = String::from("127.0.0.1:12345"))]
    dst: String,
    // Payload size of each datagram, in bytes.
    #[arg(long, default_value_t = 3600)]
    msg_size: usize,
    // Number of datagrams to send in the batch.
    #[arg(long, default_value_t = 100)]
    msg_count: usize,
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    tracing_subscriber::fmt()
        .with_max_level(tracing::level_filters::LevelFilter::TRACE)
        .init();
    let args = Args::parse();
    let addr = args.src.parse().unwrap();
    let socket = UdpSocket::bind(addr).expect("failed to create socket");
let dst = args.dst.parse().unwrap(); let send_hdr = Line::new( Link::new(socket.local_addr().expect("failed to get local addr"), dst), 64, None, args.msg_size as u16, ); let payload = vec![8u8; args.msg_size]; let payloads = vec![IoSlice::new(&payload[..]); args.msg_count]; match socket.send(&payloads, send_hdr).await { Ok(n) => tracing::info!("Sent {} packets, bytes: {}", n, n * args.msg_size), Err(e) => tracing::error!("Send failed: {}", e), } } ================================================ FILE: qudp/src/lib.rs ================================================ use std::{ future::Future, io::{self, IoSlice, IoSliceMut}, net::SocketAddr, pin::Pin, sync::atomic::AtomicI32, task::{Context, Poll, ready}, }; use bytes::BytesMut; use qbase::net::route::Line; use socket2::{Domain, Socket, Type}; use tokio::io::Interest; pub const BATCH_SIZE: usize = 64; cfg_if::cfg_if! { if #[cfg(unix)]{ #[path = "unix.rs"] mod unix; } else if #[cfg(windows)] { #[path = "windows.rs"] mod windows; } else { compile_error!("Unsupported platform"); } } #[derive(Debug)] pub struct UdpSocket { io: tokio::net::UdpSocket, ttl: AtomicI32, } impl UdpSocket { pub fn bind(addr: SocketAddr) -> io::Result { let domain = if addr.is_ipv4() { Domain::IPV4 } else { Domain::IPV6 }; let socket = Socket::new(domain, Type::DGRAM, None)?; socket.set_nonblocking(true)?; Self::config(&socket, addr)?; let io = tokio::net::UdpSocket::from_std(socket.into())?; let usc = Self { io, ttl: AtomicI32::new(Line::DEFAULT_TTL as i32), }; Ok(usc) } pub fn local_addr(&self) -> io::Result { self.io.local_addr() } pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { self.io.poll_send_ready(cx) } pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll> { self.io.poll_recv_ready(cx) } pub fn poll_send( &self, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], line: &Line, ) -> Poll> { loop { ready!(self.poll_send_ready(cx))?; self.set_ttl(line.ttl as i32)?; match self .io .try_io(Interest::WRITABLE, || 
self.sendmsg(bufs, line)) { Ok(n) => return Poll::Ready(Ok(n)), Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue, Err(e) => return Poll::Ready(Err(e)), } } } pub fn poll_recv( &self, cx: &mut Context, bufs: &mut [IoSliceMut<'_>], lines: &mut [Line], ) -> Poll> { loop { ready!(self.poll_recv_ready(cx)?); let f = || self.recvmsg(bufs, lines); let ret = self.io.try_io(Interest::READABLE, f); if matches!(&ret, Err(e) if e.kind() == io::ErrorKind::WouldBlock) { continue; } else { return Poll::Ready(ret); } } } #[allow(unreachable_code)] pub fn bind_device(&self, _device: &str) -> io::Result<()> { // #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] // android and linux support bind_device_by_index, which is called by codes below #[cfg(target_os = "fuchsia")] { let socket = socket2::SockRef::from(&self.io); return socket.bind_device(Some(_device.as_bytes())); } #[cfg(any( target_os = "ios", target_os = "visionos", target_os = "macos", target_os = "tvos", target_os = "watchos", target_os = "illumos", target_os = "solaris", target_os = "linux", target_os = "android", ))] { let socket = socket2::SockRef::from(&self.io); let index = nix::net::if_::if_nametoindex(_device)?; let index = std::num::NonZeroU32::new(index) .expect("Already checked by nix::net::if_::if_nametoindex"); match self.io.local_addr()? { SocketAddr::V4(..) => socket.bind_device_by_index_v4(Some(index))?, SocketAddr::V6(..) 
=> socket.bind_device_by_index_v6(Some(index))?, } return Ok(()); } Ok(()) } } pub trait Io { fn config(io: &socket2::Socket, addr: SocketAddr) -> io::Result<()>; fn sendmsg(&self, bufs: &[IoSlice<'_>], line: &Line) -> io::Result; fn recvmsg(&self, bufs: &mut [IoSliceMut<'_>], line: &mut [Line]) -> io::Result; fn set_ttl(&self, ttl: i32) -> io::Result<()>; } impl UdpSocket { pub fn send<'a>(&'a self, iovecs: &'a [IoSlice<'a>], line: Line) -> Send<'a> { Send { socket: self, iovecs, line, } } pub fn receiver(&self) -> Receiver<'_> { Receiver { socket: self, iovecs: (0..BATCH_SIZE) .map(|_| { let mut buf = BytesMut::with_capacity(1500); buf.resize(1500, 0); buf }) .collect::>(), lines: (0..BATCH_SIZE).map(|_| Line::default()).collect::>(), } } } pub struct Send<'a> { pub socket: &'a UdpSocket, pub iovecs: &'a [IoSlice<'a>], pub line: Line, } impl Future for Send<'_> { type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); this.socket.poll_send(cx, this.iovecs, &this.line) } } pub struct Receiver<'u> { pub socket: &'u UdpSocket, pub iovecs: Vec, pub lines: Vec, } impl Receiver<'_> { #[inline] pub fn poll_recv(&mut self, cx: &mut Context) -> Poll> { let mut bufs = self .iovecs .iter_mut() .map(|b| IoSliceMut::new(b)) .collect::>(); self.socket.poll_recv(cx, &mut bufs, &mut self.lines) } #[inline] pub async fn recv(&mut self) -> io::Result { core::future::poll_fn(|cx| self.poll_recv(cx)).await } } ================================================ FILE: qudp/src/unix.rs ================================================ use std::{ io::{self, IoSlice}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, os::fd::{AsFd, AsRawFd}, }; use nix::{ cmsg_space, sys::socket::{ ControlMessageOwned, SockaddrLike, SockaddrStorage, sockopt::{self}, }, }; use qbase::net::route::Line; use socket2::Socket; use crate::{Io, UdpSocket}; const OPTION_ON: bool = true; const OPTION_OFF: bool = false; impl Io for 
UdpSocket { fn config(socket: &Socket, addr: SocketAddr) -> io::Result<()> { let io = socket.as_fd(); nix::sys::socket::setsockopt(&io, sockopt::RcvBuf, &(2 * 1024 * 1024))?; match addr { SocketAddr::V4(_) => { #[cfg(any(target_os = "freebsd", target_os = "macos", target_os = "ios"))] { nix::sys::socket::setsockopt(&io, sockopt::IpDontFrag, &OPTION_ON)?; nix::sys::socket::setsockopt(&io, sockopt::Ipv4RecvDstAddr, &OPTION_ON)?; } #[cfg(any( target_os = "android", target_os = "linux", target_os = "freebsd", target_os = "netbsd" ))] nix::sys::socket::setsockopt(&io, sockopt::Ipv4Ttl, &(Line::DEFAULT_TTL as i32))?; nix::sys::socket::setsockopt(&io, sockopt::Ipv4PacketInfo, &OPTION_ON)?; } SocketAddr::V6(_) => { nix::sys::socket::setsockopt(&io, sockopt::Ipv6V6Only, &OPTION_OFF)?; nix::sys::socket::setsockopt(&io, sockopt::Ipv6RecvPacketInfo, &OPTION_ON)?; nix::sys::socket::setsockopt(&io, sockopt::Ipv6DontFrag, &OPTION_ON)?; nix::sys::socket::setsockopt(&io, sockopt::Ipv6Ttl, &(Line::DEFAULT_TTL as i32))?; } } socket.bind(&addr.into()) } #[cfg(any( target_os = "android", target_os = "linux", target_os = "freebsd", target_os = "netbsd" ))] fn sendmsg(&self, buffers: &[IoSlice<'_>], line: &Line) -> io::Result { use nix::{ errno::Errno, sys::socket::{MsgFlags, MultiHeaders, SockaddrIn, SockaddrIn6, sendmmsg}, }; use super::BATCH_SIZE; let slices: Vec<_> = buffers .iter() .take(BATCH_SIZE) .map(std::slice::from_ref) .collect(); let batch_size = slices.len(); if batch_size == 0 { return Ok(0); } #[cfg(feature = "gso")] let (cmsgs, space) = ( vec![nix::sys::socket::ControlMessage::UdpGsoSegments( &line.seg_size, )], Some(cmsg_space!(libc::c_int)), ); #[cfg(not(feature = "gso"))] let (cmsgs, space) = (Vec::new(), None); macro_rules! 
send_batch { ($ty:ty, $addr:expr) => {{ let sock_addr = <$ty>::from($addr); let addrs = vec![Some(sock_addr); BATCH_SIZE]; let mut data = MultiHeaders::<$ty>::preallocate(BATCH_SIZE, space); match sendmmsg( self.io.as_raw_fd(), &mut data, &slices, &addrs, &cmsgs, MsgFlags::empty(), ) { Ok(ret) => Ok(ret.count()), Err(e @ (Errno::EINTR | Errno::EAGAIN | Errno::ENOBUFS)) => { Err(io::Error::new(io::ErrorKind::WouldBlock, e)) } Err(e) => Err(e.into()), } }}; } match line.dst { SocketAddr::V4(v4) => send_batch!(SockaddrIn, v4), SocketAddr::V6(v6) => send_batch!(SockaddrIn6, v6), } } #[cfg(any( target_os = "macos", target_os = "ios", target_os = "watchos", target_os = "tvos" ))] fn sendmsg(&self, slices: &[IoSlice<'_>], send_line: &Line) -> io::Result { use nix::{ errno::Errno, sys::socket::{MsgFlags, SockaddrIn, SockaddrIn6, sendmsg}, }; let mut sent_packet = 0; for slice in slices.iter() { macro_rules! send_batch { ($ty:ty, $addr:expr) => {{ let sock_addr = <$ty>::from($addr); match sendmsg( self.io.as_raw_fd(), &[*slice], &[], MsgFlags::empty(), Some(&sock_addr), ) { Ok(_send_bytes) => sent_packet += 1, Err(_) if sent_packet > 0 => return Ok(sent_packet), Err(Errno::EINTR) => continue, Err(e @ (Errno::EAGAIN | Errno::ENOBUFS)) => { return Err(io::Error::new(io::ErrorKind::WouldBlock, e)); } Err(e) => { return Err(e.into()); } } }}; } match send_line.dst { SocketAddr::V4(v4) => send_batch!(SockaddrIn, v4), SocketAddr::V6(v6) => send_batch!(SockaddrIn6, v6), } } Ok(sent_packet) } #[cfg(any( target_os = "android", target_os = "linux", target_os = "freebsd", target_os = "netbsd" ))] fn recvmsg( &self, bufs: &mut [std::io::IoSliceMut<'_>], recv_lines: &mut [Line], ) -> io::Result { use nix::sys::socket::{MsgFlags, recvmmsg}; use super::BATCH_SIZE; let mut msgs: Vec<_> = bufs .iter_mut() .map(|buf| [std::io::IoSliceMut::new(&mut buf[..])]) .collect(); let cmsg_buffer = cmsg_space!(libc::in_pktinfo, libc::in6_pktinfo, libc::c_int); let mut data = 
nix::sys::socket::MultiHeaders::::preallocate( BATCH_SIZE, Some(cmsg_buffer), ); let res = match recvmmsg( self.io.as_raw_fd(), &mut data, &mut msgs, MsgFlags::MSG_DONTWAIT, None, ) { Ok(results) => results.collect::>(), Err(e) => { if matches!(e, nix::errno::Errno::EAGAIN | nix::errno::Errno::EINTR) { return Err(io::Error::new(io::ErrorKind::WouldBlock, e)); } return Err(e.into()); } }; let local_port = self.local_addr()?.port(); let mut count = 0; for recv_msg in res { let src_addr = recv_msg.address.unwrap().to_socketaddr(); let link = qbase::net::route::Link::new(src_addr, recv_lines[count].dst); let mut recv_line = Line { link, ttl: 0, ecn: None, seg_size: recv_msg.bytes as u16, }; for cmsg in recv_msg.cmsgs().unwrap() { parse_cmsg(cmsg, &mut recv_line); } recv_line.dst.set_port(local_port); recv_lines[count] = recv_line; count += 1; } Ok(count) } #[cfg(any( target_os = "macos", target_os = "ios", target_os = "watchos", target_os = "tvos" ))] fn recvmsg( &self, bufs: &mut [std::io::IoSliceMut<'_>], recv_lines: &mut [Line], ) -> io::Result { use nix::sys::socket::{MsgFlags, recvmsg}; let mut cmsg_space = cmsg_space!(libc::in_pktinfo, libc::in6_pktinfo, libc::c_int); let result = recvmsg::( self.io.as_raw_fd(), bufs, Some(&mut cmsg_space), MsgFlags::empty(), ); match result { Ok(recv_msg) => { if let Ok(cmsgs) = recv_msg.cmsgs() { for cmsg in cmsgs { parse_cmsg(cmsg, &mut recv_lines[0]); } } recv_lines[0].dst.set_port(self.local_addr()?.port()); recv_lines[0].src = recv_msg.address.unwrap().to_socketaddr(); recv_lines[0].seg_size = recv_msg.bytes as u16; Ok(1) } Err(e) => { if matches!(e, nix::errno::Errno::EAGAIN | nix::errno::Errno::EINTR) { // actually, it's not an error, just a signal to retry return Err(io::Error::new(io::ErrorKind::WouldBlock, e)); } Err(e.into()) } } } fn set_ttl(&self, ttl: i32) -> io::Result<()> { use std::sync::atomic::Ordering::{Acquire, SeqCst}; if ttl == self.ttl.load(Acquire) { return Ok(()); } let local = self.local_addr()?; let 
io = self.io.as_raw_fd();
// Pick the raw socket option by address family: IP_TTL for IPv4,
// IPV6_UNICAST_HOPS for IPv6.
let ret = match local.ip() {
    IpAddr::V4(_) => unsafe {
        libc::setsockopt(
            io,
            libc::IPPROTO_IP,
            libc::IP_TTL,
            &ttl as *const _ as *const libc::c_void,
            std::mem::size_of_val(&ttl) as libc::socklen_t,
        )
    },
    IpAddr::V6(_) => unsafe {
        libc::setsockopt(
            io,
            libc::IPPROTO_IPV6,
            libc::IPV6_UNICAST_HOPS,
            &ttl as *const _ as *const libc::c_void,
            std::mem::size_of_val(&ttl) as libc::socklen_t,
        )
    },
};
if ret != 0 {
    return Err(io::Error::last_os_error());
}
// Remember the value so the next poll_send can skip the syscall when the
// TTL is unchanged (checked with the cached atomic above).
self.ttl.store(ttl, SeqCst);
Ok(())
} }

/// Fold one received control message into `line`: IPv4/IPv6 packet-info
/// cmsgs carry the local destination IP the datagram arrived on.
fn parse_cmsg(cmsg: ControlMessageOwned, line: &mut Line) {
    match cmsg {
        ControlMessageOwned::Ipv4PacketInfo(pktinfo) => {
            // ipi_addr is in network byte order; from_ne_bytes of s_addr
            // reproduces the on-wire octet order here.
            let ip = IpAddr::V4(Ipv4Addr::from(pktinfo.ipi_addr.s_addr.to_ne_bytes()));
            line.link.dst.set_ip(ip);
        }
        ControlMessageOwned::Ipv6PacketInfo(pktinfo6) => {
            let ip = IpAddr::V6(Ipv6Addr::from(pktinfo6.ipi6_addr.s6_addr));
            line.link.dst.set_ip(ip);
        }
        // Other cmsg types (e.g. TTL/ECN) are ignored here.
        _ => {}
    }
}

/// Convert a nix `SockaddrStorage` into a std `SocketAddr`.
trait ToSocketAddr {
    fn to_socketaddr(&self) -> SocketAddr;
}

impl ToSocketAddr for SockaddrStorage {
    // Panics on address families other than INET/INET6 — only UDP
    // sockets are created by this crate, so those are the only families
    // the kernel can hand back.
    fn to_socketaddr(&self) -> SocketAddr {
        match self.family() {
            Some(nix::sys::socket::AddressFamily::Inet) => {
                let sockaddr_in = self.as_sockaddr_in().unwrap();
                let v4_addr = SocketAddrV4::new(sockaddr_in.ip(), sockaddr_in.port());
                SocketAddr::V4(v4_addr)
            }
            Some(nix::sys::socket::AddressFamily::Inet6) => {
                let sockaddr_in6 = self.as_sockaddr_in6().unwrap();
                let v6_addr = SocketAddrV6::new(
                    sockaddr_in6.ip(),
                    sockaddr_in6.port(),
                    sockaddr_in6.flowinfo(),
                    sockaddr_in6.scope_id(),
                );
                SocketAddr::V6(v6_addr)
            }
            _ => panic!("Unsupported address family"),
        }
    }
}

================================================
FILE: qudp/src/windows.rs
================================================
// Windows implementation of the `Io` trait using WSASendMsg/WSARecvMsg.
use std::{
    ffi::c_int,
    io, mem,
    net::{IpAddr, Ipv4Addr, SocketAddr},
    os::windows::io::AsRawSocket,
    ptr,
};

use libc::c_uchar;
use qbase::net::route::{Line, Link};
use socket2::Socket;
use windows_sys::Win32::Networking::WinSock::{self, SOCKET};

use crate::{Io, UdpSocket};

// Size of the ancillary-data (control message) buffer for WSAMSG.
const CMSG_LEN: usize = 128;
// NOTE(review): generic-argument lists in this file appear stripped by text
// extraction (e.g. `Aligned(pub(crate) T)`, bare `std::io::Result`,
// `cmsg_decode::(…)`, `mem::align_of::()`, `OnceLock =`) — restore the
// original type parameters from the repository before compiling.
/// Buffer wrapper forcing 8-byte alignment so it can safely hold WinSock
/// control-message headers plus payload.
#[derive(Copy, Clone)]
#[repr(align(8))] // Conservative bound for align_of
pub(crate) struct Aligned(pub(crate) T);

impl Io for UdpSocket {
    /// Configure a fresh socket: grow the receive buffer, enable
    /// pktinfo/TOS/TTL ancillary data for the address family, then bind.
    fn config(socket: &Socket, addr: SocketAddr) -> std::io::Result<()> {
        const OPTION_ON: c_int = 1;
        const OPTION_OFF: c_int = 0;
        let io = socket.as_raw_socket().try_into().unwrap();
        setsockopt(io, WinSock::SOL_SOCKET, WinSock::SO_RCVBUF, 2 * 1024 * 1024);
        match addr {
            SocketAddr::V4(_) => {
                setsockopt(io, WinSock::IPPROTO_IP, WinSock::IP_RECVTOS, OPTION_ON);
                setsockopt(io, WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, OPTION_ON);
                setsockopt(io, WinSock::IPPROTO_IP, WinSock::IP_RECVTTL, OPTION_ON);
                setsockopt(io, WinSock::IPPROTO_IP, WinSock::IP_RECVDSTADDR, OPTION_ON);
                setsockopt(
                    io,
                    WinSock::IPPROTO_IP,
                    WinSock::IP_TTL,
                    Line::DEFAULT_TTL as c_int,
                );
            }
            SocketAddr::V6(_) => {
                // Dual-stack: let the v6 socket also carry IPv4-mapped traffic.
                setsockopt(io, WinSock::IPPROTO_IPV6, WinSock::IPV6_V6ONLY, OPTION_OFF);
                setsockopt(io, WinSock::IPPROTO_IPV6, WinSock::IPV6_HOPLIMIT, OPTION_ON);
                setsockopt(
                    io,
                    WinSock::IPPROTO_IPV6,
                    WinSock::IPV6_RECVTCLASS,
                    OPTION_ON,
                );
                setsockopt(io, WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, OPTION_ON);
            }
        }
        if let Err(e) = socket.bind(&addr.into()) {
            tracing::error!(target: "qudp", "Failed to bind socket: {}", e);
            return Err(io::Error::new(io::ErrorKind::AddrInUse, e));
        }
        Ok(())
    }

    /// Send each buffer as one datagram via `WSASendMsg`, attaching pktinfo
    /// (when a source address is pinned) and ECN bits as control messages.
    /// Returns the number of buffers processed.
    fn sendmsg(&self, bufs: &[std::io::IoSlice<'_>], line: &Line) -> std::io::Result {
        let dst = socket2::SockAddr::from(line.dst);
        let mut count = 0;
        for buf in bufs {
            let mut ctrl_buf = Aligned([0; CMSG_LEN]);
            let mut data = WinSock::WSABUF {
                buf: buf.as_ptr() as *mut _,
                len: buf.len() as _,
            };
            let ctrl = WinSock::WSABUF {
                buf: ctrl_buf.0.as_mut_ptr(),
                len: ctrl_buf.0.len() as _,
            };
            let mut wsa_msg = WinSock::WSAMSG {
                name: dst.as_ptr() as *mut _,
                namelen: dst.len(),
                lpBuffers: &mut data,
                Control: ctrl,
                dwBufferCount: 1,
                dwFlags: 0,
            };
            // Cursor into the control buffer; advanced by append_cmsg.
            let mut cmsg = unsafe { first_cmsg(&mut wsa_msg).as_mut() };
            let mut cmsg_len = 0;
            // Pin the source address with a pktinfo cmsg when it is explicit.
            if !line.src.ip().is_unspecified() {
                let src = socket2::SockAddr::from(line.src);
                match src.family() {
                    WinSock::AF_INET => {
                        let src_ip =
                            unsafe { ptr::read(src.as_ptr() as *const WinSock::SOCKADDR_IN) };
                        let pktinfo = WinSock::IN_PKTINFO {
                            ipi_addr: src_ip.sin_addr,
                            ipi_ifindex: 0,
                        };
                        cmsg = append_cmsg(
                            &wsa_msg,
                            cmsg,
                            WinSock::IPPROTO_IP,
                            WinSock::IP_PKTINFO,
                            pktinfo,
                            &mut cmsg_len,
                        );
                    }
                    WinSock::AF_INET6 => {
                        let src_ip =
                            unsafe { ptr::read(src.as_ptr() as *const WinSock::SOCKADDR_IN6) };
                        let pktinfo = WinSock::IN6_PKTINFO {
                            ipi6_addr: src_ip.sin6_addr,
                            ipi6_ifindex: unsafe { src_ip.Anonymous.sin6_scope_id },
                        };
                        cmsg = append_cmsg(
                            &wsa_msg,
                            cmsg,
                            WinSock::IPPROTO_IPV6,
                            WinSock::IPV6_PKTINFO,
                            pktinfo,
                            &mut cmsg_len,
                        );
                    }
                    _ => {
                        return Err(io::Error::from(io::ErrorKind::InvalidInput));
                    }
                }
            }
            // Attach ECN bits, choosing the protocol level by destination
            // family (IPv4-mapped v6 destinations count as IPv4).
            if let Some(ecn) = line.ecn {
                let is_ipv4 = line.dst.is_ipv4()
                    || matches!(line.dst.ip(), IpAddr::V6(addr) if addr.to_ipv4_mapped().is_some());
                if is_ipv4 {
                    _ = append_cmsg(
                        &wsa_msg,
                        cmsg,
                        WinSock::IPPROTO_IP,
                        WinSock::IP_ECN,
                        ecn,
                        &mut cmsg_len,
                    );
                } else {
                    _ = append_cmsg(
                        &wsa_msg,
                        cmsg,
                        WinSock::IPPROTO_IPV6,
                        WinSock::IPV6_TCLASS,
                        ecn,
                        &mut cmsg_len,
                    );
                }
            }
            wsa_msg.Control.len = cmsg_len as _;
            // With no control data, pass a null control buffer to WSASendMsg.
            if cmsg_len == 0 {
                wsa_msg.Control = WinSock::WSABUF {
                    buf: ptr::null_mut(),
                    len: 0,
                };
            }
            let mut len = 0;
            let ret = unsafe {
                WinSock::WSASendMsg(
                    self.io.as_raw_socket() as usize,
                    &wsa_msg,
                    0,
                    &mut len,
                    ptr::null_mut(),
                    None,
                )
            };
            if ret != 0 {
                let e = io::Error::last_os_error();
                // WouldBlock is swallowed here; the datagram is simply dropped.
                if e.kind() != io::ErrorKind::WouldBlock {
                    return Err(e);
                }
            }
            count += 1;
        }
        Ok(count as usize)
    }

    /// Receive one datagram via the dynamically resolved `WSARecvMsg`,
    /// filling `lines[0]` with source/destination/ECN metadata. Returns 1.
    fn recvmsg(
        &self,
        bufs: &mut [std::io::IoSliceMut<'_>],
        lines: &mut [Line],
    ) -> std::io::Result {
        let wsa_recvmsg_ptr = wsarecvmsg_ptr().expect("valid function pointer for WSARecvMsg");
        let mut ctrl_buf = Aligned([0; CMSG_LEN]);
        let mut source: WinSock::SOCKADDR_INET = unsafe { mem::zeroed() };
        let ctrl = WinSock::WSABUF {
            buf: ctrl_buf.0.as_mut_ptr(),
            len: ctrl_buf.0.len() as _,
        };
        let mut wsa_msg = WinSock::WSAMSG {
            name: &mut source as *mut _ as *mut _,
            namelen: mem::size_of_val(&source) as _,
            lpBuffers: &mut WinSock::WSABUF {
                buf: bufs[0].as_mut_ptr(),
                len: bufs[0].len() as _,
            },
            Control: ctrl,
            dwBufferCount: 1,
            dwFlags: 0,
        };
        let mut len = 0;
        unsafe {
            let rc = (wsa_recvmsg_ptr)(
                self.io.as_raw_socket() as usize,
                &mut wsa_msg,
                &mut len,
                ptr::null_mut(),
                None,
            );
            if rc == -1 {
                return Err(io::Error::last_os_error());
            }
        }
        // Copy the raw SOCKADDR_INET into a socket2 SockAddr to decode it.
        let addr = unsafe {
            let (_, addr) = socket2::SockAddr::try_init(|addr_storage, len| {
                *len = mem::size_of_val(&source) as _;
                ptr::copy_nonoverlapping(&source, addr_storage as _, 1);
                Ok(())
            })?;
            addr.as_socket()
        };
        let mut ecn_bits = 0;
        let mut dst_ip = None;
        // Walk the control-message chain for destination IP and ECN bits.
        let mut cmsg: Option<&mut WinSock::CMSGHDR> = unsafe { first_cmsg(&mut wsa_msg).as_mut() };
        while let Some(cur_cmsg) = cmsg {
            // [header (len)][data][padding(len + sizeof(data))] -> [header][data][padding]
            match (cur_cmsg.cmsg_level, cur_cmsg.cmsg_type) {
                (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => {
                    let pktinfo = cmsg_decode::(cur_cmsg);
                    let ip4 = Ipv4Addr::from(u32::from_be(unsafe { pktinfo.ipi_addr.S_un.S_addr }));
                    dst_ip = Some(ip4.into());
                }
                (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => {
                    let pktinfo = cmsg_decode::(cur_cmsg);
                    // Addr is stored in big endian format
                    dst_ip = Some(IpAddr::from(unsafe { pktinfo.ipi6_addr.u.Byte }));
                }
                (WinSock::IPPROTO_IP, WinSock::IP_ECN) => {
                    ecn_bits = cmsg_decode::(cur_cmsg);
                }
                (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => {
                    ecn_bits = cmsg_decode::(cur_cmsg);
                }
                _ => {}
            }
            cmsg = unsafe { next_cmsg(&wsa_msg, cur_cmsg).as_mut() };
        }
        // Fall back to the bound local address when no pktinfo was delivered.
        let dst = if let Some(ip) = dst_ip {
            crate::SocketAddr::new(ip, self.local_addr()?.port())
        } else {
            self.local_addr()?
        };
        lines[0] = Line {
            link: Link::new(addr.unwrap(), dst),
            ttl: Line::DEFAULT_TTL,
            ecn: Some(ecn_bits as u8),
            seg_size: len as u16,
        };
        Ok(1)
    }

    /// Set the IP TTL (hop limit), caching the last value in `self.ttl` so
    /// repeated calls with the same TTL skip the setsockopt call.
    fn set_ttl(&self, ttl: i32) -> io::Result<()> {
        use std::sync::atomic::Ordering::{Acquire, SeqCst};
        if ttl == self.ttl.load(Acquire) {
            return Ok(());
        }
        let local = self.local_addr()?;
        let socket = self.io.as_raw_socket() as usize;
        match local.ip() {
            IpAddr::V4(_) => setsockopt(socket, WinSock::IPPROTO_IP, WinSock::IP_TTL, ttl),
            IpAddr::V6(_) => setsockopt(
                socket,
                WinSock::IPPROTO_IPV6,
                WinSock::IPV6_UNICAST_HOPS,
                ttl,
            ),
        };
        self.ttl.store(ttl, SeqCst);
        Ok(())
    }
}

/// Write one control message (level/type/payload) at the cursor position,
/// bump `cmsg_len` by the aligned space consumed, and return the cursor for
/// the next message (None when the buffer is exhausted).
fn append_cmsg<'a, V: Copy>(
    msg: &'a WinSock::WSAMSG,
    mut cmsg: Option<&'a mut WinSock::CMSGHDR>,
    level: libc::c_int,
    ty: libc::c_int,
    data: V,
    cmsg_len: &mut usize,
) -> Option<&'a mut WinSock::CMSGHDR> {
    let space = cmsg_space(mem::size_of_val(&data));
    let next = cmsg.take().expect("no available cmsghdr");
    next.cmsg_level = level as _;
    next.cmsg_type = ty as _;
    next.cmsg_len = cmsg_data_len(mem::size_of_val(&data)) as _;
    unsafe {
        ptr::write(cmsg_data(next) as *const V as *mut V, data);
    }
    *cmsg_len += space;
    unsafe { next_cmsg(msg, next).as_mut() }
}

/// Read the payload of a control message as a value of the target type.
fn cmsg_decode(cmsg: &mut WinSock::CMSGHDR) -> T {
    unsafe { ptr::read(cmsg_data(cmsg) as *const T) }
}

// Round `length` up to the header alignment (mirrors the CMSG_* macros).
const fn cmsghdr_align(length: usize) -> usize {
    (length + mem::align_of::() - 1) & !(mem::align_of::() - 1)
}

// Round `length` up to the payload alignment.
fn cmsgdata_align(length: usize) -> usize {
    (length + mem::align_of::() - 1) & !(mem::align_of::() - 1)
}

// cmsg_len field value: header size plus payload length (unpadded).
fn cmsg_data_len(len: usize) -> usize {
    mem::size_of::() + len
}

// Total aligned space one control message occupies in the buffer.
fn cmsg_space(len: usize) -> usize {
    let total = mem::size_of::() + len;
    let align = mem::align_of::();
    (total + align - 1) & !(align - 1)
}

/// First control message of `msg`, or null when the control buffer cannot
/// hold even a header.
unsafe fn first_cmsg(msg: &mut WinSock::WSAMSG) -> *mut WinSock::CMSGHDR {
    if msg.Control.len as usize >= mem::size_of::() {
        msg.Control.buf as *mut WinSock::CMSGHDR
    } else {
        ptr::null_mut::()
    }
}

/// Control message following `cmsg`, or null once past the buffer end.
fn next_cmsg(msg: &WinSock::WSAMSG, cmsg: &WinSock::CMSGHDR) -> *mut WinSock::CMSGHDR {
    let next = (cmsg as *const _ as usize + cmsghdr_align(cmsg.cmsg_len)) as *mut WinSock::CMSGHDR;
    let max = msg.Control.buf as usize + msg.Control.len as usize;
    if unsafe { next.offset(1) } as usize > max {
        ptr::null_mut()
    } else {
        next
    }
}

/// Pointer to the payload area of a control message (header + padding).
fn cmsg_data(cmsg: &mut WinSock::CMSGHDR) -> *mut libc::c_uchar {
    (cmsg as *const _ as usize + cmsgdata_align(mem::size_of::())) as *mut c_uchar
}

/// Thin wrapper over WinSock setsockopt for c_int-valued options; the return
/// code is deliberately ignored (options are best-effort at config time).
fn setsockopt(io: SOCKET, level: libc::c_int, name: libc::c_int, value: libc::c_int) {
    unsafe {
        WinSock::setsockopt(
            io,
            level,
            name,
            &value as *const _ as _,
            mem::size_of_val(&value) as _,
        )
    };
}

/// Resolve the `WSARecvMsg` extension function pointer once per process via
/// `WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER)`, caching it in a OnceLock.
fn wsarecvmsg_ptr() -> &'static WinSock::LPFN_WSARECVMSG {
    static WSARECVMSG_PTR: std::sync::OnceLock = std::sync::OnceLock::new();
    WSARECVMSG_PTR.get_or_init(|| {
        // A throwaway socket is required just to issue the WSAIoctl query.
        let s = unsafe { WinSock::socket(WinSock::AF_INET as _, WinSock::SOCK_DGRAM as _, 0) };
        if s == WinSock::INVALID_SOCKET {
            tracing::warn!(
                target: "qudp",
                "Failed to create socket for WSARecvMsg function pointer: {}",
                io::Error::last_os_error()
            );
            return None;
        }
        // Detect if OS expose WSARecvMsg API based on
        // https://github.com/Azure/mio-uds-windows/blob/a3c97df82018086add96d8821edb4aa85ec1b42b/src/stdnet/ext.rs#L601
        let guid = WinSock::WSAID_WSARECVMSG;
        let mut wsa_recvmsg_ptr = None;
        let mut len = 0;
        // Safety: Option handles the NULL pointer with a None value
        let ret = unsafe {
            WinSock::WSAIoctl(
                s as _,
                WinSock::SIO_GET_EXTENSION_FUNCTION_POINTER,
                &guid as *const _ as *const _,
                mem::size_of_val(&guid) as u32,
                &mut wsa_recvmsg_ptr as *mut _ as *mut _,
                mem::size_of_val(&wsa_recvmsg_ptr) as u32,
                &mut len,
                ptr::null_mut(),
                None,
            )
        };
        if ret == -1 {
            tracing::warn!(
                target: "qudp",
                "Failed to get WSARecvMsg function pointer: {}",
                io::Error::last_os_error()
            );
        } else if len as usize != mem::size_of::() {
            tracing::warn!(
                target: "qudp",
                "WSARecvMsg function pointer size mismatch: expected {}, got {}",
                mem::size_of::(),
                len
            );
            wsa_recvmsg_ptr = None;
        }
        unsafe {
            WinSock::closesocket(s);
        }
        wsa_recvmsg_ptr
    })
}
================================================
FILE: tests/keychain/gen_key.sh
================================================
#!/usr/bin/env bash
# Generate the test PKI: an ECC root CA (self-signed) plus a leaf
# certificate for quic.test.net signed by that root, with a v3 SAN.
# Runs interactively (openssl req prompts for subject fields).
set -euo pipefail

# gen root key
openssl ecparam -name secp384r1 -genkey -noout -out rootCA-ECC.key

# gen self-signed cert
openssl req -new -x509 -days 3650 -key rootCA-ECC.key -sha384 -out rootCA-ECC.crt

# gen server private key
openssl ecparam -name secp384r1 -genkey -noout -out quic-test-net-ECC.key

# create csr
openssl req -new -key quic-test-net-ECC.key -out quic-test-net.csr

# gen server cert with v3
# FIX: the redirection was mangled to `cat < openssl.cnf`, which reads the
# file instead of writing it and leaves the config body as stray commands;
# restore the heredoc that writes openssl.cnf.
cat > openssl.cnf <<EOT
[v3_req]
basicConstraints = CA:FALSE
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
subjectAltName = @alt_names
[alt_names]
DNS.1 = quic.test.net
EOT

openssl x509 -req \
    -extfile openssl.cnf -extensions v3_req \
    -in quic-test-net.csr \
    -CA rootCA-ECC.crt -CAkey rootCA-ECC.key -CAcreateserial \
    -out quic-test-net-ECC.crt -days 365 -sha384

# view info in quic-test-net-ECC.crt
openssl x509 -in quic-test-net-ECC.crt -text -noout

================================================
FILE: tests/keychain/localhost/ca.cert
================================================
-----BEGIN CERTIFICATE-----
MIIBkjCCATmgAwIBAgIUX2XYA8QU1FAkS19dimLJliUQEe4wCgYIKoZIzj0EAwIw
FDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI1MDQxNjA3Mzk1NFoXDTM1MDQxNDA3
Mzk1NFowFDESMBAGA1UEAwwJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0D
AQcDQgAEL4+GuiGFoN5syeBqmjbuciQrJfuq4NhiHw+g2K/0wDUrLOdPpNFzv4Dl
oQxneGfGp1qgja+AhimYk+zeFIWqRqNpMGcwHQYDVR0OBBYEFObmop7JSIFCq/lg
20SaK4hAGLXgMB8GA1UdIwQYMBaAFObmop7JSIFCq/lg20SaK4hAGLXgMA8GA1Ud
EwEB/wQFMAMBAf8wFAYDVR0RBA0wC4IJbG9jYWxob3N0MAoGCCqGSM49BAMCA0cA
MEQCIFjNfmSQAaNt1wt86kfb80w8g+RNIoSHk8yHN8tNM0lqAiB95+L021D+58Uf
c7z4m2eojR5BFV2lIdsbx8tMBN5RRA==
-----END CERTIFICATE-----
================================================
FILE: tests/keychain/localhost/ca.key
================================================
-----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIAXxESTjZZV9fAKLeBtFDoORO3H96YobgtSQDAivT9a9oAoGCCqGSM49 AwEHoUQDQgAEL4+GuiGFoN5syeBqmjbuciQrJfuq4NhiHw+g2K/0wDUrLOdPpNFz v4DloQxneGfGp1qgja+AhimYk+zeFIWqRg== -----END EC PRIVATE KEY----- ================================================ FILE: tests/keychain/localhost/ca.srl ================================================ 422450828B69F288653F12FD94827000BD65DF26 ================================================ FILE: tests/keychain/localhost/client.cert ================================================ -----BEGIN CERTIFICATE----- MIIBpDCCAUqgAwIBAgIUQiRQgotp8ohlPxL9lIJwAL1l3yUwCgYIKoZIzj0EAwIw FDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI1MDUyNzEwMTMxNloXDTM1MDUyNTEw MTMxNlowETEPMA0GA1UEAwwGY2xpZW50MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcD QgAEf8aJWKMK8lRW7L/GyEDJhHVRLC+9unFdVwN+Pjwuj2i88zjfB7aXdJC+gZ8/ 2PRc63f3twonhXV6XKjGjwUWUKN9MHswFwYDVR0RBBAwDoIGY2xpZW50hwR/AAAB MAsGA1UdDwQEAwIHgDATBgNVHSUEDDAKBggrBgEFBQcDAjAdBgNVHQ4EFgQU76ne 9bENwxHFwP9nGT9VPC1RTj4wHwYDVR0jBBgwFoAU5uainslIgUKr+WDbRJoriEAY teAwCgYIKoZIzj0EAwIDSAAwRQIgBN7Hq276bzZHijQB9vUJC7xDGyNs5/EL9Nm4 DgWKaocCIQCZcu350d5Zk55+gHuYtXwWO4dGfWS9FDZvGWR0g8db9w== -----END CERTIFICATE----- ================================================ FILE: tests/keychain/localhost/client.key ================================================ -----BEGIN PRIVATE KEY----- MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgF2YeSXeF2GnsG516 xEnZA9VL5LqJKWPZbIH6G74H8i6hRANCAAR/xolYowryVFbsv8bIQMmEdVEsL726 cV1XA34+PC6PaLzzON8Htpd0kL6Bnz/Y9Fzrd/e3CieFdXpcqMaPBRZQ -----END PRIVATE KEY----- ================================================ FILE: tests/keychain/localhost/server.cert ================================================ -----BEGIN CERTIFICATE----- MIIBgjCCASigAwIBAgIUQiRQgotp8ohlPxL9lIJwAL1l3yYwCgYIKoZIzj0EAwIw FDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDQxNzA4MjUyOFoXDTM2MDQxNDA4 MjUyOFowFDESMBAGA1UEAwwJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0D AQcDQgAEmoNwUXTOqO7yUjQfmTI+dg8lmteiIILzg8miSYraPKJsdCeMGiQrpLzM 
ViZyfg5VVpG3ajJYnzswe2v7dacpnqNYMFYwFAYDVR0RBA0wC4IJbG9jYWxob3N0 MB0GA1UdDgQWBBRgxdcCl/SpSR2hNzOhpReEGo0syzAfBgNVHSMEGDAWgBTm5qKe yUiBQqv5YNtEmiuIQBi14DAKBggqhkjOPQQDAgNIADBFAiBCsin9ppCSLZBgDkCn TfMn94pQ4YQ5R6SWPhv3jytAqAIhAIqAp6Q+urPUDu6xXT2zzl6xWY5+m4t26aFF exuP4i9p -----END CERTIFICATE----- ================================================ FILE: tests/keychain/localhost/server.key ================================================ -----BEGIN EC PARAMETERS----- BggqhkjOPQMBBw== -----END EC PARAMETERS----- -----BEGIN EC PRIVATE KEY----- MHcCAQEEIHgBydWpZJOQDbkIu/EWCn/2NYmF77bKjw1Xy8rGvHUdoAoGCCqGSM49 AwEHoUQDQgAEmoNwUXTOqO7yUjQfmTI+dg8lmteiIILzg8miSYraPKJsdCeMGiQr pLzMViZyfg5VVpG3ajJYnzswe2v7dacpng== -----END EC PRIVATE KEY----- ================================================ FILE: tests/keychain/quic.test.net/quic-test-net-ECC.crt ================================================ -----BEGIN CERTIFICATE----- MIIC4zCCAmmgAwIBAgIUeNy6M1upjE0Bf+BRjgbhx6fXFw8wCgYIKoZIzj0EAwMw gZMxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJISzELMAkGA1UEBwwCSEsxFTATBgNV BAoMDGdtLXF1aWMgdGVhbTEQMA4GA1UECwwHZ20tcXVpYzEbMBkGA1UEAwwSZ20t cXVpYyBtYWludGFpbmVyMSQwIgYJKoZIhvcNAQkBFhVxdWljX3RlYW1AZ2VubWV0 YS5uZXQwHhcNMjQwODI5MDgzNjAyWhcNMjUwODI5MDgzNjAyWjCBmzELMAkGA1UE BhMCQ04xEjAQBgNVBAgMCUd1YW5nZG9uZzERMA8GA1UEBwwIU2hlbnpoZW4xFTAT BgNVBAoMDGdtLXF1aWMgdGVhbTEQMA4GA1UECwwHZ20tcXVpYzEWMBQGA1UEAwwN cXVpYy50ZXN0Lm5ldDEkMCIGCSqGSIb3DQEJARYVcXVpY190ZWFtQGdlbm1ldGEu bmV0MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEjBGFuP8QGBP5aM7ItEFzuwnG+ekJ HnzJhdJRd+FaGyaMmjBKF/KKVNas9EzI8fVmRItcrhb1mJOdg1ad8SGl+fNi3Oi1 n/6CRdHCfbUfV1cOJM9O9QnTffn9aZQaC5Noo3QwcjAJBgNVHRMEAjAAMAsGA1Ud DwQEAwIF4DAYBgNVHREEETAPgg1xdWljLnRlc3QubmV0MB0GA1UdDgQWBBRrMVbA pSCmPnSRuNVHVPo7ZCaeLTAfBgNVHSMEGDAWgBTk3utiwIFAIkmjR0g8LLc6ehdg oTAKBggqhkjOPQQDAwNoADBlAjEAiddVWk2O74NiOR+A+OActVu9ZSbeaPEUsV3V 9u1hAB8ybflgPsCb/YFLB3cZB6OVAjAdEW9SEZVXIUvuf9VK5AL2SBCumUg1G+jT 5e1IIh6HAEuCOfh4eTDXVpm2H00Fi8s= -----END CERTIFICATE----- 
================================================ FILE: tests/keychain/quic.test.net/quic-test-net-ECC.key ================================================ -----BEGIN EC PRIVATE KEY----- MIGkAgEBBDD2+LMedQoNvJBnDcq1+9KFI2XfE489S9kzB/DoJW/3pzkG5Jq0Jlme 1PFoLZtfN3OgBwYFK4EEACKhZANiAASMEYW4/xAYE/lozsi0QXO7Ccb56QkefMmF 0lF34VobJoyaMEoX8opU1qz0TMjx9WZEi1yuFvWYk52DVp3xIaX582Lc6LWf/oJF 0cJ9tR9XVw4kz071CdN9+f1plBoLk2g= -----END EC PRIVATE KEY----- ================================================ FILE: tests/keychain/quic.test.net/quic-test-net.csr ================================================ -----BEGIN CERTIFICATE REQUEST----- MIIBlTCCARsCAQAwgZsxCzAJBgNVBAYTAkNOMRIwEAYDVQQIDAlHdWFuZ2Rvbmcx ETAPBgNVBAcMCFNoZW56aGVuMRUwEwYDVQQKDAxnbS1xdWljIHRlYW0xEDAOBgNV BAsMB2dtLXF1aWMxFjAUBgNVBAMMDXF1aWMudGVzdC5uZXQxJDAiBgkqhkiG9w0B CQEWFXF1aWNfdGVhbUBnZW5tZXRhLm5ldDB2MBAGByqGSM49AgEGBSuBBAAiA2IA BIwRhbj/EBgT+WjOyLRBc7sJxvnpCR58yYXSUXfhWhsmjJowShfyilTWrPRMyPH1 ZkSLXK4W9ZiTnYNWnfEhpfnzYtzotZ/+gkXRwn21H1dXDiTPTvUJ0335/WmUGguT aKAAMAoGCCqGSM49BAMCA2gAMGUCMBQlx6hnMv66mnbBZDF47v4hGdB7gxsOSEx8 EKBxsrcp7CkvL1siECJNun953MeZNQIxAJS36WwoUhhetA4YEog4lDGHeQ55f3os 4UjLeXOWKjswtxUISLB2xZMVm6kgb2vQqw== -----END CERTIFICATE REQUEST----- ================================================ FILE: tests/keychain/root/rootCA-ECC.crt ================================================ -----BEGIN CERTIFICATE----- MIICujCCAkCgAwIBAgIUfOA7KV6d4qkIqNA/6Rjb4Nf+3m4wCgYIKoZIzj0EAwMw gZMxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJISzELMAkGA1UEBwwCSEsxFTATBgNV BAoMDGdtLXF1aWMgdGVhbTEQMA4GA1UECwwHZ20tcXVpYzEbMBkGA1UEAwwSZ20t cXVpYyBtYWludGFpbmVyMSQwIgYJKoZIhvcNAQkBFhVxdWljX3RlYW1AZ2VubWV0 YS5uZXQwHhcNMjQwODI5MDczODUyWhcNMzQwODI3MDczODUyWjCBkzELMAkGA1UE BhMCQ04xCzAJBgNVBAgMAkhLMQswCQYDVQQHDAJISzEVMBMGA1UECgwMZ20tcXVp YyB0ZWFtMRAwDgYDVQQLDAdnbS1xdWljMRswGQYDVQQDDBJnbS1xdWljIG1haW50 YWluZXIxJDAiBgkqhkiG9w0BCQEWFXF1aWNfdGVhbUBnZW5tZXRhLm5ldDB2MBAG ByqGSM49AgEGBSuBBAAiA2IABO8rQjanzN5m3ZhflmnY6rx8Q4a5+CZQQPxRPt1f 
T6LTjK0NEdA9SnbITkU5OQo518UXsgMvrsO7zpIOH/HhwfYhVccxbMKXFzSOAYIE
Ium/QtQyULy533javmOCJOogcqNTMFEwHQYDVR0OBBYEFOTe62LAgUAiSaNHSDws
tzp6F2ChMB8GA1UdIwQYMBaAFOTe62LAgUAiSaNHSDwstzp6F2ChMA8GA1UdEwEB
/wQFMAMBAf8wCgYIKoZIzj0EAwMDaAAwZQIxALmDdA9EIap8KjKmWAGSSXDfV5wl
vwsciftrtl662l6GEu4uvI8lNpBqwEaEjvc2NAIwDkvRMnJnb8cmGScVa67dNSzU
8pM+auAM3NYjU2wRQmNKvKgtynG4Vkg974BnIwvp
-----END CERTIFICATE-----
================================================
FILE: tests/keychain/root/rootCA-ECC.key
================================================
-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDBa1GkhjbxHLCEC8/xuXT8uERSDGrfH+JvG2iQwz/w7voZjgEWnRZ2I
jf0GKl1Q9FGgBwYFK4EEACKhZANiAATvK0I2p8zeZt2YX5Zp2Oq8fEOGufgmUED8
UT7dX0+i04ytDRHQPUp2yE5FOTkKOdfFF7IDL67Du86SDh/x4cH2IVXHMWzClxc0
jgGCBCLpv0LUMlC8ud942r5jgiTqIHI=
-----END EC PRIVATE KEY-----
================================================
FILE: tests/keychain/root/rootCA-ECC.srl
================================================
78DCBA335BA98C4D017FE0518E06E1C7A7D7170F
================================================
FILE: tests/keychain/start-quic-server.sh
================================================
#!/usr/bin/env bash
# Launch the example QUIC server with the quic.test.net test certificate.
# FIX: `path_to` was referenced but never set or guarded; it now defaults to
# this script's parent directory (tests/), so the keychain paths resolve when
# the script is run from anywhere. Callers may still override it via the
# `path_to` environment variable.
set -euo pipefail

path_to=${path_to:-"$(cd "$(dirname "$0")/.." && pwd)"}

cargo run --example server -- ./ \
    --root --key "${path_to}/keychain/quic.test.net/quic-test-net-ECC.key" \
    --cert "${path_to}/keychain/quic.test.net/quic-test-net-ECC.crt" \
    --keylog