Repository: Hexilee/roa Branch: master Commit: 979d9783b9ca Files: 112 Total size: 340.2 KB Directory structure: gitextract_am2piau4/ ├── .github/ │ └── workflows/ │ ├── clippy.yml │ ├── code-coverage.yml │ ├── nightly-test.yml │ ├── release.yml │ ├── security-audit.yml │ └── stable-test.yml ├── .gitignore ├── .vscode/ │ └── launch.json ├── Cargo.toml ├── LICENSE ├── Makefile ├── README.md ├── assets/ │ ├── author.txt │ ├── cert.pem │ ├── css/ │ │ └── table.css │ ├── key.pem │ └── welcome.html ├── examples/ │ ├── echo.rs │ ├── hello-world.rs │ ├── https.rs │ ├── restful-api.rs │ ├── serve-file.rs │ ├── websocket-echo.rs │ └── welcome.rs ├── integration/ │ ├── diesel-example/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── src/ │ │ │ ├── bin/ │ │ │ │ └── api.rs │ │ │ ├── data_object.rs │ │ │ ├── endpoints.rs │ │ │ ├── lib.rs │ │ │ ├── models.rs │ │ │ └── schema.rs │ │ └── tests/ │ │ └── restful.rs │ ├── juniper-example/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── main.rs │ │ ├── models.rs │ │ └── schema.rs │ ├── multipart-example/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── assets/ │ │ │ └── index.html │ │ └── src/ │ │ └── main.rs │ └── websocket-example/ │ ├── Cargo.toml │ ├── README.md │ └── src/ │ └── main.rs ├── roa/ │ ├── Cargo.toml │ ├── README.md │ ├── src/ │ │ ├── body/ │ │ │ ├── file/ │ │ │ │ ├── content_disposition.rs │ │ │ │ └── help.rs │ │ │ └── file.rs │ │ ├── body.rs │ │ ├── compress.rs │ │ ├── cookie.rs │ │ ├── cors.rs │ │ ├── forward.rs │ │ ├── jsonrpc.rs │ │ ├── jwt.rs │ │ ├── lib.rs │ │ ├── logger.rs │ │ ├── query.rs │ │ ├── router/ │ │ │ ├── endpoints/ │ │ │ │ ├── dispatcher.rs │ │ │ │ └── guard.rs │ │ │ ├── endpoints.rs │ │ │ ├── err.rs │ │ │ └── path.rs │ │ ├── router.rs │ │ ├── stream.rs │ │ ├── tcp/ │ │ │ ├── incoming.rs │ │ │ └── listener.rs │ │ ├── tcp.rs │ │ ├── tls/ │ │ │ ├── incoming.rs │ │ │ └── listener.rs │ │ ├── tls.rs │ │ └── websocket.rs │ └── templates/ │ └── user.html ├── roa-async-std/ │ ├── Cargo.toml │ ├── README.md 
│ └── src/ │ ├── lib.rs │ ├── listener.rs │ ├── net.rs │ └── runtime.rs ├── roa-core/ │ ├── Cargo.toml │ ├── README.md │ └── src/ │ ├── app/ │ │ ├── future.rs │ │ ├── runtime.rs │ │ └── stream.rs │ ├── app.rs │ ├── body.rs │ ├── context/ │ │ └── storage.rs │ ├── context.rs │ ├── err.rs │ ├── executor.rs │ ├── group.rs │ ├── lib.rs │ ├── middleware.rs │ ├── request.rs │ ├── response.rs │ └── state.rs ├── roa-diesel/ │ ├── Cargo.toml │ ├── README.md │ └── src/ │ ├── async_ext.rs │ ├── lib.rs │ └── pool.rs ├── roa-juniper/ │ ├── Cargo.toml │ ├── README.md │ └── src/ │ └── lib.rs ├── rustfmt.toml ├── src/ │ └── lib.rs ├── templates/ │ └── directory.html └── tests/ ├── logger.rs ├── restful.rs └── serve-file.rs ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/clippy.yml ================================================ on: [push, pull_request] name: Clippy jobs: clippy_check: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Install Toolchain uses: actions-rs/toolchain@v1 with: toolchain: nightly override: true components: clippy - uses: actions-rs/clippy-check@v1 with: token: ${{ secrets.GITHUB_TOKEN }} args: --all-targets --all-features ================================================ FILE: .github/workflows/code-coverage.yml ================================================ name: Code Coverage on: push: branches: - master jobs: check: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 with: toolchain: nightly override: true - name: Check all uses: actions-rs/cargo@v1 with: command: check args: --all --all-features cover: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 with: toolchain: nightly override: true - name: Install libsqlite3-dev run: | sudo apt-get update sudo apt-get install -y libsqlite3-dev - name: Run cargo-tarpaulin uses: 
actions-rs/tarpaulin@v0.1 with: version: '0.21.0' args: --avoid-cfg-tarpaulin --out Xml --all --all-features - name: Upload to codecov.io uses: codecov/codecov-action@v1.0.2 with: token: ${{secrets.CODECOV_TOKEN}} - name: Archive code coverage results uses: actions/upload-artifact@v1 with: name: code-coverage-report path: cobertura.xml ================================================ FILE: .github/workflows/nightly-test.yml ================================================ name: Nightly Test on: push: pull_request: schedule: - cron: '0 0 * * *' jobs: check: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: nightly override: true - name: Check all uses: actions-rs/cargo@v1 with: command: check args: --all --all-features test: runs-on: ubuntu-latest steps: - name: Install libsqlite3-dev run: | sudo apt-get update sudo apt-get -y install libsqlite3-dev - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: nightly override: true - name: Run all tests uses: actions-rs/cargo@v1 with: command: test args: --all --all-features --no-fail-fast ================================================ FILE: .github/workflows/release.yml ================================================ name: Release on: push: branches: - release paths: - '**/Cargo.toml' - '.github/workflows/release.yml' jobs: publish: runs-on: ubuntu-latest strategy: fail-fast: false max-parallel: 1 matrix: package: - name: roa-core registryName: roa-core path: roa-core publishPath: /target/package - name: roa registryName: roa path: roa publishPath: /target/package - name: roa-juniper registryName: roa-juniper path: roa-juniper publishPath: /target/package - name: roa-diesel registryName: roa-diesel path: roa-diesel publishPath: /target/package - name: roa-async-std registryName: roa-async-std path: roa-async-std publishPath: /target/package steps: - uses: actions/checkout@v2 - name: Install Toolchain uses: 
actions-rs/toolchain@v1 with: toolchain: stable override: true - name: install libsqlite3-dev run: | sudo apt-get update sudo apt-get install -y libsqlite3-dev - name: get version working-directory: ${{ matrix.package.path }} run: echo "PACKAGE_VERSION=$(sed -nE 's/^\s*version = \"(.*?)\"/\1/p' Cargo.toml)" >> $GITHUB_ENV - name: check published version run: echo "PUBLISHED_VERSION=$(cargo search ${{ matrix.package.registryName }} --limit 1 | sed -nE 's/^[^\"]*\"//; s/\".*//1p' -)" >> $GITHUB_ENV - name: cargo login if: env.PACKAGE_VERSION != env.PUBLISHED_VERSION run: cargo login ${{ secrets.CRATE_TOKEN }} - name: cargo package if: env.PACKAGE_VERSION != env.PUBLISHED_VERSION working-directory: ${{ matrix.package.path }} run: | echo "package dir:" ls cargo package echo "We will publish:" $PACKAGE_VERSION echo "This is current latest:" $PUBLISHED_VERSION echo "post package dir:" cd ${{ matrix.publishPath }} ls - name: Publish ${{ matrix.package.name }} if: env.PACKAGE_VERSION != env.PUBLISHED_VERSION working-directory: ${{ matrix.package.path }} run: | echo "# Cargo Publish" | tee -a ${{runner.workspace }}/notes.md echo "\`\`\`" >> ${{runner.workspace }}/notes.md cargo publish --no-verify 2>&1 | tee -a ${{runner.workspace }}/notes.md echo "\`\`\`" >> ${{runner.workspace }}/notes.md - name: Create Release id: create_crate_release if: env.PACKAGE_VERSION != env.PUBLISHED_VERSION uses: jbolda/create-release@v1.1.0 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: tag_name: ${{ matrix.package.name }}-v${{ env.PACKAGE_VERSION }} release_name: "Release ${{ matrix.package.name }} v${{ env.PACKAGE_VERSION }} [crates.io]" bodyFromFile: ./../notes.md draft: false prerelease: false - name: Upload Release Asset id: upload-release-asset if: env.PACKAGE_VERSION != env.PUBLISHED_VERSION uses: actions/upload-release-asset@v1.0.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: upload_url: ${{ steps.create_crate_release.outputs.upload_url }} asset_path: ./${{ 
matrix.package.publishPath }}/${{ matrix.package.registryName }}-${{ env.PACKAGE_VERSION }}.crate asset_name: ${{ matrix.package.registryName }}-${{ env.PACKAGE_VERSION }}.crate asset_content_type: application/x-gtar ================================================ FILE: .github/workflows/security-audit.yml ================================================ name: Security Audit on: schedule: - cron: '0 0 * * *' jobs: audit: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: actions-rs/audit-check@v1 with: token: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .github/workflows/stable-test.yml ================================================ on: [push, pull_request] name: Stable Test jobs: check: runs-on: ubuntu-latest strategy: matrix: rust: - stable - 1.60.0 steps: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ matrix.rust }} override: true - name: Check all uses: actions-rs/cargo@v1 with: command: check args: --all --features "roa/full" test: runs-on: ubuntu-latest strategy: matrix: rust: - stable - 1.60.0 steps: - name: Install libsqlite3-dev run: | sudo apt-get update sudo apt-get -y install libsqlite3-dev - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ matrix.rust }} override: true - name: Run all tests uses: actions-rs/cargo@v1 with: command: test args: --all --features "roa/full" --no-fail-fast ================================================ FILE: .gitignore ================================================ /target **/*.rs.bk Cargo.lock **/upload/* .env node_modules ================================================ FILE: .vscode/launch.json ================================================ { // Use IntelliSense to learn about possible attributes. // Hover to view descriptions of existing attributes. 
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ { "type": "lldb", "request": "launch", "name": "Debug unit tests in library 'roa'", "cargo": { "args": [ "test", "--no-run", "--lib", "--package=roa" ], "filter": { "name": "roa", "kind": "lib" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in library 'roa-core'", "cargo": { "args": [ "test", "--no-run", "--lib", "--package=roa-core" ], "filter": { "name": "roa-core", "kind": "lib" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in library 'roa-diesel'", "cargo": { "args": [ "test", "--no-run", "--lib", "--package=roa-diesel" ], "filter": { "name": "roa-diesel", "kind": "lib" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in library 'roa-tokio'", "cargo": { "args": [ "+nightly", "test", "--no-run", "--lib", "--package=roa-tokio", "--all-features", ], "filter": { "name": "roa-tokio", "kind": "lib" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in library 'roa-multipart'", "cargo": { "args": [ "test", "--no-run", "--lib", "--package=roa-multipart" ], "filter": { "name": "roa-multipart", "kind": "lib" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in library 'roa-juniper'", "cargo": { "args": [ "test", "--no-run", "--lib", "--package=roa-juniper" ], "filter": { "name": "roa-juniper", "kind": "lib" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in library 'roa-jsonrpc'", "cargo": { "args": [ "test", "--no-run", "--lib", "--package=roa-jsonrpc" ], "filter": { "name": "roa-jsonrpc", "kind": "lib" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", 
"request": "launch", "name": "Debug unit tests in library 'diesel-example'", "cargo": { "args": [ "test", "--no-run", "--lib", "--package=diesel-example" ], "filter": { "name": "diesel-example", "kind": "lib" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug executable 'api'", "cargo": { "args": [ "build", "--bin=api", "--package=diesel-example" ], "filter": { "name": "api", "kind": "bin" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in executable 'api'", "cargo": { "args": [ "test", "--no-run", "--bin=api", "--package=diesel-example" ], "filter": { "name": "api", "kind": "bin" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug integration test 'restful'", "cargo": { "args": [ "test", "--no-run", "--test=restful", "--package=diesel-example" ], "filter": { "name": "restful", "kind": "test" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug executable 'multipart-example'", "cargo": { "args": [ "build", "--bin=multipart-example", "--package=multipart-example" ], "filter": { "name": "multipart-example", "kind": "bin" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in executable 'multipart-example'", "cargo": { "args": [ "test", "--no-run", "--bin=multipart-example", "--package=multipart-example" ], "filter": { "name": "multipart-example", "kind": "bin" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug executable 'websocket-example'", "cargo": { "args": [ "build", "--bin=websocket-example", "--package=websocket-example" ], "filter": { "name": "websocket-example", "kind": "bin" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in executable 'websocket-example'", "cargo": { 
"args": [ "test", "--no-run", "--bin=websocket-example", "--package=websocket-example" ], "filter": { "name": "websocket-example", "kind": "bin" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug executable 'juniper-example'", "cargo": { "args": [ "build", "--bin=juniper-example", "--package=juniper-example" ], "filter": { "name": "juniper-example", "kind": "bin" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in executable 'juniper-example'", "cargo": { "args": [ "test", "--no-run", "--bin=juniper-example", "--package=juniper-example" ], "filter": { "name": "juniper-example", "kind": "bin" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in library 'roa-root'", "cargo": { "args": [ "test", "--no-run", "--lib", "--package=roa-root" ], "filter": { "name": "roa-root", "kind": "lib" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug example 'echo'", "cargo": { "args": [ "build", "--example=echo", "--package=roa-root" ], "filter": { "name": "echo", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in example 'echo'", "cargo": { "args": [ "test", "--no-run", "--example=echo", "--package=roa-root" ], "filter": { "name": "echo", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug example 'hello-world'", "cargo": { "args": [ "build", "--example=hello-world", "--package=roa-root" ], "filter": { "name": "hello-world", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in example 'hello-world'", "cargo": { "args": [ "test", "--no-run", "--example=hello-world", "--package=roa-root" ], "filter": { "name": "hello-world", 
"kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug example 'https'", "cargo": { "args": [ "build", "--example=https", "--package=roa-root" ], "filter": { "name": "https", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in example 'https'", "cargo": { "args": [ "test", "--no-run", "--example=https", "--package=roa-root" ], "filter": { "name": "https", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug example 'restful-api'", "cargo": { "args": [ "build", "--example=restful-api", "--package=roa-root" ], "filter": { "name": "restful-api", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in example 'restful-api'", "cargo": { "args": [ "test", "--no-run", "--example=restful-api", "--package=roa-root" ], "filter": { "name": "restful-api", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug example 'serve-file'", "cargo": { "args": [ "build", "--example=serve-file", "--package=roa-root" ], "filter": { "name": "serve-file", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in example 'serve-file'", "cargo": { "args": [ "test", "--no-run", "--example=serve-file", "--package=roa-root" ], "filter": { "name": "serve-file", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug example 'websocket-echo'", "cargo": { "args": [ "build", "--example=websocket-echo", "--package=roa-root" ], "filter": { "name": "websocket-echo", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in example 
'websocket-echo'", "cargo": { "args": [ "test", "--no-run", "--example=websocket-echo", "--package=roa-root" ], "filter": { "name": "websocket-echo", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug example 'welcome'", "cargo": { "args": [ "build", "--example=welcome", "--package=roa-root" ], "filter": { "name": "welcome", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug unit tests in example 'welcome'", "cargo": { "args": [ "test", "--no-run", "--example=welcome", "--package=roa-root" ], "filter": { "name": "welcome", "kind": "example" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug integration test 'logger'", "cargo": { "args": [ "test", "--no-run", "--test=logger", "--package=roa-root" ], "filter": { "name": "logger", "kind": "test" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug integration test 'restful'", "cargo": { "args": [ "test", "--no-run", "--test=restful", "--package=roa-root" ], "filter": { "name": "restful", "kind": "test" } }, "args": [], "cwd": "${workspaceFolder}" }, { "type": "lldb", "request": "launch", "name": "Debug integration test 'serve-file'", "cargo": { "args": [ "test", "--no-run", "--test=serve-file", "--package=roa-root" ], "filter": { "name": "serve-file", "kind": "test" } }, "args": [], "cwd": "${workspaceFolder}" } ] } ================================================ FILE: Cargo.toml ================================================ [package] name = "roa-root" version = "0.6.0" authors = ["Hexilee "] edition = "2018" license = "MIT" publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [workspace] members = [ "roa", "roa-core", "roa-diesel", "roa-async-std", "roa-juniper", "integration/diesel-example", 
"integration/multipart-example", "integration/websocket-example", "integration/juniper-example" ] [dev-dependencies] tokio = { version = "1.15", features = ["full"] } reqwest = { version = "0.11", features = ["json", "cookies", "gzip"] } serde = { version = "1", features = ["derive"] } roa = { path = "./roa", features = ["full"] } test-case = "1.2" once_cell = "1.8" log = "0.4" slab = "0.4.2" multimap = "0.8.0" hyper = "0.14" chrono = "0.4" mime = "0.3" encoding = "0.2" askama = "0.10" http = "0.2" bytesize = "1.1" serde_json = "1.0" tracing = "0.1" futures = "0.3" doc-comment = "0.3.3" anyhow = "1.0" tracing-futures = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } ================================================ FILE: LICENSE ================================================ Copyright (c) 2020 Hexilee Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: Makefile ================================================ check: cargo check --all --features "roa/full" build: cargo build --all --features "roa/full" test: cargo test --all --features "roa/full" fmt: cargo +nightly fmt lint: cargo clippy --all-targets -- -D warnings check-all: cargo +nightly check --all --all-features test-all: cargo +nightly test --all --all-features ================================================ FILE: README.md ================================================

Roa

Roa is an async web framework inspired by koajs, lightweight but powerful.

[![Stable Test](https://github.com/Hexilee/roa/workflows/Stable%20Test/badge.svg)](https://github.com/Hexilee/roa/actions) [![codecov](https://codecov.io/gh/Hexilee/roa/branch/master/graph/badge.svg)](https://codecov.io/gh/Hexilee/roa) [![wiki](https://img.shields.io/badge/roa-wiki-purple.svg)](https://github.com/Hexilee/roa/wiki) [![Rust Docs](https://docs.rs/roa/badge.svg)](https://docs.rs/roa) [![Crate version](https://img.shields.io/crates/v/roa.svg)](https://crates.io/crates/roa) [![Download](https://img.shields.io/crates/d/roa.svg)](https://crates.io/crates/roa) [![MSRV-1.54](https://img.shields.io/badge/MSRV-1.54-blue.svg)](https://blog.rust-lang.org/2021/07/29/Rust-1.54.0.html) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/Hexilee/roa/blob/master/LICENSE)

Examples | Guide | Cookbook


#### Feature highlights - A lightweight, solid and well-extensible core. - Supports HTTP/1.x and HTTP/2.0 protocols. - Full streaming. - Highly extensible middleware system. - Based on [`hyper`](https://github.com/hyperium/hyper), runtime-independent, you can choose the async runtime you like. - Many useful extensions. - Official runtime schemes: - (Default) [tokio](https://github.com/tokio-rs/tokio) runtime and TcpStream. - [async-std](https://github.com/async-rs/async-std) runtime and TcpStream. - Transparent content compression (br, gzip, deflate, zstd). - Configurable and nestable router. - Named URI parameters (query and router parameters). - Cookie and JWT support. - HTTPS support. - WebSocket support. - Asynchronous multipart form support. - Other middlewares (logger, CORS, etc.). - Integrations - roa-diesel, integration with [diesel](https://github.com/diesel-rs/diesel). - roa-juniper, integration with [juniper](https://github.com/graphql-rust/juniper). - Works on stable Rust. #### Get started ```toml # Cargo.toml [dependencies] roa = "0.6" tokio = { version = "1.15", features = ["rt", "macros"] } anyhow = "1.0" ``` ```rust,no_run use roa::App; use roa::preload::*; #[tokio::main] async fn main() -> anyhow::Result<()> { let app = App::new().end("Hello, World"); app.listen("127.0.0.1:8000", |addr| { println!("Server is listening on {}", addr) })? .await?; Ok(()) } ``` Refer to [wiki](https://github.com/Hexilee/roa/wiki) for more details. 
================================================ FILE: assets/author.txt ================================================ Hexilee ================================================ FILE: assets/cert.pem ================================================ -----BEGIN CERTIFICATE----- MIIFPjCCAyYCCQDvLYiYD+jqeTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJV UzELMAkGA1UECAwCQ0ExCzAJBgNVBAcMAlNGMRAwDgYDVQQKDAdDb21wYW55MQww CgYDVQQLDANPcmcxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xODAxMjUx NzQ2MDFaFw0xOTAxMjUxNzQ2MDFaMGExCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJD QTELMAkGA1UEBwwCU0YxEDAOBgNVBAoMB0NvbXBhbnkxDDAKBgNVBAsMA09yZzEY MBYGA1UEAwwPd3d3LmV4YW1wbGUuY29tMIICIjANBgkqhkiG9w0BAQEFAAOCAg8A MIICCgKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEPn8k1 sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+MIK5U NLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM54jXy voLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZWLWr odGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAkoqND xdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNliJDmA CRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6/stI yFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuDYX2U UuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nPwPTO vRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA69un CEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEAATAN BgkqhkiG9w0BAQsFAAOCAgEApavsgsn7SpPHfhDSN5iZs1ILZQRewJg0Bty0xPfk 3tynSW6bNH3nSaKbpsdmxxomthNSQgD2heOq1By9YzeOoNR+7Pk3s4FkASnf3ToI JNTUasBFFfaCG96s4Yvs8KiWS/k84yaWuU8c3Wb1jXs5Rv1qE1Uvuwat1DSGXSoD JNluuIkCsC4kWkyq5pWCGQrabWPRTWsHwC3PTcwSRBaFgYLJaR72SloHB1ot02zL d2age9dmFRFLLCBzP+D7RojBvL37qS/HR+rQ4SoQwiVc/JzaeqSe7ZbvEH9sZYEu ALowJzgbwro7oZflwTWunSeSGDSltkqKjvWvZI61pwfHKDahUTmZ5h2y67FuGEaC CIOUI8dSVSPKITxaq3JL4ze2e9/0Lt7hj19YK2uUmtMAW5Tirz4Yx5lyGH9U8Wur y/X8VPxTc4A9TMlJgkyz0hqvhbPOT/zSWB10zXh0glKAsSBryAOEDxV1UygmSir7 YV8Qaq+oyKUTMc1MFq5vZ07M51EPaietn85t8V2Y+k/8XYltRp32NxsypxAJuyxh 
g/ko6RVTrWa1sMvz/F9LFqAdKiK5eM96lh9IU4xiLg4ob8aS/GRAA8oIFkZFhLrt tOwjIUPmEPyHWFi8dLpNuQKYalLYhuwZftG/9xV+wqhKGZO9iPrpHSYBRTap8w2y 1QU= -----END CERTIFICATE----- ================================================ FILE: assets/css/table.css ================================================ /* spacing */ table { table-layout: fixed; width: 80%; border-collapse: collapse; } thead th { text-align: left } thead th:nth-child(1) { width: 40%; } thead th:nth-child(2) { width: 20%; } thead th:nth-child(3) { width: 40%; } th, td { padding: 10px; } ================================================ FILE: assets/key.pem ================================================ -----BEGIN RSA PRIVATE KEY----- MIIJKAIBAAKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEP n8k1sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+M IK5UNLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM5 4jXyvoLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZ WLWrodGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAk oqNDxdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNli JDmACRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6 /stIyFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuD YX2UUuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nP wPTOvRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA 69unCEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEA AQKCAgAME3aoeXNCPxMrSri7u4Xnnk71YXl0Tm9vwvjRQlMusXZggP8VKN/KjP0/ 9AE/GhmoxqPLrLCZ9ZE1EIjgmZ9Xgde9+C8rTtfCG2RFUL7/5J2p6NonlocmxoJm YkxYwjP6ce86RTjQWL3RF3s09u0inz9/efJk5O7M6bOWMQ9VZXDlBiRY5BYvbqUR 6FeSzD4MnMbdyMRoVBeXE88gTvZk8xhB6DJnLzYgc0tKiRoeKT0iYv5JZw25VyRM ycLzfTrFmXCPfB1ylb483d9Ly4fBlM8nkx37PzEnAuukIawDxsPOb9yZC+hfvNJI 7NFiMN+3maEqG2iC00w4Lep4skHY7eHUEUMl+Wjr+koAy2YGLWAwHZQTm7iXn9Ab L6adL53zyCKelRuEQOzbeosJAqS+5fpMK0ekXyoFIuskj7bWuIoCX7K/kg6q5IW+ vC2FrlsrbQ79GztWLVmHFO1I4J9M5r666YS0qdh8c+2yyRl4FmSiHfGxb3eOKpxQ b6uI97iZlkxPF9LYUCSc7wq0V2gGz+6LnGvTHlHrOfVXqw/5pLAKhXqxvnroDTwz 
0Ay/xFF6ei/NSxBY5t8ztGCBm45wCU3l8pW0X6dXqwUipw5b4MRy1VFRu6rqlmbL OPSCuLxqyqsigiEYsBgS/icvXz9DWmCQMPd2XM9YhsHvUq+R4QKCAQEA98EuMMXI 6UKIt1kK2t/3OeJRyDd4iv/fCMUAnuPjLBvFE4cXD/SbqCxcQYqb+pue3PYkiTIC 71rN8OQAc5yKhzmmnCE5N26br/0pG4pwEjIr6mt8kZHmemOCNEzvhhT83nfKmV0g 9lNtuGEQMiwmZrpUOF51JOMC39bzcVjYX2Cmvb7cFbIq3lR0zwM+aZpQ4P8LHCIu bgHmwbdlkLyIULJcQmHIbo6nPFB3ZZE4mqmjwY+rA6Fh9rgBa8OFCfTtrgeYXrNb IgZQ5U8GoYRPNC2ot0vpTinraboa/cgm6oG4M7FW1POCJTl+/ktHEnKuO5oroSga /BSg7hCNFVaOhwKCAQEA4Kkys0HtwEbV5mY/NnvUD5KwfXX7BxoXc9lZ6seVoLEc KjgPYxqYRVrC7dB2YDwwp3qcRTi/uBAgFNm3iYlDzI4xS5SeaudUWjglj7BSgXE2 iOEa7EwcvVPluLaTgiWjlzUKeUCNNHWSeQOt+paBOT+IgwRVemGVpAgkqQzNh/nP tl3p9aNtgzEm1qVlPclY/XUCtf3bcOR+z1f1b4jBdn0leu5OhnxkC+Htik+2fTXD jt6JGrMkanN25YzsjnD3Sn+v6SO26H99wnYx5oMSdmb8SlWRrKtfJHnihphjG/YY l1cyorV6M/asSgXNQfGJm4OuJi0I4/FL2wLUHnU+JwKCAQEAzh4WipcRthYXXcoj gMKRkMOb3GFh1OpYqJgVExtudNTJmZxq8GhFU51MR27Eo7LycMwKy2UjEfTOnplh Us2qZiPtW7k8O8S2m6yXlYUQBeNdq9IuuYDTaYD94vsazscJNSAeGodjE+uGvb1q 1wLqE87yoE7dUInYa1cOA3+xy2/CaNuviBFJHtzOrSb6tqqenQEyQf6h9/12+DTW t5pSIiixHrzxHiFqOoCLRKGToQB+71rSINwTf0nITNpGBWmSj5VcC3VV3TG5/XxI fPlxV2yhD5WFDPVNGBGvwPDSh4jSMZdZMSNBZCy4XWFNSKjGEWoK4DFYed3DoSt9 5IG1YwKCAQA63ntHl64KJUWlkwNbboU583FF3uWBjee5VqoGKHhf3CkKMxhtGqnt +oN7t5VdUEhbinhqdx1dyPPvIsHCS3K1pkjqii4cyzNCVNYa2dQ00Qq+QWZBpwwc 3GAkz8rFXsGIPMDa1vxpU6mnBjzPniKMcsZ9tmQDppCEpBGfLpio2eAA5IkK8eEf cIDB3CM0Vo94EvI76CJZabaE9IJ+0HIJb2+jz9BJ00yQBIqvJIYoNy9gP5Xjpi+T qV/tdMkD5jwWjHD3AYHLWKUGkNwwkAYFeqT/gX6jpWBP+ZRPOp011X3KInJFSpKU DT5GQ1Dux7EMTCwVGtXqjO8Ym5wjwwsfAoIBAEcxlhIW1G6BiNfnWbNPWBdh3v/K 5Ln98Rcrz8UIbWyl7qNPjYb13C1KmifVG1Rym9vWMO3KuG5atK3Mz2yLVRtmWAVc fxzR57zz9MZFDun66xo+Z1wN3fVxQB4CYpOEI4Lb9ioX4v85hm3D6RpFukNtRQEc Gfr4scTjJX4jFWDp0h6ffMb8mY+quvZoJ0TJqV9L9Yj6Ksdvqez/bdSraev97bHQ 4gbQxaTZ6WjaD4HjpPQefMdWp97Metg0ZQSS8b8EzmNFgyJ3XcjirzwliKTAQtn6 I2sd0NCIooelrKRD8EJoDUwxoOctY7R97wpZ7/wEHU45cBCbRV3H4JILS5c= -----END RSA PRIVATE KEY----- ================================================ FILE: assets/welcome.html 
================================================ Roa Framework

Welcome!

Go to roa for more information...

================================================
FILE: examples/echo.rs
================================================
//! RUST_LOG=info cargo run --example echo,
//! then request http://127.0.0.1:8000 with some payload.

use std::error::Error as StdError;

use roa::logger::logger;
use roa::preload::*;
use roa::{App, Context};
use tracing::info;
use tracing_subscriber::EnvFilter;

/// Echo endpoint: takes the request body stream and hands it to the response
/// as its body, so whatever payload the client sends is sent back.
async fn echo(ctx: &mut Context) -> roa::Result {
    let stream = ctx.req.stream();
    ctx.resp.write_stream(stream);
    Ok(())
}

/// Initializes an env-filtered tracing subscriber, then serves [`echo`]
/// behind the `logger` middleware on 127.0.0.1:8000.
///
/// # Errors
///
/// Fails if the subscriber cannot be installed, if the address cannot be
/// listened on, or if the server future resolves to an error.
#[tokio::main]
async fn main() -> Result<(), Box<dyn StdError>> {
    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .try_init()
        .map_err(|err| anyhow::anyhow!("fail to init tracing subscriber: {}", err))?;

    let app = App::new().gate(logger).end(echo);
    app.listen("127.0.0.1:8000", |addr| {
        info!("Server is listening on {}", addr)
    })?
    .await?;
    Ok(())
}

================================================
FILE: examples/hello-world.rs
================================================
//! RUST_LOG=info cargo run --example hello-world,
//! then request http://127.0.0.1:8000.

use log::info;
use roa::logger::logger;
use roa::preload::*;
use roa::App;
use tracing_subscriber::EnvFilter;

/// Initializes an env-filtered tracing subscriber, then serves the static
/// body `"Hello, World!"` behind the `logger` middleware on 127.0.0.1:8000.
///
/// # Errors
///
/// Fails if the subscriber cannot be installed, if the address cannot be
/// listened on, or if the server future resolves to an error.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .try_init()
        .map_err(|err| anyhow::anyhow!("fail to init tracing subscriber: {}", err))?;

    let app = App::new().gate(logger).end("Hello, World!");
    app.listen("127.0.0.1:8000", |addr| {
        info!("Server is listening on {}", addr)
    })?
    .await?;
    Ok(())
}

================================================
FILE: examples/https.rs
================================================
//! RUST_LOG=info cargo run --example https,
//! then request https://127.0.0.1:8000.
use std::error::Error as StdError; use std::fs::File; use std::io::BufReader; use log::info; use roa::body::DispositionType; use roa::logger::logger; use roa::preload::*; use roa::tls::pemfile::{certs, rsa_private_keys}; use roa::tls::{Certificate, PrivateKey, ServerConfig, TlsListener}; use roa::{App, Context}; use tracing_subscriber::EnvFilter; async fn serve_file(ctx: &mut Context) -> roa::Result { ctx.write_file("assets/welcome.html", DispositionType::Inline) .await } #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .try_init() .map_err(|err| anyhow::anyhow!("fail to init tracing subscriber: {}", err))?; let mut cert_file = BufReader::new(File::open("assets/cert.pem")?); let mut key_file = BufReader::new(File::open("assets/key.pem")?); let cert_chain = certs(&mut cert_file)? .into_iter() .map(Certificate) .collect(); let config = ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert( cert_chain, PrivateKey(rsa_private_keys(&mut key_file)?.remove(0)), )?; let app = App::new().gate(logger).end(serve_file); app.listen_tls("127.0.0.1:8000", config, |addr| { info!("Server is listening on https://localhost:{}", addr.port()) })? .await?; Ok(()) } ================================================ FILE: examples/restful-api.rs ================================================ //! RUST_LOG=info cargo run --example restful-api, //! then: //! - `curl 127.0.0.1:8000/user/0` //! query user where id=0 //! - `curl -H "Content-type: application/json" -d '{"name":"Hexilee", "age": 20}' -X POST 127.0.0.1:8000/user` //! create a new user //! - `curl -H "Content-type: application/json" -d '{"name":"Alice", "age": 20}' -X PUT 127.0.0.1:8000/user/0` //! update user where id=0, return the old data //! - `curl 127.0.0.1:8000/user/0 -X DELETE` //! 
delete user where id=0 use std::result::Result as StdResult; use std::sync::Arc; use roa::http::StatusCode; use roa::preload::*; use roa::router::{get, post, Router}; use roa::{throw, App, Context, Result}; use serde::{Deserialize, Serialize}; use serde_json::json; use slab::Slab; use tokio::sync::RwLock; #[derive(Debug, Serialize, Deserialize, Clone)] struct User { name: String, age: u8, } #[derive(Clone)] struct Database { table: Arc>>, } impl Database { fn new() -> Self { Self { table: Arc::new(RwLock::new(Slab::new())), } } async fn create(&self, user: User) -> usize { self.table.write().await.insert(user) } async fn retrieve(&self, id: usize) -> Result { match self.table.read().await.get(id) { Some(user) => Ok(user.clone()), None => throw!(StatusCode::NOT_FOUND), } } async fn update(&self, id: usize, new_user: &mut User) -> Result { match self.table.write().await.get_mut(id) { Some(user) => { std::mem::swap(new_user, user); Ok(()) } None => throw!(StatusCode::NOT_FOUND), } } /** Removes and returns the user stored under `id`; throws 404 NOT FOUND when no user occupies that slot. */ async fn delete(&self, id: usize) -> Result { /* Check and remove under ONE write guard. Checking under a read lock and removing under a separate write lock lets a concurrent delete of the same id slip in between, and `Slab::remove` panics on a vacant key. */ let mut table = self.table.write().await; if !table.contains(id) { throw!(StatusCode::NOT_FOUND) } Ok(table.remove(id)) } } async fn create_user(ctx: &mut Context) -> Result { let user: User = ctx.read_json().await?; let id = ctx.create(user).await; ctx.write_json(&json!({ "id": id }))?; ctx.resp.status = StatusCode::CREATED; Ok(()) } async fn get_user(ctx: &mut Context) -> Result { let id: usize = ctx.must_param("id")?.parse()?; let user = ctx.retrieve(id).await?; ctx.write_json(&user) } async fn update_user(ctx: &mut Context) -> Result { let id: usize = ctx.must_param("id")?.parse()?; let mut user: User = ctx.read_json().await?; ctx.update(id, &mut user).await?; ctx.write_json(&user) } async fn delete_user(ctx: &mut Context) -> Result { let id: usize = ctx.must_param("id")?.parse()?; let user = ctx.delete(id).await?; ctx.write_json(&user) } #[tokio::main] async fn main() -> StdResult<(), Box> { let router = Router::new() .on("/", 
post(create_user)) .on("/:id", get(get_user).put(update_user).delete(delete_user)); let app = App::state(Database::new()).end(router.routes("/user")?); app.listen("127.0.0.1:8000", |addr| { println!("Server is listening on {}", addr) })? .await?; Ok(()) } ================================================ FILE: examples/serve-file.rs ================================================ //! RUST_LOG=info cargo run --example serve-file, //! then request http://127.0.0.1:8000. use std::borrow::Cow; use std::path::Path; use std::result::Result as StdResult; use std::time::SystemTime; use askama::Template; use bytesize::ByteSize; use chrono::offset::Local; use chrono::DateTime; use log::info; use roa::body::DispositionType::*; use roa::compress::Compress; use roa::http::StatusCode; use roa::logger::logger; use roa::preload::*; use roa::router::{get, Router}; use roa::{throw, App, Context, Next, Result}; use tokio::fs::{metadata, read_dir}; use tracing_subscriber::EnvFilter; #[derive(Template)] #[template(path = "directory.html")] struct Dir<'a> { title: &'a str, root: &'a str, dirs: Vec, files: Vec, } struct DirInfo { link: String, name: String, modified: String, } struct FileInfo { link: String, name: String, modified: String, size: String, } impl<'a> Dir<'a> { fn new(title: &'a str, root: &'a str) -> Self { Self { title, root, dirs: Vec::new(), files: Vec::new(), } } } async fn path_checker(ctx: &mut Context, next: Next<'_>) -> Result { if ctx.must_param("path")?.contains("..") { throw!(StatusCode::BAD_REQUEST, "invalid path") } else { next.await } } async fn serve_path(ctx: &mut Context) -> Result { let path_value = ctx.must_param("path")?; let path = path_value.as_ref(); let file_path = Path::new(".").join(path); let meta = metadata(&file_path).await?; if meta.is_file() { ctx.write_file(file_path, Inline).await } else if meta.is_dir() { serve_dir(ctx, path).await } else { throw!(StatusCode::NOT_FOUND, "path not found") } } async fn serve_root(ctx: &mut Context) -> Result 
{ serve_dir(ctx, "").await } async fn serve_dir(ctx: &mut Context, path: &str) -> Result { let uri_path = Path::new("/").join(path); let mut entries = read_dir(Path::new(".").join(path)).await?; let title = uri_path .file_name() .map(|os_str| os_str.to_string_lossy()) .unwrap_or(Cow::Borrowed("/")); let root_str = uri_path.to_string_lossy(); let mut dir = Dir::new(&title, &root_str); while let Some(entry) = entries.next_entry().await? { let metadata = entry.metadata().await?; if metadata.is_dir() { dir.dirs.push(DirInfo { link: uri_path .join(entry.file_name()) .to_string_lossy() .to_string(), name: entry.file_name().to_string_lossy().to_string(), modified: format_time(metadata.modified()?), }) } if metadata.is_file() { dir.files.push(FileInfo { link: uri_path .join(entry.file_name()) .to_string_lossy() .to_string(), name: entry.file_name().to_string_lossy().to_string(), modified: format_time(metadata.modified()?), size: ByteSize(metadata.len()).to_string(), }) } } ctx.render(&dir) } fn format_time(time: SystemTime) -> String { let datetime: DateTime = time.into(); datetime.format("%d/%m/%Y %T").to_string() } #[tokio::main] async fn main() -> StdResult<(), Box> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .try_init() .map_err(|err| anyhow::anyhow!("fail to init tracing subscriber: {}", err))?; let wildcard_router = Router::new().gate(path_checker).on("/", get(serve_path)); let router = Router::new() .on("/", serve_root) .include("/*{path}", wildcard_router); let app = App::new() .gate(logger) .gate(Compress::default()) .end(router.routes("/")?); app.listen("127.0.0.1:8000", |addr| { info!("Server is listening on {}", addr) })? .await .map_err(Into::into) } ================================================ FILE: examples/websocket-echo.rs ================================================ //! RUST_LOG=info cargo run --example websocket-echo, //! then request ws://127.0.0.1:8000/chat. 
use std::error::Error as StdError; use futures::StreamExt; use http::Method; use log::{error, info}; use roa::cors::Cors; use roa::logger::logger; use roa::preload::*; use roa::router::{allow, Router}; use roa::websocket::Websocket; use roa::App; use tracing_subscriber::EnvFilter; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .try_init() .map_err(|err| anyhow::anyhow!("fail to init tracing subscriber: {}", err))?; let router = Router::new().on( "/chat", allow( [Method::GET], Websocket::new(|_ctx, stream| async move { let (write, read) = stream.split(); if let Err(err) = read.forward(write).await { error!("{}", err); } }), ), ); let app = App::new() .gate(logger) .gate(Cors::new()) .end(router.routes("/")?); app.listen("127.0.0.1:8000", |addr| { info!("Server is listening on {}", addr) })? .await?; Ok(()) } ================================================ FILE: examples/welcome.rs ================================================ //! RUST_LOG=info cargo run --example welcome, //! then request http://127.0.0.1:8000. use std::error::Error as StdError; use log::info; use roa::logger::logger; use roa::preload::*; use roa::App; use tracing_subscriber::EnvFilter; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .try_init() .map_err(|err| anyhow::anyhow!("fail to init tracing subscriber: {}", err))?; let app = App::new() .gate(logger) .end(include_str!("../assets/welcome.html")); app.listen("127.0.0.1:8000", |addr| { info!("Server is listening on {}", addr) })? 
.await?; Ok(()) } ================================================ FILE: integration/diesel-example/Cargo.toml ================================================ [package] name = "diesel-example" version = "0.1.0" authors = ["Hexilee "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] tokio = { version = "1.15", features = ["full"] } diesel = { version = "1.4", features = ["extras", "sqlite"] } roa = { path = "../../roa", features = ["router", "json"] } roa-diesel = { path = "../../roa-diesel" } tracing-futures = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing = "0.1" serde = { version = "1", features = ["derive"] } anyhow = "1.0" [dev-dependencies] reqwest = { version = "0.11", features = ["json", "cookies", "gzip"] } ================================================ FILE: integration/diesel-example/README.md ================================================ ```bash RUST_LOG=info cargo run --bin api ``` - `curl 127.0.0.1:8000/post/1` query post where id=1 and published - `curl -H "Content-type: application/json" -d '{"title":"Hello", "body": "Hello, world", "published": false}' -X POST 127.0.0.1:8000/post` create a new post - `curl -H "Content-type: application/json" -d '{"title":"Hello", "body": "Hello, world", "published": true}' -X PUT 127.0.0.1:8000/post/1` update post where id=1, return the old data - `curl 127.0.0.1:8000/post/1 -X DELETE` delete post where id=1 ================================================ FILE: integration/diesel-example/src/bin/api.rs ================================================ use diesel_example::{create_pool, post_router}; use roa::logger::logger; use roa::preload::*; use roa::App; use tracing::info; use tracing_subscriber::EnvFilter; #[tokio::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .try_init() .map_err(|err| anyhow::anyhow!("fail to init 
tracing subscriber: {}", err))?; let app = App::state(create_pool()?) .gate(logger) .end(post_router().routes("/post")?); app.listen("127.0.0.1:8000", |addr| { info!("Server is listening on {}", addr) })? .await?; Ok(()) } ================================================ FILE: integration/diesel-example/src/data_object.rs ================================================ use serde::Deserialize; use crate::schema::posts; // for both transfer and access #[derive(Debug, Insertable, Deserialize)] #[table_name = "posts"] pub struct NewPost { pub title: String, pub body: String, pub published: bool, } ================================================ FILE: integration/diesel-example/src/endpoints.rs ================================================ use diesel::prelude::*; use diesel::result::Error; use roa::http::StatusCode; use roa::preload::*; use roa::router::{get, post, Router}; use roa::{throw, Context, Result}; use roa_diesel::preload::*; use crate::data_object::NewPost; use crate::models::*; use crate::schema::posts::dsl::{self, posts}; use crate::State; pub fn post_router() -> Router { Router::new() .on("/", post(create_post)) .on("/:id", get(get_post).put(update_post).delete(delete_post)) } async fn create_post(ctx: &mut Context) -> Result { let data: NewPost = ctx.read_json().await?; let conn = ctx.get_conn().await?; let post = ctx .exec .spawn_blocking(move || { conn.transaction::(|| { diesel::insert_into(crate::schema::posts::table) .values(&data) .execute(&conn)?; Ok(posts.order(dsl::id.desc()).first(&conn)?) }) }) .await?; ctx.resp.status = StatusCode::CREATED; ctx.write_json(&post) } async fn get_post(ctx: &mut Context) -> Result { let id: i32 = ctx.must_param("id")?.parse()?; match ctx .first::(posts.find(id).filter(dsl::published.eq(true))) .await? 
{ None => throw!(StatusCode::NOT_FOUND, &format!("post({}) not found", id)), Some(post) => ctx.write_json(&post), } } async fn update_post(ctx: &mut Context) -> Result { let id: i32 = ctx.must_param("id")?.parse()?; let NewPost { title, body, published, } = ctx.read_json().await?; match ctx.first::(posts.find(id)).await? { None => throw!(StatusCode::NOT_FOUND, &format!("post({}) not found", id)), Some(post) => { ctx.execute(diesel::update(posts.find(id)).set(( dsl::title.eq(title), dsl::body.eq(body), dsl::published.eq(published), ))) .await?; ctx.write_json(&post) } } } async fn delete_post(ctx: &mut Context) -> Result { let id: i32 = ctx.must_param("id")?.parse()?; match ctx.first::(posts.find(id)).await? { None => throw!(StatusCode::NOT_FOUND, &format!("post({}) not found", id)), Some(post) => { ctx.execute(diesel::delete(posts.find(id))).await?; ctx.write_json(&post) } } } ================================================ FILE: integration/diesel-example/src/lib.rs ================================================ #[macro_use] extern crate diesel; mod data_object; mod endpoints; pub mod models; pub mod schema; use diesel::prelude::*; use diesel::sqlite::SqliteConnection; use roa_diesel::{make_pool, Pool}; #[derive(Clone)] pub struct State(pub Pool); impl AsRef> for State { fn as_ref(&self) -> &Pool { &self.0 } } pub fn create_pool() -> anyhow::Result { let pool = make_pool(":memory:")?; diesel::sql_query( r" CREATE TABLE posts ( id INTEGER PRIMARY KEY, title VARCHAR NOT NULL, body TEXT NOT NULL, published BOOLEAN NOT NULL DEFAULT 'f' ) ", ) .execute(&*pool.get()?)?; Ok(State(pool)) } pub use endpoints::post_router; ================================================ FILE: integration/diesel-example/src/models.rs ================================================ use diesel::Queryable; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Queryable, Serialize, Deserialize)] pub struct Post { pub id: i32, pub title: String, pub body: String, pub published: bool, } 
================================================ FILE: integration/diesel-example/src/schema.rs ================================================ table! { posts (id) { id -> Integer, title -> Text, body -> Text, published -> Bool, } } ================================================ FILE: integration/diesel-example/tests/restful.rs ================================================ use diesel_example::models::Post; use diesel_example::{create_pool, post_router}; use roa::http::StatusCode; use roa::preload::*; use roa::App; use serde::Serialize; use tracing::{debug, info}; use tracing_subscriber::EnvFilter; #[derive(Debug, Serialize, Copy, Clone)] pub struct NewPost<'a> { pub title: &'a str, pub body: &'a str, pub published: bool, } impl PartialEq for NewPost<'_> { fn eq(&self, other: &Post) -> bool { self.title == other.title && self.body == other.body && self.published == other.published } } #[tokio::test] async fn test() -> anyhow::Result<()> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .try_init() .map_err(|err| anyhow::anyhow!("fail to init tracing subscriber: {}", err))?; let app = App::state(create_pool()?).end(post_router().routes("/post")?); let (addr, server) = app.run()?; tokio::task::spawn(server); info!("server is running on {}", addr); let base_url = format!("http://{}/post", addr); let client = reqwest::Client::new(); // Not Found let resp = client.get(&format!("{}/{}", &base_url, 0)).send().await?; assert_eq!(StatusCode::NOT_FOUND, resp.status()); debug!("{}/{} not found", &base_url, 0); // Create let first_post = NewPost { title: "Hello", body: "Welcome to roa-diesel", published: false, }; let resp = client.post(&base_url).json(&first_post).send().await?; assert_eq!(StatusCode::CREATED, resp.status()); let created_post: Post = resp.json().await?; let id = created_post.id; assert_eq!(&first_post, &created_post); // Post isn't published, get nothing let resp = client.get(&format!("{}/{}", &base_url, id)).send().await?; 
assert_eq!(StatusCode::NOT_FOUND, resp.status()); // Update let second_post = NewPost { published: true, ..first_post }; let resp = client .put(&format!("{}/{}", &base_url, id)) .json(&second_post) .send() .await?; assert_eq!(StatusCode::OK, resp.status()); // Return old post let updated_post: Post = resp.json().await?; assert_eq!(&first_post, &updated_post); // Get it let resp = client.get(&format!("{}/{}", &base_url, id)).send().await?; assert_eq!(StatusCode::OK, resp.status()); let query_post: Post = resp.json().await?; assert_eq!(&second_post, &query_post); // Delete let resp = client .delete(&format!("{}/{}", &base_url, id)) .send() .await?; assert_eq!(StatusCode::OK, resp.status()); let deleted_post: Post = resp.json().await?; assert_eq!(&second_post, &deleted_post); // Post is deleted, get nothing let resp = client.get(&format!("{}/{}", &base_url, id)).send().await?; assert_eq!(StatusCode::NOT_FOUND, resp.status()); Ok(()) } ================================================ FILE: integration/juniper-example/Cargo.toml ================================================ [package] name = "juniper-example" version = "0.1.0" authors = ["Hexilee "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] diesel = "1.4" roa = { path = "../../roa", features = ["router"] } roa-diesel = { path = "../../roa-diesel" } roa-juniper = { path = "../../roa-juniper" } diesel-example = { path = "../diesel-example" } tokio = { version = "1.15", features = ["full"] } tracing = "0.1" serde = { version = "1", features = ["derive"] } futures = "0.3" juniper = { version = "0.15", default-features = false } tracing-futures = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } anyhow = "1.0" ================================================ FILE: integration/juniper-example/README.md ================================================ ```bash RUST_LOG=info cargo run ``` Then request 
http://127.0.0.1:8000, play with the GraphQL playground! ================================================ FILE: integration/juniper-example/src/main.rs ================================================ #[macro_use] extern crate diesel; mod models; mod schema; use std::error::Error as StdError; use diesel::prelude::*; use diesel::result::Error; use diesel_example::{create_pool, State}; use juniper::http::playground::playground_source; use juniper::{ graphql_value, EmptySubscription, FieldError, FieldResult, GraphQLInputObject, RootNode, }; use roa::http::Method; use roa::logger::logger; use roa::preload::*; use roa::router::{allow, get, Router}; use roa::App; use roa_diesel::preload::*; use roa_juniper::{GraphQL, JuniperContext}; use serde::Serialize; use tracing::info; use tracing_subscriber::EnvFilter; use crate::models::Post; use crate::schema::posts; #[derive(Debug, Insertable, Serialize, GraphQLInputObject)] #[table_name = "posts"] #[graphql(description = "A new post")] struct NewPost { title: String, body: String, published: bool, } struct Query; #[juniper::graphql_object( Context = JuniperContext, )] impl Query { async fn post( &self, ctx: &JuniperContext, id: i32, published: bool, ) -> FieldResult { use crate::schema::posts::dsl::{self, posts}; match ctx .first(posts.find(id).filter(dsl::published.eq(published))) .await? { Some(post) => Ok(post), None => Err(FieldError::new( "post not found", graphql_value!({ "status": 404, "id": id }), )), } } } struct Mutation; #[juniper::graphql_object( Context = JuniperContext, )] impl Mutation { async fn create_post( &self, ctx: &JuniperContext, new_post: NewPost, ) -> FieldResult { use crate::schema::posts::dsl::{self, posts}; let conn = ctx.get_conn().await?; let post = ctx .exec .spawn_blocking(move || { conn.transaction::(|| { diesel::insert_into(crate::schema::posts::table) .values(&new_post) .execute(&conn)?; Ok(posts.order(dsl::id.desc()).first(&conn)?) 
}) }) .await?; Ok(post) } async fn update_post( &self, id: i32, ctx: &JuniperContext, new_post: NewPost, ) -> FieldResult { use crate::schema::posts::dsl::{self, posts}; match ctx.first(posts.find(id)).await? { None => Err(FieldError::new( "post not found", graphql_value!({ "status": 404, "id": id }), )), Some(old_post) => { let NewPost { title, body, published, } = new_post; ctx.execute(diesel::update(posts.find(id)).set(( dsl::title.eq(title), dsl::body.eq(body), dsl::published.eq(published), ))) .await?; Ok(old_post) } } } async fn delete_post(&self, ctx: &JuniperContext, id: i32) -> FieldResult { use crate::schema::posts::dsl::posts; match ctx.first(posts.find(id)).await? { None => Err(FieldError::new( "post not found", graphql_value!({ "status": 404, "id": id }), )), Some(old_post) => { ctx.execute(diesel::delete(posts.find(id))).await?; Ok(old_post) } } } } #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .try_init() .map_err(|err| anyhow::anyhow!("fail to init tracing subscriber: {}", err))?; let router = Router::new() .on("/", get(playground_source("/api", None))) .on( "/api", allow( [Method::GET, Method::POST], GraphQL(RootNode::new(Query, Mutation, EmptySubscription::new())), ), ); let app = App::state(create_pool()?) .gate(logger) .end(router.routes("/")?); app.listen("127.0.0.1:8000", |addr| { info!("Server is listening on {}", addr) })? 
.await?; Ok(()) } ================================================ FILE: integration/juniper-example/src/models.rs ================================================ use diesel::Queryable; use juniper::GraphQLObject; use serde::Deserialize; #[derive(Debug, Clone, Queryable, Deserialize, GraphQLObject)] #[graphql(description = "A post")] pub struct Post { pub id: i32, pub title: String, pub body: String, pub published: bool, } ================================================ FILE: integration/juniper-example/src/schema.rs ================================================ table! { posts (id) { id -> Integer, title -> Text, body -> Text, published -> Bool, } } ================================================ FILE: integration/multipart-example/Cargo.toml ================================================ [package] name = "multipart-example" version = "0.1.0" authors = ["Hexilee "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] roa = { path = "../../roa", features = ["router", "file", "multipart"] } tokio = { version = "1.15", features = ["full"] } tracing = "0.1" futures = "0.3" tracing-futures = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } anyhow = "1.0" ================================================ FILE: integration/multipart-example/README.md ================================================ ```bash RUST_LOG=info cargo run ``` Then visit `http://127.0.0.1:8000`, files will be stored in `./upload`. ================================================ FILE: integration/multipart-example/assets/index.html ================================================ Upload Test
================================================ FILE: integration/multipart-example/src/main.rs ================================================ use std::error::Error as StdError; use std::path::Path; use roa::body::{DispositionType, PowerBody}; use roa::logger::logger; use roa::preload::*; use roa::router::{get, post, Router}; use roa::{App, Context}; use tokio::fs::File; use tokio::io::AsyncWriteExt; use tracing::info; use tracing_subscriber::EnvFilter; async fn get_form(ctx: &mut Context) -> roa::Result { ctx.write_file("./assets/index.html", DispositionType::Inline) .await } /** Stores every file field of the multipart form under `./upload`, streaming each field chunk by chunk; non-file fields are ignored. */ async fn post_file(ctx: &mut Context) -> roa::Result { let mut form = ctx.read_multipart().await?; while let Some(mut field) = form.next_field().await? { info!("{:?}", field.content_type()); match field.file_name() { None => continue, /* ignore non-file field */ Some(filename) => { /* The file name is client-controlled: keep only its final path component so that `../` segments or an absolute path cannot make the write land outside ./upload. Names without a usable final component (empty, `..`) are skipped. */ let safe_name = match Path::new(filename).file_name() { Some(name) => name, None => continue, }; let path = Path::new("./upload"); let mut file = File::create(path.join(safe_name)).await?; while let Some(c) = field.chunk().await? { file.write_all(&c).await?; } } } } Ok(()) } #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .try_init() .map_err(|err| anyhow::anyhow!("fail to init tracing subscriber: {}", err))?; let router = Router::new() .on("/", get(get_form)) .on("/file", post(post_file)); let app = App::new().gate(logger).end(router.routes("/")?); app.listen("127.0.0.1:8000", |addr| { info!("Server is listening on {}", addr) })? 
.await?; Ok(()) } ================================================ FILE: integration/websocket-example/Cargo.toml ================================================ [package] name = "websocket-example" version = "0.1.0" authors = ["Hexilee "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] roa = { path = "../../roa", features = ["router", "file", "websocket"] } tokio = { version = "1.15", features = ["full"] } tracing = "0.1" futures = "0.3" http = "0.2" slab = "0.4" tracing-futures = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } anyhow = "1.0" [dev-dependencies] tokio-tungstenite = { version = "0.15", features = ["connect"] } ================================================ FILE: integration/websocket-example/README.md ================================================ WIP... ================================================ FILE: integration/websocket-example/src/main.rs ================================================ use std::borrow::Cow; use std::error::Error as StdError; use std::sync::Arc; use futures::stream::{SplitSink, SplitStream}; use futures::{SinkExt, StreamExt}; use http::Method; use roa::logger::logger; use roa::preload::*; use roa::router::{allow, RouteTable, Router, RouterError}; use roa::websocket::tungstenite::protocol::frame::coding::CloseCode; use roa::websocket::tungstenite::protocol::frame::CloseFrame; use roa::websocket::tungstenite::Error as WsError; use roa::websocket::{Message, SocketStream, Websocket}; use roa::{App, Context}; use slab::Slab; use tokio::sync::{Mutex, RwLock}; use tracing::{debug, error, info, warn}; use tracing_subscriber::EnvFilter; type Sender = SplitSink; type Channel = Slab>; #[derive(Clone)] struct SyncChannel(Arc>); impl SyncChannel { fn new() -> Self { Self(Arc::new(RwLock::new(Slab::new()))) } async fn broadcast(&self, message: Message) { let channel = self.0.read().await; for (_, sender) in channel.iter() { 
if let Err(err) = sender.lock().await.send(message.clone()).await { error!("broadcast error: {}", err); } } } async fn send(&self, index: usize, message: Message) { if let Err(err) = self.0.read().await[index].lock().await.send(message).await { error!("message send error: {}", err) } } async fn register(&self, sender: Sender) -> usize { self.0.write().await.insert(Mutex::new(sender)) } async fn deregister(&self, index: usize) -> Sender { self.0.write().await.remove(index).into_inner() } } async fn handle_message( ctx: &Context, index: usize, mut receiver: SplitStream, ) -> Result<(), WsError> { while let Some(message) = receiver.next().await { let message = message?; match message { Message::Close(frame) => { debug!("websocket connection close: {:?}", frame); break; } Message::Ping(data) => ctx.send(index, Message::Pong(data)).await, Message::Pong(data) => warn!("ignored pong: {:?}", data), msg => ctx.broadcast(msg).await, } } Ok(()) } fn route(prefix: &'static str) -> Result, RouterError> { Router::new() .on( "/chat", allow( [Method::GET], Websocket::new(|ctx: Context, stream| async move { let (sender, receiver) = stream.split(); let index = ctx.register(sender).await; let result = handle_message(&ctx, index, receiver).await; let mut sender = ctx.deregister(index).await; if let Err(err) = result { let result = sender .send(Message::Close(Some(CloseFrame { code: CloseCode::Invalid, reason: Cow::Owned(err.to_string()), }))) .await; if let Err(err) = result { warn!("send close message error: {}", err) } } }), ), ) .routes(prefix) } #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .try_init() .map_err(|err| anyhow::anyhow!("fail to init tracing subscriber: {}", err))?; let app = App::state(SyncChannel::new()).gate(logger).end(route("/")?); app.listen("127.0.0.1:8000", |addr| { info!("Server is listening on {}", addr) })? 
.await?; Ok(()) } #[cfg(test)] mod tests { use std::time::Duration; use roa::preload::*; use tokio_tungstenite::connect_async; use super::{route, App, Message, SinkExt, StdError, StreamExt, SyncChannel}; #[tokio::test] async fn echo() -> Result<(), Box> { let channel = SyncChannel::new(); let app = App::state(channel.clone()).end(route("/")?); let (addr, server) = app.run()?; tokio::task::spawn(server); let (ws_stream, _) = connect_async(format!("ws://{}/chat", addr)).await?; let (mut sender, mut recv) = ws_stream.split(); tokio::time::sleep(Duration::from_secs(1)).await; assert_eq!(1, channel.0.read().await.len()); // ping sender .send(Message::Ping(b"Hello, World!".to_vec())) .await?; let msg = recv.next().await.unwrap()?; assert!(msg.is_pong()); assert_eq!(b"Hello, World!".as_ref(), msg.into_data().as_slice()); // close sender.send(Message::Close(None)).await?; tokio::time::sleep(Duration::from_secs(1)).await; assert_eq!(0, channel.0.read().await.len()); Ok(()) } #[tokio::test] async fn broadcast() -> Result<(), Box> { let channel = SyncChannel::new(); let app = App::state(channel.clone()).end(route("/")?); let (addr, server) = app.run()?; tokio::task::spawn(server); let url = format!("ws://{}/chat", addr); for _ in 0..100 { let url = url.clone(); tokio::task::spawn(async move { if let Ok((ws_stream, _)) = connect_async(url).await { let (mut sender, mut recv) = ws_stream.split(); if let Some(Ok(message)) = recv.next().await { assert!(sender.send(message).await.is_ok()); } tokio::time::sleep(Duration::from_secs(1)).await; assert!(sender.send(Message::Close(None)).await.is_ok()); } }); } tokio::time::sleep(Duration::from_secs(1)).await; assert_eq!(100, channel.0.read().await.len()); let (ws_stream, _) = connect_async(url).await?; let (mut sender, mut recv) = ws_stream.split(); assert!(sender .send(Message::Text("Hello, World!".to_string())) .await .is_ok()); tokio::time::sleep(Duration::from_secs(2)).await; assert_eq!(1, channel.0.read().await.len()); let mut 
counter = 0i32; while let Some(item) = recv.next().await { if let Ok(Message::Text(message)) = item { assert_eq!("Hello, World!", message); } counter += 1; if counter == 101 { break; } } Ok(()) } } ================================================ FILE: roa/Cargo.toml ================================================ [package] name = "roa" version = "0.6.1" authors = ["Hexilee "] edition = "2018" license = "MIT" readme = "./README.md" repository = "https://github.com/Hexilee/roa" documentation = "https://docs.rs/roa" homepage = "https://github.com/Hexilee/roa/wiki" description = """ async web framework inspired by koajs, lightweight but powerful. """ keywords = ["http", "web", "framework", "async"] categories = ["network-programming", "asynchronous", "web-programming::http-server"] [package.metadata.docs.rs] features = ["docs"] rustdoc-args = ["--cfg", "feature=\"docs\""] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [badges] codecov = { repository = "Hexilee/roa" } [dependencies] tracing = { version = "0.1", features = ["log"] } futures = "0.3" bytesize = "1.0" async-trait = "0.1.51" url = "2.2" percent-encoding = "2.1" bytes = "1.1" headers = "0.3" tokio = "1.15" tokio-util = { version = "0.6.9", features = ["io"] } once_cell = "1.8" hyper = { version = "0.14", default-features = false, features = ["stream", "server", "http1", "http2"] } roa-core = { path = "../roa-core", version = "0.6" } cookie = { version = "0.15", features = ["percent-encode"], optional = true } jsonwebtoken = { version = "7.2", optional = true } serde = { version = "1", optional = true } serde_json = { version = "1.0", optional = true } async-compression = { version = "0.3.8", features = ["all-algorithms", "futures-io"], optional = true } # router radix_trie = { version = "0.2.1", optional = true } regex = { version = "1.5", optional = true } # body askama = { version = "0.10", optional = true } doc-comment = { version = "0.3.3", optional = 
true } serde_urlencoded = { version = "0.7", optional = true } mime_guess = { version = "2.0", optional = true } multer = { version = "2.0", optional = true } mime = { version = "0.3", optional = true } # websocket tokio-tungstenite = { version = "0.15.0", default-features = false, optional = true } # tls rustls = { version = "0.20", optional = true } tokio-rustls = { version = "0.23", optional = true } rustls-pemfile = { version = "0.2", optional = true } # jsonrpc jsonrpc-v2 = { version = "0.10", default-features = false, features = ["bytes-v10"], optional = true } [dev-dependencies] tokio = { version = "1.15", features = ["full"] } tokio-native-tls = "0.3" hyper-tls = "0.5" reqwest = { version = "0.11", features = ["json", "cookies", "gzip", "multipart"] } pretty_env_logger = "0.4" serde = { version = "1", features = ["derive"] } test-case = "1.2" slab = "0.4.5" multimap = "0.8" hyper = "0.14" mime = "0.3" encoding = "0.2" askama = "0.10" anyhow = "1.0" [features] default = ["async_rt"] full = [ "default", "json", "urlencoded", "file", "multipart", "template", "tls", "router", "jwt", "cookies", "compress", "websocket", "jsonrpc", ] docs = ["full", "roa-core/docs"] runtime = ["roa-core/runtime"] json = ["serde", "serde_json"] multipart = ["multer", "mime"] urlencoded = ["serde", "serde_urlencoded"] file = ["mime_guess", "tokio/fs"] template = ["askama"] tcp = ["tokio/net", "tokio/time"] tls = ["rustls", "tokio-rustls", "rustls-pemfile"] cookies = ["cookie"] jwt = ["jsonwebtoken", "serde", "serde_json"] router = ["radix_trie", "regex", "doc-comment"] websocket = ["tokio-tungstenite"] compress = ["async-compression"] async_rt = ["runtime", "tcp"] jsonrpc = ["jsonrpc-v2"] ================================================ FILE: roa/README.md ================================================ [![Build status](https://img.shields.io/travis/Hexilee/roa/master.svg)](https://travis-ci.org/Hexilee/roa) 
[![codecov](https://codecov.io/gh/Hexilee/roa/branch/master/graph/badge.svg)](https://codecov.io/gh/Hexilee/roa)
[![Rust Docs](https://docs.rs/roa/badge.svg)](https://docs.rs/roa)
[![Crate version](https://img.shields.io/crates/v/roa.svg)](https://crates.io/crates/roa)
[![Download](https://img.shields.io/crates/d/roa.svg)](https://crates.io/crates/roa)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/Hexilee/roa/blob/master/LICENSE)

### Introduction

Roa is an async web framework inspired by koajs, lightweight but powerful.

### Application

A Roa application is a structure composing and executing middlewares and an endpoint in a stack-like manner.

The obligatory hello world application:

```rust,no_run
use roa::App;
use roa::preload::*;
use tracing::info;
use std::error::Error as StdError;

#[tokio::main]
async fn main() -> Result<(), Box<dyn StdError>> {
    let app = App::new().end("Hello, World");
    app.listen("127.0.0.1:8000", |addr| {
        info!("Server is listening on {}", addr)
    })?
    .await?;
    Ok(())
}
```

#### Endpoint

An endpoint is a request handler.

There are some built-in endpoints in roa.

- Functional endpoint

    A normal functional endpoint is an async function with signature:
    `async fn(&mut Context<S>) -> Result`.

    ```rust
    use roa::{App, Context, Result};

    async fn endpoint(ctx: &mut Context) -> Result {
        Ok(())
    }

    let app = App::new().end(endpoint);
    ```

- Ok endpoint

    `()` is an endpoint that always returns `Ok(())`

    ```rust
    let app = roa::App::new().end(());
    ```

- Status endpoint

    `Status` is an endpoint that always returns `Err(Status)`

    ```rust
    use roa::{App, status};
    use roa::http::StatusCode;
    let app = App::new().end(status!(StatusCode::BAD_REQUEST));
    ```

- String endpoint

    Write string to body.

    ```rust
    use roa::App;

    let app = App::new().end("Hello, world"); // static slice
    let app = App::new().end("Hello, world".to_owned()); // string
    ```

- Redirect endpoint

    Redirect to a URI.
    ```rust
    use roa::App;
    use roa::http::Uri;

    let app = App::new().end("/target".parse::<Uri>().unwrap());
    ```

#### Cascading

Like koajs, middleware suspends and passes control to "downstream" by invoking `next.await`.
Then control flows back "upstream" when `next.await` returns.

The following example responds with "Hello World", however first the request flows through the x-response-time and logging middleware to mark when the request started, then continues to yield control through the endpoint. When a middleware invokes next the function suspends and passes control to the next middleware or endpoint. After the endpoint is called, the stack will unwind and each middleware is resumed to perform its upstream behaviour.

```rust,no_run
use roa::{App, Context, Next};
use roa::preload::*;
use tracing::info;
use std::error::Error as StdError;
use std::time::Instant;

#[tokio::main]
async fn main() -> Result<(), Box<dyn StdError>> {
    let app = App::new()
        .gate(logger)
        .gate(x_response_time)
        .end("Hello, World");
    app.listen("127.0.0.1:8000", |addr| {
        info!("Server is listening on {}", addr)
    })?
    .await?;
    Ok(())
}

async fn logger(ctx: &mut Context, next: Next<'_>) -> roa::Result {
    next.await?;
    let rt = ctx.load::<String>("x-response-time").unwrap();
    info!("{} {} - {}", ctx.method(), ctx.uri(), rt.as_str());
    Ok(())
}

async fn x_response_time(ctx: &mut Context, next: Next<'_>) -> roa::Result {
    let start = Instant::now();
    next.await?;
    let ms = start.elapsed().as_millis();
    ctx.store("x-response-time", format!("{}ms", ms));
    Ok(())
}
```

### Status Handling

You can catch a status returned by `next`, or simply propagate it.

```rust,no_run
use roa::{App, Context, Next, status};
use roa::preload::*;
use roa::http::StatusCode;
use tokio::task::spawn;
use tracing::info;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let app = App::new()
        .gate(catch)
        .gate(not_catch)
        .end(status!(StatusCode::IM_A_TEAPOT, "I'm a teapot!"));
    app.listen("127.0.0.1:8000", |addr| {
        info!("Server is listening on {}", addr)
    })?
    .await?;
    Ok(())
}

async fn catch(_ctx: &mut Context, next: Next<'_>) -> roa::Result {
    // catch
    if let Err(status) = next.await {
        // teapot is ok
        if status.status_code != StatusCode::IM_A_TEAPOT {
            return Err(status);
        }
    }
    Ok(())
}

async fn not_catch(ctx: &mut Context, next: Next<'_>) -> roa::Result {
    next.await?; // just throw
    unreachable!()
}
```

#### status_handler

App has a status_handler to handle status thrown by the top middleware.
This is the status_handler:

```rust,no_run
use roa::{Context, Status};
pub fn status_handler(ctx: &mut Context, status: Status) {
    ctx.resp.status = status.status_code;
    if status.expose {
        ctx.resp.write(status.message);
    } else {
        tracing::error!("{}", status);
    }
}
```

### Router.

Roa provides a configurable and nestable router.

```rust,no_run
use roa::preload::*;
use roa::router::{Router, get};
use roa::{App, Context};
use tokio::task::spawn;
use tracing::info;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let router = Router::new()
        .on("/:id", get(end)); // get dynamic "/:id"
    let app = App::new()
        .end(router.routes("/user")?); // route with prefix "/user"
    app.listen("127.0.0.1:8000", |addr| {
        info!("Server is listening on {}", addr)
    })?
    .await?;
    Ok(())
}

async fn end(ctx: &mut Context) -> roa::Result {
    // get "/user/1", then id == 1.
    let id: u64 = ctx.must_param("id")?.parse()?;
    // do something
    Ok(())
}
```

### Query

Roa provides a middleware `query_parser`.

```rust,no_run
use roa::preload::*;
use roa::query::query_parser;
use roa::{App, Context};
use tokio::task::spawn;
use tracing::info;

async fn must(ctx: &mut Context) -> roa::Result {
    // request "/?id=1", then id == 1.
    let id: u64 = ctx.must_query("id")?.parse()?;
    Ok(())
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let app = App::new()
        .gate(query_parser)
        .end(must);
    app.listen("127.0.0.1:8080", |addr| {
        info!("Server is listening on {}", addr)
    })?
    .await?;
    Ok(())
}
```

### Other modules

- body: dealing with body more conveniently.
- compress: supports transparent content compression. - cookie: cookies getter or setter. - cors: CORS support. - forward: "X-Forwarded-*" parser. - jwt: json web token support. - logger: a logger middleware. - tls: https supports. - websocket: websocket supports. ================================================ FILE: roa/src/body/file/content_disposition.rs ================================================ use std::convert::{TryFrom, TryInto}; use std::fmt::{self, Display, Formatter}; use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS}; use super::help::bug_report; use crate::http::header::HeaderValue; use crate::Status; // This encode set is used for HTTP header values and is defined at // https://tools.ietf.org/html/rfc5987#section-3.2 const HTTP_VALUE: &AsciiSet = &CONTROLS .add(b' ') .add(b'"') .add(b'%') .add(b'\'') .add(b'(') .add(b')') .add(b'*') .add(b',') .add(b'/') .add(b':') .add(b';') .add(b'<') .add(b'-') .add(b'>') .add(b'?') .add(b'[') .add(b'\\') .add(b']') .add(b'{') .add(b'}'); /// Type of content-disposition, inline or attachment #[derive(Clone, Debug, PartialEq)] pub enum DispositionType { /// Inline implies default processing Inline, /// Attachment implies that the recipient should prompt the user to save the response locally, /// rather than process it normally (as per its media type). Attachment, } /// A structure to generate value of "Content-Disposition" pub struct ContentDisposition { typ: DispositionType, encoded_filename: Option, } impl ContentDisposition { /// Construct by disposition type and optional filename. 
#[inline]
    pub(crate) fn new(typ: DispositionType, filename: Option<&str>) -> Self {
        // The name is encoded once here so that formatting never has to re-encode it.
        let encoded_filename =
            filename.map(|name| utf8_percent_encode(name, HTTP_VALUE).to_string());
        Self {
            typ,
            encoded_filename,
        }
    }
}

/// Converts the rendered disposition into a header value.
///
/// A failure is reported through `bug_report`, i.e. as a non-exposed
/// 500 INTERNAL_SERVER_ERROR status.
impl TryFrom<ContentDisposition> for HeaderValue {
    type Error = Status;

    #[inline]
    fn try_from(value: ContentDisposition) -> Result<Self, Self::Error> {
        value
            .to_string()
            .try_into()
            .map_err(|err| bug_report(format!("{}\nNot a valid header value", err)))
    }
}

/// Renders `<type>` or `<type>; filename=<name>; filename*=UTF-8''<name>`.
impl Display for ContentDisposition {
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match &self.encoded_filename {
            None => write!(f, "{}", self.typ),
            Some(name) => write!(
                f,
                "{}; filename={}; filename*=UTF-8''{}",
                self.typ, name, name
            ),
        }
    }
}

/// Renders the lowercase token used in the header: `inline` or `attachment`.
impl Display for DispositionType {
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            DispositionType::Inline => "inline",
            DispositionType::Attachment => "attachment",
        })
    }
}

================================================
FILE: roa/src/body/file/help.rs
================================================
use crate::http::StatusCode;
use crate::Status;

const BUG_HELP: &str =
    r"This is a bug of roa::body::file, please report it to https://github.com/Hexilee/roa.";

/// Build a non-exposed 500 status whose message is `message` followed by the
/// bug-report hint.
#[inline]
pub fn bug_report(message: impl ToString) -> Status {
    Status::new(
        StatusCode::INTERNAL_SERVER_ERROR,
        format!("{}\n{}", message.to_string(), BUG_HELP),
        false,
    )
}

================================================
FILE: roa/src/body/file.rs
================================================
mod content_disposition;
mod help;

use std::convert::TryInto;
use std::path::Path;

use content_disposition::ContentDisposition;
pub use content_disposition::DispositionType;
use tokio::fs::File;

use crate::{http, Context, Result, State};

/// Write file to response body then set "Content-Type" and "Content-Disposition".
#[inline]
pub async fn write_file<S: State>(
    ctx: &mut Context<S>,
    path: impl AsRef<Path>,
    typ: DispositionType,
) -> Result {
    let path = path.as_ref();
    // Opening the file is the first fallible step; nothing is written to the
    // response if it fails.
    ctx.resp.write_reader(File::open(path).await?);

    // Headers are only set when the path ends in a file name.
    if let Some(filename) = path.file_name() {
        // Content-Type is guessed from the extension, falling back to octet-stream.
        let content_type = mime_guess::from_path(filename)
            .first_or_octet_stream()
            .as_ref()
            .parse()
            .map_err(help::bug_report)?;
        ctx.resp
            .headers
            .insert(http::header::CONTENT_TYPE, content_type);

        let name = filename.to_string_lossy();
        let content_disposition = ContentDisposition::new(typ, Some(&name));
        ctx.resp.headers.insert(
            http::header::CONTENT_DISPOSITION,
            content_disposition.try_into()?,
        );
    }
    Ok(())
}

================================================
FILE: roa/src/body.rs
================================================
//! This module provides a context extension `PowerBody`.
//!
//! ### Read/write body in an easier way.
//!
//! The `roa_core` provides several methods to read/write body.
//!
//! ```rust
//! use roa::{Context, Result};
//! use tokio::io::AsyncReadExt;
//! use tokio::fs::File;
//!
//! async fn get(ctx: &mut Context) -> Result {
//!     let mut data = String::new();
//!     // implements tokio::io::AsyncRead.
//!     ctx.req.reader().read_to_string(&mut data).await?;
//!     println!("data: {}", data);
//!
//!     // although body is empty now...
//!     let stream = ctx.req.stream();
//!     ctx.resp
//!         // echo
//!         .write_stream(stream)
//!         // write object implementing tokio::io::AsyncRead
//!         .write_reader(File::open("assets/author.txt").await?)
//!         // write reader with specific chunk size
//!         .write_chunk(File::open("assets/author.txt").await?, 1024)
//!         // write text
//!         .write("I am Roa.")
//!         .write(b"I am Roa.".as_ref());
//!     Ok(())
//! }
//! ```
//!
//! These methods are useful, but they do not deal with headers and (de)serialization.
//!
//! The `PowerBody` provides more powerful methods to handle it.
//!
//! ```rust
//! use roa::{Context, Result};
//! use roa::body::{PowerBody, DispositionType::*};
//! use serde::{Serialize, Deserialize};
//! use askama::Template;
//!
use tokio::fs::File; //! //! #[derive(Debug, Serialize, Deserialize, Template)] //! #[template(path = "user.html")] //! struct User { //! id: u64, //! name: String, //! } //! //! async fn get(ctx: &mut Context) -> Result { //! // read as bytes. //! let data = ctx.read().await?; //! //! // deserialize as json. //! let user: User = ctx.read_json().await?; //! //! // deserialize as x-form-urlencoded. //! let user: User = ctx.read_form().await?; //! //! // serialize object and write it to body, //! // set "Content-Type" //! ctx.write_json(&user)?; //! //! // open file and write it to body, //! // set "Content-Type" and "Content-Disposition" //! ctx.write_file("assets/welcome.html", Inline).await?; //! //! // write text, //! // set "Content-Type" //! ctx.write("Hello, World!"); //! //! // write object implementing AsyncRead, //! // set "Content-Type" //! ctx.write_reader(File::open("assets/author.txt").await?); //! //! // render html template, based on [askama](https://github.com/djc/askama). //! // set "Content-Type" //! ctx.render(&user)?; //! Ok(()) //! } //! ``` #[cfg(feature = "template")] use askama::Template; use bytes::Bytes; use headers::{ContentLength, ContentType, HeaderMapExt}; use tokio::io::{AsyncRead, AsyncReadExt}; use crate::{async_trait, Context, Result, State}; #[cfg(feature = "file")] mod file; #[cfg(feature = "file")] use file::write_file; #[cfg(feature = "file")] pub use file::DispositionType; #[cfg(feature = "multipart")] pub use multer::Multipart; #[cfg(any(feature = "json", feature = "urlencoded"))] use serde::de::DeserializeOwned; #[cfg(feature = "json")] use serde::Serialize; /// A context extension to read/write body more simply. #[async_trait] pub trait PowerBody { /// read request body as Bytes. async fn read(&mut self) -> Result>; /// read request body as "json". 
#[cfg(feature = "json")] #[cfg_attr(feature = "docs", doc(cfg(feature = "json")))] async fn read_json(&mut self) -> Result where B: DeserializeOwned; /// read request body as "urlencoded form". #[cfg(feature = "urlencoded")] #[cfg_attr(feature = "docs", doc(cfg(feature = "urlencoded")))] async fn read_form(&mut self) -> Result where B: DeserializeOwned; /// read request body as "multipart form". #[cfg(feature = "multipart")] #[cfg_attr(feature = "docs", doc(cfg(feature = "multipart")))] async fn read_multipart(&mut self) -> Result; /// write object to response body as "application/json" #[cfg(feature = "json")] #[cfg_attr(feature = "docs", doc(cfg(feature = "json")))] fn write_json(&mut self, data: &B) -> Result where B: Serialize; /// write object to response body as "text/html; charset=utf-8" #[cfg(feature = "template")] #[cfg_attr(feature = "docs", doc(cfg(feature = "template")))] fn render(&mut self, data: &B) -> Result where B: Template; /// write object to response body as "text/plain" fn write(&mut self, data: B) where B: Into; /// write object to response body as "application/octet-stream" fn write_reader(&mut self, reader: B) where B: 'static + AsyncRead + Unpin + Sync + Send; /// write object to response body as extension name of file #[cfg(feature = "file")] #[cfg_attr(feature = "docs", doc(cfg(feature = "file")))] async fn write_file

<P>(&mut self, path: P, typ: DispositionType) -> Result
    where
        P: Send + AsRef<std::path::Path>;
}

#[async_trait]
impl<S: State> PowerBody for Context<S> {
    /// Read the whole request body into memory.
    ///
    /// `Content-Length` is used only as a bounded pre-allocation hint; the
    /// buffer still grows to whatever the body actually contains.
    #[inline]
    async fn read(&mut self) -> Result<Vec<u8>> {
        // The header is client-controlled, so it must never drive an unbounded
        // up-front allocation: a forged length would otherwise reserve (or
        // abort on) an arbitrary amount of memory before a byte is read.
        const MAX_PREALLOCATE: u64 = 64 * 1024;

        let mut data = match self.req.headers.typed_get::<ContentLength>() {
            // The value is at most MAX_PREALLOCATE here, so the cast cannot truncate.
            Some(hint) => Vec::with_capacity(hint.0.min(MAX_PREALLOCATE) as usize),
            None => Vec::new(),
        };
        self.req.reader().read_to_end(&mut data).await?;
        Ok(data)
    }

    /// Read the body and deserialize it as JSON; a malformed payload yields
    /// 400 BAD_REQUEST.
    #[cfg(feature = "json")]
    #[inline]
    async fn read_json<B>(&mut self) -> Result<B>
    where
        B: DeserializeOwned,
    {
        use crate::http::StatusCode;
        use crate::status;

        let data = self.read().await?;
        serde_json::from_slice(&data).map_err(|err| status!(StatusCode::BAD_REQUEST, err))
    }

    /// Read the body and deserialize it as an urlencoded form; a malformed
    /// payload yields 400 BAD_REQUEST.
    #[cfg(feature = "urlencoded")]
    #[inline]
    async fn read_form<B>(&mut self) -> Result<B>
    where
        B: DeserializeOwned,
    {
        use crate::http::StatusCode;
        use crate::status;

        let data = self.read().await?;
        serde_urlencoded::from_bytes(&data).map_err(|err| status!(StatusCode::BAD_REQUEST, err))
    }

    #[cfg(feature = "multipart")]
    async fn read_multipart(&mut self) -> Result<Multipart> {
        use headers::{ContentType, HeaderMapExt};

        use crate::http::StatusCode;

        // Verify that the request is 'Content-Type: multipart/*'.
        let typ: mime::Mime = self
            .req
            .headers
            .typed_get::<ContentType>()
            .ok_or_else(|| crate::status!(StatusCode::BAD_REQUEST, "fail to get content-type"))?
            .into();
        let boundary = typ
            .get_param(mime::BOUNDARY)
            .ok_or_else(|| crate::status!(StatusCode::BAD_REQUEST, "fail to get boundary"))?
.as_str(); Ok(Multipart::new(self.req.stream(), boundary)) } #[cfg(feature = "json")] #[inline] fn write_json(&mut self, data: &B) -> Result where B: Serialize, { self.resp.write(serde_json::to_vec(data)?); self.resp.headers.typed_insert(ContentType::json()); Ok(()) } #[cfg(feature = "template")] #[inline] fn render(&mut self, data: &B) -> Result where B: Template, { self.resp.write(data.render()?); self.resp .headers .typed_insert::(mime::TEXT_HTML_UTF_8.into()); Ok(()) } #[inline] fn write(&mut self, data: B) where B: Into, { self.resp.write(data); self.resp.headers.typed_insert(ContentType::text()); } #[inline] fn write_reader(&mut self, reader: B) where B: 'static + AsyncRead + Unpin + Sync + Send, { self.resp.write_reader(reader); self.resp.headers.typed_insert(ContentType::octet_stream()); } #[cfg(feature = "file")] #[inline] async fn write_file

(&mut self, path: P, typ: DispositionType) -> Result where P: Send + AsRef, { write_file(self, path, typ).await } } #[cfg(all(test, feature = "tcp"))] mod tests { use std::error::Error; use askama::Template; use http::header::CONTENT_TYPE; use http::StatusCode; use serde::{Deserialize, Serialize}; use tokio::fs::File; use tokio::task::spawn; use super::PowerBody; use crate::tcp::Listener; use crate::{http, App, Context}; #[derive(Debug, Deserialize)] struct UserDto { id: u64, name: String, } #[derive(Debug, Serialize, Hash, Eq, PartialEq, Clone, Template)] #[template(path = "user.html")] struct User<'a> { id: u64, name: &'a str, } impl PartialEq for User<'_> { fn eq(&self, other: &UserDto) -> bool { self.id == other.id && self.name == other.name } } #[allow(dead_code)] const USER: User = User { id: 0, name: "Hexilee", }; #[cfg(feature = "json")] #[tokio::test] async fn read_json() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { let user: UserDto = ctx.read_json().await?; assert_eq!(USER, user); Ok(()) } let (addr, server) = App::new().end(test).run()?; spawn(server); let client = reqwest::Client::new(); let resp = client .get(&format!("http://{}", addr)) .json(&USER) .send() .await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } #[cfg(feature = "urlencoded")] #[tokio::test] async fn read_form() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { let user: UserDto = ctx.read_form().await?; assert_eq!(USER, user); Ok(()) } let (addr, server) = App::new().end(test).run()?; spawn(server); let client = reqwest::Client::new(); let resp = client .get(&format!("http://{}", addr)) .form(&USER) .send() .await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } #[cfg(feature = "template")] #[tokio::test] async fn render() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { ctx.render(&USER) } let (addr, server) = App::new().end(test).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}", 
addr)).await?; assert_eq!(StatusCode::OK, resp.status()); assert_eq!("text/html; charset=utf-8", resp.headers()[CONTENT_TYPE]); Ok(()) } #[tokio::test] async fn write() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { ctx.write("Hello, World!"); Ok(()) } let (addr, server) = App::new().end(test).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); assert_eq!("text/plain", resp.headers()[CONTENT_TYPE]); assert_eq!("Hello, World!", resp.text().await?); Ok(()) } #[tokio::test] async fn write_octet() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { ctx.write_reader(File::open("../assets/author.txt").await?); Ok(()) } let (addr, server) = App::new().end(test).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); assert_eq!( mime::APPLICATION_OCTET_STREAM.as_ref(), resp.headers()[CONTENT_TYPE] ); assert_eq!("Hexilee", resp.text().await?); Ok(()) } #[cfg(feature = "multipart")] mod multipart { use std::error::Error as StdError; use reqwest::multipart::{Form, Part}; use reqwest::Client; use tokio::fs::read; use crate::body::PowerBody; use crate::http::header::CONTENT_TYPE; use crate::http::StatusCode; use crate::router::{post, Router}; use crate::tcp::Listener; use crate::{throw, App, Context}; const FILE_PATH: &str = "../assets/author.txt"; const FILE_NAME: &str = "author.txt"; const FIELD_NAME: &str = "file"; async fn post_file(ctx: &mut Context) -> crate::Result { let mut form = ctx.read_multipart().await?; while let Some(field) = form.next_field().await? 
{ match (field.file_name(), field.name()) { (Some(filename), Some(name)) => { assert_eq!(FIELD_NAME, name); assert_eq!(FILE_NAME, filename); let content = field.bytes().await?; let expected_content = read(FILE_PATH).await?; assert_eq!(&expected_content, &content); } _ => throw!( StatusCode::BAD_REQUEST, format!("invalid field: {:?}", field) ), } } Ok(()) } #[tokio::test] async fn upload() -> Result<(), Box> { let router = Router::new().on("/file", post(post_file)); let app = App::new().end(router.routes("/")?); let (addr, server) = app.run()?; tokio::task::spawn(server); // client let url = format!("http://{}/file", addr); let client = Client::new(); let form = Form::new().part( FIELD_NAME, Part::bytes(read(FILE_PATH).await?).file_name(FILE_NAME), ); let boundary = form.boundary().to_string(); let resp = client .post(&url) .multipart(form) .header( CONTENT_TYPE, format!(r#"multipart/form-data; boundary="{}""#, boundary), ) .send() .await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } } } ================================================ FILE: roa/src/compress.rs ================================================ //! This module provides a middleware `Compress`. //! //! ### Example //! //! ```rust //! use roa::compress::{Compress, Level}; //! use roa::body::DispositionType::*; //! use roa::{App, Context}; //! use roa::preload::*; //! use std::error::Error; //! //! async fn end(ctx: &mut Context) -> roa::Result { //! ctx.write_file("../assets/welcome.html", Inline).await //! } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let mut app = App::new().gate(Compress(Level::Fastest)).end(end); //! let (addr, server) = app.run()?; //! // server.await //! Ok(()) //! # } //! 
``` use async_compression::tokio::bufread::{BrotliEncoder, GzipEncoder, ZlibEncoder, ZstdEncoder}; pub use async_compression::Level; use tokio_util::io::StreamReader; use crate::http::header::{HeaderMap, ACCEPT_ENCODING, CONTENT_ENCODING}; use crate::http::{HeaderValue, StatusCode}; use crate::{async_trait, status, Context, Middleware, Next, Result}; /// A middleware to negotiate with client and compress response body automatically, /// supports gzip, deflate, brotli, zstd and identity. #[derive(Debug, Copy, Clone)] pub struct Compress(pub Level); /// Encodings to use. #[derive(Debug, Clone, Copy, Eq, PartialEq)] enum Encoding { /// The Gzip encoding. Gzip, /// The Deflate encoding. Deflate, /// The Brotli encoding. Brotli, /// The Zstd encoding. Zstd, /// No encoding. Identity, } impl Encoding { /// Parses a given string into its corresponding encoding. fn parse(s: &str) -> Result> { match s { "gzip" => Ok(Some(Encoding::Gzip)), "deflate" => Ok(Some(Encoding::Deflate)), "br" => Ok(Some(Encoding::Brotli)), "zstd" => Ok(Some(Encoding::Zstd)), "identity" => Ok(Some(Encoding::Identity)), "*" => Ok(None), _ => Err(status!( StatusCode::BAD_REQUEST, format!("unknown encoding: {}", s), true )), } } /// Converts the encoding into its' corresponding header value. fn to_header_value(self) -> HeaderValue { match self { Encoding::Gzip => HeaderValue::from_str("gzip").unwrap(), Encoding::Deflate => HeaderValue::from_str("deflate").unwrap(), Encoding::Brotli => HeaderValue::from_str("br").unwrap(), Encoding::Zstd => HeaderValue::from_str("zstd").unwrap(), Encoding::Identity => HeaderValue::from_str("identity").unwrap(), } } } fn select_encoding(headers: &HeaderMap) -> Result> { let mut preferred_encoding = None; let mut max_qval = 0.0; for (encoding, qval) in accept_encodings(headers)? 
{
        if qval > max_qval {
            preferred_encoding = encoding;
            max_qval = qval;
        }
    }

    Ok(preferred_encoding)
}

/// Parse a set of HTTP headers into a vector containing tuples of options containing encodings and their corresponding q-values.
///
/// If you're looking for more fine-grained control over what encoding to choose for the client, or if you don't support every [`Encoding`] listed, this is likely what you want.
///
/// Note that a result of `None` indicates that no preference is expressed on which encoding to use.
/// Either the `Accept-Encoding` header is not present, or `*` is set as the most preferred encoding.
///
/// Unknown encodings are skipped. A header value that is not valid text, or a
/// `q` weight that is not a number, yields a 400 BAD_REQUEST status.
fn accept_encodings(headers: &HeaderMap) -> Result<Vec<(Option<Encoding>, f32)>> {
    headers
        .get_all(ACCEPT_ENCODING)
        .iter()
        .map(|hval| {
            hval.to_str()
                .map_err(|err| status!(StatusCode::BAD_REQUEST, err, true))
        })
        .collect::<Result<Vec<&str>>>()?
        .iter()
        .flat_map(|s| s.split(',').map(str::trim))
        .filter_map(|item| {
            // `coding [ OWS ";" OWS "q=" qvalue ]`: split on the bare `;` so that
            // optional whitespace around it (e.g. `gzip; q=0.5`) is tolerated.
            let mut parts = item.splitn(2, ';');
            let name = parts.next()?.trim();
            // ignore unknown encodings
            let encoding = Encoding::parse(name).ok()?;
            let qval = match parts.next().map(str::trim) {
                None => 1.0,
                Some(param) => {
                    // Only a `q` weight is understood; an entry carrying any
                    // other parameter is skipped.
                    let weight = param
                        .strip_prefix("q=")
                        .or_else(|| param.strip_prefix("Q="))?;
                    match weight.trim().parse::<f32>() {
                        Ok(q) => q,
                        Err(err) => {
                            return Some(Err(status!(StatusCode::BAD_REQUEST, err, true)))
                        }
                    }
                }
            };
            Some(Ok((encoding, qval)))
        })
        .collect::<Result<Vec<(Option<Encoding>, f32)>>>()
}

impl Default for Compress {
    fn default() -> Self {
        Self(Level::Default)
    }
}

#[async_trait(?Send)]
impl<'a, S> Middleware<'a, S> for Compress {
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[inline]
    async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
        next.await?;
        let level = self.0;
        let best_encoding = select_encoding(&ctx.req.headers)?;
        // Take the body out so it can be wrapped by the chosen encoder.
        let body = std::mem::take(&mut ctx.resp.body);
        let content_encoding = match best_encoding {
            // No preference falls back to gzip.
            None | Some(Encoding::Gzip) => {
                ctx.resp
                    .write_reader(GzipEncoder::with_quality(StreamReader::new(body), level));
                Encoding::Gzip.to_header_value()
            }
            Some(Encoding::Deflate) => {
                ctx.resp
.write_reader(ZlibEncoder::with_quality(StreamReader::new(body), level)); Encoding::Deflate.to_header_value() } Some(Encoding::Brotli) => { ctx.resp .write_reader(BrotliEncoder::with_quality(StreamReader::new(body), level)); Encoding::Brotli.to_header_value() } Some(Encoding::Zstd) => { ctx.resp .write_reader(ZstdEncoder::with_quality(StreamReader::new(body), level)); Encoding::Zstd.to_header_value() } Some(Encoding::Identity) => { ctx.resp.body = body; Encoding::Identity.to_header_value() } }; ctx.resp.headers.append(CONTENT_ENCODING, content_encoding); Ok(()) } } #[cfg(all(test, feature = "tcp", feature = "file"))] mod tests { use std::io; use std::pin::Pin; use std::task::{self, Poll}; use bytes::Bytes; use futures::Stream; use tokio::task::spawn; use crate::body::DispositionType::*; use crate::compress::{Compress, Level}; use crate::http::header::ACCEPT_ENCODING; use crate::http::StatusCode; use crate::preload::*; use crate::{async_trait, App, Context, Middleware, Next}; struct Consumer { counter: usize, stream: S, assert_counter: usize, } impl Stream for Consumer where S: 'static + Send + Send + Unpin + Stream>, { type Item = io::Result; fn poll_next( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll> { match Pin::new(&mut self.stream).poll_next(cx) { Poll::Ready(Some(Ok(bytes))) => { self.counter += bytes.len(); Poll::Ready(Some(Ok(bytes))) } Poll::Ready(None) => { assert_eq!(self.assert_counter, self.counter); Poll::Ready(None) } poll => poll, } } } struct Assert(usize); #[async_trait(?Send)] impl<'a, S> Middleware<'a, S> for Assert { async fn handle(&'a self, ctx: &'a mut Context, next: Next<'a>) -> crate::Result { next.await?; let body = std::mem::take(&mut ctx.resp.body); ctx.resp.write_stream(Consumer { counter: 0, stream: body, assert_counter: self.0, }); Ok(()) } } async fn end(ctx: &mut Context) -> crate::Result { ctx.write_file("../assets/welcome.html", Inline).await } #[tokio::test] async fn compress() -> Result<(), Box> { let app = 
App::new() .gate(Assert(202)) // compressed to 202 bytes .gate(Compress(Level::Fastest)) .gate(Assert(236)) // the size of assets/welcome.html is 236 bytes. .end(end); let (addr, server) = app.run()?; spawn(server); let client = reqwest::Client::builder().gzip(true).build()?; let resp = client .get(&format!("http://{}", addr)) .header(ACCEPT_ENCODING, "gzip") .send() .await?; assert_eq!(StatusCode::OK, resp.status()); assert_eq!(236, resp.text().await?.len()); Ok(()) } } ================================================ FILE: roa/src/cookie.rs ================================================ //! This module provides a middleware `cookie_parser` and context extensions `CookieGetter` and `CookieSetter`. //! //! ### Example //! //! ```rust //! use roa::cookie::cookie_parser; //! use roa::preload::*; //! use roa::{App, Context}; //! use std::error::Error; //! //! async fn end(ctx: &mut Context) -> roa::Result { //! assert_eq!("Hexilee", ctx.must_cookie("name")?.value()); //! Ok(()) //! } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let app = App::new().gate(cookie_parser).end(end); //! let (addr, server) = app.run()?; //! // server.await //! Ok(()) //! # } //! ``` use std::sync::Arc; pub use cookie::Cookie; use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; use crate::http::{header, StatusCode}; use crate::{throw, Context, Next, Result}; /// A scope to store and load variables in Context::storage. struct CookieScope; /// A context extension. /// This extension must be used in downstream of middleware `cookier_parser`, /// otherwise you cannot get expected cookie. /// /// Percent-encoded cookies will be decoded. 
/// ### Example
///
/// ```rust
/// use roa::cookie::cookie_parser;
/// use roa::preload::*;
/// use roa::{App, Context};
/// use std::error::Error;
///
/// async fn end(ctx: &mut Context) -> roa::Result {
///     assert_eq!("Hexilee", ctx.must_cookie("name")?.value());
///     Ok(())
/// }
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn Error>> {
/// let app = App::new().gate(cookie_parser).end(end);
/// let (addr, server) = app.run()?;
/// // server.await
/// Ok(())
/// # }
/// ```
pub trait CookieGetter {
    /// Must get a cookie, throw 401 UNAUTHORIZED if it does not exist.
    fn must_cookie(&mut self, name: &str) -> Result<Arc<Cookie<'static>>>;

    /// Try to get a cookie, return `None` if it does not exist.
    ///
    /// ### Example
    ///
    /// ```rust
    /// use roa::cookie::cookie_parser;
    /// use roa::preload::*;
    /// use roa::{App, Context};
    /// use std::error::Error;
    ///
    /// async fn end(ctx: &mut Context) -> roa::Result {
    ///     assert!(ctx.cookie("name").is_none());
    ///     Ok(())
    /// }
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn Error>> {
    /// let app = App::new().gate(cookie_parser).end(end);
    /// let (addr, server) = app.run()?;
    /// // server.await
    /// Ok(())
    /// # }
    /// ```
    fn cookie(&self, name: &str) -> Option<Arc<Cookie<'static>>>;
}

/// An extension to set cookie.
pub trait CookieSetter {
    /// Set a cookie in percent encoding, should not return Err.
    ///
    /// ### Example
    ///
    /// ```rust
    /// use roa::cookie::{cookie_parser, Cookie};
    /// use roa::preload::*;
    /// use roa::{App, Context};
    /// use std::error::Error;
    ///
    /// async fn end(ctx: &mut Context) -> roa::Result {
    ///     ctx.set_cookie(Cookie::new("name", "Hexilee"));
    ///     Ok(())
    /// }
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn Error>> {
    /// let app = App::new().gate(cookie_parser).end(end);
    /// let (addr, server) = app.run()?;
    /// // server.await
    /// Ok(())
    /// # }
    /// ```
    fn set_cookie(&mut self, cookie: Cookie<'_>) -> Result;
}

/// A middleware to parse cookie.
#[inline] pub async fn cookie_parser(ctx: &mut Context, next: Next<'_>) -> Result { if let Some(cookies) = ctx.get(header::COOKIE) { for cookie in cookies .split(';') .map(|cookie| cookie.trim()) .map(Cookie::parse_encoded) .filter_map(|cookie| cookie.ok()) .map(|cookie| cookie.into_owned()) .collect::>() .into_iter() { let name = cookie.name().to_string(); ctx.store_scoped(CookieScope, name, cookie); } } next.await } impl CookieGetter for Context { #[inline] fn must_cookie(&mut self, name: &str) -> Result>> { match self.cookie(name) { Some(value) => Ok(value), None => { self.resp.headers.insert( header::WWW_AUTHENTICATE, format!( r#"Cookie name="{}""#, utf8_percent_encode(name, NON_ALPHANUMERIC) ) .parse()?, ); throw!(StatusCode::UNAUTHORIZED) } } } #[inline] fn cookie(&self, name: &str) -> Option>> { Some(self.load_scoped::(name)?.value()) } } impl CookieSetter for Context { #[inline] fn set_cookie(&mut self, cookie: Cookie<'_>) -> Result { let cookie_value = cookie.encoded().to_string(); self.resp .headers .append(header::SET_COOKIE, cookie_value.parse()?); Ok(()) } } #[cfg(all(test, feature = "tcp"))] mod tests { use tokio::task::spawn; use crate::cookie::{cookie_parser, Cookie}; use crate::http::header::{COOKIE, WWW_AUTHENTICATE}; use crate::http::StatusCode; use crate::preload::*; use crate::{App, Context}; async fn must(ctx: &mut Context) -> crate::Result { assert_eq!("Hexi Lee", ctx.must_cookie("nick name")?.value()); Ok(()) } async fn none(ctx: &mut Context) -> crate::Result { assert!(ctx.cookie("nick name").is_none()); Ok(()) } #[tokio::test] async fn parser() -> Result<(), Box> { // downstream of `cookie_parser` let (addr, server) = App::new().gate(cookie_parser).end(must).run()?; spawn(server); let client = reqwest::Client::new(); let resp = client .get(&format!("http://{}", addr)) .header(COOKIE, "nick%20name=Hexi%20Lee") .send() .await?; assert_eq!(StatusCode::OK, resp.status()); // miss `cookie_parser` let (addr, server) = 
App::new().end(must).run()?; spawn(server); let resp = client .get(&format!("http://{}", addr)) .header(COOKIE, "nick%20name=Hexi%20Lee") .send() .await?; assert_eq!(StatusCode::UNAUTHORIZED, resp.status()); Ok(()) } #[tokio::test] async fn cookie() -> Result<(), Box> { // miss cookie let (addr, server) = App::new().end(none).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); let (addr, server) = App::new().gate(cookie_parser).end(must).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::UNAUTHORIZED, resp.status()); assert_eq!( r#"Cookie name="nick%20name""#, resp.headers() .get(WWW_AUTHENTICATE) .unwrap() .to_str() .unwrap() ); // string value let (addr, server) = App::new().gate(cookie_parser).end(must).run()?; spawn(server); let client = reqwest::Client::new(); let resp = client .get(&format!("http://{}", addr)) .header(COOKIE, "nick%20name=Hexi%20Lee") .send() .await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } #[tokio::test] async fn cookie_action() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { assert_eq!("bar baz", ctx.must_cookie("bar baz")?.value()); assert_eq!("bar foo", ctx.must_cookie("foo baz")?.value()); Ok(()) } let (addr, server) = App::new().gate(cookie_parser).end(test).run()?; spawn(server); let client = reqwest::Client::new(); let resp = client .get(&format!("http://{}", addr)) .header(COOKIE, "bar%20baz=bar%20baz; foo%20baz=bar%20foo") .send() .await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } #[tokio::test] async fn set_cookie() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { ctx.set_cookie(Cookie::new("bar baz", "bar baz"))?; ctx.set_cookie(Cookie::new("bar foo", "foo baz"))?; Ok(()) } let (addr, server) = App::new().end(test).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::OK, 
resp.status()); let cookies: Vec = resp.cookies().collect(); assert_eq!(2, cookies.len()); assert_eq!(("bar%20baz"), cookies[0].name()); assert_eq!(("bar%20baz"), cookies[0].value()); assert_eq!(("bar%20foo"), cookies[1].name()); assert_eq!(("foo%20baz"), cookies[1].value()); Ok(()) } } ================================================ FILE: roa/src/cors.rs ================================================ //! This module provides a middleware `Cors`. use std::collections::HashSet; use std::convert::TryInto; use std::fmt::Debug; use std::iter::FromIterator; use std::time::Duration; use headers::{ AccessControlAllowCredentials, AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlAllowOrigin, AccessControlExposeHeaders, AccessControlMaxAge, AccessControlRequestHeaders, AccessControlRequestMethod, Header, HeaderMapExt, }; use roa_core::Status; use crate::http::header::{HeaderName, HeaderValue, ORIGIN, VARY}; use crate::http::{Method, StatusCode}; use crate::{async_trait, Context, Middleware, Next, Result}; /// A middleware to deal with Cross-Origin Resource Sharing (CORS). /// /// ### Default /// /// The default Cors middleware will satisfy all needs of a request. 
/// /// Build a default Cors middleware: /// /// ```rust /// use roa::cors::Cors; /// /// let default_cors = Cors::new(); /// ``` /// /// ### Config /// /// You can also configure it: /// /// ```rust /// use roa::cors::Cors; /// use roa::http::header::{CONTENT_DISPOSITION, AUTHORIZATION, WWW_AUTHENTICATE}; /// use roa::http::Method; /// /// let cors = Cors::builder() /// .allow_credentials(true) /// .max_age(86400) /// .allow_origin("https://github.com") /// .allow_methods(vec![Method::GET, Method::POST]) /// .allow_method(Method::PUT) /// .expose_headers(vec![CONTENT_DISPOSITION]) /// .expose_header(WWW_AUTHENTICATE) /// .allow_headers(vec![AUTHORIZATION]) /// .allow_header(CONTENT_DISPOSITION) /// .build(); /// ``` #[derive(Debug, Default)] pub struct Cors { allow_origin: Option, allow_methods: Option, expose_headers: Option, allow_headers: Option, max_age: Option, credentials: Option, } /// Builder of Cors. #[derive(Clone, Debug, Default)] pub struct Builder { credentials: bool, allowed_headers: HashSet, exposed_headers: HashSet, max_age: Option, methods: HashSet, origins: Option, } impl Cors { /// Construct default Cors. pub fn new() -> Self { Self::default() } /// Get builder. pub fn builder() -> Builder { Builder::default() } } impl Builder { /// Sets whether to add the `Access-Control-Allow-Credentials` header. pub fn allow_credentials(mut self, allow: bool) -> Self { self.credentials = allow; self } /// Adds a method to the existing list of allowed request methods. pub fn allow_method(mut self, method: Method) -> Self { self.methods.insert(method); self } /// Adds multiple methods to the existing list of allowed request methods. pub fn allow_methods(mut self, methods: impl IntoIterator) -> Self { self.methods.extend(methods); self } /// Adds a header to the list of allowed request headers. /// /// # Panics /// /// Panics if header is not a valid `http::header::HeaderName`. 
pub fn allow_header(mut self, header: H) -> Self where H: TryInto, H::Error: Debug, { self.allowed_headers .insert(header.try_into().expect("invalid header")); self } /// Adds multiple headers to the list of allowed request headers. /// /// # Panics /// /// Panics if any of the headers are not a valid `http::header::HeaderName`. pub fn allow_headers(mut self, headers: I) -> Self where I: IntoIterator, I::Item: TryInto, >::Error: Debug, { let iter = headers .into_iter() .map(|h| h.try_into().expect("invalid header")); self.allowed_headers.extend(iter); self } /// Adds a header to the list of exposed headers. /// /// # Panics /// /// Panics if the provided argument is not a valid `http::header::HeaderName`. pub fn expose_header(mut self, header: H) -> Self where H: TryInto, H::Error: Debug, { self.exposed_headers .insert(header.try_into().expect("illegal Header")); self } /// Adds multiple headers to the list of exposed headers. /// /// # Panics /// /// Panics if any of the headers are not a valid `http::header::HeaderName`. pub fn expose_headers(mut self, headers: I) -> Self where I: IntoIterator, I::Item: TryInto, >::Error: Debug, { let iter = headers .into_iter() .map(|h| h.try_into().expect("illegal Header")); self.exposed_headers.extend(iter); self } /// Add an origin to the existing list of allowed `Origin`s. /// /// # Panics /// /// Panics if the provided argument is not a valid `HeaderValue`. pub fn allow_origin(mut self, origin: H) -> Self where H: TryInto, H::Error: Debug, { self.origins = Some(origin.try_into().expect("invalid origin")); self } /// Sets the `Access-Control-Max-Age` header. pub fn max_age(mut self, seconds: u64) -> Self { self.max_age = Some(seconds); self } /// Builds the `Cors` wrapper from the configured settings. /// /// This step isn't *required*, as the `Builder` itself can be passed /// to `Filter::with`. This just allows constructing once, thus not needing /// to pay the cost of "building" every time. 
pub fn build(self) -> Cors { let Builder { allowed_headers, credentials, exposed_headers, max_age, origins, methods, } = self; let mut cors = Cors::default(); if !allowed_headers.is_empty() { cors.allow_headers = Some(AccessControlAllowHeaders::from_iter(allowed_headers)) } if credentials { cors.credentials = Some(AccessControlAllowCredentials) } if !exposed_headers.is_empty() { cors.expose_headers = Some(AccessControlExposeHeaders::from_iter(exposed_headers)) } if let Some(age) = max_age { cors.max_age = Some(Duration::from_secs(age).into()) } if origins.is_some() { cors.allow_origin = Some( AccessControlAllowOrigin::decode(&mut origins.iter()).expect("invalid origins"), ); } if !methods.is_empty() { cors.allow_methods = Some(AccessControlAllowMethods::from_iter(methods)) } cors } } #[async_trait(?Send)] impl<'a, S> Middleware<'a, S> for Cors { #[inline] async fn handle(&'a self, ctx: &'a mut Context, next: Next<'a>) -> Result { // Always set Vary header // https://github.com/rs/cors/issues/10 ctx.resp.headers.append(VARY, ORIGIN.into()); let origin = match ctx.req.headers.get(ORIGIN) { // If there is no Origin header, skip this middleware. None => return next.await, Some(origin) => AccessControlAllowOrigin::decode(&mut Some(origin).into_iter()) .map_err(|err| { Status::new( StatusCode::BAD_REQUEST, format!("invalid origin: {}", err), true, ) })?, }; // If Options::allow_origin is None, `Access-Control-Allow-Origin` will be set to `Origin`. 
let allow_origin = self.allow_origin.clone().unwrap_or(origin); let credentials = self.credentials.clone(); let insert_origin_and_credentials = move |ctx: &mut Context| { // Set "Access-Control-Allow-Origin" ctx.resp.headers.typed_insert(allow_origin); // Try to set "Access-Control-Allow-Credentials" if let Some(credentials) = credentials { ctx.resp.headers.typed_insert(credentials); } }; if ctx.method() != Method::OPTIONS { // Simple Request insert_origin_and_credentials(ctx); // Set "Access-Control-Expose-Headers" if let Some(ref exposed_headers) = self.expose_headers { ctx.resp.headers.typed_insert(exposed_headers.clone()); } next.await } else { // Preflight Request let request_method = match ctx.req.headers.typed_get::() { // If there is no Origin header or if parsing failed, skip this middleware. None => return next.await, Some(request_method) => request_method, }; // If Options::allow_methods is None, `Access-Control-Allow-Methods` will be set to `Access-Control-Request-Method`. let allow_methods = match self.allow_methods { Some(ref origin) => origin.clone(), None => AccessControlAllowMethods::from_iter(Some(request_method.into())), }; // Try to set "Access-Control-Allow-Methods" ctx.resp.headers.typed_insert(allow_methods); insert_origin_and_credentials(ctx); // Set "Access-Control-Max-Age" if let Some(ref max_age) = self.max_age { ctx.resp.headers.typed_insert(max_age.clone()); } // If allow_headers is None, try to assign `Access-Control-Request-Headers` to `Access-Control-Allow-Headers`. 
let allow_headers = self.allow_headers.clone().or_else(|| { ctx.req .headers .typed_get::() .map(|headers| headers.iter().collect()) }); if let Some(headers) = allow_headers { ctx.resp.headers.typed_insert(headers); }; ctx.resp.status = StatusCode::NO_CONTENT; Ok(()) } } } #[cfg(all(test, feature = "tcp"))] mod tests { use headers::{ AccessControlAllowCredentials, AccessControlAllowOrigin, AccessControlExposeHeaders, HeaderMapExt, HeaderName, }; use tokio::task::spawn; use super::Cors; use crate::http::header::{ ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_TYPE, ORIGIN, VARY, WWW_AUTHENTICATE, }; use crate::http::{HeaderValue, Method, StatusCode}; use crate::preload::*; use crate::{App, Context}; async fn end(ctx: &mut Context) -> crate::Result { ctx.resp.write("Hello, World"); Ok(()) } #[tokio::test] async fn default_cors() -> Result<(), Box> { let (addr, server) = App::new().gate(Cors::new()).end(end).run()?; spawn(server); let client = reqwest::Client::new(); // No origin let resp = client.get(&format!("http://{}", addr)).send().await?; assert_eq!(StatusCode::OK, resp.status()); assert!(resp .headers() .typed_get::() .is_none()); assert_eq!( HeaderValue::from_name(ORIGIN), resp.headers().get(VARY).unwrap() ); assert_eq!("Hello, World", resp.text().await?); // invalid origin let resp = client .get(&format!("http://{}", addr)) .header(ORIGIN, "github.com") .send() .await?; assert_eq!(StatusCode::BAD_REQUEST, resp.status()); // simple request let resp = client .get(&format!("http://{}", addr)) .header(ORIGIN, "http://github.com") .send() .await?; assert_eq!(StatusCode::OK, resp.status()); let allow_origin = resp .headers() .typed_get::() .unwrap(); let origin = allow_origin.origin().unwrap(); assert_eq!("http", origin.scheme()); assert_eq!("github.com", 
origin.hostname()); assert!(origin.port().is_none()); assert!(resp .headers() .typed_get::() .is_none()); assert!(resp .headers() .typed_get::() .is_none()); assert_eq!("Hello, World", resp.text().await?); // options, no Access-Control-Request-Method let resp = client .request(Method::OPTIONS, &format!("http://{}", addr)) .header(ORIGIN, "http://github.com") .send() .await?; assert_eq!(StatusCode::OK, resp.status()); assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none()); assert_eq!( HeaderValue::from_name(ORIGIN), resp.headers().get(VARY).unwrap() ); assert_eq!("Hello, World", resp.text().await?); // options, contains Access-Control-Request-Method let resp = client .request(Method::OPTIONS, &format!("http://{}", addr)) .header(ORIGIN, "http://github.com") .header(ACCESS_CONTROL_REQUEST_METHOD, "POST") .header( ACCESS_CONTROL_REQUEST_HEADERS, HeaderValue::from_name(CONTENT_TYPE), ) .send() .await?; assert_eq!(StatusCode::NO_CONTENT, resp.status()); assert_eq!( "http://github.com", resp.headers() .get(ACCESS_CONTROL_ALLOW_ORIGIN) .unwrap() .to_str()? ); assert!(resp .headers() .get(ACCESS_CONTROL_ALLOW_CREDENTIALS) .is_none()); assert!(resp.headers().get(ACCESS_CONTROL_MAX_AGE).is_none()); assert_eq!( "POST", resp.headers() .get(ACCESS_CONTROL_ALLOW_METHODS) .unwrap() .to_str()? 
); assert_eq!( HeaderValue::from_name(CONTENT_TYPE), resp.headers().get(ACCESS_CONTROL_ALLOW_HEADERS).unwrap() ); assert_eq!("", resp.text().await?); // Ok(()) } #[tokio::test] async fn configured_cors() -> Result<(), Box> { let configured_cors = Cors::builder() .allow_credentials(true) .max_age(86400) .allow_origin("https://github.com") .allow_methods(vec![Method::GET, Method::POST]) .allow_method(Method::PUT) .expose_headers(vec![CONTENT_DISPOSITION]) .expose_header(WWW_AUTHENTICATE) .allow_headers(vec![AUTHORIZATION]) .allow_header(CONTENT_TYPE) .build(); let (addr, server) = App::new().gate(configured_cors).end(end).run()?; spawn(server); let client = reqwest::Client::new(); // No origin let resp = client.get(&format!("http://{}", addr)).send().await?; assert_eq!(StatusCode::OK, resp.status()); assert!(resp .headers() .typed_get::() .is_none()); assert_eq!( HeaderValue::from_name(ORIGIN), resp.headers().get(VARY).unwrap() ); assert_eq!("Hello, World", resp.text().await?); // invalid origin let resp = client .get(&format!("http://{}", addr)) .header(ORIGIN, "github.com") .send() .await?; assert_eq!(StatusCode::BAD_REQUEST, resp.status()); // simple request let resp = client .get(&format!("http://{}", addr)) .header(ORIGIN, "http://github.io") .send() .await?; assert_eq!(StatusCode::OK, resp.status()); let allow_origin = resp .headers() .typed_get::() .unwrap(); let origin = allow_origin.origin().unwrap(); assert_eq!("https", origin.scheme()); assert_eq!("github.com", origin.hostname()); assert!(origin.port().is_none()); assert!(resp .headers() .typed_get::() .is_some()); let expose_headers = resp .headers() .typed_get::() .unwrap(); let headers = expose_headers.iter().collect::>(); assert!(headers.contains(&CONTENT_DISPOSITION)); assert!(headers.contains(&WWW_AUTHENTICATE)); assert_eq!("Hello, World", resp.text().await?); // options, no Access-Control-Request-Method let resp = client .request(Method::OPTIONS, &format!("http://{}", addr)) .header(ORIGIN, 
"http://github.com") .send() .await?; assert_eq!(StatusCode::OK, resp.status()); assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none()); assert_eq!( HeaderValue::from_name(ORIGIN), resp.headers().get(VARY).unwrap() ); assert_eq!("Hello, World", resp.text().await?); // options, contains Access-Control-Request-Method let resp = client .request(Method::OPTIONS, &format!("http://{}", addr)) .header(ORIGIN, "http://github.io") .header(ACCESS_CONTROL_REQUEST_METHOD, "POST") .header( ACCESS_CONTROL_REQUEST_HEADERS, HeaderValue::from_name(CONTENT_TYPE), ) .send() .await?; assert_eq!(StatusCode::NO_CONTENT, resp.status()); assert_eq!( "https://github.com", resp.headers() .get(ACCESS_CONTROL_ALLOW_ORIGIN) .unwrap() .to_str()? ); assert_eq!( "true", resp.headers() .get(ACCESS_CONTROL_ALLOW_CREDENTIALS) .unwrap() .to_str()? ); assert_eq!("86400", resp.headers().get(ACCESS_CONTROL_MAX_AGE).unwrap()); let allow_methods = resp .headers() .get(ACCESS_CONTROL_ALLOW_METHODS) .unwrap() .to_str()?; assert!(allow_methods.contains("POST")); assert!(allow_methods.contains("GET")); assert!(allow_methods.contains("PUT")); let allow_headers = resp .headers() .get(ACCESS_CONTROL_ALLOW_HEADERS) .unwrap() .to_str()?; assert!(allow_headers.contains(CONTENT_TYPE.as_str())); assert!(allow_headers.contains(AUTHORIZATION.as_str())); assert_eq!("", resp.text().await?); // Ok(()) } } ================================================ FILE: roa/src/forward.rs ================================================ //! This module provides a context extension `Forward`, //! which is used to parse `X-Forwarded-*` headers. use std::net::IpAddr; use crate::http::header::HOST; use crate::{Context, State}; /// A context extension `Forward` used to parse `X-Forwarded-*` request headers. pub trait Forward { /// Get true host. /// - If "x-forwarded-host" is set and valid, use it. /// - Else if "host" is set and valid, use it. /// - Else throw Err(400 BAD REQUEST). 
/// /// ### Example /// ```rust /// use roa::{Context, Result}; /// use roa::forward::Forward; /// /// async fn get(ctx: &mut Context) -> Result { /// if let Some(host) = ctx.host() { /// println!("host: {}", host); /// } /// Ok(()) /// } /// ``` fn host(&self) -> Option<&str>; /// Get true client ip. /// - If "x-forwarded-for" is set and valid, use the first ip. /// - Else use the ip of `Context::remote_addr()`. /// /// ### Example /// ```rust /// use roa::{Context, Result}; /// use roa::forward::Forward; /// /// async fn get(ctx: &mut Context) -> Result { /// println!("client ip: {}", ctx.client_ip()); /// Ok(()) /// } /// ``` fn client_ip(&self) -> IpAddr; /// Get true forwarded ips. /// - If "x-forwarded-for" is set and valid, use it. /// - Else return an empty vector. /// /// ### Example /// ```rust /// use roa::{Context, Result}; /// use roa::forward::Forward; /// /// async fn get(ctx: &mut Context) -> Result { /// println!("forwarded ips: {:?}", ctx.forwarded_ips()); /// Ok(()) /// } /// ``` fn forwarded_ips(&self) -> Vec; /// Try to get forwarded proto. /// - If "x-forwarded-proto" is not set, return None. /// - If "x-forwarded-proto" is set but fails to string, return Some(Err(400 BAD REQUEST)). 
/// /// ### Example /// ```rust /// use roa::{Context, Result}; /// use roa::forward::Forward; /// /// async fn get(ctx: &mut Context) -> Result { /// if let Some(proto) = ctx.forwarded_proto() { /// println!("forwarded proto: {}", proto); /// } /// Ok(()) /// } /// ``` fn forwarded_proto(&self) -> Option<&str>; } impl Forward for Context { #[inline] fn host(&self) -> Option<&str> { self.get("x-forwarded-host").or_else(|| self.get(HOST)) } #[inline] fn client_ip(&self) -> IpAddr { let addrs = self.forwarded_ips(); if addrs.is_empty() { self.remote_addr.ip() } else { addrs[0] } } #[inline] fn forwarded_ips(&self) -> Vec { let mut addrs = Vec::new(); if let Some(value) = self.get("x-forwarded-for") { for addr_str in value.split(',') { if let Ok(addr) = addr_str.trim().parse() { addrs.push(addr) } } } addrs } #[inline] fn forwarded_proto(&self) -> Option<&str> { self.get("x-forwarded-proto") } } #[cfg(all(test, feature = "tcp"))] mod tests { use tokio::task::spawn; use super::Forward; use crate::http::header::HOST; use crate::http::{HeaderValue, StatusCode}; use crate::preload::*; use crate::{App, Context}; #[tokio::test] async fn host() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { assert_eq!(Some("github.com"), ctx.host()); Ok(()) } let (addr, server) = App::new().end(test).run()?; spawn(server); let client = reqwest::Client::new(); let resp = client .get(&format!("http://{}", addr)) .header(HOST, HeaderValue::from_static("github.com")) .send() .await?; assert_eq!(StatusCode::OK, resp.status()); let resp = client .get(&format!("http://{}", addr)) .header(HOST, "google.com") .header("x-forwarded-host", "github.com") .send() .await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } #[tokio::test] async fn host_err() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { ctx.req.headers.remove(HOST); assert_eq!(None, ctx.host()); Ok(()) } let (addr, server) = App::new().end(test).run()?; spawn(server); let resp = 
reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } #[tokio::test] async fn client_ip() -> Result<(), Box> { async fn remote_addr(ctx: &mut Context) -> crate::Result { assert_eq!(ctx.remote_addr.ip(), ctx.client_ip()); Ok(()) } let (addr, server) = App::new().end(remote_addr).run()?; spawn(server); reqwest::get(&format!("http://{}", addr)).await?; async fn forward_addr(ctx: &mut Context) -> crate::Result { assert_eq!("192.168.0.1", ctx.client_ip().to_string()); Ok(()) } let (addr, server) = App::new().end(forward_addr).run()?; spawn(server); let client = reqwest::Client::new(); client .get(&format!("http://{}", addr)) .header("x-forwarded-for", "192.168.0.1, 8.8.8.8") .send() .await?; Ok(()) } #[tokio::test] async fn forwarded_proto() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { assert_eq!(Some("https"), ctx.forwarded_proto()); Ok(()) } let (addr, server) = App::new().end(test).run()?; spawn(server); let client = reqwest::Client::new(); client .get(&format!("http://{}", addr)) .header("x-forwarded-proto", "https") .send() .await?; Ok(()) } } ================================================ FILE: roa/src/jsonrpc.rs ================================================ //! //! ## roa::jsonrpc //! //! This module provides a json rpc endpoint. //! //! ### Example //! //! ```rust,no_run //! use roa::App; //! use roa::jsonrpc::{RpcEndpoint, Data, Error, Params, Server}; //! use roa::tcp::Listener; //! use tracing::info; //! //! #[derive(serde::Deserialize)] //! struct TwoNums { //! a: usize, //! b: usize, //! } //! //! async fn add(Params(params): Params) -> Result { //! Ok(params.a + params.b) //! } //! //! async fn sub(Params(params): Params<(usize, usize)>) -> Result { //! Ok(params.0 - params.1) //! } //! //! async fn message(data: Data) -> Result { //! Ok(String::from(&*data)) //! } //! //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { //! let rpc = Server::new() //! 
.with_data(Data::new(String::from("Hello!"))) //! .with_method("sub", sub) //! .with_method("message", message) //! .finish_unwrapped(); //! //! let app = App::new().end(RpcEndpoint(rpc)); //! app.listen("127.0.0.1:8000", |addr| { //! info!("Server is listening on {}", addr) //! })? //! .await?; //! Ok(()) //! } //! ``` use bytes::Bytes; #[doc(no_inline)] pub use jsonrpc_v2::*; use crate::body::PowerBody; use crate::{async_trait, Context, Endpoint, Result, State}; /// A wrapper for [`jsonrpc_v2::Server`], implemented [`roa::Endpoint`]. /// /// [`jsonrpc_v2::Server`]: https://docs.rs/jsonrpc-v2/0.10.1/jsonrpc_v2/struct.Server.html /// [`roa::Endpoint`]: https://docs.rs/roa/0.6.0/roa/trait.Endpoint.html pub struct RpcEndpoint(pub Server); #[async_trait(? Send)] impl<'a, S, R> Endpoint<'a, S> for RpcEndpoint where S: State, R: Router + Sync + Send + 'static, { #[inline] async fn call(&'a self, ctx: &'a mut Context) -> Result { let data = ctx.read().await?; let resp = self.0.handle(Bytes::from(data)).await; ctx.write_json(&resp) } } ================================================ FILE: roa/src/jwt.rs ================================================ //! This module provides middleware `JwtGuard` and a context extension `JwtVerifier`. //! //! ### Example //! //! ```rust //! use roa::jwt::{guard, DecodingKey}; //! use roa::{App, Context}; //! use roa::http::header::AUTHORIZATION; //! use roa::http::StatusCode; //! use roa::preload::*; //! use tokio::task::spawn; //! use jsonwebtoken::{encode, Header, EncodingKey}; //! use serde::{Deserialize, Serialize}; //! use std::time::{Duration, SystemTime, UNIX_EPOCH}; //! //! #[derive(Debug, Serialize, Deserialize)] //! struct User { //! sub: String, //! company: String, //! exp: u64, //! id: u64, //! name: String, //! } //! //! const SECRET: &[u8] = b"123456"; //! //! async fn test(ctx: &mut Context) -> roa::Result { //! let user: User = ctx.claims()?; //! assert_eq!(0, user.id); //! assert_eq!("Hexilee", &user.name); //! 
Ok(()) //! } //! //! #[tokio::main] //! async fn main() -> Result<(), Box> { //! let (addr, server) = App::new() //! .gate(guard(DecodingKey::from_secret(SECRET))) //! .end(test).run()?; //! spawn(server); //! let mut user = User { //! sub: "user".to_string(), //! company: "None".to_string(), //! exp: (SystemTime::now() + Duration::from_secs(86400)) //! .duration_since(UNIX_EPOCH)? //! .as_secs(), //! id: 0, //! name: "Hexilee".to_string(), //! }; //! //! let client = reqwest::Client::new(); //! let resp = client //! .get(&format!("http://{}", addr)) //! .header( //! AUTHORIZATION, //! format!( //! "Bearer {}", //! encode( //! &Header::default(), //! &user, //! &EncodingKey::from_secret(SECRET) //! )? //! ), //! ) //! .send() //! .await?; //! assert_eq!(StatusCode::OK, resp.status()); //! Ok(()) //! } //! ``` use headers::authorization::Bearer; use headers::{Authorization, HeaderMapExt}; use jsonwebtoken::decode; pub use jsonwebtoken::{DecodingKey, Validation}; use serde::de::DeserializeOwned; use serde_json::Value; use crate::http::header::{HeaderValue, WWW_AUTHENTICATE}; use crate::http::StatusCode; use crate::{async_trait, throw, Context, Middleware, Next, Result, Status}; /// A private scope. struct JwtScope; static INVALID_TOKEN: HeaderValue = HeaderValue::from_static(r#"Bearer realm="", error="invalid_token""#); /// A function to set value of WWW_AUTHENTICATE. #[inline] fn set_www_authenticate(ctx: &mut Context) { ctx.resp .headers .insert(WWW_AUTHENTICATE, INVALID_TOKEN.clone()); } /// Throw a internal server error. #[inline] fn guard_not_set() -> Status { Status::new( StatusCode::INTERNAL_SERVER_ERROR, "middleware `JwtGuard` is not set correctly", false, ) } /// A context extension. /// This extension must be used in downstream of middleware `guard` or `guard_by`, /// otherwise you cannot get expected claims. 
/// /// ### Example /// /// ```rust /// use roa::{Context, Result}; /// use roa::jwt::JwtVerifier; /// use serde_json::Value; /// /// async fn get(ctx: &mut Context) -> Result { /// let claims: Value = ctx.claims()?; /// Ok(()) /// } /// ``` pub trait JwtVerifier { /// Deserialize claims from token. fn claims(&self) -> Result where C: 'static + DeserializeOwned; /// Verify token and deserialize claims with a validation. /// Use this method if this validation is different from that one of `JwtGuard`. fn verify(&mut self, validation: &Validation) -> Result where C: 'static + DeserializeOwned; } /// Guard by default validation. pub fn guard(secret: DecodingKey) -> JwtGuard { JwtGuard::new(secret, Validation::default()) } /// A middleware to deny unauthorized requests. /// /// The json web token should be deliver by request header "authorization", /// in format of `Authorization: Bearer `. /// /// If request fails to pass verification, return 401 UNAUTHORIZED and set response header "WWW-Authenticate". #[derive(Debug, Clone, PartialEq)] pub struct JwtGuard { secret: DecodingKey<'static>, validation: Validation, } impl JwtGuard { /// Construct guard. pub fn new(secret: DecodingKey, validation: Validation) -> Self { Self { secret: secret.into_static(), validation, } } /// Verify token. #[inline] fn verify(&self, ctx: &Context) -> Option<(Bearer, Value)> { let bearer = ctx.req.headers.typed_get::>()?.0; let value = decode::(bearer.token(), &self.secret, &self.validation) .ok()? .claims; Some((bearer, value)) } } #[async_trait(? 
Send)] impl<'a, S> Middleware<'a, S> for JwtGuard { #[inline] async fn handle(&'a self, ctx: &'a mut Context, next: Next<'a>) -> Result { match self.verify(ctx) { None => { set_www_authenticate(ctx); throw!(StatusCode::UNAUTHORIZED) } Some((bearer, value)) => { ctx.store_scoped(JwtScope, "secret", self.secret.clone()); ctx.store_scoped(JwtScope, "token", bearer); ctx.store_scoped(JwtScope, "value", value); next.await } } } } impl JwtVerifier for Context { #[inline] fn claims(&self) -> Result where C: 'static + DeserializeOwned, { let value = self.load_scoped::("value"); match value { Some(claims) => Ok(serde_json::from_value((*claims).clone())?), None => Err(guard_not_set()), } } #[inline] fn verify(&mut self, validation: &Validation) -> Result where C: 'static + DeserializeOwned, { let secret = self.load_scoped::>("secret"); let token = self.load_scoped::("token"); match (secret, token) { (Some(secret), Some(token)) => match decode(token.token(), &secret, validation) { Ok(data) => Ok(data.claims), Err(_) => { set_www_authenticate(self); throw!(StatusCode::UNAUTHORIZED) } }, _ => Err(guard_not_set()), } } } #[cfg(all(test, feature = "tcp"))] mod tests { use std::time::{Duration, SystemTime, UNIX_EPOCH}; use jsonwebtoken::{encode, EncodingKey, Header}; use serde::{Deserialize, Serialize}; use tokio::task::spawn; use super::{guard, DecodingKey, INVALID_TOKEN}; use crate::http::header::{AUTHORIZATION, WWW_AUTHENTICATE}; use crate::http::StatusCode; use crate::preload::*; use crate::{App, Context}; #[derive(Debug, Serialize, Deserialize)] struct User { sub: String, company: String, exp: u64, id: u64, name: String, } const SECRET: &[u8] = b"123456"; #[tokio::test] async fn claims() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { let user: User = ctx.claims()?; assert_eq!(0, user.id); assert_eq!("Hexilee", &user.name); Ok(()) } let (addr, server) = App::new() .gate(guard(DecodingKey::from_secret(SECRET))) .end(test) .run()?; spawn(server); let 
resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::UNAUTHORIZED, resp.status()); assert_eq!(&INVALID_TOKEN, &resp.headers()[WWW_AUTHENTICATE]); // non-string header value let client = reqwest::Client::new(); let resp = client .get(&format!("http://{}", addr)) .header(AUTHORIZATION, [255].as_ref()) .send() .await?; assert_eq!(StatusCode::UNAUTHORIZED, resp.status()); assert_eq!(&INVALID_TOKEN, &resp.headers()[WWW_AUTHENTICATE]); // non-Bearer header value let resp = client .get(&format!("http://{}", addr)) .header(AUTHORIZATION, "Basic hahaha") .send() .await?; assert_eq!(StatusCode::UNAUTHORIZED, resp.status()); assert_eq!(&INVALID_TOKEN, &resp.headers()[WWW_AUTHENTICATE]); // invalid token let resp = client .get(&format!("http://{}", addr)) .header(AUTHORIZATION, "Bearer hahaha") .send() .await?; assert_eq!(StatusCode::UNAUTHORIZED, resp.status()); assert_eq!(&INVALID_TOKEN, &resp.headers()[WWW_AUTHENTICATE]); // expired token let mut user = User { sub: "user".to_string(), company: "None".to_string(), exp: (SystemTime::now() - Duration::from_secs(1)) .duration_since(UNIX_EPOCH)? .as_secs(), // one second ago id: 0, name: "Hexilee".to_string(), }; let resp = client .get(&format!("http://{}", addr)) .header( AUTHORIZATION, format!( "Bearer {}", encode(&Header::default(), &user, &EncodingKey::from_secret(SECRET),)? ), ) .send() .await?; assert_eq!(StatusCode::UNAUTHORIZED, resp.status()); assert_eq!(&INVALID_TOKEN, &resp.headers()[WWW_AUTHENTICATE]); user.exp = (SystemTime::now() + Duration::from_millis(60)) .duration_since(UNIX_EPOCH)? .as_secs(); // one hour later let resp = client .get(&format!("http://{}", addr)) .header( AUTHORIZATION, format!( "Bearer {}", encode(&Header::default(), &user, &EncodingKey::from_secret(SECRET),)? 
), ) .send() .await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } #[tokio::test] async fn jwt_verify_not_set() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { let _: User = ctx.claims()?; Ok(()) } let (addr, server) = App::new().end(test).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::INTERNAL_SERVER_ERROR, resp.status()); Ok(()) } } ================================================ FILE: roa/src/lib.rs ================================================ #![cfg_attr(feature = "docs", feature(doc_cfg))] #![cfg_attr(feature = "docs", doc = include_str!("../README.md"))] #![cfg_attr(feature = "docs", warn(missing_docs))] pub use roa_core::*; #[cfg(feature = "router")] #[cfg_attr(feature = "docs", doc(cfg(feature = "router")))] pub mod router; #[cfg(feature = "tcp")] #[cfg_attr(feature = "docs", doc(cfg(feature = "tcp")))] pub mod tcp; #[cfg(feature = "tls")] #[cfg_attr(feature = "docs", doc(cfg(feature = "tls")))] pub mod tls; #[cfg(feature = "websocket")] #[cfg_attr(feature = "docs", doc(cfg(feature = "websocket")))] pub mod websocket; #[cfg(feature = "cookies")] #[cfg_attr(feature = "docs", doc(cfg(feature = "cookies")))] pub mod cookie; #[cfg(feature = "jwt")] #[cfg_attr(feature = "docs", doc(cfg(feature = "jwt")))] pub mod jwt; #[cfg(feature = "compress")] #[cfg_attr(feature = "docs", doc(cfg(feature = "compress")))] pub mod compress; #[cfg(feature = "jsonrpc")] #[cfg_attr(feature = "docs", doc(cfg(feature = "jsonrpc")))] pub mod jsonrpc; pub mod body; pub mod cors; pub mod forward; pub mod logger; pub mod query; pub mod stream; /// Reexport all extension traits. 
pub mod preload { pub use crate::body::PowerBody; #[cfg(feature = "cookies")] pub use crate::cookie::{CookieGetter, CookieSetter}; pub use crate::forward::Forward; #[cfg(feature = "jwt")] pub use crate::jwt::JwtVerifier; pub use crate::query::Query; #[cfg(feature = "router")] pub use crate::router::RouterParam; #[cfg(feature = "tcp")] #[doc(no_inline)] pub use crate::tcp::Listener; #[cfg(all(feature = "tcp", feature = "tls"))] #[doc(no_inline)] pub use crate::tls::TlsListener; } ================================================ FILE: roa/src/logger.rs ================================================ //! This module provides a middleware `logger`. //! //! ### Example //! //! ```rust //! use roa::logger::logger; //! use roa::preload::*; //! use roa::App; //! use roa::http::StatusCode; //! use tokio::task::spawn; //! //! #[tokio::main] //! async fn main() -> Result<(), Box> { //! pretty_env_logger::init(); //! let app = App::new() //! .gate(logger) //! .end("Hello, World"); //! let (addr, server) = app.run()?; //! spawn(server); //! let resp = reqwest::get(&format!("http://{}", addr)).await?; //! assert_eq!(StatusCode::OK, resp.status()); //! Ok(()) //! } //! ``` use std::pin::Pin; use std::time::Instant; use std::{io, mem}; use bytes::Bytes; use bytesize::ByteSize; use futures::task::{self, Poll}; use futures::{Future, Stream}; use roa_core::http::{Method, StatusCode}; use tracing::{error, info}; use crate::http::Uri; use crate::{Context, Executor, JoinHandle, Next, Result}; /// A finite-state machine to log success information in each successful response. enum StreamLogger { /// Polling state, as a body stream. Polling { stream: S, task: LogTask }, /// Logging state, as a logger future. Logging(JoinHandle<()>), /// Complete, as a empty stream. Complete, } /// A task structure to log when polling is complete. 
#[derive(Clone)] struct LogTask { counter: u64, method: Method, status_code: StatusCode, uri: Uri, start: Instant, exec: Executor, } impl LogTask { #[inline] fn log(&self) -> JoinHandle<()> { let LogTask { counter, method, status_code, uri, start, exec, } = self.clone(); exec.spawn_blocking(move || { info!( "<-- {} {} {}ms {} {}", method, uri, start.elapsed().as_millis(), ByteSize(counter), status_code, ) }) } } impl Stream for StreamLogger where S: 'static + Send + Send + Unpin + Stream>, { type Item = io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { match &mut *self { StreamLogger::Polling { stream, task } => { match futures::ready!(Pin::new(stream).poll_next(cx)) { Some(Ok(bytes)) => { task.counter += bytes.len() as u64; Poll::Ready(Some(Ok(bytes))) } None => { let handler = task.log(); *self = StreamLogger::Logging(handler); self.poll_next(cx) } err => Poll::Ready(err), } } StreamLogger::Logging(handler) => { futures::ready!(Pin::new(handler).poll(cx)); *self = StreamLogger::Complete; self.poll_next(cx) } StreamLogger::Complete => Poll::Ready(None), } } } /// A middleware to log information about request and response. /// /// Based on crate `log`, the log level must be greater than `INFO` to log all information, /// and should be greater than `ERROR` when you need error information only. pub async fn logger(ctx: &mut Context, next: Next<'_>) -> Result { info!("--> {} {}", ctx.method(), ctx.uri().path()); let start = Instant::now(); let mut result = next.await; let method = ctx.method().clone(); let uri = ctx.uri().clone(); let exec = ctx.exec.clone(); match &mut result { Err(status) => { let status_code = status.status_code; let message = if status.expose { status.message.clone() } else { // set expose to true; then root status_handler won't log this status. 
status.expose = true; // take unexposed message mem::take(&mut status.message) }; ctx.exec .spawn_blocking(move || { error!("<-- {} {} {}\n{}", method, uri, status_code, message,); }) .await } Ok(_) => { let status_code = ctx.status(); // logging when body polling complete. let logger = StreamLogger::Polling { stream: mem::take(&mut ctx.resp.body), task: LogTask { counter: 0, method, uri, status_code, start, exec, }, }; ctx.resp.write_stream(logger); } } result } ================================================ FILE: roa/src/query.rs ================================================ //! This module provides a middleware `query_parser` and a context extension `Query`. //! //! ### Example //! //! ```rust //! use roa::query::query_parser; //! use roa::{App, Context}; //! use roa::http::StatusCode; //! use roa::preload::*; //! use tokio::task::spawn; //! //! async fn must(ctx: &mut Context) -> roa::Result { //! assert_eq!("Hexilee", &*ctx.must_query("name")?); //! Ok(()) //! } //! //! #[tokio::main] //! async fn main() -> Result<(), Box> { //! let app = App::new() //! .gate(query_parser) //! .end(must); //! let (addr, server) = app.run()?; //! spawn(server); //! let resp = reqwest::get(&format!("http://{}?name=Hexilee", addr)).await?; //! assert_eq!(StatusCode::OK, resp.status()); //! Ok(()) //! } //! ``` use url::form_urlencoded::parse; use crate::http::StatusCode; use crate::{Context, Next, Result, Status, Variable}; /// A scope to store and load variables in Context::storage. struct QueryScope; /// A context extension. /// This extension must be used in downstream of middleware `query_parser`, /// otherwise you cannot get expected query variable. 
/// /// ### Example /// /// ```rust /// use roa::query::query_parser; /// use roa::{App, Context}; /// use roa::http::StatusCode; /// use roa::preload::*; /// use tokio::task::spawn; /// /// async fn must(ctx: &mut Context) -> roa::Result { /// assert_eq!("Hexilee", &*ctx.must_query("name")?); /// Ok(()) /// } /// /// #[tokio::main] /// async fn main() -> Result<(), Box> { /// // downstream of `query_parser` /// let app = App::new() /// .gate(query_parser) /// .end(must); /// let (addr, server) = app.run()?; /// spawn(server); /// let resp = reqwest::get(&format!("http://{}?name=Hexilee", addr)).await?; /// assert_eq!(StatusCode::OK, resp.status()); /// /// // miss `query_parser` /// let app = App::new().end(must); /// let (addr, server) = app.run()?; /// spawn(server); /// let resp = reqwest::get(&format!("http://{}?name=Hexilee", addr)).await?; /// assert_eq!(StatusCode::BAD_REQUEST, resp.status()); /// Ok(()) /// } /// ``` pub trait Query { /// Must get a variable, throw 400 BAD_REQUEST if it not exists. /// ### Example /// /// ```rust /// use roa::query::query_parser; /// use roa::{App, Context}; /// use roa::http::StatusCode; /// use roa::preload::*; /// use tokio::task::spawn; /// /// async fn must(ctx: &mut Context) -> roa::Result { /// assert_eq!("Hexilee", &*ctx.must_query("name")?); /// Ok(()) /// } /// /// #[tokio::main] /// async fn main() -> Result<(), Box> { /// // downstream of `query_parser` /// let app = App::new() /// .gate(query_parser) /// .end(must); /// let (addr, server) = app.run()?; /// spawn(server); /// let resp = reqwest::get(&format!("http://{}", addr)).await?; /// assert_eq!(StatusCode::BAD_REQUEST, resp.status()); /// Ok(()) /// } /// ``` fn must_query<'a>(&self, name: &'a str) -> Result>; /// Query a variable, return `None` if it not exists. 
/// ### Example /// /// ```rust /// use roa::query::query_parser; /// use roa::{App, Context}; /// use roa::http::StatusCode; /// use roa::preload::*; /// use tokio::task::spawn; /// /// async fn test(ctx: &mut Context) -> roa::Result { /// assert!(ctx.query("name").is_none()); /// Ok(()) /// } /// /// #[tokio::main] /// async fn main() -> Result<(), Box> { /// // downstream of `query_parser` /// let app = App::new() /// .gate(query_parser) /// .end(test); /// let (addr, server) = app.run()?; /// spawn(server); /// let resp = reqwest::get(&format!("http://{}", addr)).await?; /// assert_eq!(StatusCode::OK, resp.status()); /// Ok(()) /// } /// ``` fn query<'a>(&self, name: &'a str) -> Option>; } /// A middleware to parse query. #[inline] pub async fn query_parser(ctx: &mut Context, next: Next<'_>) -> Result { let query_string = ctx.uri().query().unwrap_or(""); let pairs: Vec<(String, String)> = parse(query_string.as_bytes()).into_owned().collect(); for (key, value) in pairs { ctx.store_scoped(QueryScope, key, value); } next.await } impl Query for Context { #[inline] fn must_query<'a>(&self, name: &'a str) -> Result> { self.query(name).ok_or_else(|| { Status::new( StatusCode::BAD_REQUEST, format!("query `{}` is required", name), true, ) }) } #[inline] fn query<'a>(&self, name: &'a str) -> Option> { self.load_scoped::(name) } } #[cfg(all(test, feature = "tcp"))] mod tests { use tokio::task::spawn; use crate::http::StatusCode; use crate::preload::*; use crate::query::query_parser; use crate::{App, Context}; #[tokio::test] async fn query() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { assert_eq!("Hexilee", &*ctx.must_query("name")?); Ok(()) } // miss key let (addr, server) = App::new().gate(query_parser).end(test).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}/", addr)).await?; assert_eq!(StatusCode::BAD_REQUEST, resp.status()); // string value let (addr, server) = App::new().gate(query_parser).end(test).run()?; 
spawn(server); let resp = reqwest::get(&format!("http://{}?name=Hexilee", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } #[tokio::test] async fn query_parse() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { assert_eq!(120, ctx.must_query("age")?.parse::()?); Ok(()) } // invalid int value let (addr, server) = App::new().gate(query_parser).end(test).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}?age=Hexilee", addr)).await?; assert_eq!(StatusCode::BAD_REQUEST, resp.status()); let (addr, server) = App::new().gate(query_parser).end(test).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}?age=120", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } #[tokio::test] async fn query_action() -> Result<(), Box> { async fn test(ctx: &mut Context) -> crate::Result { assert_eq!("Hexilee", &*ctx.must_query("name")?); assert_eq!("rust", &*ctx.must_query("lang")?); Ok(()) } let (addr, server) = App::new().gate(query_parser).end(test).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}?name=Hexilee&lang=rust", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } } ================================================ FILE: roa/src/router/endpoints/dispatcher.rs ================================================ use std::collections::HashMap; use doc_comment::doc_comment; use super::method_not_allowed; use crate::http::Method; use crate::{async_trait, Context, Endpoint, Result}; macro_rules! impl_http_methods { ($end:ident, $method:expr) => { doc_comment! { concat!("Method to add or override endpoint on ", stringify!($method), ". 
You can use it as follow: ```rust use roa::{App, Context, Result}; use roa::router::get; async fn foo(ctx: &mut Context) -> Result { Ok(()) } async fn bar(ctx: &mut Context) -> Result { Ok(()) } let app = App::new().end(get(foo).", stringify!($end), "(bar)); ```"), pub fn $end(mut self, endpoint: impl for<'a> Endpoint<'a, S>) -> Self { self.0.insert($method, Box::new(endpoint)); self } } }; } macro_rules! impl_http_functions { ($end:ident, $method:expr) => { doc_comment! { concat!("Function to construct dispatcher with ", stringify!($method), " and an endpoint. You can use it as follow: ```rust use roa::{App, Context, Result}; use roa::router::", stringify!($end), "; async fn end(ctx: &mut Context) -> Result { Ok(()) } let app = App::new().end(", stringify!($end), "(end)); ```"), pub fn $end(endpoint: impl for<'a> Endpoint<'a, S>) -> Dispatcher { Dispatcher::::default().$end(endpoint) } } }; } /// An endpoint wrapper to dispatch requests by http method. pub struct Dispatcher(HashMap Endpoint<'a, S>>>); impl_http_functions!(get, Method::GET); impl_http_functions!(post, Method::POST); impl_http_functions!(put, Method::PUT); impl_http_functions!(patch, Method::PATCH); impl_http_functions!(options, Method::OPTIONS); impl_http_functions!(delete, Method::DELETE); impl_http_functions!(head, Method::HEAD); impl_http_functions!(trace, Method::TRACE); impl_http_functions!(connect, Method::CONNECT); impl Dispatcher { impl_http_methods!(get, Method::GET); impl_http_methods!(post, Method::POST); impl_http_methods!(put, Method::PUT); impl_http_methods!(patch, Method::PATCH); impl_http_methods!(options, Method::OPTIONS); impl_http_methods!(delete, Method::DELETE); impl_http_methods!(head, Method::HEAD); impl_http_methods!(trace, Method::TRACE); impl_http_methods!(connect, Method::CONNECT); } /// Empty dispatcher. 
impl<S> Default for Dispatcher<S> {
    fn default() -> Self {
        Self(HashMap::new())
    }
}

#[async_trait(?Send)]
impl<'a, S> Endpoint<'a, S> for Dispatcher<S>
where
    S: 'static,
{
    /// Forward the request to the endpoint registered for its http method,
    /// or fail with `method_not_allowed` when none is registered.
    #[inline]
    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result<()> {
        match self.0.get(ctx.method()) {
            Some(endpoint) => endpoint.call(ctx).await,
            None => method_not_allowed(ctx.method()),
        }
    }
}

================================================
FILE: roa/src/router/endpoints/guard.rs
================================================
use std::collections::HashSet;
use std::iter::FromIterator;

use super::method_not_allowed;
use crate::http::Method;
use crate::{async_trait, Context, Endpoint, Result};

/// Methods allowed in `Guard`.
const ALL_METHODS: [Method; 9] = [
    Method::GET,
    Method::POST,
    Method::PUT,
    Method::PATCH,
    Method::OPTIONS,
    Method::DELETE,
    Method::HEAD,
    Method::TRACE,
    Method::CONNECT,
];

/// An endpoint wrapper to guard endpoint by http method.
pub struct Guard<E> {
    white_list: HashSet<Method>,
    endpoint: E,
}

/// Initialize hash set.
///
/// Clones each method straight into the set; no intermediate `Vec` is built.
fn hash_set(methods: impl AsRef<[Method]>) -> HashSet<Method> {
    HashSet::from_iter(methods.as_ref().iter().cloned())
}

/// A function to construct guard by white list.
///
/// Only requests with http method in list can access this endpoint, otherwise will get a 405 METHOD NOT ALLOWED.
///
/// ```
/// use roa::{App, Context, Result};
/// use roa::http::Method;
/// use roa::router::allow;
///
/// async fn foo(ctx: &mut Context) -> Result {
///     Ok(())
/// }
///
/// let app = App::new().end(allow([Method::GET, Method::POST], foo));
/// ```
pub fn allow<E>(methods: impl AsRef<[Method]>, endpoint: E) -> Guard<E> {
    Guard {
        endpoint,
        white_list: hash_set(methods),
    }
}

/// A function to construct guard by black list.
///
/// Only requests with http method not in list can access this endpoint, otherwise will get a 405 METHOD NOT ALLOWED.
///
/// ```
/// use roa::{App, Context, Result};
/// use roa::http::Method;
/// use roa::router::deny;
///
/// async fn foo(ctx: &mut Context) -> Result {
///     Ok(())
/// }
///
/// let app = App::new().end(deny([Method::PUT, Method::DELETE], foo));
/// ```
pub fn deny<E>(methods: impl AsRef<[Method]>, endpoint: E) -> Guard<E> {
    // The white list is every method in `ALL_METHODS` that is not denied.
    // Methods outside `ALL_METHODS` never enter the white list.
    let black_list = hash_set(methods);
    Guard {
        endpoint,
        white_list: &hash_set(ALL_METHODS) - &black_list,
    }
}

#[async_trait(?Send)]
impl<'a, S, E> Endpoint<'a, S> for Guard<E>
where
    E: Endpoint<'a, S>,
{
    /// Call the wrapped endpoint when the request method is in the white list,
    /// otherwise fail with `method_not_allowed`.
    #[inline]
    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
        if self.white_list.contains(ctx.method()) {
            self.endpoint.call(ctx).await
        } else {
            method_not_allowed(ctx.method())
        }
    }
}

================================================
FILE: roa/src/router/endpoints.rs
================================================
mod dispatcher;
mod guard;

use crate::http::{Method, StatusCode};
use crate::{throw, Result};

/// Throw 405 METHOD NOT ALLOWED with a message naming the rejected method.
#[inline]
fn method_not_allowed(method: &Method) -> Result {
    throw!(
        StatusCode::METHOD_NOT_ALLOWED,
        format!("Method {} not allowed", method)
    )
}

pub use dispatcher::{connect, delete, get, head, options, patch, post, put, trace, Dispatcher};
pub use guard::{allow, deny, Guard};

================================================
FILE: roa/src/router/err.rs
================================================
use std::fmt::{self, Display, Formatter};

use roa_core::http;

/// Error occurring in building route table.
#[derive(Debug)]
pub enum RouterError {
    /// Dynamic paths miss variable.
    MissingVariable(String),
    /// Variables, methods or paths conflict.
    Conflict(Conflict),
}

/// Router conflict.
#[derive(Debug, Eq, PartialEq)]
pub enum Conflict {
    /// The same path is registered twice.
    Path(String),
    /// The method is already set on the path.
    Method(String, http::Method),
    /// The same variable name is used twice.
    Variable {
        paths: (String, String),
        var_name: String,
    },
}

impl Display for Conflict {
    /// Write the human-readable description directly into the formatter,
    /// without building a temporary `String` first.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
        match self {
            Conflict::Path(path) => write!(f, "conflict path: `{}`", path),
            Conflict::Method(path, method) => write!(
                f,
                "conflict method: `{}` on `{}` is already set",
                method, path
            ),
            Conflict::Variable { paths, var_name } => write!(
                f,
                "conflict variable `{}`: between `{}` and `{}`",
                var_name, paths.0, paths.1
            ),
        }
    }
}

impl Display for RouterError {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
        match self {
            RouterError::Conflict(conflict) => write!(f, "Conflict! {}", conflict),
            RouterError::MissingVariable(path) => {
                write!(f, "missing variable on path {}", path)
            }
        }
    }
}

impl From<Conflict> for RouterError {
    fn from(conflict: Conflict) -> Self {
        RouterError::Conflict(conflict)
    }
}

impl std::error::Error for Conflict {}
impl std::error::Error for RouterError {}

#[cfg(test)]
mod tests {
    use roa_core::http;

    use super::{Conflict, RouterError};

    #[test]
    fn conflict_to_string() {
        assert_eq!(
            "conflict path: `/`",
            Conflict::Path("/".to_string()).to_string()
        );
        assert_eq!(
            "conflict method: `GET` on `/` is already set",
            Conflict::Method("/".to_string(), http::Method::GET).to_string()
        );
        assert_eq!(
            "conflict variable `id`: between `/:id` and `/user/:id`",
            Conflict::Variable {
                paths: ("/:id".to_string(), "/user/:id".to_string()),
                var_name: "id".to_string()
            }
            .to_string()
        );
    }

    #[test]
    fn err_to_string() {
        assert_eq!(
            "Conflict!
conflict path: `/`", RouterError::Conflict(Conflict::Path("/".to_string())).to_string() ); assert_eq!( "missing variable on path /:", RouterError::MissingVariable("/:".to_string()).to_string() ); } } ================================================ FILE: roa/src/router/path.rs ================================================ use std::collections::HashSet; use std::convert::AsRef; use std::str::FromStr; use regex::{escape, Captures, Regex}; use super::{Conflict, RouterError}; /// Match pattern *{variable} const WILDCARD: &str = r"\*\{(?P\w*)\}"; /// Match pattern /:variable/ const VARIABLE: &str = r"/:(?P\w*)/"; /// {/path path/ /path/} => /path/ pub fn standardize_path(raw_path: &str) -> String { format!("/{}/", raw_path.trim_matches('/')) } /// Join multiple segments. pub fn join_path<'a>(paths: impl 'a + AsRef<[&'a str]>) -> String { paths .as_ref() .iter() .map(|path| path.trim_matches('/')) .filter(|path| !path.is_empty()) .collect::>() .join("/") } /// Build pattern. fn must_build(pattern: &str) -> Regex { Regex::new(pattern).unwrap_or_else(|err| { panic!( r#"{} regex pattern {} is invalid, this is a bug of roa::router::path. please report it to https://github.com/Hexilee/roa"#, err, pattern ) }) } /// Parsed path. #[derive(Clone)] pub enum Path { Static(String), Dynamic(RegexPath), } /// Dynamic path. #[derive(Clone)] pub struct RegexPath { pub raw: String, pub vars: HashSet, pub re: Regex, } impl FromStr for Path { type Err = RouterError; fn from_str(raw_path: &str) -> Result { let path = standardize_path(raw_path); Ok(match path_to_regexp(&path)? 
{
            None => Path::Static(path),
            Some((pattern, vars)) => Path::Dynamic(RegexPath {
                raw: path,
                vars,
                re: must_build(&format!(r"^{}$", pattern)),
            }),
        })
    }
}

/// Translate a route path into a regex pattern plus the set of variable names it captures.
///
/// Returns `Ok(None)` when the path contains neither `*{name}` wildcards nor `/:name/`
/// segment variables, i.e. it is static. Returns `RouterError::MissingVariable` when a
/// wildcard or variable has an empty name, and a `Conflict::Variable` error (converted
/// into `RouterError`) when the same name is used twice.
fn path_to_regexp(path: &str) -> Result<Option<(String, HashSet<String>)>, RouterError> {
    let mut pattern = escape(path);
    let mut vars = HashSet::new();
    let wildcard_re = must_build(WILDCARD);
    let variable_re = must_build(VARIABLE);
    let wildcards: Vec<Captures> = wildcard_re.captures_iter(path).collect();
    // Double every slash so that continuous variables like /:year/:month/:day/
    // each have their own leading and trailing slash for `VARIABLE` to consume.
    let variable_template = path.replace('/', "//");
    let variables: Vec<Captures> = variable_re.captures_iter(&variable_template).collect();
    if wildcards.is_empty() && variables.is_empty() {
        return Ok(None);
    }

    // detect variable conflicts.
    let try_add_variable = |set: &mut HashSet<String>, variable: String| {
        if set.insert(variable.clone()) {
            Ok(())
        } else {
            Err(Conflict::Variable {
                paths: (path.to_string(), path.to_string()),
                var_name: variable,
            })
        }
    };

    // match wildcard patterns
    for cap in wildcards {
        let variable = &cap["var"];
        if variable.is_empty() {
            return Err(RouterError::MissingVariable(path.to_string()));
        }
        let var = escape(variable);
        pattern = pattern.replace(
            &escape(&format!(r"*{{{}}}", variable)),
            &format!(r"(?P<{}>\S+)", &var),
        );
        try_add_variable(&mut vars, var)?;
    }

    // match segment variable patterns
    for cap in variables {
        let variable = &cap["var"];
        if variable.is_empty() {
            return Err(RouterError::MissingVariable(path.to_string()));
        }
        let var = escape(variable);
        // Anchor on the surrounding slashes: replacing the bare `:name` would also
        // rewrite the prefix of a longer variable (`:id` inside `:id2`) or a static
        // segment that merely contains `:name`, yielding a duplicate capture group.
        pattern = pattern.replace(
            &escape(&format!(r"/:{}/", variable)),
            &format!(r"/(?P<{}>[^\s/]+)/", &var),
        );
        try_add_variable(&mut vars, var)?;
    }
    Ok(Some((pattern, vars)))
}

#[cfg(test)]
mod tests {
    use test_case::test_case;

    use super::{must_build, path_to_regexp, Path, VARIABLE, WILDCARD};

    #[test_case("/:id/"; "pure dynamic")]
    #[test_case("/user/:id/"; "static prefix")]
    #[test_case("/user/:id/name"; "static prefix and suffix")]
    fn var_regex_match(path: &str) {
        let re = must_build(VARIABLE);
        let cap = re.captures(path);
        assert!(cap.is_some());
// NOTE: the statements down to the first closing brace are the tail of a test
// function that begins before this chunk; they are reproduced unchanged.
        assert_eq!("id", &cap.unwrap()["var"]);
    }

    #[test_case("/-:id/"; "invalid prefix")]
    #[test_case("/:i-d/"; "invalid variable name")]
    #[test_case("/:id-/"; "invalid suffix")]
    fn var_regex_mismatch(path: &str) {
        let re = must_build(VARIABLE);
        let cap = re.captures(path);
        assert!(cap.is_none());
    }

    #[test_case("*{id}"; "pure dynamic")]
    #[test_case("user-*{id}"; "static prefix")]
    #[test_case("user-*{id}-name"; "static prefix and suffix")]
    fn wildcard_regex_match(path: &str) {
        let re = must_build(WILDCARD);
        let cap = re.captures(path);
        assert!(cap.is_some());
        assert_eq!("id", &cap.unwrap()["var"]);
    }

    #[test_case("*"; "no variable")]
    #[test_case("*{-id}"; "invalid variable name")]
    fn wildcard_regex_mismatch(path: &str) {
        let re = must_build(WILDCARD);
        let cap = re.captures(path);
        assert!(cap.is_none());
    }

    // The expected patterns use named capture groups, one per path variable.
    #[test_case(r"/:id/" => r"/(?P<id>[^\s/]+)/"; "single variable")]
    #[test_case(r"/:year/:month/:day/" => r"/(?P<year>[^\s/]+)/(?P<month>[^\s/]+)/(?P<day>[^\s/]+)/"; "multiple variable")]
    #[test_case(r"*{id}" => r"(?P<id>\S+)"; "single wildcard")]
    #[test_case(r"*{year}_*{month}_*{day}" => r"(?P<year>\S+)_(?P<month>\S+)_(?P<day>\S+)"; "multiple wildcard")]
    fn path_to_regexp_dynamic_pattern(path: &str) -> String {
        path_to_regexp(path).unwrap().unwrap().0
    }

    #[test_case(r"/id/")]
    #[test_case(r"/user/post/")]
    fn path_to_regexp_static(path: &str) {
        assert!(path_to_regexp(path).unwrap().is_none())
    }

    #[test_case(r"/:/"; "missing variable name")]
    #[test_case(r"*{}"; "wildcard missing variable name")]
    #[test_case(r"/:id/:id/"; "conflict variable")]
    #[test_case(r"*{id}-*{id}"; "wildcard conflict variable")]
    #[test_case(r"/:id/*{id}"; "mix conflict variable")]
    fn path_to_regexp_err(path: &str) {
        assert!(path_to_regexp(path).is_err())
    }

    /// Assert that `pattern` parses to a dynamic path whose regex matches `path`.
    fn path_match(pattern: &str, path: &str) {
        let pattern: Path = pattern.parse().unwrap();
        match pattern {
            Path::Static(pattern) => panic!("`{}` should be dynamic", pattern),
            Path::Dynamic(re) => assert!(re.re.is_match(path)),
        }
    }

    /// Assert that `pattern` parses to a dynamic path whose regex rejects `path`.
    fn path_not_match(pattern: &str, path: &str) {
        let pattern: Path = pattern.parse().unwrap();
        match pattern {
            Path::Static(pattern) => panic!("`{}` should be dynamic", pattern),
            Path::Dynamic(re) => {
                println!("regex: {}", re.re);
                assert!(!re.re.is_match(path))
            }
        }
    }

    #[test_case(r"/user/1/")]
    #[test_case(r"/user/65535/")]
    fn single_variable_path_match(path: &str) {
        path_match(r"/user/:id", path)
    }

    #[test_case(r"/2000/01/01/")]
    #[test_case(r"/2020/02/20/")]
    fn multiple_variable_path_match(path: &str) {
        path_match(r"/:year/:month/:day", path)
    }

    #[test_case(r"/usr/include/boost/boost.h/")]
    #[test_case(r"/usr/include/uv/uv.h/")]
    fn segment_wildcard_path_match(path: &str) {
        path_match(r"/usr/include/*{dir}/*{file}.h", path)
    }

    #[test_case(r"/srv/static/app/index.html/")]
    #[test_case(r"/srv/static/../../index.html/")]
    fn full_wildcard_path_match(path: &str) {
        path_match(r"/srv/static/*{path}/", path)
    }

    #[test_case(r"/srv/app/index.html/")]
    #[test_case(r"/srv/../../index.html/")]
    fn variable_path_not_match(path: &str) {
        path_not_match(r"/srv/:path/", path)
    }

    #[should_panic]
    #[test]
    fn must_build_fails() {
        must_build(r"{");
    }
}

================================================
FILE: roa/src/router.rs
================================================
//! This module provides a context extension `RouterParam` and
//! many endpoint wrappers like `Router`, `Dispatcher` and `Guard`.
//!
//! ### Example
//!
//! ```rust
//! use roa::router::{Router, RouterParam, get, allow};
//! use roa::{App, Context, Status, MiddlewareExt, Next};
//! use roa::http::{StatusCode, Method};
//! use roa::tcp::Listener;
//! use tokio::task::spawn;
//!
//!
//! async fn gate(_ctx: &mut Context, next: Next<'_>) -> Result<(), Status> {
//!     next.await
//! }
//!
//! async fn query(ctx: &mut Context) -> Result<(), Status> {
//!     Ok(())
//! }
//!
//! async fn create(ctx: &mut Context) -> Result<(), Status> {
//!     Ok(())
//! }
//!
//! async fn graphql(ctx: &mut Context) -> Result<(), Status> {
//!     Ok(())
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let router = Router::new()
//!         .gate(gate)
//!         .on("/restful", get(query).post(create))
//!         .on("/graphql", allow([Method::GET, Method::POST], graphql));
//!     let app = App::new()
//!         .end(router.routes("/api")?);
//!     let (addr, server) = app.run()?;
//!     spawn(server);
//!     let resp = reqwest::get(&format!("http://{}/api/restful", addr)).await?;
//!     assert_eq!(StatusCode::OK, resp.status());
//!
//!     let resp = reqwest::get(&format!("http://{}/restful", addr)).await?;
//!     assert_eq!(StatusCode::NOT_FOUND, resp.status());
//!     Ok(())
//! }
//! ```
//!

mod endpoints;
mod err;
mod path;

use std::convert::AsRef;
use std::result::Result as StdResult;

#[doc(inline)]
pub use endpoints::*;
use err::Conflict;
#[doc(inline)]
pub use err::RouterError;
use path::{join_path, standardize_path, Path, RegexPath};
use percent_encoding::percent_decode_str;
use radix_trie::Trie;

use crate::http::StatusCode;
use crate::{
    async_trait, throw, Boxed, Context, Endpoint, EndpointExt, Middleware, MiddlewareExt, Result,
    Shared, Status, Variable,
};

/// A private scope to store and load variables in Context::storage.
struct RouterScope;

/// A context extension.
/// This extension must be used in `Router`,
/// otherwise you cannot get expected router parameters.
///
/// ### Example
///
/// ```rust
/// use roa::router::{Router, RouterParam};
/// use roa::{App, Context, Status};
/// use roa::http::StatusCode;
/// use roa::tcp::Listener;
/// use tokio::task::spawn;
///
/// async fn test(ctx: &mut Context) -> Result<(), Status> {
///     let id: u64 = ctx.must_param("id")?.parse()?;
///     assert_eq!(0, id);
///     Ok(())
/// }
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let router = Router::new().on("/:id", test);
///     let app = App::new().end(router.routes("/user")?);
///     let (addr, server) = app.run()?;
///     spawn(server);
///     let resp = reqwest::get(&format!("http://{}/user/0", addr)).await?;
///     assert_eq!(StatusCode::OK, resp.status());
///     Ok(())
/// }
///
///
/// ```
pub trait RouterParam {
    /// Must get a router parameter, throw 500 INTERNAL SERVER ERROR if it does not exist.
    fn must_param<'a>(&self, name: &'a str) -> Result<Variable<'a, String>>;

    /// Try to get a router parameter, return `None` if it does not exist.
    /// ### Example
    ///
    /// ```rust
    /// use roa::router::{Router, RouterParam};
    /// use roa::{App, Context, Status};
    /// use roa::http::StatusCode;
    /// use roa::tcp::Listener;
    /// use tokio::task::spawn;
    ///
    /// async fn test(ctx: &mut Context) -> Result<(), Status> {
    ///     assert!(ctx.param("name").is_none());
    ///     Ok(())
    /// }
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let router = Router::new().on("/:id", test);
    ///     let app = App::new().end(router.routes("/user")?);
    ///     let (addr, server) = app.run()?;
    ///     spawn(server);
    ///     let resp = reqwest::get(&format!("http://{}/user/0", addr)).await?;
    ///     assert_eq!(StatusCode::OK, resp.status());
    ///     Ok(())
    /// }
    ///
    ///
    /// ```
    fn param<'a>(&self, name: &'a str) -> Option<Variable<'a, String>>;
}

/// A builder of `RouteTable`.
pub struct Router<S> {
    middleware: Shared<S>,
    endpoints: Vec<(String, Boxed<S>)>,
}

/// An endpoint to route request by uri path.
pub struct RouteTable { static_route: Trie>, dynamic_route: Vec<(RegexPath, Boxed)>, } impl Router where S: 'static, { /// Construct a new router. pub fn new() -> Self { Self { middleware: ().shared(), endpoints: Vec::new(), } } /// Register a new endpoint. pub fn on(mut self, path: &'static str, endpoint: impl for<'a> Endpoint<'a, S>) -> Self { self.endpoints .push((path.to_string(), self.register(endpoint))); self } /// Chain an endpoint to Router::middleware. fn register(&self, endpoint: impl for<'a> Endpoint<'a, S>) -> Boxed { self.middleware.clone().end(endpoint).boxed() } /// Include another router with prefix. pub fn include(mut self, prefix: &'static str, router: Router) -> Self { for (path, endpoint) in router.endpoints { self.endpoints .push((join_path([prefix, path.as_str()]), self.register(endpoint))) } self } /// Chain a middleware to Router::middleware. pub fn gate(self, next: impl for<'a> Middleware<'a, S>) -> Router { let Self { middleware, endpoints, } = self; Self { middleware: middleware.chain(next).shared(), endpoints, } } /// Build RouteTable with path prefix. pub fn routes(self, prefix: &'static str) -> StdResult, RouterError> { let mut route_table = RouteTable::default(); for (raw_path, endpoint) in self.endpoints { route_table.insert(join_path([prefix, raw_path.as_str()]), endpoint)?; } Ok(route_table) } } impl RouteTable where S: 'static, { fn new() -> Self { Self { static_route: Trie::new(), dynamic_route: Vec::new(), } } /// Insert endpoint to table. fn insert( &mut self, raw_path: impl AsRef, endpoint: Boxed, ) -> StdResult<(), RouterError> { match raw_path.as_ref().parse()? 
{ Path::Static(path) => { if self.static_route.insert(path.clone(), endpoint).is_some() { return Err(Conflict::Path(path).into()); } } Path::Dynamic(regex_path) => self.dynamic_route.push((regex_path, endpoint)), } Ok(()) } } impl Default for Router where S: 'static, { fn default() -> Self { Self::new() } } impl Default for RouteTable where S: 'static, { fn default() -> Self { Self::new() } } #[async_trait(?Send)] impl<'a, S> Endpoint<'a, S> for RouteTable where S: 'static, { #[inline] async fn call(&'a self, ctx: &'a mut Context) -> Result { let uri = ctx.uri(); // standardize path let path = standardize_path(&percent_decode_str(uri.path()).decode_utf8().map_err( |err| { Status::new( StatusCode::BAD_REQUEST, format!("{}\npath `{}` is not a valid utf-8 string", err, uri.path()), true, ) }, )?); // search static routes if let Some(end) = self.static_route.get(&path) { return end.call(ctx).await; } // search dynamic routes for (regexp_path, end) in self.dynamic_route.iter() { if let Some(cap) = regexp_path.re.captures(&path) { for var in regexp_path.vars.iter() { ctx.store_scoped(RouterScope, var.to_string(), cap[var.as_str()].to_string()); } return end.call(ctx).await; } } // 404 NOT FOUND throw!(StatusCode::NOT_FOUND) } } impl RouterParam for Context { #[inline] fn must_param<'a>(&self, name: &'a str) -> Result> { self.param(name).ok_or_else(|| { Status::new( StatusCode::INTERNAL_SERVER_ERROR, format!("router variable `{}` is required", name), false, ) }) } #[inline] fn param<'a>(&self, name: &'a str) -> Option> { self.load_scoped::(name) } } #[cfg(all(test, feature = "tcp"))] mod tests { use encoding::EncoderTrap; use percent_encoding::NON_ALPHANUMERIC; use tokio::task::spawn; use super::Router; use crate::http::StatusCode; use crate::tcp::Listener; use crate::{App, Context, Next, Status}; async fn gate(ctx: &mut Context, next: Next<'_>) -> Result<(), Status> { ctx.store("id", "0".to_string()); next.await } async fn test(ctx: &mut Context) -> Result<(), Status> { 
let id: u64 = ctx.load::("id").unwrap().parse()?; assert_eq!(0, id); Ok(()) } #[tokio::test] async fn gate_test() -> Result<(), Box> { let router = Router::new().gate(gate).on("/", test); let app = App::new().end(router.routes("/route")?); let (addr, server) = app.run()?; spawn(server); let resp = reqwest::get(&format!("http://{}/route", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } #[tokio::test] async fn route() -> Result<(), Box> { let user_router = Router::new().on("/", test); let router = Router::new().gate(gate).include("/user", user_router); let app = App::new().end(router.routes("/route")?); let (addr, server) = app.run()?; spawn(server); let resp = reqwest::get(&format!("http://{}/route/user", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } #[test] fn conflict_path() -> Result<(), Box> { let evil_router = Router::new().on("/endpoint", test); let router = Router::new() .on("/route/endpoint", test) .include("/route", evil_router); let ret = router.routes("/"); assert!(ret.is_err()); Ok(()) } #[tokio::test] async fn route_not_found() -> Result<(), Box> { let app = App::new().end(Router::default().routes("/")?); let (addr, server) = app.run()?; spawn(server); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::NOT_FOUND, resp.status()); Ok(()) } #[tokio::test] async fn non_utf8_uri() -> Result<(), Box> { let app = App::new().end(Router::default().routes("/")?); let (addr, server) = app.run()?; spawn(server); let gbk_path = encoding::label::encoding_from_whatwg_label("gbk") .unwrap() .encode("路由", EncoderTrap::Strict) .unwrap(); let encoded_path = percent_encoding::percent_encode(&gbk_path, NON_ALPHANUMERIC).to_string(); let uri = format!("http://{}/{}", addr, encoded_path); let resp = reqwest::get(&uri).await?; assert_eq!(StatusCode::BAD_REQUEST, resp.status()); assert!(resp .text() .await? 
.ends_with("path `/%C2%B7%D3%C9` is not a valid utf-8 string")); Ok(()) } } ================================================ FILE: roa/src/stream.rs ================================================ //! this module provides a stream adaptor `AsyncStream` use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use futures::io::{AsyncRead as Read, AsyncWrite as Write}; use futures::ready; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{instrument, trace}; /// A adaptor between futures::io::{AsyncRead, AsyncWrite} and tokio::io::{AsyncRead, AsyncWrite}. pub struct AsyncStream(pub IO); impl AsyncRead for AsyncStream where IO: Unpin + Read, { #[inline] fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let read_size = ready!(Pin::new(&mut self.0).poll_read(cx, buf.initialize_unfilled()))?; buf.advance(read_size); Poll::Ready(Ok(())) } } impl AsyncWrite for AsyncStream where IO: Unpin + Write, { #[inline] fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { Pin::new(&mut self.0).poll_write(cx, buf) } #[inline] fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_flush(cx) } #[inline] fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_close(cx) } } impl Read for AsyncStream where IO: Unpin + AsyncRead, { #[inline] #[instrument(skip(self, cx, buf))] fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { let mut read_buf = ReadBuf::new(buf); ready!(Pin::new(&mut self.0).poll_read(cx, &mut read_buf))?; trace!("read {} bytes", read_buf.filled().len()); Poll::Ready(Ok(read_buf.filled().len())) } } impl Write for AsyncStream where IO: Unpin + AsyncWrite, { #[inline] #[instrument(skip(self, cx, buf))] fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let size = ready!(Pin::new(&mut self.0).poll_write(cx, 
buf))?; trace!("wrote {} bytes", size); Poll::Ready(Ok(size)) } #[inline] fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_flush(cx) } #[inline] fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_shutdown(cx) } } ================================================ FILE: roa/src/tcp/incoming.rs ================================================ use std::convert::TryInto; use std::future::Future; use std::mem::transmute; use std::net::{SocketAddr, TcpListener as StdListener, ToSocketAddrs}; use std::pin::Pin; use std::task::{self, Poll}; use std::time::Duration; use std::{fmt, io, matches}; use roa_core::{Accept, AddrStream}; use tokio::net::{TcpListener, TcpStream}; use tokio::time::{sleep, Sleep}; use tracing::{debug, error, trace}; /// A stream of connections from binding to an address. /// As an implementation of roa_core::Accept. #[must_use = "streams do nothing unless polled"] pub struct TcpIncoming { addr: SocketAddr, listener: Box, sleep_on_errors: bool, tcp_nodelay: bool, timeout: Option>>, accept: Option>>, } type BoxedAccept<'a> = Box> + Send + Sync>; impl TcpIncoming { /// Creates a new `TcpIncoming` binding to provided socket address. pub fn bind(addr: impl ToSocketAddrs) -> io::Result { let listener = StdListener::bind(addr)?; TcpIncoming::from_std(listener) } /// Creates a new `TcpIncoming` from std TcpListener. pub fn from_std(listener: StdListener) -> io::Result { let addr = listener.local_addr()?; listener.set_nonblocking(true)?; Ok(TcpIncoming { listener: Box::new(listener.try_into()?), addr, sleep_on_errors: true, tcp_nodelay: false, timeout: None, accept: None, }) } /// Get the local address bound to this listener. pub fn local_addr(&self) -> SocketAddr { self.addr } /// Set the value of `TCP_NODELAY` option for accepted connections. 
pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self { self.tcp_nodelay = enabled; self } /// Set whether to sleep on accept errors. /// /// A possible scenario is that the process has hit the max open files /// allowed, and so trying to accept a new connection will fail with /// `EMFILE`. In some cases, it's preferable to just wait for some time, if /// the application will likely close some files (or connections), and try /// to accept the connection again. If this option is `true`, the error /// will be logged at the `error` level, since it is still a big deal, /// and then the listener will sleep for 1 second. /// /// In other cases, hitting the max open files should be treat similarly /// to being out-of-memory, and simply error (and shutdown). Setting /// this option to `false` will allow that. /// /// Default is `true`. pub fn set_sleep_on_errors(&mut self, val: bool) { self.sleep_on_errors = val; } /// Poll TcpStream. fn poll_stream( &mut self, cx: &mut task::Context<'_>, ) -> Poll> { // Check if a previous timeout is active that was set by IO errors. if let Some(ref mut to) = self.timeout { futures::ready!(Pin::new(to).poll(cx)); } self.timeout = None; loop { if self.accept.is_none() { let accept: Pin> = Box::pin(self.listener.accept()); self.accept = Some(unsafe { transmute(accept) }); } if let Some(f) = &mut self.accept { match futures::ready!(f.as_mut().poll(cx)) { Ok((socket, addr)) => { if let Err(e) = socket.set_nodelay(self.tcp_nodelay) { trace!("error trying to set TCP nodelay: {}", e); } self.accept = None; return Poll::Ready(Ok((socket, addr))); } Err(e) => { // Connection errors can be ignored directly, continue by // accepting the next request. if is_connection_error(&e) { debug!("accepted connection already errored: {}", e); continue; } if self.sleep_on_errors { error!("accept error: {}", e); // Sleep 1s. 
let mut timeout = Box::pin(sleep(Duration::from_secs(1))); match timeout.as_mut().poll(cx) { Poll::Ready(()) => { // Wow, it's been a second already? Ok then... continue; } Poll::Pending => { self.timeout = Some(timeout); return Poll::Pending; } } } else { return Poll::Ready(Err(e)); } } } } } } } impl Accept for TcpIncoming { type Conn = AddrStream; type Error = io::Error; #[inline] fn poll_accept( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll>> { let (stream, addr) = futures::ready!(self.poll_stream(cx))?; Poll::Ready(Some(Ok(AddrStream::new(addr, stream)))) } } impl Drop for TcpIncoming { fn drop(&mut self) { self.accept = None; } } /// This function defines errors that are per-connection. Which basically /// means that if we get this error from `accept()` system call it means /// next connection might be ready to be accepted. /// /// All other errors will incur a timeout before next `accept()` is performed. /// The timeout is useful to handle resource exhaustion errors like ENFILE /// and EMFILE. Otherwise, could enter into tight loop. fn is_connection_error(e: &io::Error) -> bool { matches!( e.kind(), io::ErrorKind::ConnectionRefused | io::ErrorKind::ConnectionAborted | io::ErrorKind::ConnectionReset ) } impl fmt::Debug for TcpIncoming { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("TcpIncoming") .field("addr", &self.addr) .field("sleep_on_errors", &self.sleep_on_errors) .field("tcp_nodelay", &self.tcp_nodelay) .finish() } } ================================================ FILE: roa/src/tcp/listener.rs ================================================ use std::net::{SocketAddr, ToSocketAddrs}; use std::sync::Arc; use roa_core::{App, Endpoint, Executor, Server, State}; use super::TcpIncoming; /// An app extension. pub trait Listener { /// http server type Server; /// Listen on a socket addr, return a server and the real addr it binds. 
// (continuation of `pub trait Listener`, opened earlier in this file)
    fn bind(self, addr: impl ToSocketAddrs) -> std::io::Result<(SocketAddr, Self::Server)>;

    /// Listen on a socket addr, return a server, and pass real addr to the callback.
    fn listen(
        self,
        addr: impl ToSocketAddrs,
        callback: impl Fn(SocketAddr),
    ) -> std::io::Result<Self::Server>;

    /// Listen on an unused port of 127.0.0.1, return a server and the real addr it binds.
    /// ### Example
    /// ```rust
    /// use roa::{App, Context, Status};
    /// use roa::tcp::Listener;
    /// use roa::http::StatusCode;
    /// use tokio::task::spawn;
    /// use std::time::Instant;
    ///
    /// async fn end(_ctx: &mut Context) -> Result<(), Status> {
    ///     Ok(())
    /// }
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let (addr, server) = App::new().end(end).run()?;
    ///     spawn(server);
    ///     let resp = reqwest::get(&format!("http://{}", addr)).await?;
    ///     assert_eq!(StatusCode::OK, resp.status());
    ///     Ok(())
    /// }
    /// ```
    fn run(self) -> std::io::Result<(SocketAddr, Self::Server)>;
}

impl<S, E> Listener for App<S, Arc<E>>
where
    S: State,
    E: for<'a> Endpoint<'a, S>,
{
    type Server = Server<TcpIncoming, Self, Executor>;

    fn bind(self, addr: impl ToSocketAddrs) -> std::io::Result<(SocketAddr, Self::Server)> {
        let incoming = TcpIncoming::bind(addr)?;
        let local_addr = incoming.local_addr();
        Ok((local_addr, self.accept(incoming)))
    }

    fn listen(
        self,
        addr: impl ToSocketAddrs,
        callback: impl Fn(SocketAddr),
    ) -> std::io::Result<Self::Server> {
        let (addr, server) = self.bind(addr)?;
        callback(addr);
        Ok(server)
    }

    fn run(self) -> std::io::Result<(SocketAddr, Self::Server)> {
        // Port 0 lets the OS pick an unused port.
        self.bind("127.0.0.1:0")
    }
}

================================================
FILE: roa/src/tcp.rs
================================================
//! This module provides an acceptor implementing `roa_core::Accept` and an app extension.
//!
//! ### TcpIncoming
//!
//! ```
//! use roa::{App, Context, Result};
//! use roa::tcp::TcpIncoming;
//! use std::io;
//!
//! async fn end(_ctx: &mut Context) -> Result {
//!     Ok(())
//! }
//! # #[tokio::main]
//! # async fn main() -> io::Result<()> {
//! let app = App::new().end(end);
//! let incoming = TcpIncoming::bind("127.0.0.1:0")?;
//! let server = app.accept(incoming);
//! // server.await
//! Ok(())
//! # }
//! ```
//!
//! ### Listener
//!
//! ```
//! use roa::{App, Context, Result};
//! use roa::tcp::Listener;
//! use std::io;
//!
//! async fn end(_ctx: &mut Context) -> Result {
//!     Ok(())
//! }
//! # #[tokio::main]
//! # async fn main() -> io::Result<()> {
//! let app = App::new().end(end);
//! let (addr, server) = app.bind("127.0.0.1:0")?;
//! // server.await
//! Ok(())
//! # }
//! ```

mod incoming;
mod listener;

#[doc(inline)]
pub use incoming::TcpIncoming;
#[doc(inline)]
pub use listener::Listener;

================================================
FILE: roa/src/tls/incoming.rs
================================================
use std::io;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{self, Context, Poll};

use futures::Future;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;

use super::ServerConfig;
use crate::{Accept, AddrStream};

/// A stream of connections based on another stream.
/// As an implementation of roa_core::Accept.
pub struct TlsIncoming<I> {
    incoming: I,
    acceptor: TlsAcceptor,
}

type AcceptFuture<IO> =
    dyn 'static + Sync + Send + Unpin + Future<Output = io::Result<TlsStream<IO>>>;

/// A finite-state machine to do tls handshake.
pub enum WrapTlsStream<IO> {
    /// Handshaking state.
    Handshaking(Box<AcceptFuture<IO>>),
    /// Streaming state.
    Streaming(Box<TlsStream<IO>>),
}

use WrapTlsStream::*;

impl<IO> WrapTlsStream<IO> {
    /// Poll the pending handshake; on success yield the `Streaming` state
    /// wrapping the established TLS stream.
    #[inline]
    fn poll_handshake(
        handshake: &mut AcceptFuture<IO>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<Self>> {
        let stream = futures::ready!(Pin::new(handshake).poll(cx))?;
        Poll::Ready(Ok(Streaming(Box::new(stream))))
    }
}

// Every I/O method first completes the handshake if it is still pending,
// switches `self` to `Streaming`, then retries the same operation.
impl<IO> AsyncRead for WrapTlsStream<IO>
where
    IO: 'static + Unpin + AsyncRead + AsyncWrite,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match &mut *self {
            Streaming(stream) => Pin::new(stream).poll_read(cx, buf),
            Handshaking(handshake) => {
                *self = futures::ready!(Self::poll_handshake(handshake, cx))?;
                self.poll_read(cx, buf)
            }
        }
    }
}

impl<IO> AsyncWrite for WrapTlsStream<IO>
where
    IO: 'static + Unpin + AsyncRead + AsyncWrite,
{
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match &mut *self {
            Streaming(stream) => Pin::new(stream).poll_write(cx, buf),
            Handshaking(handshake) => {
                *self = futures::ready!(Self::poll_handshake(handshake, cx))?;
                self.poll_write(cx, buf)
            }
        }
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        match &mut *self {
            Streaming(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
            Handshaking(handshake) => {
                *self = futures::ready!(Self::poll_handshake(handshake, cx))?;
                self.poll_write_vectored(cx, bufs)
            }
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut *self {
            Streaming(stream) => Pin::new(stream).poll_flush(cx),
            Handshaking(handshake) => {
                *self = futures::ready!(Self::poll_handshake(handshake, cx))?;
                self.poll_flush(cx)
            }
        }
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut *self {
            Streaming(stream) => Pin::new(stream).poll_shutdown(cx),
            Handshaking(handshake) => {
                *self = futures::ready!(Self::poll_handshake(handshake, cx))?;
                self.poll_shutdown(cx)
            }
        }
    }
}

impl<I> TlsIncoming<I> {
    /// Construct from inner incoming.
// (continuation of `impl<I> TlsIncoming<I>`, opened earlier in this file)
    pub fn new(incoming: I, config: ServerConfig) -> Self {
        Self {
            incoming,
            acceptor: Arc::new(config).into(),
        }
    }
}

// Dereferences to the wrapped incoming, so its methods stay reachable
// through `TlsIncoming`.
impl<I> Deref for TlsIncoming<I> {
    type Target = I;

    fn deref(&self) -> &Self::Target {
        &self.incoming
    }
}

impl<I> DerefMut for TlsIncoming<I> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.incoming
    }
}

impl<I, IO> Accept for TlsIncoming<I>
where
    IO: 'static + Send + Sync + Unpin + AsyncRead + AsyncWrite,
    I: Unpin + Accept<Conn = AddrStream<IO>>,
{
    type Conn = AddrStream<WrapTlsStream<IO>>;
    type Error = I::Error;

    /// Accept a connection from the inner incoming and wrap it in a
    /// `Handshaking` stream; the handshake future is only stored here and is
    /// polled later by the stream's I/O methods.
    #[inline]
    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        Poll::Ready(
            match futures::ready!(Pin::new(&mut self.incoming).poll_accept(cx)) {
                Some(Ok(AddrStream {
                    stream,
                    remote_addr,
                })) => {
                    let accept_future = self.acceptor.accept(stream);
                    Some(Ok(AddrStream::new(
                        remote_addr,
                        Handshaking(Box::new(accept_future)),
                    )))
                }
                Some(Err(err)) => Some(Err(err)),
                None => None,
            },
        )
    }
}

================================================
FILE: roa/src/tls/listener.rs
================================================
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;

use super::{ServerConfig, TlsIncoming};
use crate::tcp::TcpIncoming;
use crate::{App, Endpoint, Executor, Server, State};

impl TlsIncoming<TcpIncoming> {
    /// Bind a socket addr.
    #[cfg_attr(feature = "docs", doc(cfg(feature = "tls")))]
    pub fn bind(addr: impl ToSocketAddrs, config: ServerConfig) -> io::Result<Self> {
        Ok(Self::new(TcpIncoming::bind(addr)?, config))
    }
}

/// An app extension.
#[cfg_attr(feature = "docs", doc(cfg(feature = "tls")))]
pub trait TlsListener {
    /// http server
    type Server;

    /// Listen on a socket addr, return a server and the real addr it binds.
    fn bind_tls(
        self,
        addr: impl ToSocketAddrs,
        config: ServerConfig,
    ) -> std::io::Result<(SocketAddr, Self::Server)>;

    /// Listen on a socket addr, return a server, and pass real addr to the callback.
fn listen_tls( self, addr: impl ToSocketAddrs, config: ServerConfig, callback: impl Fn(SocketAddr), ) -> std::io::Result; /// Listen on an unused port of 127.0.0.1, return a server and the real addr it binds. /// ### Example /// ```rust /// use roa::{App, Context, Status}; /// use roa::tls::{TlsIncoming, ServerConfig, TlsListener, Certificate, PrivateKey}; /// use roa::tls::pemfile::{certs, rsa_private_keys}; /// use roa_core::http::StatusCode; /// use tokio::task::spawn; /// use std::time::Instant; /// use std::fs::File; /// use std::io::BufReader; /// /// async fn end(_ctx: &mut Context) -> Result<(), Status> { /// Ok(()) /// } /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// let mut cert_file = BufReader::new(File::open("../assets/cert.pem")?); /// let mut key_file = BufReader::new(File::open("../assets/key.pem")?); /// let cert_chain = certs(&mut cert_file)?.into_iter().map(Certificate).collect(); /// /// let config = ServerConfig::builder() /// .with_safe_defaults() /// .with_no_client_auth() /// .with_single_cert(cert_chain, PrivateKey(rsa_private_keys(&mut key_file)?.remove(0)))?; /// /// let server = App::new().end(end).listen_tls("127.0.0.1:8000", config, |addr| { /// println!("Server is listening on https://localhost:{}", addr.port()); /// })?; /// // server.await /// Ok(()) /// # } /// ``` fn run_tls(self, config: ServerConfig) -> std::io::Result<(SocketAddr, Self::Server)>; } impl TlsListener for App> where S: State, E: for<'a> Endpoint<'a, S>, { type Server = Server, Self, Executor>; fn bind_tls( self, addr: impl ToSocketAddrs, config: ServerConfig, ) -> std::io::Result<(SocketAddr, Self::Server)> { let incoming = TlsIncoming::bind(addr, config)?; let local_addr = incoming.local_addr(); Ok((local_addr, self.accept(incoming))) } fn listen_tls( self, addr: impl ToSocketAddrs, config: ServerConfig, callback: impl Fn(SocketAddr), ) -> std::io::Result { let (addr, server) = self.bind_tls(addr, config)?; callback(addr); Ok(server) } fn 
run_tls(self, config: ServerConfig) -> std::io::Result<(SocketAddr, Self::Server)> { self.bind_tls("127.0.0.1:0", config) } } #[cfg(test)] mod tests { use std::fs::File; use std::io::{self, BufReader}; use futures::{AsyncReadExt, TryStreamExt}; use hyper::client::{Client, HttpConnector}; use hyper::Body; use hyper_tls::{native_tls, HttpsConnector}; use tokio::task::spawn; use tokio_native_tls::TlsConnector; use crate::http::StatusCode; use crate::tls::pemfile::{certs, rsa_private_keys}; use crate::tls::{Certificate, PrivateKey, ServerConfig, TlsListener}; use crate::{App, Context, Status}; async fn end(ctx: &mut Context) -> Result<(), Status> { ctx.resp.write("Hello, World!"); Ok(()) } #[tokio::test] async fn run_tls() -> Result<(), Box> { let mut cert_file = BufReader::new(File::open("../assets/cert.pem")?); let mut key_file = BufReader::new(File::open("../assets/key.pem")?); let cert_chain = certs(&mut cert_file)? .into_iter() .map(Certificate) .collect(); let config = ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert( cert_chain, PrivateKey(rsa_private_keys(&mut key_file)?.remove(0)), )?; let app = App::new().end(end); let (addr, server) = app.run_tls(config)?; spawn(server); let native_tls_connector = native_tls::TlsConnector::builder() .danger_accept_invalid_hostnames(true) .danger_accept_invalid_certs(true) .build()?; let tls_connector = TlsConnector::from(native_tls_connector); let mut http_connector = HttpConnector::new(); http_connector.enforce_http(false); let https_connector = HttpsConnector::from((http_connector, tls_connector)); let client = Client::builder().build::<_, Body>(https_connector); let resp = client .get(format!("https://localhost:{}", addr.port()).parse()?) 
.await?; assert_eq!(StatusCode::OK, resp.status()); let mut text = String::new(); resp.into_body() .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) .into_async_read() .read_to_string(&mut text) .await?; assert_eq!("Hello, World!", text); Ok(()) } } ================================================ FILE: roa/src/tls.rs ================================================ //! This module provides an acceptor implementing `roa_core::Accept` and an app extension. //! //! ### TlsIncoming //! //! ```rust //! use roa::{App, Context, Status}; //! use roa::tls::{TlsIncoming, ServerConfig, Certificate, PrivateKey}; //! use roa::tls::pemfile::{certs, rsa_private_keys}; //! use std::fs::File; //! use std::io::BufReader; //! //! async fn end(_ctx: &mut Context) -> Result<(), Status> { //! Ok(()) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let mut cert_file = BufReader::new(File::open("../assets/cert.pem")?); //! let mut key_file = BufReader::new(File::open("../assets/key.pem")?); //! let cert_chain = certs(&mut cert_file)?.into_iter().map(Certificate).collect(); //! //! let config = ServerConfig::builder() //! .with_safe_defaults() //! .with_no_client_auth() //! .with_single_cert(cert_chain, PrivateKey(rsa_private_keys(&mut key_file)?.remove(0)))?; //! //! let incoming = TlsIncoming::bind("127.0.0.1:0", config)?; //! let server = App::new().end(end).accept(incoming); //! // server.await //! Ok(()) //! # } //! ``` //! //! ### TlsListener //! //! ```rust //! use roa::{App, Context, Status}; //! use roa::tls::{ServerConfig, TlsListener, Certificate, PrivateKey}; //! use roa::tls::pemfile::{certs, rsa_private_keys}; //! use std::fs::File; //! use std::io::BufReader; //! //! async fn end(_ctx: &mut Context) -> Result<(), Status> { //! Ok(()) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let mut cert_file = BufReader::new(File::open("../assets/cert.pem")?); //! 
let mut key_file = BufReader::new(File::open("../assets/key.pem")?); //! let cert_chain = certs(&mut cert_file)?.into_iter().map(Certificate).collect(); //! //! let config = ServerConfig::builder() //! .with_safe_defaults() //! .with_no_client_auth() //! .with_single_cert(cert_chain, PrivateKey(rsa_private_keys(&mut key_file)?.remove(0)))?; //! let (addr, server) = App::new().end(end).bind_tls("127.0.0.1:0", config)?; //! // server.await //! Ok(()) //! # } //! ``` #[doc(no_inline)] pub use rustls::*; #[doc(no_inline)] pub use rustls_pemfile as pemfile; mod incoming; #[cfg(feature = "tcp")] mod listener; #[doc(inline)] pub use incoming::TlsIncoming; #[doc(inline)] #[cfg(feature = "tcp")] pub use listener::TlsListener; ================================================ FILE: roa/src/websocket.rs ================================================ //! This module provides a websocket endpoint. //! //! ### Example //! ``` //! use futures::StreamExt; //! use roa::router::{Router, RouterError}; //! use roa::websocket::Websocket; //! use roa::{App, Context}; //! use roa::http::Method; //! //! # fn main() -> Result<(), RouterError> { //! let router = Router::new().on("/chat", Websocket::new(|_ctx, stream| async move { //! let (write, read) = stream.split(); //! // echo //! if let Err(err) = read.forward(write).await { //! println!("forward err: {}", err); //! } //! })); //! let app = App::new().end(router.routes("/")?); //! Ok(()) //! # } //! 
``` use std::future::Future; use std::marker::PhantomData; use std::sync::Arc; use headers::{ Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion, Upgrade, }; use hyper::upgrade::{self, Upgraded}; pub use tokio_tungstenite::tungstenite; pub use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig}; use tokio_tungstenite::WebSocketStream; use crate::http::header::UPGRADE; use crate::http::StatusCode; use crate::{async_trait, throw, Context, Endpoint, State, Status}; /// An alias for WebSocketStream. pub type SocketStream = WebSocketStream; /// The Websocket middleware. /// /// ### Example /// ``` /// use futures::StreamExt; /// use roa::router::{Router, RouterError}; /// use roa::websocket::Websocket; /// use roa::{App, Context}; /// use roa::http::Method; /// /// # fn main() -> Result<(), RouterError> { /// let router = Router::new().on("/chat", Websocket::new(|_ctx, stream| async move { /// let (write, read) = stream.split(); /// // echo /// if let Err(err) = read.forward(write).await { /// println!("forward err: {}", err); /// } /// })); /// let app = App::new().end(router.routes("/")?); /// Ok(()) /// # } /// ``` /// /// ### Parameter /// /// - Context /// /// The context is the same as the roa context, /// however, both reading the body from the request and writing anything to the response are unavailing. /// /// - SocketStream /// /// The websocket stream, implementing `Stream` and `Sink`. /// /// ### Return /// /// Must be `()`, as roa cannot deal with errors occurring in websocket. 
pub struct Websocket where F: Fn(Context, SocketStream) -> Fut, { task: Arc, config: Option, _s: PhantomData, _fut: PhantomData, } unsafe impl Send for Websocket where F: Sync + Send + Fn(Context, SocketStream) -> Fut { } unsafe impl Sync for Websocket where F: Sync + Send + Fn(Context, SocketStream) -> Fut { } impl Websocket where F: Fn(Context, SocketStream) -> Fut, { fn config(config: Option, task: F) -> Self { Self { task: Arc::new(task), config, _s: PhantomData::default(), _fut: PhantomData::default(), } } /// Construct a websocket middleware by task closure. pub fn new(task: F) -> Self { Self::config(None, task) } /// Construct a websocket middleware with config. /// ### Example /// ``` /// use futures::StreamExt; /// use roa::router::{Router, RouterError}; /// use roa::websocket::{Websocket, WebSocketConfig}; /// use roa::{App, Context}; /// use roa::http::Method; /// /// # fn main() -> Result<(), RouterError> { /// let router = Router::new().on("/chat", Websocket::with_config( /// WebSocketConfig::default(), /// |_ctx, stream| async move { /// let (write, read) = stream.split(); /// // echo /// if let Err(err) = read.forward(write).await { /// println!("forward err: {}", err); /// } /// }) /// ); /// let app = App::new().end(router.routes("/")?); /// # Ok(()) /// # } /// ``` pub fn with_config(config: WebSocketConfig, task: F) -> Self { Self::config(Some(config), task) } } #[async_trait(?Send)] impl<'a, F, S, Fut> Endpoint<'a, S> for Websocket where S: State, F: 'static + Sync + Send + Fn(Context, SocketStream) -> Fut, Fut: 'static + Send + Future, { #[inline] async fn call(&'a self, ctx: &'a mut Context) -> Result<(), Status> { let header_map = &ctx.req.headers; let key = header_map .typed_get::() .filter(|upgrade| upgrade == &Upgrade::websocket()) .and(header_map.typed_get::()) .filter(|connection| connection.contains(UPGRADE)) .and(header_map.typed_get::()) .filter(|version| version == &SecWebsocketVersion::V13) .and(header_map.typed_get::()); match key 
{ None => throw!(StatusCode::BAD_REQUEST, "invalid websocket upgrade request"), Some(key) => { let raw_req = ctx.req.take_raw(); let context = ctx.clone(); let task = self.task.clone(); let config = self.config; // Setup a future that will eventually receive the upgraded // connection and talk a new protocol, and spawn the future // into the runtime. // // Note: This can't possibly be fulfilled until the 101 response // is returned below, so it's better to spawn this future instead // waiting for it to complete to then return a response. ctx.exec.spawn(async move { match upgrade::on(raw_req).await { Err(err) => tracing::error!("websocket upgrade error: {}", err), Ok(upgraded) => { let websocket = WebSocketStream::from_raw_socket( upgraded, tungstenite::protocol::Role::Server, config, ) .await; task(context, websocket).await } } }); ctx.resp.status = StatusCode::SWITCHING_PROTOCOLS; ctx.resp.headers.typed_insert(Connection::upgrade()); ctx.resp.headers.typed_insert(Upgrade::websocket()); ctx.resp.headers.typed_insert(SecWebsocketAccept::from(key)); Ok(()) } } } } ================================================ FILE: roa/templates/user.html ================================================ User Homepage

{{name}}

{{id}}

================================================ FILE: roa-async-std/Cargo.toml ================================================ [package] authors = ["Hexilee "] categories = [ "network-programming", "asynchronous", "web-programming::http-server", ] description = "async-std-based runtime and acceptor" documentation = "https://docs.rs/roa-async-std" edition = "2018" homepage = "https://github.com/Hexilee/roa/wiki" keywords = ["http", "web", "framework", "async"] license = "MIT" name = "roa-async-std" readme = "./README.md" repository = "https://github.com/Hexilee/roa" version = "0.6.0" [package.metadata.docs.rs] features = ["docs"] rustdoc-args = ["--cfg", "feature=\"docs\""] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] futures = "0.3" tracing = "0.1" roa = {path = "../roa", version = "0.6.0", default-features = false} async-std = {version = "1.10", features = ["unstable"]} futures-timer = "3.0" [dev-dependencies] reqwest = "0.11" roa = {path = "../roa", version = "0.6.0"} tracing-subscriber = { version = "0.3", features = ["env-filter"]} tokio = { version = "1.15", features = ["full"] } async-std = {version = "1.10", features = ["attributes", "unstable"]} [features] docs = ["roa/docs"] ================================================ FILE: roa-async-std/README.md ================================================ [![Stable Test](https://github.com/Hexilee/roa/workflows/Stable%20Test/badge.svg)](https://github.com/Hexilee/roa/actions) [![codecov](https://codecov.io/gh/Hexilee/roa/branch/master/graph/badge.svg)](https://codecov.io/gh/Hexilee/roa) [![Rust Docs](https://docs.rs/roa-async-std/badge.svg)](https://docs.rs/roa-async-std) [![Crate version](https://img.shields.io/crates/v/roa-async-std.svg)](https://crates.io/crates/roa-async-std) [![Download](https://img.shields.io/crates/d/roa-async-std.svg)](https://crates.io/crates/roa-async-std) [![License: 
MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/Hexilee/roa/blob/master/LICENSE) This crate provides async-std-based runtime and acceptor for roa. ```rust,no_run use roa::http::StatusCode; use roa::{App, Context}; use roa_async_std::{Listener, Exec}; use std::error::Error; async fn end(_ctx: &mut Context) -> roa::Result { Ok(()) } #[async_std::main] async fn main() -> Result<(), Box> { let (addr, server) = App::with_exec((), Exec).end(end).run()?; println!("server is listening on {}", addr); server.await?; Ok(()) } ``` ================================================ FILE: roa-async-std/src/lib.rs ================================================ #![cfg_attr(feature = "docs", doc = include_str!("../README.md"))] #![cfg_attr(feature = "docs", warn(missing_docs))] mod listener; mod net; mod runtime; #[doc(inline)] pub use listener::Listener; #[doc(inline)] pub use net::TcpIncoming; #[doc(inline)] pub use runtime::Exec; ================================================ FILE: roa-async-std/src/listener.rs ================================================ use std::net::{SocketAddr, ToSocketAddrs}; use std::sync::Arc; use roa::{App, Endpoint, Executor, Server, State}; use super::TcpIncoming; /// An app extension. pub trait Listener { /// http server type Server; /// Listen on a socket addr, return a server and the real addr it binds. fn bind(self, addr: impl ToSocketAddrs) -> std::io::Result<(SocketAddr, Self::Server)>; /// Listen on a socket addr, return a server, and pass real addr to the callback. fn listen( self, addr: impl ToSocketAddrs, callback: impl Fn(SocketAddr), ) -> std::io::Result; /// Listen on an unused port of 127.0.0.1, return a server and the real addr it binds. 
/// ### Example /// ```rust,no_run /// use roa::{App, Context, Status}; /// use roa_async_std::{Exec, Listener}; /// use roa::http::StatusCode; /// use async_std::task::spawn; /// use std::time::Instant; /// /// async fn end(_ctx: &mut Context) -> Result<(), Status> { /// Ok(()) /// } /// /// #[async_std::main] /// async fn main() -> Result<(), Box> { /// let (_, server) = App::with_exec((), Exec).end(end).run()?; /// server.await?; /// Ok(()) /// } /// ``` fn run(self) -> std::io::Result<(SocketAddr, Self::Server)>; } impl Listener for App> where S: State, E: for<'a> Endpoint<'a, S>, { type Server = Server; fn bind(self, addr: impl ToSocketAddrs) -> std::io::Result<(SocketAddr, Self::Server)> { let incoming = TcpIncoming::bind(addr)?; let local_addr = incoming.local_addr(); Ok((local_addr, self.accept(incoming))) } fn listen( self, addr: impl ToSocketAddrs, callback: impl Fn(SocketAddr), ) -> std::io::Result { let (addr, server) = self.bind(addr)?; callback(addr); Ok(server) } fn run(self) -> std::io::Result<(SocketAddr, Self::Server)> { self.bind("127.0.0.1:0") } } #[cfg(test)] mod tests { use std::error::Error; use roa::http::StatusCode; use roa::App; use super::Listener; use crate::Exec; #[tokio::test] async fn incoming() -> Result<(), Box> { let (addr, server) = App::with_exec((), Exec).end(()).run()?; tokio::task::spawn(server); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } } ================================================ FILE: roa-async-std/src/net.rs ================================================ use std::future::Future; use std::mem::transmute; use std::net::{TcpListener as StdListener, ToSocketAddrs}; use std::pin::Pin; use std::task::{self, Poll}; use std::time::Duration; use std::{fmt, io, matches}; use async_std::net::{SocketAddr, TcpListener, TcpStream}; use futures_timer::Delay; use roa::stream::AsyncStream; use roa::{Accept, AddrStream}; use tracing::{debug, error, trace}; /// A 
stream of connections from binding to an address. /// As an implementation of roa_core::Accept. #[must_use = "streams do nothing unless polled"] pub struct TcpIncoming { addr: SocketAddr, listener: Box, sleep_on_errors: bool, tcp_nodelay: bool, timeout: Option>>, accept: Option>>, } type BoxedAccept<'a> = Box> + Send + Sync>; impl TcpIncoming { /// Creates a new `TcpIncoming` binding to provided socket address. pub fn bind(addr: impl ToSocketAddrs) -> io::Result { let listener = StdListener::bind(addr)?; TcpIncoming::from_std(listener) } /// Creates a new `TcpIncoming` from std TcpListener. pub fn from_std(listener: StdListener) -> io::Result { let addr = listener.local_addr()?; Ok(TcpIncoming { listener: Box::new(listener.into()), addr, sleep_on_errors: true, tcp_nodelay: false, timeout: None, accept: None, }) } /// Get the local address bound to this listener. pub fn local_addr(&self) -> SocketAddr { self.addr } /// Set the value of `TCP_NODELAY` option for accepted connections. pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self { self.tcp_nodelay = enabled; self } /// Set whether to sleep on accept errors. /// /// A possible scenario is that the process has hit the max open files /// allowed, and so trying to accept a new connection will fail with /// `EMFILE`. In some cases, it's preferable to just wait for some time, if /// the application will likely close some files (or connections), and try /// to accept the connection again. If this option is `true`, the error /// will be logged at the `error` level, since it is still a big deal, /// and then the listener will sleep for 1 second. /// /// In other cases, hitting the max open files should be treat similarly /// to being out-of-memory, and simply error (and shutdown). Setting /// this option to `false` will allow that. /// /// Default is `true`. pub fn set_sleep_on_errors(&mut self, val: bool) { self.sleep_on_errors = val; } /// Poll TcpStream. 
fn poll_stream( &mut self, cx: &mut task::Context<'_>, ) -> Poll> { // Check if a previous timeout is active that was set by IO errors. if let Some(ref mut to) = self.timeout { futures::ready!(Pin::new(to).poll(cx)); } self.timeout = None; loop { if self.accept.is_none() { let accept: Pin> = Box::pin(self.listener.accept()); self.accept = Some(unsafe { transmute(accept) }); } if let Some(f) = &mut self.accept { match futures::ready!(f.as_mut().poll(cx)) { Ok((socket, addr)) => { if let Err(e) = socket.set_nodelay(self.tcp_nodelay) { trace!("error trying to set TCP nodelay: {}", e); } self.accept = None; return Poll::Ready(Ok((socket, addr))); } Err(e) => { // Connection errors can be ignored directly, continue by // accepting the next request. if is_connection_error(&e) { debug!("accepted connection already errored: {}", e); continue; } if self.sleep_on_errors { error!("accept error: {}", e); // Sleep 1s. let mut timeout = Box::pin(Delay::new(Duration::from_secs(1))); match timeout.as_mut().poll(cx) { Poll::Ready(()) => { // Wow, it's been a second already? Ok then... continue; } Poll::Pending => { self.timeout = Some(timeout); return Poll::Pending; } } } else { return Poll::Ready(Err(e)); } } } } } } } impl Accept for TcpIncoming { type Conn = AddrStream>; type Error = io::Error; #[inline] fn poll_accept( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll>> { let (stream, addr) = futures::ready!(self.poll_stream(cx))?; Poll::Ready(Some(Ok(AddrStream::new(addr, AsyncStream(stream))))) } } impl Drop for TcpIncoming { fn drop(&mut self) { self.accept = None; } } /// This function defines errors that are per-connection. Which basically /// means that if we get this error from `accept()` system call it means /// next connection might be ready to be accepted. /// /// All other errors will incur a timeout before next `accept()` is performed. /// The timeout is useful to handle resource exhaustion errors like ENFILE /// and EMFILE. 
Otherwise, could enter into tight loop. fn is_connection_error(e: &io::Error) -> bool { matches!( e.kind(), io::ErrorKind::ConnectionRefused | io::ErrorKind::ConnectionAborted | io::ErrorKind::ConnectionReset ) } impl fmt::Debug for TcpIncoming { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("TcpIncoming") .field("addr", &self.addr) .field("sleep_on_errors", &self.sleep_on_errors) .field("tcp_nodelay", &self.tcp_nodelay) .finish() } } #[cfg(test)] mod tests { use std::error::Error; use roa::http::StatusCode; use roa::App; use tracing_subscriber::{fmt, EnvFilter}; use super::TcpIncoming; use crate::Exec; #[tokio::test] async fn incoming() -> Result<(), Box> { fmt().with_env_filter(EnvFilter::from_default_env()).init(); let app = App::with_exec((), Exec).end(()); let incoming = TcpIncoming::bind("127.0.0.1:0")?; let addr = incoming.local_addr(); tokio::task::spawn(app.accept(incoming)); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } } ================================================ FILE: roa-async-std/src/runtime.rs ================================================ use std::future::Future; use std::pin::Pin; use roa::Spawn; /// Future Object pub type FutureObj = Pin>>; /// Blocking task Object pub type BlockingObj = Box; /// async-std-based executor. 
/// /// ``` /// use roa::App; /// use roa_async_std::Exec; /// /// let app = App::with_exec((), Exec); /// ``` pub struct Exec; impl Spawn for Exec { #[inline] fn spawn(&self, fut: FutureObj) { async_std::task::spawn(fut); } #[inline] fn spawn_blocking(&self, task: BlockingObj) { async_std::task::spawn_blocking(task); } } #[cfg(test)] mod tests { use std::error::Error; use roa::http::StatusCode; use roa::tcp::Listener; use roa::App; use super::Exec; #[tokio::test] async fn exec() -> Result<(), Box> { let app = App::with_exec((), Exec).end(()); let (addr, server) = app.bind("127.0.0.1:0")?; tokio::spawn(server); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); Ok(()) } } ================================================ FILE: roa-core/Cargo.toml ================================================ [package] name = "roa-core" version = "0.6.1" authors = ["Hexilee "] edition = "2018" license = "MIT" readme = "./README.md" repository = "https://github.com/Hexilee/roa" documentation = "https://docs.rs/roa-core" homepage = "https://github.com/Hexilee/roa/wiki" description = "core components of roa web framework" keywords = ["http", "web", "framework", "async"] categories = ["network-programming", "asynchronous", "web-programming::http-server"] [package.metadata.docs.rs] features = ["docs"] rustdoc-args = ["--cfg", "feature=\"docs\""] [badges] codecov = { repository = "Hexilee/roa" } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] futures = "0.3" bytes = "1.1" http = "0.2" hyper = { version = "0.14", default-features = false, features = ["stream", "server", "http1", "http2"] } tracing = "0.1" tokio = "1.15" tokio-util = { version = "0.6.9", features = ["io"] } async-trait = "0.1.51" crossbeam-queue = "0.3" [dev-dependencies] tokio = { version = "1.15", features = ["fs", "macros", "rt"] } [features] runtime = ["tokio/rt"] docs = ["runtime"] 
================================================ FILE: roa-core/README.md ================================================ [![Stable Test](https://github.com/Hexilee/roa/workflows/Stable%20Test/badge.svg)](https://github.com/Hexilee/roa/actions) [![codecov](https://codecov.io/gh/Hexilee/roa/branch/master/graph/badge.svg)](https://codecov.io/gh/Hexilee/roa) [![Rust Docs](https://docs.rs/roa-core/badge.svg)](https://docs.rs/roa-core) [![Crate version](https://img.shields.io/crates/v/roa-core.svg)](https://crates.io/crates/roa-core) [![Download](https://img.shields.io/crates/d/roa-core.svg)](https://crates.io/crates/roa-core) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/Hexilee/roa/blob/master/LICENSE) ### Introduction Core components of Roa framework. If you are new to roa, please go to the documentation of [roa framework](https://docs.rs/roa). ### Application A Roa application is a structure composing and executing middlewares and an endpoint in a stack-like manner. The obligatory hello world application: ```rust use roa_core::App; let app = App::new().end("Hello, World"); ``` #### Endpoint An endpoint is a request handler. There are some built-in endpoints in roa_core. - Functional endpoint A normal functional endpoint is an async function with signature: `async fn(&mut Context) -> Result`. ```rust use roa_core::{App, Context, Result}; async fn endpoint(ctx: &mut Context) -> Result { Ok(()) } let app = App::new().end(endpoint); ``` - Ok endpoint `()` is an endpoint that always returns `Ok(())` ```rust let app = roa_core::App::new().end(()); ``` - Status endpoint `Status` is an endpoint that always returns `Err(Status)` ```rust use roa_core::{App, status}; use roa_core::http::StatusCode; let app = App::new().end(status!(StatusCode::BAD_REQUEST)); ``` - String endpoint Write string to body. 
```rust use roa_core::App; let app = App::new().end("Hello, world"); // static slice let app = App::new().end("Hello, world".to_owned()); // string ``` - Redirect endpoint Redirect to a URI. ```rust use roa_core::App; use roa_core::http::Uri; let app = App::new().end("/target".parse::<Uri>().unwrap()); ``` #### Cascading The following example responds with "Hello World", however, the request flows through the `logging` middleware to mark when the request started, then continues to yield control through the endpoint. When a middleware invokes `next.await` the function suspends and passes control to the next middleware or endpoint. After the endpoint is called, the stack will unwind and each middleware is resumed to perform its upstream behaviour. ```rust use roa_core::{App, Context, Result, Status, MiddlewareExt, Next}; use std::time::Instant; use tracing::info; let app = App::new().gate(logging).end("Hello, World"); async fn logging(ctx: &mut Context, next: Next<'_>) -> Result { let inbound = Instant::now(); next.await?; info!("time elapsed: {} ms", inbound.elapsed().as_millis()); Ok(()) } ``` ### Status Handling You can catch or directly throw a status returned by next. ```rust use roa_core::{App, Context, Result, Status, MiddlewareExt, Next, throw}; use roa_core::http::StatusCode; let app = App::new().gate(catch).gate(gate).end(end); async fn catch(ctx: &mut Context, next: Next<'_>) -> Result { // catch if let Err(status) = next.await { // teapot is ok if status.status_code != StatusCode::IM_A_TEAPOT { return Err(status); } } Ok(()) } async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { next.await?; // just throw unreachable!() } async fn end(ctx: &mut Context) -> Result { throw!(StatusCode::IM_A_TEAPOT, "I'm a teapot!") } ``` #### status_handler App has a status_handler to handle `Status` thrown by the top middleware. 
This is the status_handler: ```rust use roa_core::{Context, Status, Result, State}; pub fn status_handler(ctx: &mut Context, status: Status) { ctx.resp.status = status.status_code; if status.expose { ctx.resp.write(status.message); } else { tracing::error!("{}", status); } } ``` ### HTTP Server. Use `roa_core::accept` to construct a http server. Please refer to `roa::tcp` for more information. ================================================ FILE: roa-core/src/app/future.rs ================================================ use std::future::Future; use std::pin::Pin; use futures::task::{Context, Poll}; /// A wrapper to make future `Send`. It's used to wrap future returned by top middleware. /// So the future returned by each middleware or endpoint can be `?Send`. /// /// But how to ensure thread safety? Because the middleware and the context must be `Sync + Send`, /// which means the only factor causing future `!Send` is the variables generated in `Future::poll`. /// And these variable mustn't be accessed from other threads. #[allow(clippy::non_send_fields_in_send_ty)] pub struct SendFuture(pub F); impl Future for SendFuture where F: 'static + Future + Unpin, { type Output = F::Output; #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(&mut self.0).poll(cx) } } unsafe impl Send for SendFuture {} ================================================ FILE: roa-core/src/app/runtime.rs ================================================ use crate::executor::{BlockingObj, FutureObj}; use crate::{App, Spawn}; impl App { /// Construct app with default runtime. #[cfg_attr(feature = "docs", doc(cfg(feature = "runtime")))] #[inline] pub fn state(state: S) -> Self { Self::with_exec(state, Exec) } } impl App<(), ()> { /// Construct app with default runtime. #[cfg_attr(feature = "docs", doc(cfg(feature = "runtime")))] #[inline] pub fn new() -> Self { Self::state(()) } } impl Default for App<(), ()> { /// Construct app with default runtime. 
fn default() -> Self { Self::new() } } pub struct Exec; impl Spawn for Exec { #[inline] fn spawn(&self, fut: FutureObj) { tokio::task::spawn(fut); } #[inline] fn spawn_blocking(&self, task: BlockingObj) { tokio::task::spawn_blocking(task); } } ================================================ FILE: roa-core/src/app/stream.rs ================================================ use std::fmt::Debug; use std::io; use std::net::SocketAddr; use std::pin::Pin; use std::task::{self, Poll}; use futures::ready; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{instrument, trace}; /// A transport yielded by `AddrIncoming`. pub struct AddrStream { /// The remote address of this stream. pub remote_addr: SocketAddr, /// The inner stream. pub stream: IO, } impl AddrStream { /// Construct an AddrStream from an addr and an AsyncReadWriter. #[inline] pub fn new(remote_addr: SocketAddr, stream: IO) -> AddrStream { AddrStream { remote_addr, stream, } } } impl AsyncRead for AddrStream where IO: Unpin + AsyncRead, { #[inline] #[instrument(skip(cx, buf))] fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let poll = Pin::new(&mut self.stream).poll_read(cx, buf); trace!("poll read: {:?}", poll); ready!(poll)?; trace!("read {} bytes", buf.filled().len()); Poll::Ready(Ok(())) } } impl AsyncWrite for AddrStream where IO: Unpin + AsyncWrite, { #[inline] #[instrument(skip(cx, buf))] fn poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll> { let write_size = ready!(Pin::new(&mut self.stream).poll_write(cx, buf))?; trace!("wrote {} bytes", write_size); Poll::Ready(Ok(write_size)) } #[inline] fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.stream).poll_flush(cx) } #[inline] fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.stream).poll_shutdown(cx) } } impl Debug for AddrStream { fn fmt(&self, 
f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AddrStream") .field("remote_addr", &self.remote_addr) .finish() } } ================================================ FILE: roa-core/src/app.rs ================================================ mod future; #[cfg(feature = "runtime")] mod runtime; mod stream; use std::convert::Infallible; use std::error::Error; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::task::Poll; use future::SendFuture; use http::{Request as HttpRequest, Response as HttpResponse}; use hyper::service::Service; use hyper::{Body as HyperBody, Server}; use tokio::io::{AsyncRead, AsyncWrite}; pub use self::stream::AddrStream; use crate::{ Accept, Chain, Context, Endpoint, Executor, Middleware, MiddlewareExt, Request, Response, Spawn, State, }; /// The Application of roa. /// ### Example /// ```rust,no_run /// use roa_core::{App, Context, Next, Result, MiddlewareExt}; /// use tracing::info; /// use tokio::fs::File; /// /// let app = App::new().gate(gate).end(end); /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// info!("{} {}", ctx.method(), ctx.uri()); /// next.await /// } /// /// async fn end(ctx: &mut Context) -> Result { /// ctx.resp.write_reader(File::open("assets/welcome.html").await?); /// Ok(()) /// } /// ``` /// /// ### State /// The `State` is designed to share data or handler between middlewares. /// The only type implementing `State` in this crate is `()`; you can implement your custom state if necessary. 
/// /// ```rust /// use roa_core::{App, Context, Next, Result}; /// use tracing::info; /// use futures::lock::Mutex; /// /// use std::sync::Arc; /// use std::collections::HashMap; /// /// #[derive(Clone)] /// struct State { /// id: u64, /// database: Arc>>, /// } /// /// impl State { /// fn new() -> Self { /// Self { /// id: 0, /// database: Arc::new(Mutex::new(HashMap::new())) /// } /// } /// } /// /// let app = App::state(State::new()).gate(gate).end(end); /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// ctx.id = 1; /// next.await /// } /// /// async fn end(ctx: &mut Context) -> Result { /// let id = ctx.id; /// ctx.database.lock().await.get(&id); /// Ok(()) /// } /// ``` /// pub struct App { service: T, exec: Executor, state: S, } /// An implementation of hyper HttpService. pub struct HttpService { endpoint: Arc, remote_addr: SocketAddr, exec: Executor, pub(crate) state: S, } impl App { /// Map app::service fn map_service(self, mapper: impl FnOnce(T) -> U) -> App { let Self { exec, state, service, } = self; App { service: mapper(service), exec, state, } } } impl App { /// Construct an application with custom runtime. pub fn with_exec(state: S, exec: impl 'static + Send + Sync + Spawn) -> Self { Self { service: (), exec: Executor(Arc::new(exec)), state, } } } impl App where T: for<'a> Middleware<'a, S>, { /// Use a middleware. pub fn gate(self, middleware: M) -> App> where M: for<'a> Middleware<'a, S>, { self.map_service(move |service| service.chain(middleware)) } /// Set endpoint, then app can only be used to serve http request. pub fn end(self, endpoint: E) -> App>> where E: for<'a> Endpoint<'a, S>, { self.map_service(move |service| Arc::new(service.end(endpoint))) } } impl App> where E: for<'a> Endpoint<'a, S>, { /// Construct a hyper server by an incoming. 
pub fn accept(self, incoming: I) -> Server where S: State, IO: 'static + Send + Sync + Unpin + AsyncRead + AsyncWrite, I: Accept>, I::Error: Into>, { Server::builder(incoming) .executor(self.exec.clone()) .serve(self) } /// Make a fake http service for test. #[cfg(test)] pub fn http_service(&self) -> HttpService where S: Clone, { let endpoint = self.service.clone(); let addr = ([127, 0, 0, 1], 0); let state = self.state.clone(); let exec = self.exec.clone(); HttpService::new(endpoint, addr.into(), exec, state) } } macro_rules! impl_poll_ready { () => { #[inline] fn poll_ready( &mut self, _cx: &mut std::task::Context<'_>, ) -> Poll> { Poll::Ready(Ok(())) } }; } type AppFuture = Pin>> + Send>>; impl Service<&AddrStream> for App> where S: State, E: for<'a> Endpoint<'a, S>, IO: 'static + Send + Sync + Unpin + AsyncRead + AsyncWrite, { type Response = HttpService; type Error = std::io::Error; type Future = AppFuture; impl_poll_ready!(); #[inline] fn call(&mut self, stream: &AddrStream) -> Self::Future { let endpoint = self.service.clone(); let addr = stream.remote_addr; let state = self.state.clone(); let exec = self.exec.clone(); Box::pin(async move { Ok(HttpService::new(endpoint, addr, exec, state)) }) } } type HttpFuture = Pin, Infallible>> + Send>>; impl Service> for HttpService where S: State, E: for<'a> Endpoint<'a, S>, { type Response = HttpResponse; type Error = Infallible; type Future = HttpFuture; impl_poll_ready!(); #[inline] fn call(&mut self, req: HttpRequest) -> Self::Future { let service = self.clone(); Box::pin(async move { let serve_future = SendFuture(Box::pin(service.serve(req.into()))); Ok(serve_future.await.into()) }) } } impl HttpService { pub fn new(endpoint: Arc, remote_addr: SocketAddr, exec: Executor, state: S) -> Self { Self { endpoint, remote_addr, exec, state, } } /// Receive a request then return a response. /// The entry point of http service. 
pub async fn serve(self, req: Request) -> Response where S: 'static, E: for<'a> Endpoint<'a, S>, { let Self { endpoint, remote_addr, exec, state, } = self; let mut ctx = Context::new(req, state, exec, remote_addr); if let Err(status) = endpoint.call(&mut ctx).await { ctx.resp.status = status.status_code; if status.expose { ctx.resp.write(status.message); } else { ctx.exec .spawn_blocking(move || tracing::error!("Uncaught status: {}", status)) .await; } } ctx.resp } } impl Clone for HttpService { fn clone(&self) -> Self { Self { endpoint: self.endpoint.clone(), state: self.state.clone(), exec: self.exec.clone(), remote_addr: self.remote_addr, } } } #[cfg(all(test, feature = "runtime"))] mod tests { use http::StatusCode; use crate::{App, Request}; #[tokio::test] async fn gate_simple() -> Result<(), Box> { let service = App::new().end(()).http_service(); let resp = service.serve(Request::default()).await; assert_eq!(StatusCode::OK, resp.status); Ok(()) } } ================================================ FILE: roa-core/src/body.rs ================================================ use std::mem; use std::pin::Pin; use std::task::{Context, Poll}; use bytes::{Bytes, BytesMut}; use futures::future::ok; use futures::stream::{once, Stream, StreamExt}; use tokio::io::{self, AsyncRead, ReadBuf}; const DEFAULT_CHUNK_SIZE: usize = 4096; /// The body of response. /// /// ### Example /// /// ```rust /// use roa_core::Body; /// use futures::StreamExt; /// use std::io; /// use bytes::Bytes; /// /// async fn read_body(body: Body) -> io::Result { /// Ok(match body { /// Body::Empty => Bytes::new(), /// Body::Once(bytes) => bytes, /// Body::Stream(mut stream) => { /// let mut bytes = Vec::new(); /// while let Some(item) = stream.next().await { /// bytes.extend_from_slice(&*item?); /// } /// bytes.into() /// } /// }) /// } /// ``` pub enum Body { /// Empty kind Empty, /// Bytes kind. Once(Bytes), /// Stream kind. Stream(Segment), } /// A boxed stream. 
#[derive(Default)]
pub struct Segment(Option<Pin<Box<dyn Stream<Item = io::Result<Bytes>> + Sync + Send + 'static>>>);

impl Body {
    /// Construct an empty body.
    #[inline]
    pub fn empty() -> Self {
        Body::Empty
    }

    /// Construct a once body.
    #[inline]
    pub fn once(bytes: impl Into<Bytes>) -> Self {
        Body::Once(bytes.into())
    }

    /// Construct a body of stream kind wrapping the given stream.
    #[inline]
    pub fn stream<S>(stream: S) -> Self
    where
        S: Stream<Item = io::Result<Bytes>> + Sync + Send + 'static,
    {
        Body::Stream(Segment::new(stream))
    }

    /// Write stream.
    ///
    /// The stream is appended after whatever the body already holds:
    /// an empty body becomes the stream, a once body is emitted first,
    /// and an existing stream is chained with the new one.
    #[inline]
    pub fn write_stream(
        &mut self,
        stream: impl Stream<Item = io::Result<Bytes>> + Sync + Send + 'static,
    ) -> &mut Self {
        match self {
            Body::Empty => {
                *self = Self::stream(stream);
            }
            Body::Once(bytes) => {
                let stream = once(ok(mem::take(bytes))).chain(stream);
                *self = Self::stream(stream);
            }
            Body::Stream(segment) => {
                *self = Self::stream(mem::take(segment).chain(stream));
            }
        }
        self
    }

    /// Write reader with default chunk size.
    #[inline]
    pub fn write_reader(
        &mut self,
        reader: impl AsyncRead + Sync + Send + Unpin + 'static,
    ) -> &mut Self {
        self.write_chunk(reader, DEFAULT_CHUNK_SIZE)
    }

    /// Write reader with chunk size.
    #[inline]
    pub fn write_chunk(
        &mut self,
        reader: impl AsyncRead + Sync + Send + Unpin + 'static,
        chunk_size: usize,
    ) -> &mut Self {
        self.write_stream(ReaderStream::new(reader, chunk_size))
    }

    /// Write `Bytes`.
    #[inline]
    pub fn write(&mut self, data: impl Into<Bytes>) -> &mut Self {
        match self {
            Body::Empty => {
                *self = Self::once(data.into());
                self
            }
            body => body.write_stream(once(ok(data.into()))),
        }
    }
}

impl Segment {
    /// Box and pin a stream so it can be stored in a `Body`.
    #[inline]
    fn new(stream: impl Stream<Item = io::Result<Bytes>> + Sync + Send + 'static) -> Self {
        Self(Some(Box::pin(stream)))
    }
}

impl From<Body> for hyper::Body {
    #[inline]
    fn from(body: Body) -> Self {
        match body {
            Body::Empty => hyper::Body::empty(),
            Body::Once(bytes) => hyper::Body::from(bytes),
            Body::Stream(stream) => hyper::Body::wrap_stream(stream),
        }
    }
}

impl Default for Body {
    #[inline]
    fn default() -> Self {
        Self::empty()
    }
}

/// A stream adapter that reads an `AsyncRead` in chunks of at most `chunk_size` bytes.
pub struct ReaderStream<R> {
    chunk_size: usize,
    reader: R,
}

impl<R> ReaderStream<R> {
    #[inline]
    fn new(reader: R, chunk_size: usize) -> Self {
        Self { reader, chunk_size }
    }
}

impl<R> Stream for ReaderStream<R>
where
    R: AsyncRead + Unpin,
{
    type Item = io::Result<Bytes>;

    /// Read one chunk from the reader; a read of zero bytes ends the stream.
    #[inline]
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let chunk_size = self.chunk_size;
        // Zero-fill the buffer instead of `set_len` on uninitialised capacity:
        // `ReadBuf::new` treats the whole slice as initialised memory, so handing
        // it uninitialised bytes would be undefined behaviour.
        let mut chunk = BytesMut::with_capacity(chunk_size);
        chunk.resize(chunk_size, 0);
        let mut buf = ReadBuf::new(&mut *chunk);
        futures::ready!(Pin::new(&mut self.reader).poll_read(cx, &mut buf))?;
        let filled_len = buf.filled().len();
        if filled_len == 0 {
            Poll::Ready(None)
        } else {
            // Keep only the bytes the reader actually produced.
            chunk.truncate(filled_len);
            Poll::Ready(Some(Ok(chunk.freeze())))
        }
    }
}

impl Stream for Body {
    type Item = io::Result<Bytes>;

    #[inline]
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match &mut *self {
            Body::Empty => Poll::Ready(None),
            Body::Once(bytes) => {
                // Yield the single chunk, then leave the body empty so the next poll ends the stream.
                let data = mem::take(bytes);
                *self = Body::empty();
                Poll::Ready(Some(Ok(data)))
            }
            Body::Stream(stream) => Pin::new(stream).poll_next(cx),
        }
    }
}

impl Stream for Segment {
    type Item = io::Result<Bytes>;

    #[inline]
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.0 {
            // A defaulted (taken) segment holds no stream and is already finished.
            None => Poll::Ready(None),
            Some(ref mut stream) => stream.as_mut().poll_next(cx),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::io;

    use futures::{AsyncReadExt, TryStreamExt};
use tokio::fs::File; use super::Body; async fn read_body(body: Body) -> io::Result { let mut data = String::new(); body.into_async_read().read_to_string(&mut data).await?; Ok(data) } #[tokio::test] async fn body_empty() -> std::io::Result<()> { let body = Body::default(); assert_eq!("", read_body(body).await?); Ok(()) } #[tokio::test] async fn body_single() -> std::io::Result<()> { let mut body = Body::default(); body.write("Hello, World"); assert_eq!("Hello, World", read_body(body).await?); Ok(()) } #[tokio::test] async fn body_multiple() -> std::io::Result<()> { let mut body = Body::default(); body.write("He").write("llo, ").write("World"); assert_eq!("Hello, World", read_body(body).await?); Ok(()) } #[tokio::test] async fn body_composed() -> std::io::Result<()> { let mut body = Body::empty(); body.write("He") .write("llo, ") .write_reader(File::open("../assets/author.txt").await?) .write_reader(File::open("../assets/author.txt").await?) .write("."); assert_eq!("Hello, HexileeHexilee.", read_body(body).await?); Ok(()) } } ================================================ FILE: roa-core/src/context/storage.rs ================================================ use std::any::{Any, TypeId}; use std::borrow::Cow; use std::collections::HashMap; use std::fmt::Display; use std::ops::Deref; use std::str::FromStr; use std::sync::Arc; use http::StatusCode; use crate::Status; pub trait Value: Any + Send + Sync {} impl Value for V where V: Any + Send + Sync {} /// A context scoped storage. #[derive(Clone)] pub struct Storage(HashMap, Arc>>); /// A wrapper of Arc. 
/// /// ### Deref /// /// ```rust /// use roa_core::Variable; /// /// fn consume(var: Variable) { /// let value: &V = &var; /// } /// ``` /// /// ### Parse /// /// ```rust /// use roa_core::{Variable, Result}; /// /// fn consume>(var: Variable) -> Result { /// let value: i32 = var.parse()?; /// Ok(()) /// } /// ``` #[derive(Debug, Clone)] pub struct Variable<'a, V> { key: &'a str, value: Arc, } impl Deref for Variable<'_, V> { type Target = V; #[inline] fn deref(&self) -> &Self::Target { &self.value } } impl<'a, V> Variable<'a, V> { /// Construct a variable from name and value. #[inline] fn new(key: &'a str, value: Arc) -> Self { Self { key, value } } /// Consume self and get inner Arc. #[inline] pub fn value(self) -> Arc { self.value } } impl Variable<'_, V> where V: AsRef, { /// A wrapper of `str::parse`. Converts `T::FromStr::Err` to `roa_core::Error` automatically. #[inline] pub fn parse(&self) -> Result where T: FromStr, T::Err: Display, { self.as_ref().parse().map_err(|err| { Status::new( StatusCode::BAD_REQUEST, format!( "{}\ntype of variable `{}` should be {}", err, self.key, std::any::type_name::() ), true, ) }) } } impl Storage { /// Construct an empty Bucket. #[inline] pub fn new() -> Self { Self(HashMap::new()) } /// Inserts a key-value pair into the storage. /// /// If the storage did not have this key present, [`None`] is returned. /// /// If the storage did have this key present, the value is updated, and the old /// value is returned. pub fn insert(&mut self, scope: S, key: K, value: V) -> Option> where S: Any, K: Into>, V: Value, { let id = TypeId::of::(); match self.0.get_mut(&id) { Some(bucket) => bucket .insert(key.into(), Arc::new(value)) .and_then(|value| value.downcast().ok()), None => { self.0.insert(id, HashMap::new()); self.insert(scope, key, value) } } } /// If the storage did not have this key present, [`None`] is returned. 
/// /// If the storage did have this key present, the key-value pair will be returned as a `Variable` #[inline] pub fn get<'a, S, V>(&self, key: &'a str) -> Option> where S: Any, V: Value, { let value = self.0.get(&TypeId::of::())?.get(key)?.clone(); Some(Variable::new(key, value.clone().downcast().ok()?)) } } impl Default for Storage { #[inline] fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use std::sync::Arc; use http::StatusCode; use super::{Storage, Variable}; #[test] fn storage() { struct Scope; let mut storage = Storage::default(); assert!(storage.get::("id").is_none()); assert!(storage.insert(Scope, "id", "1").is_none()); let id: i32 = storage .get::("id") .unwrap() .parse() .unwrap(); assert_eq!(1, id); assert_eq!( 1, storage .insert(Scope, "id", "2") .unwrap() .parse::() .unwrap() ); } #[test] fn variable() { assert_eq!( 1, Variable::new("id", Arc::new("1")).parse::().unwrap() ); let result = Variable::new("id", Arc::new("x")).parse::(); assert!(result.is_err()); let status = result.unwrap_err(); assert_eq!(StatusCode::BAD_REQUEST, status.status_code); assert!(status .message .ends_with("type of variable `id` should be usize")); } } ================================================ FILE: roa-core/src/context.rs ================================================ mod storage; use std::any::Any; use std::borrow::Cow; use std::net::SocketAddr; use std::ops::{Deref, DerefMut}; use std::sync::Arc; use http::header::AsHeaderName; use http::{Method, StatusCode, Uri, Version}; pub use storage::Variable; use storage::{Storage, Value}; use crate::{status, Executor, Request, Response}; /// A structure to share request, response and other data between middlewares. 
/// /// ### Example /// /// ```rust /// use roa_core::{App, Context, Next, Result}; /// use tracing::info; /// use tokio::fs::File; /// /// let app = App::new().gate(gate).end(end); /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// info!("{} {}", ctx.method(), ctx.uri()); /// next.await /// } /// /// async fn end(ctx: &mut Context) -> Result { /// ctx.resp.write_reader(File::open("assets/welcome.html").await?); /// Ok(()) /// } /// ``` pub struct Context { /// The request, to read http method, uri, version, headers and body. pub req: Request, /// The response, to set http status, version, headers and body. pub resp: Response, /// The executor, to spawn futures or blocking tasks. pub exec: Executor, /// Socket addr of last client or proxy. pub remote_addr: SocketAddr, storage: Storage, state: S, } impl Context { /// Construct a context from a request, a state, an executor and a remote address. #[inline] pub(crate) fn new(request: Request, state: S, exec: Executor, remote_addr: SocketAddr) -> Self { Self { req: request, resp: Response::default(), state, exec, storage: Storage::default(), remote_addr, } } /// Get a reference to the request URI. /// /// ### Example /// ```rust /// use roa_core::{App, Context, Result}; /// /// let app = App::new().end(get); /// /// async fn get(ctx: &mut Context) -> Result { /// assert_eq!("/", ctx.uri().to_string()); /// Ok(()) /// } /// ``` #[inline] pub fn uri(&self) -> &Uri { &self.req.uri } /// Get a reference to the request method. /// /// ### Example /// ```rust /// use roa_core::{App, Context, Result}; /// use roa_core::http::Method; /// /// let app = App::new().end(get); /// /// async fn get(ctx: &mut Context) -> Result { /// assert_eq!(Method::GET, ctx.method()); /// Ok(()) /// } /// ``` #[inline] pub fn method(&self) -> &Method { &self.req.method } /// Search for a header value and try to get its string reference. 
/// /// ### Example /// ```rust /// use roa_core::{App, Context, Result}; /// use roa_core::http::header::CONTENT_TYPE; /// /// let app = App::new().end(get); /// /// async fn get(ctx: &mut Context) -> Result { /// assert_eq!( /// Some("text/plain"), /// ctx.get(CONTENT_TYPE), /// ); /// Ok(()) /// } /// ``` #[inline] pub fn get(&self, name: impl AsHeaderName) -> Option<&str> { self.req .headers .get(name) .and_then(|value| value.to_str().ok()) } /// Search for a header value and get its string reference. /// /// Otherwise return a 400 BAD REQUEST. /// /// ### Example /// ```rust /// use roa_core::{App, Context, Result}; /// use roa_core::http::header::CONTENT_TYPE; /// /// let app = App::new().end(get); /// /// async fn get(ctx: &mut Context) -> Result { /// assert_eq!( /// "text/plain", /// ctx.must_get(CONTENT_TYPE)?, /// ); /// Ok(()) /// } /// ``` #[inline] pub fn must_get(&self, name: impl AsHeaderName) -> crate::Result<&str> { let value = self .req .headers .get(name) .ok_or_else(|| status!(StatusCode::BAD_REQUEST))?; value .to_str() .map_err(|err| status!(StatusCode::BAD_REQUEST, err)) } /// Clone response::status. /// /// ### Example /// ```rust /// use roa_core::{App, Context, Result}; /// use roa_core::http::StatusCode; /// /// let app = App::new().end(get); /// /// async fn get(ctx: &mut Context) -> Result { /// assert_eq!(StatusCode::OK, ctx.status()); /// Ok(()) /// } /// ``` #[inline] pub fn status(&self) -> StatusCode { self.resp.status } /// Clone request::version. /// /// ### Example /// ```rust /// use roa_core::{App, Context, Result}; /// use roa_core::http::Version; /// /// let app = App::new().end(get); /// /// async fn get(ctx: &mut Context) -> Result { /// assert_eq!(Version::HTTP_11, ctx.version()); /// Ok(()) /// } /// ``` #[inline] pub fn version(&self) -> Version { self.req.version } /// Store key-value pair in specific scope. 
/// /// ### Example /// ```rust /// use roa_core::{App, Context, Result, Next}; /// /// struct Scope; /// struct AnotherScope; /// /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// ctx.store_scoped(Scope, "id", "1".to_string()); /// next.await /// } /// /// async fn end(ctx: &mut Context) -> Result { /// assert_eq!(1, ctx.load_scoped::("id").unwrap().parse::()?); /// assert!(ctx.load_scoped::("id").is_none()); /// Ok(()) /// } /// /// let app = App::new().gate(gate).end(end); /// ``` #[inline] pub fn store_scoped(&mut self, scope: SC, key: K, value: V) -> Option> where SC: Any, K: Into>, V: Value, { self.storage.insert(scope, key, value) } /// Store key-value pair in public scope. /// /// ### Example /// ```rust /// use roa_core::{App, Context, Result, Next}; /// /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// ctx.store("id", "1".to_string()); /// next.await /// } /// /// async fn end(ctx: &mut Context) -> Result { /// assert_eq!(1, ctx.load::("id").unwrap().parse::()?); /// Ok(()) /// } /// /// let app = App::new().gate(gate).end(end); /// ``` #[inline] pub fn store(&mut self, key: K, value: V) -> Option> where K: Into>, V: Value, { self.store_scoped(PublicScope, key, value) } /// Search for value by key in specific scope. /// /// ### Example /// /// ```rust /// use roa_core::{App, Context, Result, Next}; /// /// struct Scope; /// /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// ctx.store_scoped(Scope, "id", "1".to_owned()); /// next.await /// } /// /// async fn end(ctx: &mut Context) -> Result { /// assert_eq!(1, ctx.load_scoped::("id").unwrap().parse::()?); /// Ok(()) /// } /// /// let app = App::new().gate(gate).end(end); /// ``` #[inline] pub fn load_scoped<'a, SC, V>(&self, key: &'a str) -> Option> where SC: Any, V: Value, { self.storage.get::(key) } /// Search for value by key in public scope. 
/// /// ### Example /// ```rust /// use roa_core::{App, Context, Result, Next}; /// /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// ctx.store("id", "1".to_string()); /// next.await /// } /// /// async fn end(ctx: &mut Context) -> Result { /// assert_eq!(1, ctx.load::("id").unwrap().parse::()?); /// Ok(()) /// } /// /// let app = App::new().gate(gate).end(end); /// ``` #[inline] pub fn load<'a, V>(&self, key: &'a str) -> Option> where V: Value, { self.load_scoped::(key) } } /// Public storage scope. struct PublicScope; impl Deref for Context { type Target = S; #[inline] fn deref(&self) -> &Self::Target { &self.state } } impl DerefMut for Context { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { &mut self.state } } impl Clone for Context { #[inline] fn clone(&self) -> Self { Self { req: Request::default(), resp: Response::new(), state: self.state.clone(), exec: self.exec.clone(), storage: self.storage.clone(), remote_addr: self.remote_addr, } } } #[cfg(all(test, feature = "runtime"))] mod tests_with_runtime { use std::error::Error; use http::{HeaderValue, StatusCode, Version}; use crate::{App, Context, Next, Request, Status}; #[tokio::test] async fn status_and_version() -> Result<(), Box> { async fn test(ctx: &mut Context) -> Result<(), Status> { assert_eq!(Version::HTTP_11, ctx.version()); assert_eq!(StatusCode::OK, ctx.status()); Ok(()) } let service = App::new().end(test).http_service(); service.serve(Request::default()).await; Ok(()) } #[derive(Clone)] struct State { data: usize, } #[tokio::test] async fn state() -> Result<(), Box> { async fn gate(ctx: &mut Context, next: Next<'_>) -> Result<(), Status> { ctx.data = 1; next.await } async fn test(ctx: &mut Context) -> Result<(), Status> { assert_eq!(1, ctx.data); Ok(()) } let service = App::state(State { data: 1 }) .gate(gate) .end(test) .http_service(); service.serve(Request::default()).await; Ok(()) } #[tokio::test] async fn must_get() -> Result<(), Box> { use 
http::header::{CONTENT_TYPE, HOST}; async fn test(ctx: &mut Context) -> Result<(), Status> { assert_eq!(Ok("github.com"), ctx.must_get(HOST)); ctx.must_get(CONTENT_TYPE)?; unreachable!() } let service = App::new().end(test).http_service(); let mut req = Request::default(); req.headers .insert(HOST, HeaderValue::from_static("github.com")); let resp = service.serve(req).await; assert_eq!(StatusCode::BAD_REQUEST, resp.status); Ok(()) } } ================================================ FILE: roa-core/src/err.rs ================================================ use std::fmt::{Display, Formatter}; use std::result::Result as StdResult; pub use http::StatusCode; /// Type alias for `StdResult`. pub type Result = StdResult; /// Construct a `Status`. /// /// - `status!(status_code)` will be expanded to `status!(status_code, "")` /// - `status!(status_code, message)` will be expanded to `status!(status_code, message, true)` /// - `status!(status_code, message, expose)` will be expanded to `Status::new(status_code, message, expose)` /// /// ### Example /// ```rust /// use roa_core::{App, Context, Next, Result, status}; /// use roa_core::http::StatusCode; /// /// let app = App::new() /// .gate(gate) /// .end(status!(StatusCode::IM_A_TEAPOT, "I'm a teapot!")); /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// next.await?; // throw /// unreachable!(); /// ctx.resp.status = StatusCode::OK; /// Ok(()) /// } /// ``` #[macro_export] macro_rules! status { ($status_code:expr) => { $crate::status!($status_code, "") }; ($status_code:expr, $message:expr) => { $crate::status!($status_code, $message, true) }; ($status_code:expr, $message:expr, $expose:expr) => { $crate::Status::new($status_code, $message, $expose) }; } /// Throw an `Err(Status)`. 
/// /// - `throw!(status_code)` will be expanded to `throw!(status_code, "")` /// - `throw!(status_code, message)` will be expanded to `throw!(status_code, message, true)` /// - `throw!(status_code, message, expose)` will be expanded to `return Err(Status::new(status_code, message, expose));` /// /// ### Example /// ```rust /// use roa_core::{App, Context, Next, Result, throw}; /// use roa_core::http::StatusCode; /// /// let app = App::new().gate(gate).end(end); /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// next.await?; // throw /// unreachable!(); /// ctx.resp.status = StatusCode::OK; /// Ok(()) /// } /// /// async fn end(ctx: &mut Context) -> Result { /// throw!(StatusCode::IM_A_TEAPOT, "I'm a teapot!"); // throw /// unreachable!() /// } /// ``` #[macro_export] macro_rules! throw { ($status_code:expr) => { return core::result::Result::Err($crate::status!($status_code)) }; ($status_code:expr, $message:expr) => { return core::result::Result::Err($crate::status!($status_code, $message)) }; ($status_code:expr, $message:expr, $expose:expr) => { return core::result::Result::Err($crate::status!($status_code, $message, $expose)) }; } /// The `Status` of roa. #[derive(Debug, Clone, Eq, PartialEq)] pub struct Status { /// StatusCode will be responded to client if Error is thrown by the top middleware. /// /// ### Example /// ```rust /// use roa_core::{App, Context, Next, Result, MiddlewareExt, throw}; /// use roa_core::http::StatusCode; /// /// let app = App::new().gate(gate).end(end); /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// ctx.resp.status = StatusCode::OK; /// next.await // not caught /// } /// /// async fn end(ctx: &mut Context) -> Result { /// throw!(StatusCode::IM_A_TEAPOT, "I'm a teapot!"); // throw /// unreachable!() /// } /// ``` pub status_code: StatusCode, /// Data will be written to response body if self.expose is true. /// StatusCode will be responded to client if Error is thrown by the top middleware. 
/// /// ### Example /// ```rust /// use roa_core::{App, Context, Result, Status}; /// use roa_core::http::StatusCode; /// /// let app = App::new().end(end); /// /// async fn end(ctx: &mut Context) -> Result { /// Err(Status::new(StatusCode::IM_A_TEAPOT, "I'm a teapot!", false)) // message won't be exposed to user. /// } /// /// ``` pub message: String, /// if message exposed. pub expose: bool, } impl Status { /// Construct an error. #[inline] pub fn new(status_code: StatusCode, message: impl ToString, expose: bool) -> Self { Self { status_code, message: message.to_string(), expose, } } } impl From for Status where E: std::error::Error, { #[inline] fn from(err: E) -> Self { Self::new(StatusCode::INTERNAL_SERVER_ERROR, err, false) } } impl Display for Status { #[inline] fn fmt(&self, f: &mut Formatter<'_>) -> StdResult<(), std::fmt::Error> { f.write_str(&format!("{}: {}", self.status_code, self.message)) } } ================================================ FILE: roa-core/src/executor.rs ================================================ use std::future::Future; use std::pin::Pin; use std::sync::Arc; use futures::channel::oneshot::{channel, Receiver}; use futures::task::{Context, Poll}; use hyper::rt; /// Future Object pub type FutureObj = Pin>>; /// Blocking task Object pub type BlockingObj = Box; /// Executor constraint. pub trait Spawn { /// Spawn a future object fn spawn(&self, fut: FutureObj); /// Spawn a blocking task object fn spawn_blocking(&self, task: BlockingObj); } /// A type implementing hyper::rt::Executor #[derive(Clone)] pub struct Executor(pub(crate) Arc); /// A handle that awaits the result of a task. pub struct JoinHandle(Receiver); impl Executor { /// Spawn a future by app runtime #[inline] pub fn spawn(&self, fut: Fut) -> JoinHandle where Fut: 'static + Send + Future, Fut::Output: 'static + Send, { let (sender, recv) = channel(); self.0.spawn(Box::pin(async move { if sender.send(fut.await).is_err() { // handler is dropped, do nothing. 
}; })); JoinHandle(recv) } /// Spawn a blocking task by app runtime #[inline] pub fn spawn_blocking(&self, task: T) -> JoinHandle where T: 'static + Send + FnOnce() -> R, R: 'static + Send, { let (sender, recv) = channel(); self.0.spawn_blocking(Box::new(|| { if sender.send(task()).is_err() { // handler is dropped, do nothing. }; })); JoinHandle(recv) } } impl Future for JoinHandle { type Output = T; #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let ready = futures::ready!(Pin::new(&mut self.0).poll(cx)); Poll::Ready(ready.expect("receiver in JoinHandle shouldn't be canceled")) } } impl rt::Executor for Executor where F: 'static + Send + Future, F::Output: 'static + Send, { #[inline] fn execute(&self, fut: F) { self.0.spawn(Box::pin(async move { let _ = fut.await; })); } } #[cfg(test)] mod tests { use std::sync::Arc; use super::{BlockingObj, Executor, FutureObj, Spawn}; pub struct Exec; impl Spawn for Exec { fn spawn(&self, fut: FutureObj) { tokio::task::spawn(fut); } fn spawn_blocking(&self, task: BlockingObj) { tokio::task::spawn_blocking(task); } } #[tokio::test] async fn spawn() { let exec = Executor(Arc::new(Exec)); assert_eq!(1, exec.spawn(async { 1 }).await); } #[tokio::test] async fn spawn_blocking() { let exec = Executor(Arc::new(Exec)); assert_eq!(1, exec.spawn_blocking(|| 1).await); } } ================================================ FILE: roa-core/src/group.rs ================================================ use std::sync::Arc; use crate::{async_trait, Context, Endpoint, Middleware, Next, Result}; /// A set of method to chain middleware/endpoint to middleware /// or make middleware shared. pub trait MiddlewareExt: Sized + for<'a> Middleware<'a, S> { /// Chain two middlewares. fn chain(self, next: M) -> Chain where M: for<'a> Middleware<'a, S>, { Chain(self, next) } /// Chain an endpoint to a middleware. fn end(self, next: E) -> Chain where E: for<'a> Endpoint<'a, S>, { Chain(self, next) } /// Make middleware shared. 
fn shared(self) -> Shared where S: 'static, { Shared(Arc::new(self)) } } /// Extra methods of endpoint. pub trait EndpointExt: Sized + for<'a> Endpoint<'a, S> { /// Box an endpoint. fn boxed(self) -> Boxed where S: 'static, { Boxed(Box::new(self)) } } impl MiddlewareExt for T where T: for<'a> Middleware<'a, S> {} impl EndpointExt for T where T: for<'a> Endpoint<'a, S> {} /// A middleware composing and executing other middlewares in a stack-like manner. pub struct Chain(T, U); /// Shared middleware. pub struct Shared(Arc Middleware<'a, S>>); /// Boxed endpoint. pub struct Boxed(Box Endpoint<'a, S>>); #[async_trait(?Send)] impl<'a, S, T, U> Middleware<'a, S> for Chain where U: Middleware<'a, S>, T: for<'b> Middleware<'b, S>, { #[inline] async fn handle(&'a self, ctx: &'a mut Context, next: Next<'a>) -> Result { /* NOTE(review): `ptr` aliases `ctx`, so the `next` future built from `self.1` and the call to `self.0.handle` both hold a `&mut` to the same context; this looks sound only if the two never touch the context at the same time - confirm. */ let ptr = ctx as *mut Context; let mut next = self.1.handle(unsafe { &mut *ptr }, next); self.0.handle(ctx, &mut next).await } } #[async_trait(?Send)] impl<'a, S> Middleware<'a, S> for Shared where S: 'static, { #[inline] async fn handle(&'a self, ctx: &'a mut Context, next: Next<'a>) -> Result { self.0.handle(ctx, next).await } } impl Clone for Shared { #[inline] fn clone(&self) -> Self { Self(self.0.clone()) } } #[async_trait(?Send)] impl<'a, S> Endpoint<'a, S> for Boxed where S: 'static, { #[inline] async fn call(&'a self, ctx: &'a mut Context) -> Result { self.0.call(ctx).await } } #[async_trait(?Send)] impl<'a, S, T, U> Endpoint<'a, S> for Chain where U: Endpoint<'a, S>, T: for<'b> Middleware<'b, S>, { #[inline] async fn call(&'a self, ctx: &'a mut Context) -> Result { /* NOTE(review): same raw-pointer aliasing of `ctx` as in the `Middleware` impl for `Chain`; the endpoint future and the middleware share one `&mut Context` - confirm they never use it concurrently. */ let ptr = ctx as *mut Context; let mut next = self.1.call(unsafe { &mut *ptr }); self.0.handle(ctx, &mut next).await } } #[cfg(all(test, feature = "runtime"))] mod tests { use std::sync::Arc; use futures::lock::Mutex; use http::StatusCode; use crate::{async_trait, App, Context, Middleware, Next, Request, Status}; struct Pusher { data: usize, vector: Arc>>, } impl Pusher { fn new(data: 
usize, vector: Arc>>) -> Self { Self { data, vector } } } #[async_trait(?Send)] impl<'a> Middleware<'a, ()> for Pusher { async fn handle(&'a self, _ctx: &'a mut Context, next: Next<'a>) -> Result<(), Status> { self.vector.lock().await.push(self.data); next.await?; self.vector.lock().await.push(self.data); Ok(()) } } #[tokio::test] async fn middleware_order() -> Result<(), Box> { let vector = Arc::new(Mutex::new(Vec::new())); let service = App::new() .gate(Pusher::new(0, vector.clone())) .gate(Pusher::new(1, vector.clone())) .gate(Pusher::new(2, vector.clone())) .gate(Pusher::new(3, vector.clone())) .gate(Pusher::new(4, vector.clone())) .gate(Pusher::new(5, vector.clone())) .gate(Pusher::new(6, vector.clone())) .gate(Pusher::new(7, vector.clone())) .gate(Pusher::new(8, vector.clone())) .gate(Pusher::new(9, vector.clone())) .end(()) .http_service(); let resp = service.serve(Request::default()).await; assert_eq!(StatusCode::OK, resp.status); for i in 0..10 { assert_eq!(i, vector.lock().await[i]); assert_eq!(i, vector.lock().await[19 - i]); } Ok(()) } } ================================================ FILE: roa-core/src/lib.rs ================================================ #![cfg_attr(feature = "docs", feature(doc_cfg))] #![cfg_attr(feature = "docs", doc = include_str!("../README.md"))] #![cfg_attr(feature = "docs", warn(missing_docs))] mod app; mod body; mod context; mod err; mod executor; mod group; mod middleware; mod request; mod response; mod state; #[doc(inline)] pub use app::{AddrStream, App}; pub use async_trait::async_trait; #[doc(inline)] pub use body::Body; #[doc(inline)] pub use context::{Context, Variable}; #[doc(inline)] pub use err::{Result, Status}; #[doc(inline)] pub use executor::{Executor, JoinHandle, Spawn}; #[doc(inline)] pub use group::{Boxed, Chain, EndpointExt, MiddlewareExt, Shared}; pub use http; pub use hyper::server::accept::Accept; pub use hyper::server::Server; #[doc(inline)] pub use middleware::{Endpoint, Middleware, Next}; 
#[doc(inline)] pub use request::Request; #[doc(inline)] pub use response::Response; #[doc(inline)] pub use state::State; ================================================ FILE: roa-core/src/middleware.rs ================================================ use std::future::Future; use http::header::LOCATION; use http::{StatusCode, Uri}; use crate::{async_trait, throw, Context, Result, Status}; /// ### Middleware /// /// #### Built-in middlewares /// /// - Functional middleware /// /// A functional middleware is an async function with signature: /// `async fn(&mut Context, Next<'_>) -> Result`. /// /// ```rust /// use roa_core::{App, Context, Next, Result}; /// /// async fn middleware(ctx: &mut Context, next: Next<'_>) -> Result { /// next.await /// } /// /// let app = App::new().gate(middleware); /// ``` /// /// - Blank middleware /// /// `()` is a blank middleware; it just calls the next middleware or endpoint. /// /// ```rust /// let app = roa_core::App::new().gate(()); /// ``` /// /// #### Custom middleware /// /// You can implement custom `Middleware` for other types. /// /// ```rust /// use roa_core::{App, Middleware, Context, Next, Result, async_trait}; /// use std::sync::Arc; /// use std::time::Instant; /// /// /// struct Logger; /// /// #[async_trait(?Send)] /// impl <'a> Middleware<'a> for Logger { /// async fn handle(&'a self, ctx: &'a mut Context, next: Next<'a>) -> Result { /// let start = Instant::now(); /// let result = next.await; /// println!("time elapsed: {}ms", start.elapsed().as_millis()); /// result /// } /// } /// /// let app = App::new().gate(Logger); /// ``` #[async_trait(?Send)] pub trait Middleware<'a, S = ()>: 'static + Sync + Send { /// Handle context and next, return status. 
async fn handle(&'a self, ctx: &'a mut Context, next: Next<'a>) -> Result; } #[async_trait(?Send)] impl<'a, S, T, F> Middleware<'a, S> for T where S: 'a, T: 'static + Send + Sync + Fn(&'a mut Context, Next<'a>) -> F, F: 'a + Future, { #[inline] async fn handle(&'a self, ctx: &'a mut Context, next: Next<'a>) -> Result { (self)(ctx, next).await } } /// ### Endpoint /// /// An endpoint is a request handler. /// /// #### Built-in endpoints /// /// There are some built-in endpoints. /// /// - Functional endpoint /// /// A normal functional endpoint is an async function with signature: /// `async fn(&mut Context) -> Result`. /// /// ```rust /// use roa_core::{App, Context, Result}; /// /// async fn endpoint(ctx: &mut Context) -> Result { /// Ok(()) /// } /// /// let app = App::new().end(endpoint); /// ``` /// - Ok endpoint /// /// `()` is an endpoint that always returns `Ok(())` /// /// ```rust /// let app = roa_core::App::new().end(()); /// ``` /// /// - Status endpoint /// /// `Status` is an endpoint that always returns `Err(Status)` /// /// ```rust /// use roa_core::{App, status}; /// use roa_core::http::StatusCode; /// let app = App::new().end(status!(StatusCode::BAD_REQUEST)); /// ``` /// /// - String endpoint /// /// Write string to body. /// /// ```rust /// use roa_core::App; /// /// let app = App::new().end("Hello, world"); // static slice /// let app = App::new().end("Hello, world".to_owned()); // string /// ``` /// /// - Redirect endpoint /// /// Redirect to a URI. /// /// ```rust /// use roa_core::App; /// use roa_core::http::Uri; /// /// let app = App::new().end("/target".parse::<Uri>().unwrap()); /// ``` /// /// #### Custom endpoint /// /// You can implement custom `Endpoint` for your types. 
/// /// ```rust /// use roa_core::{App, Endpoint, Context, Next, Result, async_trait}; /// /// fn is_endpoint(endpoint: impl for<'a> Endpoint<'a>) { /// } /// /// struct Service; /// /// #[async_trait(?Send)] /// impl <'a> Endpoint<'a> for Service { /// async fn call(&'a self, ctx: &'a mut Context) -> Result { /// Ok(()) /// } /// } /// /// let app = App::new().end(Service); /// ``` #[async_trait(?Send)] pub trait Endpoint<'a, S = ()>: 'static + Sync + Send { /// Call this endpoint. async fn call(&'a self, ctx: &'a mut Context) -> Result; } #[async_trait(?Send)] impl<'a, S, T, F> Endpoint<'a, S> for T where S: 'a, T: 'static + Send + Sync + Fn(&'a mut Context) -> F, F: 'a + Future, { #[inline] async fn call(&'a self, ctx: &'a mut Context) -> Result { (self)(ctx).await } } /// blank middleware. #[async_trait(?Send)] impl<'a, S> Middleware<'a, S> for () { #[allow(clippy::trivially_copy_pass_by_ref)] #[inline] async fn handle(&'a self, _ctx: &'a mut Context, next: Next<'a>) -> Result { next.await } } /// ok endpoint, always return Ok(()) #[async_trait(?Send)] impl<'a, S> Endpoint<'a, S> for () { #[allow(clippy::trivially_copy_pass_by_ref)] #[inline] async fn call(&'a self, _ctx: &'a mut Context) -> Result { Ok(()) } } /// status endpoint. #[async_trait(?Send)] impl<'a, S> Endpoint<'a, S> for Status { #[inline] async fn call(&'a self, _ctx: &'a mut Context) -> Result { Err(self.clone()) } } /// String endpoint. #[async_trait(?Send)] impl<'a, S> Endpoint<'a, S> for String { #[inline] #[allow(clippy::ptr_arg)] async fn call(&'a self, ctx: &'a mut Context) -> Result { ctx.resp.write(self.clone()); Ok(()) } } /// Static slice endpoint. #[async_trait(?Send)] impl<'a, S> Endpoint<'a, S> for &'static str { #[inline] async fn call(&'a self, ctx: &'a mut Context) -> Result { ctx.resp.write(*self); Ok(()) } } /// Redirect endpoint. 
#[async_trait(?Send)] impl<'a, S> Endpoint<'a, S> for Uri { #[inline] async fn call(&'a self, ctx: &'a mut Context) -> Result { ctx.resp.headers.insert(LOCATION, self.to_string().parse()?); throw!(StatusCode::PERMANENT_REDIRECT) } } /// Type of the second parameter in a middleware, /// an alias for `&mut (dyn Unpin + Future<Output = Result>)` /// /// A middleware hands control to the next middleware (or the endpoint) by calling `next.await`. /// /// ### Example /// /// ```rust /// use roa_core::{App, Context, Result, Status, MiddlewareExt, Next}; /// use roa_core::http::StatusCode; /// /// let app = App::new() /// .gate(first) /// .gate(second) /// .gate(third) /// .end(end); /// async fn first(ctx: &mut Context, next: Next<'_>) -> Result { /// assert!(ctx.store("id", "1").is_none()); /// next.await?; /// assert_eq!("5", *ctx.load::<&'static str>("id").unwrap()); /// Ok(()) /// } /// async fn second(ctx: &mut Context, next: Next<'_>) -> Result { /// assert_eq!("1", *ctx.load::<&'static str>("id").unwrap()); /// assert_eq!("1", *ctx.store("id", "2").unwrap()); /// next.await?; /// assert_eq!("4", *ctx.store("id", "5").unwrap()); /// Ok(()) /// } /// async fn third(ctx: &mut Context, next: Next<'_>) -> Result { /// assert_eq!("2", *ctx.store("id", "3").unwrap()); /// next.await?; // next is the endpoint `end` /// assert_eq!("3", *ctx.store("id", "4").unwrap()); /// Ok(()) /// } /// /// async fn end(ctx: &mut Context) -> Result { /// assert_eq!("3", *ctx.load::<&'static str>("id").unwrap()); /// Ok(()) /// } /// ``` /// /// ### Error Handling /// /// You can catch an error returned by `next`, or propagate it directly. 
/// /// ```rust /// use roa_core::{App, Context, Result, Status, MiddlewareExt, Next, status}; /// use roa_core::http::StatusCode; /// /// let app = App::new() /// .gate(catch) /// .gate(gate) /// .end(status!(StatusCode::IM_A_TEAPOT, "I'm a teapot!")); /// /// async fn catch(ctx: &mut Context, next: Next<'_>) -> Result { /// // catch /// if let Err(err) = next.await { /// // teapot is ok /// if err.status_code != StatusCode::IM_A_TEAPOT { /// return Err(err); /// } /// } /// Ok(()) /// } /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// next.await?; // just throw /// unreachable!() /// } /// ``` /// pub type Next<'a> = &'a mut (dyn Unpin + Future); #[cfg(test)] mod tests { use futures::{AsyncReadExt, TryStreamExt}; use http::header::LOCATION; use http::{StatusCode, Uri}; use crate::{status, App, Request}; const HELLO: &str = "Hello, world"; #[tokio::test] async fn status_endpoint() { let app = App::new().end(status!(StatusCode::BAD_REQUEST)); let service = app.http_service(); let resp = service.serve(Request::default()).await; assert_eq!(StatusCode::BAD_REQUEST, resp.status); } #[tokio::test] async fn string_endpoint() { let app = App::new().end(HELLO.to_owned()); let service = app.http_service(); let mut data = String::new(); service .serve(Request::default()) .await .body .into_async_read() .read_to_string(&mut data) .await .unwrap(); assert_eq!(HELLO, data); } #[tokio::test] async fn static_slice_endpoint() { let app = App::new().end(HELLO); let service = app.http_service(); let mut data = String::new(); service .serve(Request::default()) .await .body .into_async_read() .read_to_string(&mut data) .await .unwrap(); assert_eq!(HELLO, data); } #[tokio::test] async fn redirect_endpoint() { let app = App::new().end("/target".parse::().unwrap()); let service = app.http_service(); let resp = service.serve(Request::default()).await; assert_eq!(StatusCode::PERMANENT_REDIRECT, resp.status); assert_eq!("/target", 
resp.headers[LOCATION].to_str().unwrap()) } } ================================================ FILE: roa-core/src/request.rs ================================================ use std::io; use bytes::Bytes; use futures::stream::{Stream, TryStreamExt}; use http::{Extensions, HeaderMap, HeaderValue, Method, Uri, Version}; use hyper::Body; use tokio::io::AsyncRead; use tokio_util::io::StreamReader; /// Http request type of roa. pub struct Request { /// The request's method pub method: Method, /// The request's URI pub uri: Uri, /// The request's version pub version: Version, /// The request's headers pub headers: HeaderMap, extensions: Extensions, body: Body, } impl Request { /// Take raw hyper request. /// This method will consume inner body and extensions; /// method, URI, version and headers are copied into the raw request. #[inline] pub fn take_raw(&mut self) -> http::Request { let mut builder = http::Request::builder() .method(self.method.clone()) .uri(self.uri.clone()) .version(self.version); *builder.extensions_mut().expect("fail to get extensions") = std::mem::take(&mut self.extensions); *builder.headers_mut().expect("fail to get headers") = self.headers.clone(); builder .body(self.raw_body()) .expect("fail to build raw body") } /// Take raw hyper body. /// This method will consume inner body. #[inline] pub fn raw_body(&mut self) -> Body { std::mem::take(&mut self.body) } /// Get body as Stream. /// This method will consume inner body. #[inline] pub fn stream( &mut self, ) -> impl Stream> + Sync + Send + Unpin + 'static { self.raw_body() .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) } /// Get body as AsyncRead. /// This method will consume inner body. 
#[inline] pub fn reader(&mut self) -> impl AsyncRead + Sync + Send + Unpin + 'static { StreamReader::new(self.stream()) } } impl From> for Request { #[inline] fn from(req: http::Request) -> Self { let (parts, body) = req.into_parts(); Self { method: parts.method, uri: parts.uri, version: parts.version, headers: parts.headers, extensions: parts.extensions, body, } } } impl Default for Request { #[inline] fn default() -> Self { http::Request::new(Body::empty()).into() } } #[cfg(all(test, feature = "runtime"))] mod tests { use http::StatusCode; use hyper::Body; use tokio::io::AsyncReadExt; use crate::{App, Context, Request, Status}; async fn test(ctx: &mut Context) -> Result<(), Status> { let mut data = String::new(); ctx.req.reader().read_to_string(&mut data).await?; assert_eq!("Hello, World!", data); Ok(()) } #[tokio::test] async fn body_read() -> Result<(), Box> { let app = App::new().end(test); let service = app.http_service(); let req = Request::from(http::Request::new(Body::from("Hello, World!"))); let resp = service.serve(req).await; assert_eq!(StatusCode::OK, resp.status); Ok(()) } } ================================================ FILE: roa-core/src/response.rs ================================================ //! A module for Response and its body use std::ops::{Deref, DerefMut}; use http::{HeaderMap, HeaderValue, StatusCode, Version}; pub use crate::Body; /// Http response type of roa. pub struct Response { /// Status code. pub status: StatusCode, /// Version of HTTP protocol. pub version: Version, /// Raw header map. pub headers: HeaderMap, /// Response body. 
pub body: Body, } impl Response { #[inline] pub(crate) fn new() -> Self { Self { status: StatusCode::default(), version: Version::default(), headers: HeaderMap::default(), body: Body::default(), } } #[inline] fn into_resp(self) -> http::Response { let (mut parts, _) = http::Response::new(()).into_parts(); let Response { status, version, headers, body, } = self; parts.status = status; parts.version = version; parts.headers = headers; http::Response::from_parts(parts, body.into()) } } impl Deref for Response { type Target = Body; #[inline] fn deref(&self) -> &Self::Target { &self.body } } impl DerefMut for Response { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { &mut self.body } } impl From for http::Response { #[inline] fn from(value: Response) -> Self { value.into_resp() } } impl Default for Response { #[inline] fn default() -> Self { Self::new() } } ================================================ FILE: roa-core/src/state.rs ================================================ /// The `State` trait; it should be replaced with a trait alias. /// The `App::state` will be cloned when a request comes in. /// /// `State` is designed to share data or handlers between middlewares. 
/// /// ### Example /// ```rust /// use roa_core::{App, Context, Next, Result}; /// use roa_core::http::StatusCode; /// /// #[derive(Clone)] /// struct State { /// id: u64, /// } /// /// let app = App::state(State { id: 0 }).gate(gate).end(end); /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result { /// ctx.id = 1; /// next.await /// } /// /// async fn end(ctx: &mut Context) -> Result { /// let id = ctx.id; /// assert_eq!(1, id); /// Ok(()) /// } /// ``` pub trait State: 'static + Clone + Send + Sync + Sized {} impl State for T {} ================================================ FILE: roa-diesel/Cargo.toml ================================================ [package] name = "roa-diesel" version = "0.6.0" authors = ["Hexilee "] edition = "2018" license = "MIT" readme = "./README.md" repository = "https://github.com/Hexilee/roa" documentation = "https://docs.rs/roa-diesel" homepage = "https://github.com/Hexilee/roa/wiki" description = "diesel integration with roa framework" keywords = ["http", "web", "framework", "orm"] categories = ["database"] [package.metadata.docs.rs] features = ["docs"] rustdoc-args = ["--cfg", "feature=\"docs\""] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] roa = { path = "../roa", version = "0.6.0", default-features = false } diesel = { version = "1.4", features = ["extras"] } r2d2 = "0.8" [dev-dependencies] diesel = { version = "1.4", features = ["extras", "sqlite"] } [features] docs = ["roa/docs"] ================================================ FILE: roa-diesel/README.md ================================================ [![Stable Test](https://github.com/Hexilee/roa/workflows/Stable%20Test/badge.svg)](https://github.com/Hexilee/roa/actions) [![codecov](https://codecov.io/gh/Hexilee/roa/branch/master/graph/badge.svg)](https://codecov.io/gh/Hexilee/roa) [![Rust Docs](https://docs.rs/roa-diesel/badge.svg)](https://docs.rs/roa-diesel) [![Crate 
version](https://img.shields.io/crates/v/roa-diesel.svg)](https://crates.io/crates/roa-diesel) [![Download](https://img.shields.io/crates/d/roa-diesel.svg)](https://crates.io/crates/roa-diesel) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/Hexilee/roa/blob/master/LICENSE) This crate provides diesel integration with roa framework. ### AsyncPool A context extension to access r2d2 pool asynchronously. ```rust use roa::{Context, Result}; use diesel::sqlite::SqliteConnection; use roa_diesel::Pool; use roa_diesel::preload::*; use diesel::r2d2::ConnectionManager; #[derive(Clone)] struct State(Pool); impl AsRef> for State { fn as_ref(&self) -> &Pool { &self.0 } } async fn get(ctx: Context) -> Result { let conn = ctx.get_conn().await?; // handle conn Ok(()) } ``` ### SqlQuery A context extension to execute diesel query asynchronously. Refer to [integration example](https://github.com/Hexilee/roa/tree/master/integration/diesel-example) for more use cases. ================================================ FILE: roa-diesel/src/async_ext.rs ================================================ use diesel::connection::Connection; use diesel::helper_types::Limit; use diesel::query_dsl::methods::{ExecuteDsl, LimitDsl, LoadQuery}; use diesel::query_dsl::RunQueryDsl; use diesel::result::{Error as DieselError, OptionalExtension}; use roa::{async_trait, Context, Result, State}; use crate::pool::{AsyncPool, Pool}; /// A context extension to execute diesel dsl asynchronously. #[async_trait] pub trait SqlQuery { /// Executes the given command, returning the number of rows affected. /// /// `execute` is usually used in conjunction with [`insert_into`](../fn.insert_into.html), /// [`update`](../fn.update.html) and [`delete`](../fn.delete.html) where the number of /// affected rows is often enough information. /// /// When asking the database to return data from a query, [`load`](#method.load) should /// probably be used instead. 
async fn execute(&self, exec: E) -> Result where E: 'static + Send + ExecuteDsl; /// Executes the given query, returning a `Vec` with the returned rows. /// /// When using the query builder, /// the return type can be /// a tuple of the values, /// or a struct which implements [`Queryable`]. /// /// When this method is called on [`sql_query`], /// the return type can only be a struct which implements [`QueryableByName`] /// /// For insert, update, and delete operations where only a count of affected is needed, /// [`execute`] should be used instead. /// /// [`Queryable`]: ../deserialize/trait.Queryable.html /// [`QueryableByName`]: ../deserialize/trait.QueryableByName.html /// [`execute`]: fn.execute.html /// [`sql_query`]: ../fn.sql_query.html /// async fn load_data(&self, query: Q) -> Result> where U: 'static + Send, Q: 'static + Send + LoadQuery; /// Runs the command, and returns the affected row. /// /// `Ok(None)` is returned if the query affected 0 rows: the `Context` /// implementation converts diesel's `NotFound` error with `.optional()`, /// so callers do not need to do that themselves. /// /// When this method is called on an insert, update, or delete statement, /// it will implicitly add a `RETURNING *` to the query, /// unless a returning clause was already specified. async fn get_result(&self, query: Q) -> Result> where U: 'static + Send, Q: 'static + Send + LoadQuery; /// Runs the command, returning a `Vec` with the affected rows. /// /// This method is an alias for [`load_data`], but with a name that makes more /// sense for insert, update, and delete statements. /// /// [`load_data`]: SqlQuery::load_data async fn get_results(&self, query: Q) -> Result> where U: 'static + Send, Q: 'static + Send + LoadQuery; /// Attempts to load a single record. /// /// This method is equivalent to `.limit(1).get_result()` /// /// Returns `Ok(Some(record))` if found, and `Ok(None)` if no results are /// returned. 
If the query truly is optional, you can call `.optional()` on /// plain diesel results; the `Context` implementation does it for you, /// so a missing record is `Ok(None)`, not `Err(NotFound)`. /// async fn first(&self, query: Q) -> Result> where U: 'static + Send, Q: 'static + Send + LimitDsl, Limit: LoadQuery; } #[async_trait] impl SqlQuery for Context where S: State + AsRef>, Conn: 'static + Connection, { #[inline] async fn execute(&self, exec: E) -> Result where E: 'static + Send + ExecuteDsl, { let conn = self.get_conn().await?; Ok(self .exec .spawn_blocking(move || ExecuteDsl::::execute(exec, &*conn)) .await?) } /// Executes the given query, returning a `Vec` with the returned rows. /// /// A diesel `NotFound` error is reported as an empty `Vec`. /// /// When using the query builder, /// the return type can be /// a tuple of the values, /// or a struct which implements [`Queryable`]. /// /// When this method is called on [`sql_query`], /// the return type can only be a struct which implements [`QueryableByName`] /// /// For insert, update, and delete operations where only a count of affected is needed, /// [`execute`] should be used instead. /// /// [`Queryable`]: ../deserialize/trait.Queryable.html /// [`QueryableByName`]: ../deserialize/trait.QueryableByName.html /// [`execute`]: fn.execute.html /// [`sql_query`]: ../fn.sql_query.html /// #[inline] async fn load_data(&self, query: Q) -> Result> where U: 'static + Send, Q: 'static + Send + LoadQuery, { let conn = self.get_conn().await?; match self.exec.spawn_blocking(move || query.load(&*conn)).await { Ok(data) => Ok(data), Err(DieselError::NotFound) => Ok(Vec::new()), Err(err) => Err(err.into()), } } /// Runs the command, and returns the affected row. /// /// `Ok(None)` is returned if the query affected 0 rows: diesel's /// `NotFound` error is converted with `.optional()` here, /// so callers do not need to do that themselves. /// /// When this method is called on an insert, update, or delete statement, /// it will implicitly add a `RETURNING *` to the query, /// unless a returning clause was already specified. 
#[inline] async fn get_result(&self, query: Q) -> Result> where U: 'static + Send, Q: 'static + Send + LoadQuery, { let conn = self.get_conn().await?; Ok(self .exec .spawn_blocking(move || query.get_result(&*conn)) .await .optional()?) } /// Runs the command, returning a `Vec` with the affected rows. /// /// This method is an alias for [`load_data`], but with a name that makes more /// sense for insert, update, and delete statements. /// /// [`load_data`]: SqlQuery::load_data #[inline] async fn get_results(&self, query: Q) -> Result> where U: 'static + Send, Q: 'static + Send + LoadQuery, { self.load_data(query).await } /// Attempts to load a single record. /// /// This method is equivalent to `.limit(1).get_result()` /// /// Returns `Ok(Some(record))` if found, and `Ok(None)` if no results are /// returned: diesel's `NotFound` error is converted with `.optional()` /// here, so callers do not need to do that themselves. /// #[inline] async fn first(&self, query: Q) -> Result> where U: 'static + Send, Q: 'static + Send + LimitDsl, Limit: LoadQuery, { let conn = self.get_conn().await?; Ok(self .exec .spawn_blocking(move || query.limit(1).get_result(&*conn)) .await .optional()?) } } ================================================ FILE: roa-diesel/src/lib.rs ================================================ #![cfg_attr(feature = "docs", doc = include_str!("../README.md"))] #![cfg_attr(feature = "docs", warn(missing_docs))] mod async_ext; mod pool; #[doc(inline)] pub use diesel::r2d2::ConnectionManager; #[doc(inline)] pub use pool::{builder, make_pool, Pool, WrapConnection}; /// preload ext traits. 
pub mod preload { #[doc(inline)] pub use crate::async_ext::SqlQuery; #[doc(inline)] pub use crate::pool::AsyncPool; } ================================================ FILE: roa-diesel/src/pool.rs ================================================ use std::time::Duration; use diesel::r2d2::{ConnectionManager, PoolError}; use diesel::Connection; use r2d2::{Builder, PooledConnection}; use roa::{async_trait, Context, State, Status}; /// An alias for r2d2::Pool>. pub type Pool = r2d2::Pool>; /// An alias for r2d2::PooledConnection>. pub type WrapConnection = PooledConnection>; /// Create a connection pool. /// /// ### Example /// /// ``` /// use roa_diesel::{make_pool, Pool}; /// use diesel::sqlite::SqliteConnection; /// use std::error::Error; /// /// # fn main() -> Result<(), Box> { /// let pool: Pool = make_pool(":memory:")?; /// Ok(()) /// # } /// ``` pub fn make_pool(url: impl Into) -> Result, PoolError> where Conn: Connection + 'static, { r2d2::Pool::new(ConnectionManager::::new(url)) } /// Create a pool builder. pub fn builder() -> Builder> where Conn: Connection + 'static, { r2d2::Pool::builder() } /// A context extension to access r2d2 pool asynchronously. #[async_trait] pub trait AsyncPool where Conn: Connection + 'static, { /// Retrieves a connection from the pool. /// /// Waits for at most the configured connection timeout before returning an /// error. 
/// /// ``` /// use roa::{Context, Result}; /// use diesel::sqlite::SqliteConnection; /// use roa_diesel::preload::AsyncPool; /// use roa_diesel::Pool; /// use diesel::r2d2::ConnectionManager; /// /// #[derive(Clone)] /// struct State(Pool); /// /// impl AsRef> for State { /// fn as_ref(&self) -> &Pool { /// &self.0 /// } /// } /// /// async fn get(ctx: Context) -> Result { /// let conn = ctx.get_conn().await?; /// // handle conn /// Ok(()) /// } /// ``` async fn get_conn(&self) -> Result, Status>; /// Retrieves a connection from the pool, waiting for at most `timeout` /// /// The given timeout will be used instead of the configured connection /// timeout. async fn get_timeout(&self, timeout: Duration) -> Result, Status>; /// Returns information about the current state of the pool. async fn pool_state(&self) -> r2d2::State; } #[async_trait] impl AsyncPool for Context where S: State + AsRef>, Conn: Connection + 'static, { #[inline] async fn get_conn(&self) -> Result, Status> { let pool = self.as_ref().clone(); Ok(self.exec.spawn_blocking(move || pool.get()).await?) } #[inline] async fn get_timeout(&self, timeout: Duration) -> Result, Status> { let pool = self.as_ref().clone(); Ok(self .exec .spawn_blocking(move || pool.get_timeout(timeout)) .await?) 
} #[inline] async fn pool_state(&self) -> r2d2::State { let pool = self.as_ref().clone(); self.exec.spawn_blocking(move || pool.state()).await } } ================================================ FILE: roa-juniper/Cargo.toml ================================================ [package] name = "roa-juniper" version = "0.6.0" authors = ["Hexilee "] edition = "2018" readme = "./README.md" repository = "https://github.com/Hexilee/roa" documentation = "https://docs.rs/roa-juniper" homepage = "https://github.com/Hexilee/roa/wiki" description = "juniper integration for roa" keywords = ["http", "web", "framework", "async"] categories = ["network-programming", "asynchronous", "web-programming::http-server"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] roa = { path = "../roa", version = "0.6.0", default-features = false, features = ["json"] } futures = "0.3" juniper = { version = "0.15", default-features = false } ================================================ FILE: roa-juniper/README.md ================================================ [![Stable Test](https://github.com/Hexilee/roa/workflows/Stable%20Test/badge.svg)](https://github.com/Hexilee/roa/actions) [![codecov](https://codecov.io/gh/Hexilee/roa/branch/master/graph/badge.svg)](https://codecov.io/gh/Hexilee/roa) [![Rust Docs](https://docs.rs/roa-juniper/badge.svg)](https://docs.rs/roa-juniper) [![Crate version](https://img.shields.io/crates/v/roa-juniper.svg)](https://crates.io/crates/roa-juniper) [![Download](https://img.shields.io/crates/d/roa-juniper.svg)](https://crates.io/crates/roa-juniper) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/Hexilee/roa/blob/master/LICENSE) ## Roa-juniper This crate provides a juniper context and a graphql endpoint. ### Example Refer to [integration-example](https://github.com/Hexilee/roa/tree/master/integration/juniper-example). 
================================================ FILE: roa-juniper/src/lib.rs ================================================ //! This crate provides a juniper context and a graphql endpoint. //! //! ### Example //! //! Refer to [integration-example](https://github.com/Hexilee/roa/tree/master/integration/juniper-example) #![warn(missing_docs)] use std::ops::{Deref, DerefMut}; use juniper::http::GraphQLRequest; use juniper::{GraphQLType, GraphQLTypeAsync, RootNode, ScalarValue}; use roa::preload::*; use roa::{async_trait, Context, Endpoint, Result, State}; /// A wrapper for `roa::Context` /// that serves as an implementation of `juniper::Context`. pub struct JuniperContext(Context); impl juniper::Context for JuniperContext {} impl Deref for JuniperContext { type Target = Context; #[inline] fn deref(&self) -> &Self::Target { &self.0 } } impl DerefMut for JuniperContext { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } /// A GraphQL endpoint wrapping a juniper `RootNode`. /// Each call reads a `GraphQLRequest` from the JSON request body, executes it /// against the root node with a `JuniperContext` built from a clone of the /// current context, and writes the response as JSON. pub struct GraphQL( pub RootNode<'static, QueryT, MutationT, SubscriptionT, Sca>, ) where QueryT: GraphQLType, MutationT: GraphQLType, SubscriptionT: GraphQLType, Sca: ScalarValue; #[async_trait(?Send)] impl<'a, S, QueryT, MutationT, SubscriptionT, Sca> Endpoint<'a, S> for GraphQL where S: State, QueryT: GraphQLTypeAsync> + Send + Sync + 'static, QueryT::TypeInfo: Send + Sync, MutationT: GraphQLTypeAsync + Send + Sync + 'static, MutationT::TypeInfo: Send + Sync, SubscriptionT: GraphQLType + Send + Sync + 'static, SubscriptionT::TypeInfo: Send + Sync, Sca: ScalarValue + Send + Sync + 'static, { #[inline] async fn call(&'a self, ctx: &'a mut Context) -> Result { let request: GraphQLRequest = ctx.read_json().await?; let juniper_ctx = JuniperContext(ctx.clone()); let resp = request.execute(&self.0, &juniper_ctx).await; ctx.write_json(&resp) } } ================================================ FILE: rustfmt.toml ================================================ group_imports = "StdExternalCrate" imports_granularity = 
"Module" reorder_imports = true unstable_features = true ================================================ FILE: src/lib.rs ================================================ #[cfg(doctest)] doc_comment::doctest!("../README.md"); ================================================ FILE: templates/directory.html ================================================ {{title}}

{{root}}


{% for dir in dirs %} {% endfor %} {% for file in files %} {% endfor %}
Name Size Modified
{{dir.name}} - {{dir.modified}}
{{file.name}} {{file.size}} {{file.modified}}
================================================ FILE: tests/logger.rs ================================================ use std::sync::RwLock; use log::{Level, LevelFilter, Metadata, Record}; use once_cell::sync::Lazy; use roa::http::StatusCode; use roa::logger::logger; use roa::preload::*; use roa::{throw, App, Context}; use tokio::fs::File; use tokio::task::spawn; struct TestLogger { records: RwLock>, } impl log::Log for TestLogger { fn enabled(&self, metadata: &Metadata) -> bool { metadata.level() <= Level::Info } fn log(&self, record: &Record) { self.records .write() .unwrap() .push((record.level().to_string(), record.args().to_string())) } fn flush(&self) {} } static LOGGER: Lazy = Lazy::new(|| TestLogger { records: RwLock::new(Vec::new()), }); fn init() -> anyhow::Result<()> { log::set_logger(&*LOGGER) .map(|()| log::set_max_level(LevelFilter::Info)) .map_err(|err| anyhow::anyhow!("fail to init logger: {}", err)) } #[tokio::test] async fn log() -> anyhow::Result<()> { init()?; async fn bytes_info(ctx: &mut Context) -> roa::Result { ctx.resp.write("Hello, World."); Ok(()) } // bytes info let (addr, server) = App::new().gate(logger).end(bytes_info).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::OK, resp.status()); assert_eq!("Hello, World.", resp.text().await?); let records = LOGGER.records.read().unwrap().clone(); assert_eq!(2, records.len()); assert_eq!("INFO", records[0].0); assert_eq!("--> GET /", records[0].1.trim_end()); assert_eq!("INFO", records[1].0); assert!(records[1].1.starts_with("<-- GET /")); assert!(records[1].1.contains("13 B")); assert!(records[1].1.trim_end().ends_with("200 OK")); // error async fn err(_ctx: &mut Context) -> roa::Result { throw!(StatusCode::BAD_REQUEST, "Hello, World!") } let (addr, server) = App::new().gate(logger).end(err).run()?; spawn(server); let resp = reqwest::get(&format!("http://{}", addr)).await?; assert_eq!(StatusCode::BAD_REQUEST, resp.status()); 
assert_eq!("Hello, World!", resp.text().await?);
    let records = LOGGER.records.read().unwrap().clone();
    assert_eq!(4, records.len());
    assert_eq!("INFO", records[2].0);
    assert_eq!("--> GET /", records[2].1.trim_end());
    assert_eq!("ERROR", records[3].0);
    assert!(records[3].1.starts_with("<-- GET /"));
    assert!(records[3].1.contains(&StatusCode::BAD_REQUEST.to_string()));
    assert!(records[3].1.trim_end().ends_with("Hello, World!"));

    // stream info
    async fn stream_info(ctx: &mut Context) -> roa::Result {
        ctx.resp
            .write_reader(File::open("assets/welcome.html").await?);
        Ok(())
    }
    // bytes info
    let (addr, server) = App::new().gate(logger).end(stream_info).run()?;
    spawn(server);
    let resp = reqwest::get(&format!("http://{}", addr)).await?;
    assert_eq!(StatusCode::OK, resp.status());
    assert_eq!(236, resp.text().await?.len());
    let records = LOGGER.records.read().unwrap().clone();
    assert_eq!(6, records.len());
    assert_eq!("INFO", records[4].0);
    assert_eq!("--> GET /", records[4].1.trim_end());
    assert_eq!("INFO", records[5].0);
    assert!(records[5].1.starts_with("<-- GET /"));
    assert!(records[5].1.contains("236 B"));
    assert!(records[5].1.trim_end().ends_with("200 OK"));
    Ok(())
}

================================================
FILE: tests/restful.rs
================================================
use std::collections::HashMap;
use std::sync::Arc;

use http::StatusCode;
use multimap::MultiMap;
use roa::preload::*;
use roa::query::query_parser;
use roa::router::{get, post, Router};
use roa::{throw, App, Context};
use serde::{Deserialize, Serialize};
use serde_json::json;
use slab::Slab;
use tokio::sync::RwLock;
use tokio::task::spawn;

/// A user record exchanged as JSON by the endpoints below.
#[derive(Debug, Clone, Deserialize, Serialize, Hash, Eq, PartialEq)]
struct User {
    name: String,
    age: u8,
    favorite_fruit: String,
}

/// In-memory user store: a slab of records plus a name -> ids index.
struct DB {
    main_table: Slab<User>,
    name_index: MultiMap<String, usize>,
}

impl DB {
    /// Creates an empty store.
    fn new() -> Self {
        Self {
            main_table: Slab::new(),
            name_index: MultiMap::new(),
        }
    }

    /// Stores `user`, indexes it by name and returns the id assigned by the slab.
    fn add(&mut self, user: User) -> usize {
        let name = user.name.clone();
        let id = self.main_table.insert(user);
        self.name_index.insert(name, id);
        id
    }

    /// Removes every occurrence of `id` from the ids indexed under `name`.
    ///
    /// Does nothing when `name` is not indexed.
    fn delete_index(&mut self, name: &str, id: usize) {
        if let Some(ids) = self.name_index.get_vec_mut(name) {
            // `retain` filters in place; removing by position while walking a
            // snapshot would use stale indices once an element has been removed.
            ids.retain(|&uid| uid != id);
        }
    }

    /// Removes and returns the record stored under `id`, or `None` if absent.
    fn delete(&mut self, id: usize) -> Option<User> {
        // Check for presence before calling `Slab::remove`.
        if !self.main_table.contains(id) {
            return None;
        }
        let user = self.main_table.remove(id);
        self.delete_index(&user.name, id);
        Some(user)
    }

    /// Returns the record stored under `id`, if any.
    fn get(&self, id: usize) -> Option<&User> {
        self.main_table.get(id)
    }

    /// Returns `(id, record)` pairs for every id indexed under `name`
    /// that is still present in the main table.
    fn get_by_name(&self, name: &str) -> Vec<(usize, &User)> {
        self.name_index
            .get_vec(name)
            .into_iter()
            .flatten()
            .filter_map(|&id| self.get(id).map(|user| (id, user)))
            .collect()
    }

    /// Swaps the record stored under `id` with `*new_user` and re-indexes its name.
    ///
    /// On success returns `true` and `new_user` holds the previous record.
    /// Returns `false`, leaving `new_user` untouched, when `id` is absent.
    fn update(&mut self, id: usize, new_user: &mut User) -> bool {
        match self.main_table.get_mut(id) {
            None => false,
            Some(stored) => {
                let new_name = new_user.name.clone();
                std::mem::swap(stored, new_user);
                // After the swap `new_user` is the old record, so its name is
                // the index entry that has to go.
                self.delete_index(&new_user.name, id);
                self.name_index.insert(new_name, id);
                true
            }
        }
    }
}

/// Cloneable application state: the store behind an async read-write lock.
#[derive(Clone)]
struct State(Arc<RwLock<DB>>);

impl State {
    /// Wraps `db` for shared use.
    fn new(db: DB) -> Self {
        Self(Arc::new(RwLock::new(db)))
    }

    /// Takes the write lock and adds `user`; returns its id.
    async fn add(&mut self, user: User) -> usize {
        self.0.write().await.add(user)
    }

    /// Takes the write lock and deletes the record under `id`.
    async fn delete(&mut self, id: usize) -> Option<User> {
        self.0.write().await.delete(id)
    }

    /// Takes the read lock and returns a copy of the record under `id`.
    async fn get_user(&self, id: usize) -> Option<User> {
        self.0.read().await.get(id).cloned()
    }

    /// Takes the read lock and returns copies of all records indexed under `name`.
    async fn get_by_name(&self, name: &str) -> Vec<(usize, User)> {
        let db = self.0.read().await;
        db.get_by_name(name)
            .into_iter()
            .map(|(id, user)| (id, user.clone()))
            .collect()
    }

    /// Takes the read lock and returns copies of all records with their ids.
    async fn get_all(&self) -> Vec<(usize, User)> {
        let db = self.0.read().await;
        db.main_table
            .iter()
            .map(|(id, user)| (id, user.clone()))
            .collect()
    }

    /// Takes the write lock and swaps the record under `id` with `*new_user`.
    async fn update(&mut self, id: usize, new_user: &mut User) -> bool {
        self.0.write().await.update(id, new_user)
    }
}

/// `POST /`: stores the JSON body as a user and answers `201` with `{"id": <id>}`.
async fn create_user(ctx: &mut Context<State>) -> roa::Result {
    let user = ctx.read_json().await?;
    let id = ctx.add(user).await;
    ctx.resp.status = StatusCode::CREATED;
    ctx.write_json(&json!({ "id": id }))
}

/// `GET /:id`: answers with the user as JSON, or throws `404` when absent.
async fn query_user(ctx: &mut Context<State>) -> roa::Result {
    let id = ctx.must_param("id")?.parse()?;
    match ctx.get_user(id).await {
        Some(user) => ctx.write_json(&user),
        None => throw!(StatusCode::NOT_FOUND, format!("id({}) not found", id)),
    }
}

/// `PUT /:id`: replaces the user with the JSON body and answers with the
/// previous record, or throws `404` when absent.
async fn update_user(ctx: &mut Context<State>) -> roa::Result {
    let id = ctx.must_param("id")?.parse()?;
    let mut user = ctx.read_json().await?;
    if ctx.update(id, &mut user).await {
        // `update` swapped the records, so `user` is now the old one.
        ctx.write_json(&user)
    } else {
        throw!(StatusCode::NOT_FOUND, format!("id({}) not found", id))
    }
}

/// `DELETE /:id`: removes the user and answers with it, or throws `404` when absent.
async fn delete_user(ctx: &mut Context<State>) -> roa::Result {
    let id = ctx.must_param("id")?.parse()?;
    match ctx.delete(id).await {
        Some(user) => ctx.write_json(&user),
        None => throw!(StatusCode::NOT_FOUND, format!("id({}) not found", id)),
    }
}

/// Builds the single-user CRUD router.
fn crud_router() -> Router<State> {
    Router::new()
        .on("/", post(create_user))
        .on("/:id", get(query_user).put(update_user).delete(delete_user))
}

#[tokio::test]
async fn restful_crud() -> Result<(), Box<dyn std::error::Error>> {
    let app = App::state(State::new(DB::new())).end(crud_router().routes("/user")?);
    let (addr, server) = app.run()?;
    spawn(server);

    let collection_url = format!("http://{}/user", addr);
    let user_url = format!("http://{}/user/0", addr);
    let client = reqwest::Client::new();

    // first get, 404 Not Found
    let resp = reqwest::get(&user_url).await?;
    assert_eq!(StatusCode::NOT_FOUND, resp.status());

    let user = User {
        name: "Hexilee".to_string(),
        age: 20,
        favorite_fruit: "Apple".to_string(),
    };

    // post
    let resp = client.post(&collection_url).json(&user).send().await?;
    assert_eq!(StatusCode::CREATED, resp.status());
    let data: HashMap<String, usize> = serde_json::from_str(&resp.text().await?)?;
    assert_eq!(0, data["id"]);

    // get
    let resp = reqwest::get(&user_url).await?;
    assert_eq!(StatusCode::OK, resp.status());
    assert_eq!(&user, &resp.json().await?);

    // put
    let another = User {
        name: "Bob".to_string(),
        age: 120,
        favorite_fruit: "Lemon".to_string(),
    };
    let resp = client.put(&user_url).json(&another).send().await?;
    assert_eq!(StatusCode::OK, resp.status());

    // return first user
    assert_eq!(&user, &resp.json().await?);

    // updated, get new user
    let resp = reqwest::get(&user_url).await?;
    assert_eq!(StatusCode::OK, resp.status());
    assert_eq!(&another, &resp.json().await?);

    // delete, get deleted user
    let resp = client.delete(&user_url).send().await?;
    assert_eq!(StatusCode::OK, resp.status());
    assert_eq!(&another, &resp.json().await?);

    // delete again, 404 Not Found
    let resp = client.delete(&user_url).send().await?;
    assert_eq!(StatusCode::NOT_FOUND, resp.status());

    // put again, 404 Not Found
    let resp = client.put(&user_url).json(&another).send().await?;
    assert_eq!(StatusCode::NOT_FOUND, resp.status());
    Ok(())
}

/// `POST /user`: stores every user in the JSON array body and answers `201`
/// with the list of assigned ids, in input order.
async fn create_batch(ctx: &mut Context<State>) -> roa::Result {
    let users: Vec<User> = ctx.read_json().await?;
    let mut ids = Vec::with_capacity(users.len());
    for user in users {
        ids.push(ctx.add(user).await);
    }
    ctx.resp.status = StatusCode::CREATED;
    ctx.write_json(&ids)
}

/// `GET /user`: answers with all users, or only those matching the `name`
/// query parameter when it is present.
async fn query_batch(ctx: &mut Context<State>) -> roa::Result {
    let users = match ctx.query("name") {
        Some(name) => ctx.get_by_name(&name).await,
        None => ctx.get_all().await,
    };
    ctx.write_json(&users)
}

/// Builds the batch router.
fn batch_router() -> Router<State> {
    Router::new().on("/user", get(query_batch).post(create_batch))
}

#[tokio::test]
async fn batch() -> Result<(), Box<dyn std::error::Error>> {
    let app = App::state(State::new(DB::new()))
        .gate(query_parser)
        .end(batch_router().routes("/")?);
    let (addr, server) = app.run()?;
    spawn(server);

    let users_url = format!("http://{}/user", addr);

    // first get, list empty
    let resp = reqwest::get(&users_url).await?;
    assert_eq!(StatusCode::OK, resp.status());
    let data: Vec<(usize, User)> = resp.json().await?;
    assert!(data.is_empty());

    // post batch
    let client = reqwest::Client::new();
    let users = vec![
        User {
            name: "Hexilee".to_string(),
            age: 20,
            favorite_fruit: "Apple".to_string(),
        },
        User {
            name: "Bob".to_string(),
            age: 120,
            favorite_fruit: "Lemon".to_string(),
        },
        User {
            name: "Hexilee".to_string(),
            age: 40,
            favorite_fruit: "Orange".to_string(),
        },
    ];
    let resp = client.post(&users_url).json(&users).send().await?;
    assert_eq!(StatusCode::CREATED, resp.status());
    let ids: Vec<usize> = resp.json().await?;
    assert_eq!(vec![0, 1, 2], ids);

    // get all
    let resp = reqwest::get(&users_url).await?;
    assert_eq!(StatusCode::OK, resp.status());
    let data: Vec<(usize, User)> = resp.json().await?;
    assert_eq!(3, data.len());
    for (index, user) in &data {
        assert_eq!(&users[*index], user);
    }

    // get by name
    let resp = reqwest::get(&format!("http://{}/user?name=Alice", addr)).await?;
    assert_eq!(StatusCode::OK, resp.status());
    let data: Vec<(usize, User)> = resp.json().await?;
    assert!(data.is_empty());

    let resp = reqwest::get(&format!("http://{}/user?name=Hexilee", addr)).await?;
    assert_eq!(StatusCode::OK, resp.status());
    let data: Vec<(usize, User)> = resp.json().await?;
    assert_eq!(2, data.len());
    assert_eq!(0, data[0].0);
    assert_eq!(&users[0], &data[0].1);
    assert_eq!(2, data[1].0);
    assert_eq!(&users[2], &data[1].1);
    Ok(())
}

================================================
FILE: tests/serve-file.rs
================================================
use http::header::ACCEPT_ENCODING;
use roa::body::DispositionType;
use roa::compress::Compress;
use roa::preload::*;
use roa::router::{get, Router};
use roa::{App, Context};
use tokio::fs::read_to_string;
use tokio::task::spawn;

/// Serves a fixed file from a plain `GET` endpoint.
#[tokio::test]
async fn serve_static_file() -> Result<(), Box<dyn std::error::Error>> {
    async fn test(ctx: &mut Context) -> roa::Result {
        ctx.write_file("assets/author.txt", DispositionType::Inline)
            .await
    }
    let app = App::new().end(get(test));
    let (addr, server) = app.run()?;
    spawn(server);
    let resp = reqwest::get(&format!("http://{}", addr)).await?;
    assert_eq!("Hexilee", resp.text().await?);
    Ok(())
}

/// Serves a file whose name comes from the `:filename` router variable.
#[tokio::test]
async fn serve_router_variable() -> Result<(), Box<dyn std::error::Error>> {
    async fn test(ctx: &mut Context) -> roa::Result {
        let filename = ctx.must_param("filename")?;
        ctx.write_file(format!("assets/{}", &*filename), DispositionType::Inline)
            .await
    }
    let router = Router::new().on("/:filename", get(test));
    let app = App::new().end(router.routes("/")?);
    let (addr, server) = app.run()?;
    spawn(server);
    let resp = reqwest::get(&format!("http://{}/author.txt", addr)).await?;
    assert_eq!("Hexilee", resp.text().await?);
    Ok(())
}

/// Serves a file whose relative path comes from the `*{path}` router wildcard.
#[tokio::test]
async fn serve_router_wildcard() -> Result<(), Box<dyn std::error::Error>> {
    async fn test(ctx: &mut Context) -> roa::Result {
        let path = ctx.must_param("path")?;
        ctx.write_file(format!("./{}", &*path), DispositionType::Inline)
            .await
    }
    let router = Router::new().on("/*{path}", get(test));
    let app = App::new().end(router.routes("/")?);
    let (addr, server) = app.run()?;
    spawn(server);
    let resp = reqwest::get(&format!("http://{}/assets/author.txt", addr)).await?;
    assert_eq!("Hexilee", resp.text().await?);
    Ok(())
}

/// Serves a file through the `Compress` gate and checks that a gzip-enabled
/// client reads back the original file content.
#[tokio::test]
async fn serve_gzip() -> Result<(), Box<dyn std::error::Error>> {
    async fn test(ctx: &mut Context) -> roa::Result {
        ctx.write_file("assets/welcome.html", DispositionType::Inline)
            .await
    }
    let app = App::new().gate(Compress::default()).end(get(test));
    let (addr, server) = app.run()?;
    spawn(server);
    let client = reqwest::Client::builder().gzip(true).build()?;
    let resp = client
        .get(&format!("http://{}", addr))
        .header(ACCEPT_ENCODING, "gzip")
        .send()
        .await?;
    assert_eq!(
        read_to_string("assets/welcome.html").await?,
        resp.text().await?
    );
    Ok(())
}