Repository: xaynetwork/xaynet Branch: master Commit: 3289a3003288 Files: 247 Total size: 1.2 MB Directory structure: gitextract_ecf7igk4/ ├── .dockerignore ├── .github/ │ ├── codecov.yml │ ├── dependabot.yml │ └── workflows/ │ ├── dockercompose-validation.yml │ ├── dockerfile-validation.yml │ ├── dockerhub-cleanup.yml │ ├── dockerhub-master.yml │ ├── dockerhub-pr-with-parameters.yml │ ├── dockerhub-release.yml │ ├── kubernetes-manifests.yml │ ├── rust-audit-cron.yml │ ├── rust-next.yml │ └── rust.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── README.tpl ├── ROADMAP.md ├── bindings/ │ └── python/ │ ├── .gitignore │ ├── .isort.cfg │ ├── .pylintrc │ ├── Cargo.toml │ ├── README.md │ ├── examples/ │ │ ├── README.md │ │ ├── download_global_model.py │ │ ├── download_global_model_async.py │ │ ├── hello_world.py │ │ ├── hello_world_async.py │ │ ├── keras_house_prices/ │ │ │ ├── .gitignore │ │ │ ├── README.md │ │ │ ├── keras_house_prices/ │ │ │ │ ├── __init__.py │ │ │ │ ├── data_handlers/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── data_handler.py │ │ │ │ │ └── regression_data.py │ │ │ │ ├── participant.py │ │ │ │ └── regressor.py │ │ │ └── setup.py │ │ ├── multiple_participants.py │ │ ├── participate_in_update.py │ │ └── restore.py │ ├── migration_guide.md │ ├── src/ │ │ ├── lib.rs │ │ └── python_ffi.rs │ └── xaynet_sdk/ │ ├── __init__.py │ ├── async_participant.py │ └── participant.py ├── configs/ │ ├── config.toml │ └── docker-dev.toml ├── docker/ │ ├── .dev.env │ ├── Dockerfile │ └── docker-compose.yml ├── k8s/ │ └── coordinator/ │ ├── base/ │ │ ├── deployment.yaml │ │ ├── kustomization.yaml │ │ └── service.yaml │ └── development/ │ ├── cert-volume-mount.yaml │ ├── config-volume-mount.yaml │ ├── config.toml │ ├── history-limit.yaml │ ├── ingress.yaml │ └── kustomization.yaml ├── rust/ │ ├── .gitignore │ ├── Cargo.toml │ ├── benches/ │ │ ├── Cargo.toml │ │ ├── messages/ │ │ │ ├── sum.rs │ │ │ └── update.rs │ │ └── models/ │ │ ├── from_primitives.rs │ │ └── 
to_primitives.rs │ ├── examples/ │ │ ├── Cargo.toml │ │ └── test-drive/ │ │ ├── main.rs │ │ ├── participant.rs │ │ └── settings.rs │ ├── rustfmt.toml │ ├── xaynet/ │ │ ├── Cargo.toml │ │ └── src/ │ │ └── lib.rs │ ├── xaynet-analytics/ │ │ ├── Cargo.toml │ │ └── src/ │ │ ├── controller.rs │ │ ├── data_combination/ │ │ │ ├── data_combiner.rs │ │ │ ├── data_points/ │ │ │ │ ├── data_point.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── screen_active_time.rs │ │ │ │ ├── screen_enter_count.rs │ │ │ │ ├── was_active_each_past_period.rs │ │ │ │ └── was_active_past_n_days.rs │ │ │ └── mod.rs │ │ ├── database/ │ │ │ ├── analytics_event/ │ │ │ │ ├── adapter.rs │ │ │ │ ├── data_model.rs │ │ │ │ ├── mod.rs │ │ │ │ └── repo.rs │ │ │ ├── common.rs │ │ │ ├── controller_data/ │ │ │ │ ├── adapter.rs │ │ │ │ ├── data_model.rs │ │ │ │ ├── mod.rs │ │ │ │ └── repo.rs │ │ │ ├── isar.rs │ │ │ ├── mod.rs │ │ │ └── screen_route/ │ │ │ ├── adapter.rs │ │ │ ├── data_model.rs │ │ │ ├── mod.rs │ │ │ └── repo.rs │ │ ├── lib.rs │ │ └── sender.rs │ ├── xaynet-core/ │ │ ├── Cargo.toml │ │ └── src/ │ │ ├── common.rs │ │ ├── crypto/ │ │ │ ├── encrypt.rs │ │ │ ├── hash.rs │ │ │ ├── mod.rs │ │ │ ├── prng.rs │ │ │ └── sign.rs │ │ ├── lib.rs │ │ ├── mask/ │ │ │ ├── config/ │ │ │ │ ├── mod.rs │ │ │ │ └── serialization.rs │ │ │ ├── masking.rs │ │ │ ├── mod.rs │ │ │ ├── model.rs │ │ │ ├── object/ │ │ │ │ ├── mod.rs │ │ │ │ └── serialization/ │ │ │ │ ├── mod.rs │ │ │ │ ├── unit.rs │ │ │ │ └── vect.rs │ │ │ ├── scalar.rs │ │ │ └── seed.rs │ │ ├── message/ │ │ │ ├── message.rs │ │ │ ├── mod.rs │ │ │ ├── payload/ │ │ │ │ ├── chunk.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── sum.rs │ │ │ │ ├── sum2.rs │ │ │ │ └── update.rs │ │ │ ├── traits.rs │ │ │ └── utils/ │ │ │ ├── chunkable_iterator.rs │ │ │ └── mod.rs │ │ └── testutils/ │ │ ├── messages.rs │ │ ├── mod.rs │ │ └── multipart.rs │ ├── xaynet-mobile/ │ │ ├── .cargo/ │ │ │ └── config.toml │ │ ├── .gitignore │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── build.rs │ │ ├── cbindgen.toml │ │ 
├── src/ │ │ │ ├── ffi/ │ │ │ │ ├── config.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── participant.rs │ │ │ │ └── settings.rs │ │ │ ├── lib.rs │ │ │ ├── participant.rs │ │ │ ├── reqwest_client.rs │ │ │ └── settings.rs │ │ ├── tests/ │ │ │ ├── ffi_test.c │ │ │ └── minunit.h │ │ └── xaynet_ffi.h │ ├── xaynet-sdk/ │ │ ├── Cargo.toml │ │ └── src/ │ │ ├── client.rs │ │ ├── lib.rs │ │ ├── message_encoder/ │ │ │ ├── chunker.rs │ │ │ ├── encoder.rs │ │ │ └── mod.rs │ │ ├── settings/ │ │ │ ├── max_message_size.rs │ │ │ └── mod.rs │ │ ├── state_machine/ │ │ │ ├── io.rs │ │ │ ├── mod.rs │ │ │ ├── phase.rs │ │ │ ├── phases/ │ │ │ │ ├── awaiting.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── new_round.rs │ │ │ │ ├── sending.rs │ │ │ │ ├── sum.rs │ │ │ │ ├── sum2.rs │ │ │ │ └── update.rs │ │ │ ├── state_machine.rs │ │ │ └── tests/ │ │ │ ├── mod.rs │ │ │ ├── phases/ │ │ │ │ ├── mod.rs │ │ │ │ ├── new_round.rs │ │ │ │ ├── sum.rs │ │ │ │ ├── sum2.rs │ │ │ │ └── update.rs │ │ │ └── utils.rs │ │ ├── traits.rs │ │ └── utils/ │ │ ├── concurrent_futures.rs │ │ └── mod.rs │ └── xaynet-server/ │ ├── Cargo.toml │ └── src/ │ ├── bin/ │ │ └── main.rs │ ├── examples.rs │ ├── lib.rs │ ├── metrics/ │ │ ├── mod.rs │ │ └── recorders/ │ │ ├── influxdb/ │ │ │ ├── dispatcher.rs │ │ │ ├── mod.rs │ │ │ ├── models.rs │ │ │ ├── recorder.rs │ │ │ └── service.rs │ │ └── mod.rs │ ├── rest.rs │ ├── services/ │ │ ├── fetchers/ │ │ │ ├── mod.rs │ │ │ ├── model.rs │ │ │ ├── round_parameters.rs │ │ │ ├── seed_dict.rs │ │ │ └── sum_dict.rs │ │ ├── messages/ │ │ │ ├── decryptor.rs │ │ │ ├── error.rs │ │ │ ├── message_parser.rs │ │ │ ├── mod.rs │ │ │ ├── multipart/ │ │ │ │ ├── buffer.rs │ │ │ │ ├── mod.rs │ │ │ │ └── service.rs │ │ │ ├── state_machine.rs │ │ │ └── task_validator.rs │ │ ├── mod.rs │ │ └── tests/ │ │ ├── fetchers.rs │ │ ├── mod.rs │ │ └── utils.rs │ ├── settings/ │ │ ├── mod.rs │ │ └── s3.rs │ ├── state_machine/ │ │ ├── coordinator.rs │ │ ├── events.rs │ │ ├── initializer.rs │ │ ├── mod.rs │ │ ├── phases/ │ │ │ ├── 
failure.rs │ │ │ ├── handler.rs │ │ │ ├── idle.rs │ │ │ ├── mod.rs │ │ │ ├── phase.rs │ │ │ ├── shutdown.rs │ │ │ ├── sum.rs │ │ │ ├── sum2.rs │ │ │ ├── unmask.rs │ │ │ └── update.rs │ │ ├── requests.rs │ │ └── tests/ │ │ ├── coordinator_state.rs │ │ ├── event_bus.rs │ │ ├── impls.rs │ │ ├── initializer.rs │ │ ├── mod.rs │ │ └── utils.rs │ └── storage/ │ ├── coordinator_storage/ │ │ ├── mod.rs │ │ └── redis/ │ │ ├── impls.rs │ │ └── mod.rs │ ├── mod.rs │ ├── model_storage/ │ │ ├── mod.rs │ │ ├── noop.rs │ │ └── s3.rs │ ├── store.rs │ ├── tests/ │ │ ├── mod.rs │ │ └── utils.rs │ ├── traits.rs │ └── trust_anchor/ │ ├── mod.rs │ └── noop.rs └── scripts/ └── bump_version.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: .dockerignore ================================================ **/.ignore **/shell.nix **/.envrc .git .github assets bindings configs docker k8s rust/target scripts ================================================ FILE: .github/codecov.yml ================================================ coverage: status: patch: off ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: cargo directory: "/rust" schedule: interval: daily time: "09:00" timezone: "Europe/Berlin" - package-ecosystem: cargo directory: "/bindings/python" schedule: interval: weekly day: "monday" - package-ecosystem: pip directory: "/bindings/python/examples/keras_house_prices" schedule: interval: weekly day: "monday" - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" day: "monday" ================================================ FILE: .github/workflows/dockercompose-validation.yml ================================================ name: docker-compose validation on: push: paths: - 'docker/docker-compose*yml' jobs: check-docker-compose: 
name: docker-compose validation runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v3 - name: Verify docker-compose working-directory: ./docker run: docker-compose -f docker-compose.yml config -q ================================================ FILE: .github/workflows/dockerfile-validation.yml ================================================ name: Dockerfiles linting on: push: paths: - 'docker/Dockerfile**' jobs: lint: name: Dockerfiles linting runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v3 - name: Lint file run: docker run -v "$GITHUB_WORKSPACE/docker/Dockerfile:/Dockerfile" replicated/dockerfilelint /Dockerfile ================================================ FILE: .github/workflows/dockerhub-cleanup.yml ================================================ name: DockerHub Scheduled Cleanup on: schedule: - cron: '00 00 * * sun' workflow_dispatch: jobs: dockerhub-cleanup-inactive: name: Cleanup inactive xaynet tags on Dockerhub runs-on: ubuntu-latest steps: - name: Setup hub-tool env: DHUSER: ${{ secrets.DOCKER_USERNAME }} DHTOKEN: ${{ secrets.DOCKER_PASSWORD }} run: | export DEBIAN_FRONTEND="noninteractive" sudo apt-get update sudo apt-get install -y jq LATEST=$(curl -fsSL "https://api.github.com/repos/docker/hub-tool/releases/latest" | jq -r .tag_name) wget "https://github.com/docker/hub-tool/releases/download/${LATEST}/hub-tool-linux-amd64.tar.gz" -O /tmp/hub-tool-linux-amd64.tar.gz tar xzvf /tmp/hub-tool-linux-amd64.tar.gz --strip-components 1 -C /tmp hub-tool/hub-tool # Create config.json before restricting its mode (chmod on a missing file fails the step); the later '>' redirect truncates it in place and keeps 0600. mkdir -pv -m 700 ~/.docker touch ~/.docker/config.json chmod -v 600 ~/.docker/config.json
printf '%s' "ewogICJhdXRocyI6IHsKICAgICJodWItdG9vbCI6IHsKICAgICAgImF1dGgiOiAiREhVU0VSVE9LRU4iCiAgICB9LAogICAgImh1Yi10b29sLXJlZnJlc2gtdG9rZW4iOiB7CiAgICAgICJhdXRoIjogIkRIVVNFUiIKICAgIH0sCiAgICAiaHViLXRvb2wtdG9rZW4iOiB7CiAgICAgICJhdXRoIjogIkRIVVNFUiIsCiAgICAgICJpZGVudGl0eXRva2VuIjogIkpXVFRPS0VOIgogICAgfQogIH0KfQoK" | base64 -d > ~/.docker/config.json # printf '%s' (not echo -e) so backslashes in the secrets are passed through verbatim. RUSERTOKEN=$(printf '%s:%s' "${DHUSER}" "${DHTOKEN}" | base64 -w0) RUSER=$(printf '%s:' "${DHUSER}" | base64 -w0) sed -i -e "s,DHUSERTOKEN,${RUSERTOKEN},g" -e "s,DHUSER,${RUSER},g" ~/.docker/config.json # The login endpoint expects the plain-text username/password: build the request body with jq from the step env (no base64 mangling, nothing written to disk) and fail loudly if no token comes back. JWT=$(jq -nc '{username: env.DHUSER, password: env.DHTOKEN}' | curl -s -XPOST "https://hub.docker.com/v2/users/login" -H "Content-Type:application/json" -d @- | jq -r .token) if [[ -z "${JWT}" || "${JWT}" == "null" ]] then echo "Docker Hub login failed" >&2 exit 1 fi sed -i -e "s,JWTTOKEN,${JWT},g" ~/.docker/config.json - name: Delete target tags run: | echo "Inactive tags:" /tmp/hub-tool tag ls xaynetwork/xaynet | grep -e STATUS -e inactive TAGS=$(/tmp/hub-tool tag ls xaynetwork/xaynet | grep inactive | grep -v -e "v[0-9]\+\.[0-9]\+\.[0-9]\+" | awk '{ print $1 }') if [[ -n "${TAGS}" ]] then printf '\n\n' # TAGS is deliberately unquoted: one hub-tool call per whitespace-separated tag. for tag in ${TAGS} do /tmp/hub-tool tag rm -f "${tag}" done fi ================================================ FILE: .github/workflows/dockerhub-master.yml ================================================ name: DockerHub (master) on: push: branches: - master jobs: build-tag-push-master: name: build-tag-push-master runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Login to DockerHub uses: docker/login-action@v2 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 - name: build-tag-push uses: docker/build-push-action@v3 id: docker with: context: .
file: docker/Dockerfile tags: xaynetwork/xaynet:development push: true build-args: COORDINATOR_FEATURES=metrics - name: Notify on Slack uses: 8398a7/action-slack@v3 if: always() with: status: custom fields: workflow,job,repo,ref custom_payload: | { username: 'GitHub Actions', icon_emoji: ':octocat:', attachments: [{ color: '${{ steps.docker.outcome }}' === 'success' ? 'good' : '${{ steps.docker.outcome }}' === 'failure' ? 'danger' : 'warning', text: `${process.env.AS_WORKFLOW}\nRepository: :xaynet: ${process.env.AS_REPO}\nRef: ${process.env.AS_REF}\nTags: development`, }] } env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} ================================================ FILE: .github/workflows/dockerhub-pr-with-parameters.yml ================================================ name: DockerHub (PR) with parameters on: issue_comment: types: [created] jobs: check_comments: name: Check comments for /deploy runs-on: ubuntu-latest steps: - name: Check for Command id: command uses: xt0rted/slash-command-action@v1 with: repo-token: ${{ secrets.GITHUB_TOKEN }} command: deploy reaction: "true" reaction-type: "eyes" allow-edits: "false" permission-level: write - uses: jungwinter/split@v2 id: split with: msg: '${{ steps.command.outputs.command-arguments }}' maxsplit: 1 - uses: xt0rted/pull-request-comment-branch@v1 id: comment-branch with: repo_token: ${{ secrets.GITHUB_TOKEN }} - uses: actions/checkout@v3 if: success() with: ref: ${{ steps.comment-branch.outputs.head_ref }} - name: Find and Replace uses: jacobtomlinson/gha-find-replace@master with: find: "newTag: development" replace: "newTag: ${{ steps.comment-branch.outputs.head_ref }}" include: "kustomization.yaml" - name: Login to DockerHub uses: docker/login-action@v2 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 - name: build-tag-push uses: docker/build-push-action@v3 
id: docker with: context: . file: docker/Dockerfile tags: xaynetwork/xaynet:${{ steps.comment-branch.outputs.head_ref }} push: true build-args: | ${{ steps.split.outputs._0 }} ${{ steps.split.outputs._1 }} # Report the build outcome (success or failure) whenever a /deploy command was recognized; success() alone hid failed builds and made the 'danger'/'warning' colors unreachable. - name: Notify on Slack uses: 8398a7/action-slack@v3 if: ${{ always() && steps.command.outcome == 'success' }} with: status: custom fields: workflow,job,repo,ref custom_payload: | { username: 'GitHub Actions', icon_emoji: ':octocat:', attachments: [{ color: '${{ steps.docker.outcome }}' === 'success' ? 'good' : '${{ steps.docker.outcome }}' === 'failure' ? 'danger' : 'warning', text: `${process.env.AS_WORKFLOW}\nRepository: :xaynet: ${process.env.AS_REPO}\nRef: ${process.env.AS_REF}\nTags: ${{ steps.comment-branch.outputs.head_ref }}`, }] } env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} ================================================ FILE: .github/workflows/dockerhub-release.yml ================================================ name: DockerHub (Release) on: push: tags: - v[0-9]+.[0-9]+.[0-9]+ jobs: build-tag-push-release: name: build-tag-push-release runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Login to DockerHub uses: docker/login-action@v2 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 - name: build-tag-push uses: docker/build-push-action@v3 id: docker with: context: . file: docker/Dockerfile tags: xaynetwork/xaynet:latest push: true build-args: RELEASE_BUILD=1 - name: Notify on Slack uses: 8398a7/action-slack@v3 if: always() with: status: custom fields: workflow,job,repo,ref custom_payload: | { username: 'GitHub Actions', icon_emoji: ':octocat:', attachments: [{ color: '${{ steps.docker.outcome }}' === 'success' ? 'good' : '${{ steps.docker.outcome }}' === 'failure' ?
'danger' : 'warning', text: `${process.env.AS_WORKFLOW}\nRepository: :xaynet: ${process.env.AS_REPO}\nRef: ${process.env.AS_REF} :heavy_check_mark:`, }] } env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} ================================================ FILE: .github/workflows/kubernetes-manifests.yml ================================================ name: Kubernetes manifests validation on: push: paths: - 'k8s/**' jobs: k8s-kustomize-validation: name: Kubernetes manifests validation runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v3 - name: Verify Kubernetes manifests run: kubectl kustomize $GITHUB_WORKSPACE/k8s/coordinator/development > /dev/null # Print only errors, if any ================================================ FILE: .github/workflows/rust-audit-cron.yml ================================================ name: Rust Audit for Security Vulnerabilities (master) on: schedule: - cron: '00 08 * * mon-fri' jobs: audit: name: Rust Audit runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v3 with: ref: master - name: Run rust-audit id: rust-audit run: | cargo audit --deny-warnings -f rust/Cargo.lock - name: Notify on Slack uses: 8398a7/action-slack@v3 if: ${{ failure() }} with: status: custom fields: workflow,job,repo custom_payload: | { username: 'GitHub Actions', icon_emoji: ':octocat:', attachments: [{ color: '${{ steps.rust-audit.outcome }}' === 'success' ? 'good' : '${{ steps.rust-audit.outcome }}' === 'failure' ? 
'danger' : 'warning', text: `${process.env.AS_WORKFLOW}\nRepository: ${process.env.AS_REPO}\nRef: master :warning:`, }] } env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} ================================================ FILE: .github/workflows/rust-next.yml ================================================ name: Rust-CI Next on: schedule: - cron: '00 04 10,20 * *' jobs: registry-cache: name: cargo-fetch timeout-minutes: 5 runs-on: ubuntu-latest outputs: cache-key: ${{ steps.cache-key.outputs.key }} cache-date: ${{ steps.get-date.outputs.date }} steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install stable toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: stable default: true # We want to create a new cache after a week. Otherwise, the cache will # take up too much space by caching old dependencies # %G is the ISO week-based year that belongs with %V (%Y-%V would label 2024-12-30 as 2024-01). # Step outputs go to $GITHUB_OUTPUT; the ::set-output workflow command is deprecated. - name: Year + ISO week number id: get-date run: echo "date=$(/bin/date -u "+%G-%V")" >> "$GITHUB_OUTPUT" shell: bash # We can use the registry cache of the normal rust ci - name: Cache key id: cache-key run: echo "key=${{ runner.os }}-cargo-registry-${{ steps.get-date.outputs.date }}-${{ hashFiles('**/Cargo.lock') }}" >> "$GITHUB_OUTPUT" shell: bash - name: Cache cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ steps.cache-key.outputs.key }} restore-keys: ${{ runner.os }}-cargo-registry-${{ steps.get-date.outputs.date }}- - name: cargo fetch working-directory: ./rust run: cargo fetch format: name: cargo-fmt needs: registry-cache timeout-minutes: 10 runs-on: ubuntu-latest strategy: fail-fast: false matrix: cargo_manifest: [rust, bindings/python] steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install nightly toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: nightly components: rustfmt default: true - name: Use cached cargo registry uses:
actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} # cargo fmt does not create any artifacts, therefore we don't need to cache the target folder - name: cargo fmt working-directory: ${{ matrix.cargo_manifest }} run: cargo fmt --all -- --check check: name: cargo-check needs: registry-cache timeout-minutes: 20 runs-on: ubuntu-latest strategy: fail-fast: false matrix: rust_version: [stable, beta] cargo_manifest: [rust, bindings/python] steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ matrix.rust_version }} default: true - name: Use cached cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} - name: cargo check working-directory: ${{ matrix.cargo_manifest }} env: RUSTFLAGS: "-D warnings" run: | cargo check --all-targets cargo check --all-targets --all-features clippy: name: cargo-clippy needs: [registry-cache, check] timeout-minutes: 20 runs-on: ubuntu-latest strategy: fail-fast: false matrix: rust_version: [stable, beta] steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ matrix.rust_version }} default: true components: clippy - name: Use cached cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} - name: cargo clippy working-directory: rust run: | cargo clippy --all-targets -- --deny warnings --deny clippy::cargo cargo clippy --all-targets --all-features -- --deny warnings --deny clippy::cargo docs: name: cargo-doc needs: [registry-cache, check] timeout-minutes: 20 runs-on: ubuntu-latest strategy: fail-fast: false matrix: rust_version: [stable, beta] steps: - name: Checkout repository uses: 
actions/checkout@v3 - name: Install toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ matrix.rust_version }} default: true - name: Use cached cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} - name: Check the building of docs working-directory: ./rust run: cargo doc --all-features --document-private-items --no-deps --color always notify: name: notify if: failure() needs: [format, check, clippy, docs] timeout-minutes: 20 runs-on: ubuntu-latest steps: - name: Notify on Slack uses: 8398a7/action-slack@v3 with: status: custom fields: workflow,repo custom_payload: | { username: 'GitHub Actions', icon_emoji: ':octocat:', attachments: [{ color: 'danger', text: `${process.env.AS_WORKFLOW} :warning:\nRepository: ${process.env.AS_REPO}`, }] } env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} ================================================ FILE: .github/workflows/rust.yml ================================================ name: Rust-CI on: push: paths: - 'rust/**' - 'bindings/python/**' - '.github/workflows/rust.yml' - 'README.md' - 'README.tpl' env: RUST_STABLE: 1.55.0 RUST_NIGHTLY: nightly-2021-09-09 jobs: registry-cache: name: cargo-fetch timeout-minutes: 5 runs-on: ubuntu-latest outputs: cache-key: ${{ steps.cache-key.outputs.key }} cache-date: ${{ steps.get-date.outputs.date }} steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install stable toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ env.RUST_STABLE }} default: true # We want to create a new cache after a week. 
Otherwise, the cache will # take up too much space by caching old dependencies # %G is the ISO week-based year that belongs with %V (%Y-%V would label 2024-12-30 as 2024-01). # Step outputs go to $GITHUB_OUTPUT; the ::set-output workflow command is deprecated. - name: Year + ISO week number id: get-date run: echo "date=$(/bin/date -u "+%G-%V")" >> "$GITHUB_OUTPUT" shell: bash - name: Cache key id: cache-key run: echo "key=${{ runner.os }}-cargo-registry-${{ steps.get-date.outputs.date }}-${{ hashFiles('**/Cargo.lock') }}" >> "$GITHUB_OUTPUT" shell: bash - name: Cache cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ steps.cache-key.outputs.key }} restore-keys: ${{ runner.os }}-cargo-registry-${{ steps.get-date.outputs.date }}- - name: cargo fetch working-directory: ./rust run: cargo fetch format: name: cargo-fmt needs: registry-cache timeout-minutes: 10 runs-on: ubuntu-latest strategy: fail-fast: false matrix: cargo_manifest: [rust, bindings/python] steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install nightly toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ env.RUST_NIGHTLY }} components: rustfmt default: true - name: Use cached cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} # cargo fmt does not create any artifacts, therefore we don't need to cache the target folder - name: cargo fmt working-directory: ${{ matrix.cargo_manifest }} run: cargo fmt --all -- --check check: name: cargo-check needs: registry-cache timeout-minutes: 20 runs-on: ubuntu-latest strategy: fail-fast: false matrix: cargo_manifest: [rust, bindings/python] steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install stable toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ env.RUST_STABLE }} default: true - name: Use cached cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} - name: Cache build artifacts uses:
actions/cache@v3.0.8 with: path: ${{ matrix.cargo_manifest }}/target key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-check-${{ matrix.cargo_manifest }}-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }} # restore-keys is matched as a prefix of previously saved keys, so its segments must follow the same "check-<manifest>" order as key. restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-check-${{ matrix.cargo_manifest }}-${{ needs.registry-cache.outputs.cache-date }}- - name: cargo check working-directory: ${{ matrix.cargo_manifest }} env: RUSTFLAGS: "-D warnings" run: | cargo check --all-targets cargo check --all-targets --all-features clippy: name: cargo-clippy needs: [registry-cache, check] timeout-minutes: 20 runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install stable toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ env.RUST_STABLE }} default: true components: clippy - name: Use cached cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} - name: Cache build artifacts uses: actions/cache@v3.0.8 with: path: ${{ github.workspace }}/rust/target key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-clippy-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }} restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-clippy-${{ needs.registry-cache.outputs.cache-date }}- - name: cargo clippy working-directory: rust run: | cargo clippy --all-targets -- --deny warnings --deny clippy::cargo cargo clippy --all-targets --all-features -- --deny warnings --deny clippy::cargo test: name: cargo-test needs: [registry-cache, check] timeout-minutes: 20 runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install stable toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ env.RUST_STABLE }} default: true - name:
Use cached cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} - name: Cache build artifacts uses: actions/cache@v3.0.8 with: path: ${{ github.workspace }}/rust/target key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-tests-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }} restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-tests-${{ needs.registry-cache.outputs.cache-date }}- - name: Start docker-compose working-directory: ./docker run: docker-compose up -d influxdb minio redis # --include-ignored is stable since Rust 1.51; the nightly-only "-Z unstable-options" is not needed on the pinned stable toolchain (libtest rejects -Z there). - name: Run tests (unit & integration & doc) working-directory: ./rust env: RUSTFLAGS: "-D warnings" run: | cargo test --lib --bins --examples --tests -- --include-ignored cargo test --lib --bins --examples --tests --all-features -- --include-ignored cargo test --doc --all-features - name: Stop docker-compose working-directory: ./docker run: docker-compose down bench: name: cargo-bench needs: [registry-cache, check] timeout-minutes: 20 runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install stable toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ env.RUST_STABLE }} default: true - name: Use cached cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} - name: Cache build artifacts uses: actions/cache@v3.0.8 with: path: ${{ github.workspace }}/rust/target key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-bench-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }} restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-bench-${{ needs.registry-cache.outputs.cache-date }}- - name: Run Bench working-directory: ./rust/benches run: cargo bench - name:
Upload bench artifacts uses: actions/upload-artifact@v3 with: name: bench_${{ github.sha }} path: ${{ github.workspace }}/rust/benches/target/criterion docs: name: cargo-doc needs: [registry-cache, check] timeout-minutes: 20 runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install stable toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ env.RUST_STABLE }} default: true - name: Use cached cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} - name: Cache build artifacts uses: actions/cache@v3.0.8 with: path: ${{ github.workspace }}/rust/target key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-doc-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }} restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-doc-${{ needs.registry-cache.outputs.cache-date }}- - name: Check the building of docs working-directory: ./rust run: cargo doc --all-features --document-private-items --no-deps --color always coverage: name: cargo-tarpaulin needs: [registry-cache, check] timeout-minutes: 20 runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install stable toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: toolchain: ${{ env.RUST_STABLE }} default: true profile: minimal - name: Use cached cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} - name: Cache build artifacts uses: actions/cache@v3.0.8 with: path: ${{ github.workspace }}/rust/target key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-tarpaulin-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }} restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-tarpaulin-${{ 
needs.registry-cache.outputs.cache-date }}- - name: Start docker-compose working-directory: ./docker run: docker-compose up -d influxdb minio redis - name: Run cargo-tarpaulin uses: actions-rs/tarpaulin@v0.1 with: version: '0.16.0' args: '--manifest-path rust/Cargo.toml --all-features --force-clean --lib --ignore-tests --ignored --workspace --exclude xaynet-analytics' - name: Stop docker-compose working-directory: ./docker run: docker-compose down - name: Upload to codecov.io uses: codecov/codecov-action@v3.1.0 with: token: ${{ secrets.CODECOV_TOKEN }} python_sdk: name: python sdk needs: [registry-cache, format, check] timeout-minutes: 20 runs-on: ubuntu-latest env: working-directory: ./bindings/python steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install Rust id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ env.RUST_STABLE }} default: true - name: Cache cargo registry uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git key: ${{ needs.registry-cache.outputs.cache-key }} - name: Cache cargo target uses: actions/cache@v3.0.8 with: path: ${{ env.working-directory }}/target key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-python-bindings-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }} restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-python-bindings-${{ needs.registry-cache.outputs.cache-date }}- - name: Setup Python 3.6 uses: actions/setup-python@v4 with: python-version: 3.6 architecture: "x64" # Step outputs go to $GITHUB_OUTPUT; the ::set-output workflow command is deprecated. - name: Get pip cache dir id: pip-cache run: echo "dir=$(pip cache dir)" >> "$GITHUB_OUTPUT" # bindings/python has no setup.py, so hashFiles() on it returned '' and the key never changed; the pinned pip packages live in this workflow, so hash it instead. - name: Cache pip packages uses: actions/cache@v3.0.8 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-${{ hashFiles('.github/workflows/rust.yml') }} - name: Install dependencies and build sdk run: | pip install --upgrade pip pip install --upgrade setuptools pip install maturin==0.9.1 black==20.8b1
isort==5.7.0 maturin build working-directory: ${{ env.working-directory }} - name: black working-directory: ${{ env.working-directory }} run: black --check . - name: isort working-directory: ${{ env.working-directory }} run: isort --check-only --diff . readme: name: cargo-readme timeout-minutes: 20 runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v3 - name: Install stable toolchain id: rust-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ env.RUST_STABLE }} default: true - name: Cache cargo readme uses: actions/cache@v3.0.8 with: path: | ~/.cargo/registry ~/.cargo/git ~/.cargo/bin/cargo-readme key: ${{ runner.os }}-cargo-readme-bin - name: Install cargo readme run: cargo install cargo-readme || true - name: Check that readme matches docs working-directory: ./ run: | cargo readme --project-root rust/xaynet/ --template ../../README.tpl --output ../../CARGO_README.md git diff --exit-code --no-index README.md CARGO_README.md ================================================ FILE: .gitignore ================================================ **/.ignore/ # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore # General **.DS_Store **.swp # vscode workspace settings .vscode/* CARGO_README.md ================================================ FILE: CHANGELOG.md ================================================ # Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to the [Semantic Versioning](http://semver.org/spec/v2.0.0.html). 
## [unreleased] ### Changed #### `xaynet-sdk` - Update to `tokio` `v1.x` - Update to `reqwest` `v0.11.x` - Update to `bytes` `v1.x` #### `xaynet-mobile` - Update to `tokio` `v1.x` - Update to `reqwest` `v0.11.x` #### `examples` - Update to `tokio` `v1.x` - Update to `reqwest` `v0.11.x` #### `xaynet-server` - Update to `tokio` `v1.x` - Update to `warp` `v0.3.x` - Update to `bytes` `v1.x` - Update to `rusoto_core` `v0.46.x` - Update to `rusoto_s3` `v0.46.x` - Update to `tower` `v0.4.x` - Update to `redis` `v0.19.x` - Enable optional server-side client authentication via TLS - Environment variable prefixes now respect the `__` separator, i.e. all environment variables have changed from `XAYNET_*` to `XAYNET__*`. ## [0.11.0] - 2021-01-18 ### Added #### Rust SDK `xaynet-sdk` `xaynet-sdk` contains the basic building blocks required to run the _Privacy-Enhancing Technology_ (PET) Protocol. It consists of a state machine and two I/O interfaces with which specific Xaynet participants can be developed that are adapted to the respective environments/requirements. If you are interested in building your own Xaynet participant, you can take a look at `xaynet-sdk`, at our [Rust participant](https://github.com/xaynetwork/xaynet/blob/master/rust/examples/test-drive/participant.rs), which we use primarily for testing, or at [`xaynet-mobile`](https://github.com/xaynetwork/xaynet/blob/master/rust/xaynet-mobile/src/participant.rs), our mobile-friendly participant. #### A Mobile friendly Xaynet participant `xaynet-mobile` `xaynet-mobile` provides a mobile-friendly implementation of a Xaynet participant. It gives the user a lot of control over how to drive the participant execution. You can regularly pause the execution of the participant, save it, and later restore it and continue the execution. When running on a device that is low on battery or does not have access to Wi-Fi, for instance, it can be useful to be able to pause the participant.
**C API** Furthermore, `xaynet-mobile` offers `C` bindings that allow `xaynet-mobile` to be used in other programming languages such as `Dart`. #### Python participant SDK `xaynet-sdk-python` We are happy to announce that we finally released `xaynet-sdk-python`, a Python SDK that consists of two experimental Xaynet participants (`ParticipantABC` and `AsyncParticipant`). The `ParticipantABC` API is similar to the old one, which we introduced in `v0.8.0`. Aside from some changes to the method signature, the biggest change is that the participant now runs in its own thread. To migrate from `v0.8.0` to `v0.11.0`, please follow the [migration guide](https://github.com/xaynetwork/xaynet/blob/master/bindings/python/migration_guide.md). However, we noticed that our Participant API may be difficult to integrate with existing applications, considering that the code for the training has to be moved into the `train_round` method, which can lead to significant changes to the existing code. Therefore, we offer a second API (`AsyncParticipant`) in which the training of the model is no longer part of the participant. A more in-depth explanation of the differences between the Participant APIs and examples of how to use them can be found [here](https://github.com/xaynetwork/xaynet/blob/master/bindings/python/README.md). #### Multi-part messages Participant messages can get large, possibly too large to be sent successfully in one go. On mobile devices in particular, the internet connection may not be as reliable. In order to make the transmission of messages more robust, we implemented multi-part messages to break a large message into parts and send them sequentially to the coordinator. If the transmission of part of a message fails, only that part will be resent and not the entire message. #### Coordinator state managed in Redis In order to be able to restore the state of the coordinator after a failure or shutdown, the state is managed in Redis and no longer in memory.
The Redis client can be configured via the `[redis]` setting: ```toml [redis] url = "redis://127.0.0.1/" ``` #### Support for storing global models in S3/Minio The coordinator is able to save a global model in S3/Minio after a successful round. The S3 client can be configured via the `[s3]` setting: ```toml [s3] access_key = "minio" secret_access_key = "minio123" region = ["minio", "http://localhost:9000"] [s3.buckets] global_models = "global-models" ``` `xaynet-server` must be compiled with the feature flag `model-persistence` in order to enable this feature. #### Restore coordinator state The state of the coordinator can be restored after a failure or shutdown. Restoring the coordinator state can be configured via the `[restore]` setting: ```toml [restore] enable = true ``` `xaynet-server` must be compiled with the feature flag `model-persistence` in order to enable this feature. #### Improved collection of state machine metrics In `v0.10.0`, we introduced the collection of metrics that are emitted in the state machine of `xaynet-server` and sent to an InfluxDB instance. In `v0.11.0`, we have revised the implementation and improved it further. Metrics are now sent much faster and adding metrics to the code has become much easier.
### Removed - `xaynet_client` (was split into `xaynet_sdk` and `xaynet_mobile`) - `xaynet_ffi` (is now part of `xaynet_mobile`) - `xaynet_macro` ## [0.10.0] - 2020-09-22 ### Added - Preparation for redis support: prepare for `xaynet_server` to store PET data in redis [#416](https://github.com/xaynetwork/xaynet/pull/416), [#515](https://github.com/xaynetwork/xaynet/pull/515) - Add support for multipart messages in the message structure [#508](https://github.com/xaynetwork/xaynet/pull/508), [#513](https://github.com/xaynetwork/xaynet/pull/513), [#514](https://github.com/xaynetwork/xaynet/pull/514) - Generalised scalar extension [#496](https://github.com/xaynetwork/xaynet/pull/496), [#507](https://github.com/xaynetwork/xaynet/pull/507) - Add server metrics [#487](https://github.com/xaynetwork/xaynet/pull/487), [#488](https://github.com/xaynetwork/xaynet/pull/488), [#489](https://github.com/xaynetwork/xaynet/pull/489), [#493](https://github.com/xaynetwork/xaynet/pull/493) - Refactor the client into a state machine, and add a client tailored for mobile devices [#471](https://github.com/xaynetwork/xaynet/pull/471), [#497](https://github.com/xaynetwork/xaynet/pull/497), [#506](https://github.com/xaynetwork/xaynet/pull/506) ### Changed - Split the xaynet crate into several sub-crates: - `xaynet_core` (0.1.0 released), re-exported as `xaynet::core` - `xaynet_client` (0.1.0 released), re-exported as `xaynet::client` when compiled with `--features client` - `xaynet_server` (0.1.0 released), re-exported as `xaynet::server` when compiled with `--features server` - `xaynet_macro` (0.1.0 released) - `xaynet_ffi` (not released) ## [0.9.0] - 2020-07-24 `xain/xain-fl` repository was renamed to `xaynetwork/xaynet`. The new crate will be published as `xaynet` under `v0.9.0`. ### Added This release introduces the integration of the [PET protocol](https://uploads-ssl.webflow.com/5f0c5c0bb18a279f0a62919e/5f157004da6585f299fa542b_XayNet%20Whitepaper%202.1.pdf) into the platform. 
**Note:** The integration of the PET protocol required a complete rewrite of the codebase and is therefore not compatible with the previous release. ## [0.8.0] - 2020-04-08 ### Added - New tutorial for the Python SDK [#355](https://github.com/xaynetwork/xaynet/pull/355) - Swagger description of the REST API [#345](https://github.com/xaynetwork/xaynet/pull/345), which is published at https://xain-fl.readthedocs.io/en/latest/ [#358](https://github.com/xaynetwork/xaynet/pull/358) - The Python examples now accept additional parameters (model size, heartbeat period, verbosity, etc.) [#351](https://github.com/xaynetwork/xaynet/pull/351) - Publish docker images to Docker Hub ### Security - Stop using `pickle` for message serialization [#355](https://github.com/xaynetwork/xaynet/pull/355). `pickle` is insecure and can lead to remote code execution. Instead, the default aggregator uses `numpy.save()`. ### Fixed - The documentation has been updated at https://xain-fl.readthedocs.io/en/latest/ [#358](https://github.com/xaynetwork/xaynet/pull/358) - Document aggregator error on Darwin platform [#365](https://github.com/xaynetwork/xaynet/pull/365/files) ### Changed - Simplified the Python SDK API [#355](https://github.com/xaynetwork/xaynet/pull/355) - Added unit tests for the coordinator and aggregator [#353](https://github.com/xaynetwork/xaynet/pull/353), [#352](https://github.com/xaynetwork/xaynet/pull/352) - Refactor the metrics store [#340](https://github.com/xaynetwork/xaynet/pull/340) - Speed up the docker builds [#348](https://github.com/xaynetwork/xaynet/pull/348) ## [0.7.0] - 2020-03-25 In this release, we archived the Python code under the `legacy` folder and shifted the development to Rust. This release has many breaking changes from the previous versions. More details will be made available through the updated README.md of the repository.
## [0.6.0] - 2020-02-26 - HOTFIX add disclaimer (#309) [janpetschexain] - PB-314: document the new weight exchange mechanism (#308) [Corentin Henry] - PB-407 add more debug level logging (#303) [janpetschexain] - PB-44 add heartbeat time and timeout to config (#305) [Robert Steiner] - PB-423 lock round access (#304) [kwok] - PB-439 Make thread pool workers configurable (#302) [Robert Steiner] - PB-159: update xain-{proto,sdk} dependencies to the right branch (#301) [Corentin Henry] - PB-159: remove weights from gRPC messages (#298) [Corentin Henry] - PB-431 send participant state to influxdb (#300) [Robert Steiner] - PB-434 separate metrics (#296) [Robert Steiner] - PB-406 :snowflake: Configure mypy (#297) [Anastasiia Tymoshchuk] - PB-428 send coordinator states (#292) [Robert Steiner] - PB-425 split weight init from training (#295) [janpetschexain] - PB-398 Round resumption in Coordinator (#285) [kwok] - Merge pull request #294 from xainag/master. [Daniel Kravetz] - Hotfix: PB-432 :pencil: :books: Update test badge and CI to reflect changes. [Daniel Kravetz] - PB-417 Start new development cycle (#291) [Anastasiia Tymoshchuk, kwok] ## [0.5.0] - 2020-02-12 Fix minor issues, update documentation. - PB-402 Add more logs (#281) [Robert Steiner] - DO-76 :whale: non alpine image (#287) [Daniel Kravetz] - PB-401 Add console renderer (#280) [Robert Steiner] - DO-80 :ambulance: Update dev Dockerfile to build gRPC (#286) [Daniel Kravetz] - DO-78 :sparkles: add grafana (#284) [Daniel Kravetz] - DO-66 :sparkles: Add keycloak (#283) [Daniel Kravetz] - PB-400 increment epoch base (#282) [janpetschexain] - PB-397 Simplify write metrics function (#279) [Robert Steiner] - PB-385 Fix xain-sdk test (#278) [Robert Steiner] - PB-352 Add sdk config (#272) [Robert Steiner] - Merge pull request #277 from xainag/master. [Daniel Kravetz] - Hotfix: update ci. [Daniel Kravetz] - DO-72 :art: Make CI name and feature consistent with other repos. 
[Daniel Kravetz] - DO-47 :newspaper: Build test package on release branch. [Daniel Kravetz] - PB-269: enable reading participants weights from S3 (#254) [Corentin Henry] - PB-363 Start new development cycle (#271) [Anastasiia Tymoshchuk] - PB-119 enable isort diff (#262) [janpetschexain] - PB-363 :gem: Release v0.4.0. [Daniel Kravetz] - DO-73 :green_heart: Disable continue_on_failure for CI jobs. Fix mypy. [Daniel Kravetz] ## [0.4.0] - 2020-02-04 Flatten model weights instead of using lists. Fix minor issues, update documentation. - PB-116: pin docutils version (#259) [Corentin Henry] - PB-119 update isort config and calls (#260) [janpetschexain] - PB-351 Store participant metrics (#244) [Robert Steiner] - Adjust isort config (#258) [Robert Steiner] - PB-366 flatten weights (#253) [janpetschexain] - PB-379 Update black setup (#255) [Anastasiia Tymoshchuk] - PB-387 simplify serve module (#251) [Corentin Henry] - PB-104: make the tests fast again (#252) [Corentin Henry] - PB-122: handle sigint properly (#250) [Corentin Henry] - PB-383 write aggregated weights after each round (#246) [Corentin Henry] - PB-104: Fix exception in monitor_hearbeats() (#248) [Corentin Henry] - DO-57 Update docker-compose files for provisioning InfluxDB (#249) [Ricardo Saffi Marques] - DO-59 Provision Redis 5.x for persisting states for the Coordinator (#247) [Ricardo Saffi Marques] - PB-381: make the log level configurable (#243) [Corentin Henry] - PB-382: cleanup storage (#245) [Corentin Henry] - PB-380: split get_logger() (#242) [Corentin Henry] - XP-332: grpc resource exhausted (#238) [Robert Steiner] - XP-456: fix coordinator command (#241) [Corentin Henry] - XP-485 Document revised state machine (#240) [kwok] - XP-456: replace CLI argument with a config file (#221) [Corentin Henry] - DO-48 :snowflake: :rocket: Build stable package on git tag with SemVer (#234) [Daniel Kravetz] - XP-407 update documentation (#239) [janpetschexain] - XP-406 remove numpy file cli (#237) [janpetschexain] 
- XP-544 fix aggregate module (#235) [janpetschexain] - DO-58: cache xain-fl dependencies in Docker (#232) [Corentin Henry] - XP-479 Start training rounds from 0 (#226) [kwok] ## [0.3.0] - 2020-01-21 - XP-505 cleanup docstrings in xain_fl.coordinator (#228) - XP-498 more generic shebangs (#229) - XP-510 allow for zero epochs on cli (#227) - XP-508 Replace circleci badge (#225) - XP-505 docstrings cleanup (#224) - XP-333 Replace numproto with xain-proto (#220) - XP-499 Remove conftest, exclude tests folder (#223) - XP-480 revise message names (#222) - XP-436 Reinstate FINISHED heartbeat from Coordinator (#219) - XP-308 store aggregated weights in S3 buckets (#215) - XP-308 store aggregated weights in S3 buckets (#215) - XP-422 ai metrics (#216) - XP-119 Fix gRPC testing setup so that it can run on macOS (#217) - XP-433 Fix docker headings (#218) - Xp 373 add sdk as dependency in fl (#214) - DO-49 Create initial buckets (#213) - XP-424 Remove unused packages (#212) - XP-271 fix pylint issues (#210) - XP-374 Clean up docs (#211) - DO-43 docker compose minio (#208) - XP-384 remove unused files (#209) - XP-357 make controller parametrisable (#201) - XP 273 scripts cleanup (#206) - XP-385 Fix docs badge (#204) - XP-354 Remove proto files (#200) - DO-17 Add Dockerfiles, dockerignore and docs (#202) - XP-241 remove legacy participant and sdk dir (#199) - XP-168 update setup.py (#191) - XP-261 move tests to own dir (#197) - XP-257 cleanup cproto dir (#198) - XP-265 move benchmarks to separate repo (#193) - XP-255 update codeowners and authors in setup (#195) - XP-255 update codeowners and authors in setup (#195) - XP-229 Update Readme.md (#189) - XP-337 Clean up docs before generation (#188) - XP-264 put coordinator as own package (#183) - XP-272 Archive rust code (#186) - Xp 238 add participant selection (#179) - XP-229 Update readme (#185) - XP-334 Add make docs into docs make file (#184) - XP-291 harmonize docs styles (#181) - XP-300 Update docs makefile (#180) - XP-228 
Update readme (#178) - XP-248 use structlog (#173) - XP-207 model framework agnostic (#166) - XAIN-284 rename package name (#176) - XP-251 Add ability to pass params per cmd args to coordinator (#174) - XP-167 Add gitter badge (#171) - Hotfix badge versions and style (#170) - Integrate docs with readthedocs (#169) - add pull request template (#168) ## [0.2.0] - 2019-12-02 ### Changed - Renamed package from xain to xain-fl ## [0.1.0] - 2019-09-25 The first public release of **XAIN** ### Added - FedML implementation on well known [benchmarks](https://github.com/xaynetwork/xaynet/tree/v0.1.0/xain/benchmark) using a realistic deep learning model structure. [Unreleased]: https://github.com/xaynetwork/xaynet/compare/v0.11.0...HEAD [0.11.0]: https://github.com/xaynetwork/xaynet/compare/v0.10.0...v0.11.0 [0.10.0]: https://github.com/xaynetwork/xaynet/compare/v0.9.0...v0.10.0 [0.9.0]: https://github.com/xaynetwork/xaynet/compare/v0.8.0...v0.9.0 [0.8.0]: https://github.com/xaynetwork/xaynet/compare/v0.7.0...v0.8.0 [0.7.0]: https://github.com/xaynetwork/xaynet/compare/v0.6.0...v0.7.0 [0.6.0]: https://github.com/xaynetwork/xaynet/compare/v0.5.0...v0.6.0 [0.5.0]: https://github.com/xaynetwork/xaynet/compare/v0.4.0...v0.5.0 [0.4.0]: https://github.com/xaynetwork/xaynet/compare/v0.3.0...v0.4.0 [0.3.0]: https://github.com/xaynetwork/xaynet/compare/v0.2.0...v0.3.0 [0.2.1]: https://github.com/xaynetwork/xaynet/compare/v0.2.0...v0.2.1 [0.2.0]: https://github.com/xaynetwork/xaynet/compare/v0.1.0...v0.2.0 [0.1.0]: https://github.com/xaynetwork/xaynet/tree/v0.1.0 ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
================================================ FILE: README.md ================================================ [![crates.io badge](https://img.shields.io/crates/v/xaynet.svg)](https://crates.io/crates/xaynet) [![docs.rs badge](https://docs.rs/xaynet/badge.svg)](https://docs.rs/xaynet) [![rustc badge](https://img.shields.io/badge/rustc-1.51.0+-lightgray.svg)](https://www.rust-lang.org/learn/get-started) [![Coverage Status](https://codecov.io/gh/xaynetwork/xaynet/branch/master/graph/badge.svg)](https://codecov.io/gh/xaynetwork/xaynet) ![Maintenance](https://img.shields.io/badge/maintenance-activly--developed-brightgreen.svg) [![roadmap badge](https://img.shields.io/badge/Roadmap-2021-blue)](./ROADMAP.md) ![Xaynet banner](./assets/xaynet_banner.png) # xaynet ## Xaynet: Train on the Edge with Federated Learning Want a framework that supports federated learning on the edge, in desktop browsers, integrates well with mobile apps, is performant, and preserves privacy? Welcome to XayNet, written entirely in Rust! ### Making federated learning easy for developers Frameworks for machine learning - including those expressly for federated learning - exist already. These frameworks typically facilitate federated learning of cross-silo use cases - for example in collaborative learning across a limited number of hospitals or for instance across multiple banks working on a common use case without the need to share valuable and sensitive data. This repository focusses on masked cross-device federated learning to enable the orchestration of machine learning in millions of low-power edge devices, such as smartphones or even cars. By doing this, we hope to also increase the pace and scope of adoption of federated learning in practice and especially allow the protection of end user data. All data remains in private local premises, whereby only encrypted AI models get automatically and asynchronously aggregated. 
Thus, we provide a solution to the AI privacy dilemma and bridge the often-existing gap between privacy and convenience. Imagine, for example, a voice assistant to learn new words directly on device level and sharing this knowledge with all other instances, without recording and collecting your voice input centrally. Or, think about search engine that learns to personalise search results without collecting your often sensitive search queries centrally… There are thousands of such use cases that right today still trade privacy for convenience. We think this shouldn’t be the case and we want to provide an alternative to overcome this dilemma. Concretely, we provide developers with: - **App dev tools**: An SDK to integrate federated learning into apps written in Dart or other languages of choice for mobile development, as well as frameworks like Flutter. - **Privacy via cross-device federated learning**: Train your AI models locally on edge devices such as mobile phones, browsers, or even in cars. Federated learning automatically aggregates the local models into a global model. Thus, all insights inherent in the local models are captured, while the user data stays private on end devices. - **Security Privacy via homomorphic encryption**: Aggregate models with the highest security and trust. Xayn’s masking protocol encrypts all models homomorphically. This enables you to aggregate encrypted local models into a global one – without having to decrypt local models at all. This protects private and even the most sensitive data. ### The case for writing this framework in Rust Our framework for federated learning is not only a framework for machine learning as such. Rather, it supports the federation of machine learning that takes place on possibly heterogeneous devices and where use cases involve many such devices. 
The programming language in which this framework is written should therefore give us strong support for the following: - **Runs "everywhere"**: the language should not require its own runtime and code should compile on a wide range of devices. - **Memory and concurrency safety**: code that compiles should be both memory safe and free of data races. - **Secure communication**: state of the art cryptography should be available in vetted implementations. - **Asynchronous communication**: abstractions for asynchronous communication should exist that make federated learning scale. - **Fast and functional**: the language should offer functional abstractions but also compile code into fast executables. Rust is one of the very few choices of modern programming languages that meets these requirements: - its concepts of Ownership and Borrowing make it both memory and thread-safe (hence avoiding many common concurrency issues). - it has a strong and static type discipline and traits, which describe shareable functionality of a type. - it is a modern systems programming language, with some functional style features such as pattern matching, closures and iterators. - its idiomatic code compares favourably to idiomatic C in performance. - it compiles to WASM and can therefore be applied natively in browser settings. - it is widely deployable and doesn't necessarily depend on a runtime, unlike languages such as Java and their need for a virtual machine to run its code. Foreign Function Interfaces support calls from other languages/frameworks, including Dart, Python and Flutter. - it compiles into LLVM, and so it can draw from the abundant tool suites for LLVM. --- # Getting Started ## Minimum supported rust version rustc 1.51.0 ## Running the platform There are a few different ways to run the backend: via docker, or by deploying it to a Kubernetes cluster or by compiling the code and running the binary manually. 1. 
Everything described below assumes that your shell's working directory is the root of the repository. 2. The following instructions assume you have pre-existing knowledge of some of the referenced software (like `docker` and `docker-compose`) and/or a working setup (if you decide to compile the Rust code and run the binary manually). 3. In case you need help with setting up your system accordingly, we recommend you refer to the official documentation of each tool, as supporting them here would be beyond the scope of this project: * [Rust](https://www.rust-lang.org/tools/install) * [Docker](https://docs.docker.com/) and [Docker Compose](https://docs.docker.com/compose/) * [Kubernetes](https://kubernetes.io/docs/home/) **Note:** With Xaynet `v0.11` the coordinator needs a connection to a Redis instance in order to save its state. **Don't connect the coordinator to a Redis instance that is used in production!** We recommend connecting the coordinator to its own Redis instance. We have invested a lot of time to make sure that the coordinator only deletes its own data, but in the current state of development, we cannot guarantee that this will always be the case. ### Using Docker The convenience of using the docker setup is that there's no need to set up a working Rust environment on your system, as everything is done inside the container. #### Run an image from Docker Hub Docker images of the latest releases are provided on [Docker Hub](https://hub.docker.com/r/xaynetwork/xaynet).
You can try them out with the default `configs/docker-dev.toml` by running: **Xaynet below v0.11** ```bash docker run -v ${PWD}/configs/docker-dev.toml:/app/config.toml -p 8081:8081 xaynetwork/xaynet:v0.10.0 /app/coordinator -c /app/config.toml ``` **Xaynet v0.11+** ```bash # don't forget to adjust the Redis url in configs/docker-dev.toml docker run -v ${PWD}/configs/docker-dev.toml:/app/config.toml -p 8081:8081 xaynetwork/xaynet:v0.11.0 ``` The docker image contains a release build of the coordinator without optional features. #### Run a coordinator with additional infrastructure Start the coordinator by pointing to the `docker/docker-compose.yml` file. It spins up all infrastructure that is essential to run the coordinator with default or optional features. Keep in mind that this file is used for development only. ```bash docker-compose -f docker/docker-compose.yml up --build ``` #### Create a release build If you would like, you can create an optimized release build of the coordinator, but keep in mind that the compilation will be slower. ```bash docker build --build-arg RELEASE_BUILD=1 -f ./docker/Dockerfile . ``` #### Build a coordinator with optional features Optional features can be specified via the build argument `COORDINATOR_FEATURES`. ```bash docker build --build-arg COORDINATOR_FEATURES=tls,metrics -f ./docker/Dockerfile . ``` ### Using Kubernetes To deploy an instance of the coordinator to your Kubernetes cluster, use the manifests that are located inside the `k8s/coordinator` folder. The manifests rely on `kustomize` to be generated (`kustomize` is officially supported by `kubectl` since v1.14). We recommend you thoroughly go through the manifests and adjust them according to your own setup (namespace, ingress, etc.). Remember to also check (and adjust if necessary) the default configuration for the coordinator, available at `k8s/coordinator/development/config.toml`. 
Please adjust the domain used in the `k8s/coordinator/development/ingress.yaml` file so it matches your needs (you can also skip `ingress` altogether, just make sure you remove its reference from `k8s/coordinator/development/kustomization.yaml`). Keep in mind that the `ingress` configuration that is shown on `k8s/coordinator/development/ingress.yaml` relies on resources that aren't available in this repository, due to their sensitive nature (TLS key and certificate, for instance). To verify the generated manifests, run: ```bash kubectl kustomize k8s/coordinator/development ``` To apply them: ```bash kubectl apply -k k8s/coordinator/development ``` In case you are not exposing your coordinator via `ingress`, you can still reach it using a port-forward. The example below creates a port-forward at port `8081` assuming the coordinator pod is still using the `app=coordinator` label: ```bash kubectl port-forward $(kubectl get pods -l "app=coordinator" -o jsonpath="{.items[0].metadata.name}") 8081 ``` ### Building the project manually The coordinator without optional features can be built and started with: ```bash cd rust cargo run --bin coordinator -- -c ../configs/config.toml ``` ## Running the example The example can be found under [rust/examples/](./rust/examples/). It uses a dummy model but is network-capable, so it's a good starting point for checking connectivity with the coordinator. ### `test-drive` Make sure you have a running instance of the coordinator and that the clients you will spawn with the command below are able to reach it through the network. Here is an example on how to start `20` participants that will connect to a coordinator running on `127.0.0.1:8081`: ```bash cd rust RUST_LOG=info cargo run --example test-drive -- -n 20 -u http://127.0.0.1:8081 ``` For more in-depth details on how to run examples, see the accompanying Getting Started guide under [rust/xaynet-server/src/examples.rs](./rust/xaynet-server/src/examples.rs). 
## Troubleshooting If you have any difficulties running the project, please reach out to us by [opening an issue](https://github.com/xaynetwork/xaynet/issues/new) and describing your setup and the problems you're facing. ================================================ FILE: README.tpl ================================================ [![crates.io badge](https://img.shields.io/crates/v/xaynet.svg)](https://crates.io/crates/xaynet) [![docs.rs badge](https://docs.rs/xaynet/badge.svg)](https://docs.rs/xaynet) [![rustc badge](https://img.shields.io/badge/rustc-1.51.0+-lightgray.svg)](https://www.rust-lang.org/learn/get-started) {{badges}} [![roadmap badge](https://img.shields.io/badge/Roadmap-2021-blue)](./ROADMAP.md) ![Xaynet banner](./assets/xaynet_banner.png) # {{crate}} {{readme}} --- # Getting Started ## Minimum supported rust version rustc 1.51.0 ## Running the platform There are a few different ways to run the backend: via docker, or by deploying it to a Kubernetes cluster or by compiling the code and running the binary manually. 1. Everything described below assumes that your shell's working directory is the root of the repository. 2. The following instructions assume you have pre-existing knowledge of some of the referenced software (like `docker` and `docker-compose`) and/or a working setup (if you decide to compile the Rust code and run the binary manually). 3. In case you need help with setting up your system accordingly, we recommend you refer to the official documentation of each tool, as supporting them here would be beyond the scope of this project: * [Rust](https://www.rust-lang.org/tools/install) * [Docker](https://docs.docker.com/) and [Docker Compose](https://docs.docker.com/compose/) * [Kubernetes](https://kubernetes.io/docs/home/) **Note:** With Xaynet `v0.11` the coordinator needs a connection to a Redis instance in order to save its state.
**Don't connect the coordinator to a Redis instance that is used in production!** We recommend connecting the coordinator to its own Redis instance. We have invested a lot of time to make sure that the coordinator only deletes its own data, but in the current state of development, we cannot guarantee that this will always be the case. ### Using Docker The convenience of using the docker setup is that there's no need to set up a working Rust environment on your system, as everything is done inside the container. #### Run an image from Docker Hub Docker images of the latest releases are provided on [Docker Hub](https://hub.docker.com/r/xaynetwork/xaynet). You can try them out with the default `configs/docker-dev.toml` by running: **Xaynet below v0.11** ```bash docker run -v ${PWD}/configs/docker-dev.toml:/app/config.toml -p 8081:8081 xaynetwork/xaynet:v0.10.0 /app/coordinator -c /app/config.toml ``` **Xaynet v0.11+** ```bash # don't forget to adjust the Redis url in configs/docker-dev.toml docker run -v ${PWD}/configs/docker-dev.toml:/app/config.toml -p 8081:8081 xaynetwork/xaynet:v0.11.0 ``` The docker image contains a release build of the coordinator without optional features. #### Run a coordinator with additional infrastructure Start the coordinator by pointing to the `docker/docker-compose.yml` file. It spins up all infrastructure that is essential to run the coordinator with default or optional features. Keep in mind that this file is used for development only. ```bash docker-compose -f docker/docker-compose.yml up --build ``` #### Create a release build If you would like, you can create an optimized release build of the coordinator, but keep in mind that the compilation will be slower. ```bash docker build --build-arg RELEASE_BUILD=1 -f ./docker/Dockerfile . ``` #### Build a coordinator with optional features Optional features can be specified via the build argument `COORDINATOR_FEATURES`.
```bash docker build --build-arg COORDINATOR_FEATURES=tls,metrics -f ./docker/Dockerfile . ``` ### Using Kubernetes To deploy an instance of the coordinator to your Kubernetes cluster, use the manifests that are located inside the `k8s/coordinator` folder. The manifests rely on `kustomize` to be generated (`kustomize` is officially supported by `kubectl` since v1.14). We recommend you thoroughly go through the manifests and adjust them according to your own setup (namespace, ingress, etc.). Remember to also check (and adjust if necessary) the default configuration for the coordinator, available at `k8s/coordinator/development/config.toml`. Please adjust the domain used in the `k8s/coordinator/development/ingress.yaml` file so it matches your needs (you can also skip `ingress` altogether, just make sure you remove its reference from `k8s/coordinator/development/kustomization.yaml`). Keep in mind that the `ingress` configuration that is shown on `k8s/coordinator/development/ingress.yaml` relies on resources that aren't available in this repository, due to their sensitive nature (TLS key and certificate, for instance). To verify the generated manifests, run: ```bash kubectl kustomize k8s/coordinator/development ``` To apply them: ```bash kubectl apply -k k8s/coordinator/development ``` In case you are not exposing your coordinator via `ingress`, you can still reach it using a port-forward. The example below creates a port-forward at port `8081` assuming the coordinator pod is still using the `app=coordinator` label: ```bash kubectl port-forward $(kubectl get pods -l "app=coordinator" -o jsonpath="{.items[0].metadata.name}") 8081 ``` ### Building the project manually The coordinator without optional features can be built and started with: ```bash cd rust cargo run --bin coordinator -- -c ../configs/config.toml ``` ## Running the example The example can be found under [rust/examples/](./rust/examples/). 
It uses a dummy model but is network-capable, so it's a good starting point for checking connectivity with the coordinator. ### `test-drive` Make sure you have a running instance of the coordinator and that the clients you will spawn with the command below are able to reach it through the network. Here is an example of how to start `20` participants that will connect to a coordinator running on `127.0.0.1:8081`: ```bash cd rust RUST_LOG=info cargo run --example test-drive -- -n 20 -u http://127.0.0.1:8081 ``` For more in-depth details on how to run examples, see the accompanying Getting Started guide under [rust/xaynet-server/src/examples.rs](./rust/xaynet-server/src/examples.rs). ## Troubleshooting If you have any difficulties running the project, please reach out to us by [opening an issue](https://github.com/xaynetwork/xaynet/issues/new) and describing your setup and the problems you're facing. ================================================ FILE: ROADMAP.md ================================================ # Roadmap 2021 ![Roadmap Q1](./assets/roadmap_q1.png) In Q1 we focus entirely on using XayNet for the [Xayn app] in terms of federated learning and first simple analytics, such as gathering relevant AI performance data like [NDCG metrics], because we want to know how our AI models perform without violating the privacy of our users. As you know, our framework originated with the aim of aggregating machine learning models securely and privately between edge devices. In the process, the models are transformed into one-dimensional lists, so that in the end we only aggregate a list of numbers. So why not also aggregate other numerical analytics data, like AI performance metrics or user behaviour such as screen times in our app, all of course with the privacy guarantees of XayNet? As such, we focus predominantly on mobile cross-device learning but also extend our framework to cover such use cases.
In Q1, however, we mostly focus on the internal mobile use case and on testing, laying the foundation for further generalisation to external use cases in the community during the rest of the year. ![Roadmap Q2](./assets/roadmap_q2.png) In Q2 we have three main focus points: extending XayNet to also support web applications, since our [Xayn app] will also be provided as a web version via [WASM]; integrating our product analytics extensions into our [Xayn app]; and optimising the client for higher performance, which is one of the major bottlenecks. ![Roadmap Q3](./assets/roadmap_q3.png) In Q3, we can imagine opening up the analytics layer to more general use cases outside of Xayn itself. Until then, our core focus is predominantly internal, but of course we hope to receive feature suggestions and reviews from the community and from external parties. We also want to make the coordinator more observable as a foundation for further optimisations. [Xayn app]: https://www.xayn.com/ [NDCG metrics]: https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG [WASM]: https://webassembly.org/ ================================================ FILE: bindings/python/.gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ global_model.bin state.bin ================================================ FILE: bindings/python/.isort.cfg ================================================ [settings] combine_as_imports=True force_grid_wrap=0 force_sort_within_sections=True include_trailing_comma=True indent=4 line_length=88 multi_line_output=3 use_parentheses=True ================================================ FILE: bindings/python/.pylintrc ================================================ [MASTER] # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. extension-pkg-whitelist= # Add files or directories to the blacklist. They should be base names, not # paths. ignore=grpc # Add files or directories matching the regex patterns to the blacklist. The # regex matches against base names, not paths. # ignore-patterns= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # number of processors available to use. # Use only 1 because of https://github.com/PyCQA/pylint/issues/374 jobs=1 # Control the amount of potential inferred values when inferring a single # object. This can help the performance when dealing with large functions or # complex, nested conditions. limit-inference-results=100 # Pickle collected data for later comparisons. persistent=yes # Specify a configuration file. 
#rcfile= # When enabled, pylint would attempt to guess common misconfiguration and emit # user-friendly hints instead of false-positive error messages. suggestion-mode=yes # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. confidence= # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once). You can also use "--disable=all" to # disable everything first and then reenable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use "--disable=all --enable=classes # --disable=W". 
disable=print-statement, old-raise-syntax, backtick, long-suffix, old-ne-operator, old-octal-literal, import-star-module-level, non-ascii-bytes-literal, raw-checker-failed, bad-inline-option, locally-disabled, file-ignored, suppressed-message, useless-suppression, deprecated-pragma, apply-builtin, basestring-builtin, buffer-builtin, cmp-builtin, coerce-builtin, execfile-builtin, file-builtin, long-builtin, raw_input-builtin, reduce-builtin, standarderror-builtin, unicode-builtin, xrange-builtin, coerce-method, delslice-method, getslice-method, setslice-method, no-absolute-import, old-division, dict-iter-method, dict-view-method, next-method-called, metaclass-assignment, indexing-exception, raising-string, reload-builtin, oct-method, hex-method, nonzero-method, cmp-method, input-builtin, round-builtin, intern-builtin, unichr-builtin, map-builtin-not-iterating, zip-builtin-not-iterating, range-builtin-not-iterating, filter-builtin-not-iterating, using-cmp-argument, eq-without-hash, div-method, idiv-method, rdiv-method, exception-message-attribute, invalid-str-codec, sys-max-int, bad-python3-import, deprecated-string-function, deprecated-str-translate-call, deprecated-itertools-function, deprecated-types-field, next-method-defined, dict-items-not-iterating, dict-keys-not-iterating, dict-values-not-iterating, deprecated-operator-function, deprecated-urllib-function, xreadlines-attribute, deprecated-sys-function, exception-escape, comprehension-escape, c-extension-no-member, duplicate-code, bad-continuation, fixme, redefined-builtin, missing-docstring, # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. 
enable= [REFACTORING] # Maximum number of nested blocks for function / method body max-nested-blocks=5 # Complete name of functions that never returns. When checking for # inconsistent-return-statements if a never returning function is called then # it will be considered as an explicit return statement and no message will be # printed. never-returning-functions=sys.exit [LOGGING] # Format style used to check logging format string. `old` means using % # formatting, while `new` is for `{}` formatting. logging-format-style=old # Logging modules to check that the string format arguments are in logging # function parameter format. logging-modules=logging [SPELLING] # Limits count of emitted suggestions for spelling mistakes. max-spelling-suggestions=4 # Spelling dictionary name. Available dictionaries: none. To make it working # install python-enchant package.. spelling-dict= # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains private dictionary; one word per line. spelling-private-dict-file= # Tells whether to store unknown words to indicated private dictionary in # --spelling-private-dict-file option instead of raising a message. spelling-store-unknown-words=no [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=FIXME, XXX, TODO [TYPECHECK] # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. # regular expressions currently don't work https://github.com/PyCQA/pylint/issues/2498. generated-members= # Tells whether missing members accessed in mixin class should be ignored. 
A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # Tells whether to warn about missing members when the owner of the attribute # is inferred to be None. ignore-none=yes # This flag controls whether pylint should warn about no-member and similar # checks whenever an opaque object is returned when inferring. The inference # can return multiple potential results while evaluating a Python object, but # some branches might not be evaluated, which results in partial inference. In # that case, it might be useful to still emit no-member and other checks for # the rest of the inferred objects. ignore-on-opaque-inference=yes # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. ignored-classes=optparse.Values,thread._local,_thread._local,SQLAlchemy # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis. It # supports qualified module names, as well as Unix pattern matching. ignored-modules= # Show a hint with possible names when a member name was not found. The aspect # of finding the hint is based on edit distance. missing-member-hint=yes # The minimum edit distance a name should have in order to be considered a # similar match for a missing member name. missing-member-hint-distance=1 # The total number of similar names that should be taken in consideration when # showing a hint for a missing member. missing-member-max-choices=1 [VARIABLES] # List of additional names supposed to be defined in builtins. Remember that # you should avoid defining new builtins when possible. additional-builtins= # Tells whether unused global variables should be treated as a violation. 
allow-global-unused-variables=yes # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_, _cb # A regular expression matching the name of dummy variables (i.e. expected to # not be used). dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ # Argument names that match this expression will be ignored. Default to name # with leading underscore. ignored-argument-names=_.*|^ignored_|^unused_ # Tells whether we should check for unused import in __init__ files. init-import=no # List of qualified module names which can have objects that can redefine # builtins. redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io [FORMAT] # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )?<?https?://\S+>?$ # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1 # tab). indent-string='    ' # Maximum number of characters on a single line. max-line-length=100 # Maximum number of lines in a module. max-module-lines=2000 # List of optional constructs for which whitespace checking is disabled. `dict- # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. # `trailing-comma` allows a space between comma and closing bracket: (a, ). # `empty-line` allows space-only lines. no-space-check=trailing-comma, dict-separator # Allow the body of a class to be on the same line as the declaration if body # contains single statement. single-line-class-stmt=no # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=no [SIMILARITIES] # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities.
ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no # Minimum lines number of a similarity. min-similarity-lines=10 [BASIC] # Naming style matching correct argument names. argument-naming-style=snake_case # Regular expression matching correct argument names. Overrides argument- # naming-style. argument-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming style matching correct attribute names. attr-naming-style=snake_case # Regular expression matching correct attribute names. Overrides attr-naming- # style. attr-rgx=[a-z_][a-z0-9_]{2,}$ # Bad variable names which should always be refused, separated by a comma. bad-names=foo, bar, baz, toto, tutu, tata # Naming style matching correct class attribute names. class-attribute-naming-style=any # Regular expression matching correct class attribute names. Overrides class- # attribute-naming-style. class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ # Naming style matching correct class names. class-naming-style=PascalCase # Regular expression matching correct class names. Overrides class-naming- # style. class-rgx=[A-Z_][a-zA-Z0-9]+$ # Naming style matching correct constant names. const-naming-style=UPPER_CASE # Regular expression matching correct constant names. Overrides const-naming- # style. const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. docstring-min-length=1 # Naming style matching correct function names. function-naming-style=snake_case # Regular expression matching correct function names. Overrides function- # naming-style. function-rgx=[a-z_][a-z0-9_]{2,}$ # Good variable names which should always be accepted, separated by a comma. good-names=i, j, k, ex, Run, _, logger, _y # Include a hint for the correct naming format with invalid-name. include-naming-hint=no # Naming style matching correct inline iteration names. 
inlinevar-naming-style=any # Regular expression matching correct inline iteration names. Overrides # inlinevar-naming-style. inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ # Naming style matching correct method names. method-naming-style=snake_case # Regular expression matching correct method names. Overrides method-naming- # style. method-rgx=[a-z_][a-z0-9_]{2,}$ # Naming style matching correct module names. module-naming-style=snake_case # Regular expression matching correct module names. Overrides module-naming- # style. module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. name-group= # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=^_ # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. # These decorators are taken in consideration only for invalid-name. property-classes=abc.abstractproperty # Naming style matching correct variable names. variable-naming-style=snake_case # Regular expression matching correct variable names. Overrides variable- # naming-style. variable-rgx=[a-z_][a-z0-9_]{2,30}$ [STRING] # This flag controls whether the implicit-str-concat-in-sequence should # generate a warning on implicit string concatenation in sequences defined over # several lines. check-str-concat-over-line-jumps=no [STRING_QUOTES] # The quote character for triple-quoted docstrings. docstring-quote=double # The quote character for string literals. string-quote=double-avoid-escape # The quote character for triple-quoted strings (non-docstring). triple-quote=double [IMPORTS] # Allow wildcard imports from modules that define __all__. allow-wildcard-with-all=no # Analyse import fallback blocks. 
This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no # Deprecated modules which should not be used, separated by a comma. deprecated-modules=optparse,tkinter.tix # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled). ext-import-graph= # Create a graph of every (i.e. internal and external) dependencies in the # given file (report RP0402 must not be disabled). import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled). int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__, __new__, setUp # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict, _fields, _replace, _source, _make # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs [DESIGN] # Maximum number of arguments for function / method. max-args=10 # Maximum number of attributes for a class (see R0902). max-attributes=11 # Maximum number of boolean expressions in an if statement. max-bool-expr=5 # Maximum number of branch for function / method body. max-branches=26 # Maximum number of locals for function / method body. max-locals=25 # Maximum number of parents for a class (see R0901). max-parents=7 # Maximum number of public methods for a class (see R0904). 
max-public-methods=25 # Maximum number of return / yield for function / method body. max-returns=6 # Maximum number of statements in function / method body. max-statements=100 # Minimum number of public methods for a class (see R0903). min-public-methods=0 [EXCEPTIONS] # Exceptions that will emit a warning when being caught. Defaults to # "BaseException, Exception". overgeneral-exceptions=Exception ================================================ FILE: bindings/python/Cargo.toml ================================================ [package] name = "xaynet-sdk-python" version = "0.1.0" authors = ["Xayn Engineering "] edition = "2018" description = "The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant." readme = "README.md" homepage = "https://xaynet.dev/" repository = "https://github.com/xaynetwork/xaynet/" license = "Apache-2.0" keywords = ["federated-learning", "fl", "ai", "machine-learning"] categories = ["science", "cryptography"] [package.metadata.maturin] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Information Technology", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development", "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Application Frameworks", "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Operating System :: MacOS :: MacOS X", "Operating System 
:: POSIX :: Linux", ] requires-python = ">=3.6" requires-dist = [ "justbackoff (==0.6.0)", ] [package.metadata] # minimum supported rust version msrv = "1.51.0" [dependencies] sodiumoxide = "0.2.7" tracing = "0.1.36" tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } pyo3 = {version = "=0.13.2", features = ["abi3-py36", "extension-module"]} xaynet-core = { path = "../../rust/xaynet-core", version = "0.2.0"} xaynet-mobile = { path = "../../rust/xaynet-mobile", version = "0.1.0"} xaynet-sdk = { path = "../../rust/xaynet-sdk", version = "0.1.0"} [lib] name = "xaynet_sdk" crate-type = ["cdylib"] ================================================ FILE: bindings/python/README.md ================================================ ![Xaynet banner](../../assets/xaynet_banner.png) ## Installation **Prerequisites** - Python 3.6 or higher **1. Install it via `pip`** ```bash # create and activate a virtual environment e.g. pyenv virtualenv xaynet pyenv activate xaynet pip install xaynet-sdk-python ``` **2. Build it from source** ```bash # first install rust via https://rustup.rs/ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh # clone the xaynet repository git clone https://github.com/xaynetwork/xaynet.git cd xaynet/bindings/python # create and activate a virtual environment e.g. pyenv virtualenv xaynet pyenv activate xaynet # install maturin pip install maturin==0.9.1 pip install justbackoff==0.6.0 # install xaynet-sdk maturin develop ``` ## Participant API(s) The Python SDK consists of two experimental Xaynet participants: `ParticipantABC` and `AsyncParticipant`. The word `Async` does not refer to either `asyncio` or asynchronous federated learning. It refers to when a local model can be set. In `ParticipantABC` the local model can only be set if the participant was selected as an update participant, while in `AsyncParticipant` the model can be set at any time.
### `ParticipantABC` The `ParticipantABC` API is similar to the old one which we introduced in [`v0.8.0`](https://github.com/xaynetwork/xaynet/blob/v0.8.0/python/sdk/xain_sdk/participant.py#L24). Aside from some changes to the method signatures, the biggest change is that the participant now runs in its own thread. To migrate from `v0.8.0` to `v0.11.0` please follow the [migration guide](./migration_guide.md). ![ParticipantABC](../../assets/python_participant.svg) **Public API of `ParticipantABC` and `InternalParticipant`** ```python def spawn_participant( coordinator_url: str, participant: ParticipantABC, args: Tuple = (), kwargs: dict = {}, state: Optional[List[int]] = None, scalar: float = 1.0, ): """ Spawns an `InternalParticipant` in a separate thread and returns a participant handle. If a `state` is passed, this state is restored, otherwise a new `InternalParticipant` is created. Args: coordinator_url: The url of the coordinator. participant: A class that implements `ParticipantABC`. args: The args that get passed to the constructor of the `participant` class. kwargs: The kwargs that get passed to the constructor of the `participant` class. state: A serialized participant state. Defaults to `None`. scalar: The scalar used for masking. Defaults to `1.0`. Note: The `scalar` is used later when the models are aggregated in order to scale their weights. It can be used when you want to weight the participants' updates differently. For example: If not all participant updates should be weighted equally but proportionally to their training samples, the scalar would be set to `scalar = 1 / number_of_samples`. Returns: The `InternalParticipant`. Raises: CryptoInit: If the initialization of the underlying crypto library has failed. ParticipantInit: If the participant cannot be initialized. This is most likely caused by an invalid `coordinator_url`. ParticipantRestore: If the participant cannot be restored due to invalid serialized state.
This exception can never be thrown if the `state` is `None`. Exception: Any exception that can be thrown during the instantiation of `participant`. """ class ParticipantABC(ABC): def train_round(self, training_input: Optional[TrainingInput]) -> TrainingResult: """ Trains a model. `training_input` is the deserialized global model (see `deserialize_training_input`). If no global model exists (usually in the first round), `training_input` will be `None`. In this case the weights of the model should be initialized and returned. Args: self: The participant. training_input: The deserialized global model (weights of the global model) or None. Returns: The updated model weights (the local model). """ def serialize_training_result(self, training_result: TrainingResult) -> list: """ Serializes the `training_result` into a `list`. The data type of the elements must match the data type defined in the coordinator configuration. Args: self: The participant. training_result: The `TrainingResult` of `train_round`. Returns: The `training_result` as a `list`. """ def deserialize_training_input(self, global_model: list) -> TrainingInput: """ Deserializes the `global_model` from a `list` to the type of `TrainingInput`. The data type of the elements matches the data type defined in the coordinator configuration. If no global model exists (usually in the first round), the method will not be called by the `InternalParticipant`. Args: self: The participant. global_model: The global model. Returns: The `TrainingInput` for `train_round`. """ def participate_in_update_task(self) -> bool: """ A callback used by the `InternalParticipant` to determine whether the `train_round` method should be called. This callback is only called if the participant is selected as an update participant. If `participate_in_update_task` returns `False`, `train_round` will not be called by the `InternalParticipant`. If the method is not overridden, it returns `True` by default. 
Returns: Whether the `train_round` method should be called when the participant is an update participant. """ def on_new_global_model(self, global_model: Optional[TrainingInput]) -> None: """ A callback that is called by the `InternalParticipant` once a new global model is available. If no global model exists (usually in the first round), `global_model` will be `None`. If a global model exists, `global_model` is already the deserialized global model. (See `deserialize_training_input`) If the method is not overridden, it does nothing by default. Args: self: The participant. global_model: The deserialized global model or `None`. """ def on_stop(self) -> None: """ A callback that is called by the `InternalParticipant` before the `InternalParticipant` thread is stopped. This callback can be used, for example, to show performance values that have been collected in the participant over the course of the training rounds. If the method is not overridden, it does nothing by default. Args: self: The participant. """ class InternalParticipant: def stop(self) -> List[int]: """ Stops the execution of the participant and returns its serialized state. The serialized state can be passed to the `spawn_participant` function to restore a participant. After calling `stop`, the participant is consumed. Every further method call on the handle of `InternalParticipant` leads to an `UninitializedParticipant` exception. Note: The serialized state contains unencrypted **private key(s)**. If used in production, it is important that the serialized state is securely saved. Returns: The serialized state of the participant. """ ``` ### `AsyncParticipant` We noticed that the API of `ParticipantABC`/`InternalParticipant` reduces a fair amount of code on the user side; however, it may not be flexible enough to cover some of the following use cases: 1. The user wants to use the global/local model in a different thread.
It is possible to provide methods for this on the `InternalParticipant` but they are not straightforward to implement. To make them thread-safe, it is probably necessary to use synchronization primitives but this would make the `InternalParticipant` more complicated. In addition, questions arise such as: Would the user want to be able to get the current local model at any time or would they like to be notified as soon as a new local model is available? 2. Train a model without the participant Since the training of the model is embedded in the `ParticipantABC`, this will probably lead to code duplication if the user wants to perform the training without the participant. Furthermore, the embedding of the training in the `ParticipantABC` can also be a problem once the participant is integrated into an existing application, considering the code for the training has to be moved into the `train_round` method, which can lead to significant changes to the existing code. 3. Custom exception handling Last but not least, the question arises how we can inform the user that an exception has been thrown. We do not want the participant to be terminated with every exception but we want to give the user the opportunity to respond appropriately. The main issue we saw is that the participant is responsible for both training the model and running the PET protocol. Therefore, we offer a second API in which the training of the model is no longer part of the participant. This results in a simpler and more flexible API, but it comes with the tradeoff that the user needs to perform the de/serialization of the global/local model on their side.
![AsyncParticipant](../../assets/python_async_participant.svg) **Public API of `AsyncParticipant`** ```python def spawn_async_participant(coordinator_url: str, state: Optional[List[int]] = None, scalar: float = 1.0) -> (AsyncParticipant, threading.Event): """ Spawns an `AsyncParticipant` in a separate thread and returns a participant handle together with a global model notifier. If a `state` is passed, this state is restored, otherwise a new participant is created. The global model notifier sets the flag once a new global model is available. The flag is also set when the global model is `None` (usually in the first round). The flag is reset once the method `get_global_model` has been called but it is also possible to reset the flag manually by calling [`clear()`](https://docs.python.org/3/library/threading.html#threading.Event.clear). Args: coordinator_url: The url of the coordinator. state: A serialized participant state. Defaults to `None`. scalar: The scalar used for masking. Defaults to `1.0`. Note: The `scalar` is used later when the models are aggregated in order to scale their weights. It can be used when you want to weight the participants' updates differently. For example: If not all participant updates should be weighted equally but proportionally to their training samples, the scalar would be set to `scalar = 1 / number_of_samples`. Returns: A tuple which consists of an `AsyncParticipant` and a global model notifier. Raises: CryptoInit: If the initialization of the underlying crypto library has failed. ParticipantInit: If the participant cannot be initialized. This is most likely caused by an invalid `coordinator_url`. ParticipantRestore: If the participant cannot be restored due to invalid serialized state. This exception can never be thrown if the `state` is `None`. """ class AsyncParticipant: def get_global_model(self) -> Optional[list]: """ Fetches the current global model. This method can be called at any time.
If no global model exists (usually in the first round), the method returns `None`. Returns: The current global model or `None`. The data type of the elements matches the data type defined in the coordinator configuration. Raises: GlobalModelUnavailable: If the participant cannot connect to the coordinator to get the global model. GlobalModelDataTypeMisMatch: If the data type of the global model does not match the data type defined in the coordinator configuration. """ def set_local_model(self, local_model: list): """ Sets a local model. This method can be called at any time. Internally the participant first caches the local model. As soon as the participant is selected as an update participant, the currently cached local model is used. This means that the cache is empty after this operation. If a local model is already in the cache and `set_local_model` is called with a new local model, the current cached local model will be replaced by the new one. If the participant is an update participant and there is no local model in the cache, the participant waits until a local model is set or until a new round has been started. Args: local_model: The local model. The data type of the elements must match the data type defined in the coordinator configuration. Raises: LocalModelLengthMisMatch: If the length of the local model does not match the length defined in the coordinator configuration. LocalModelDataTypeMisMatch: If the data type of the local model does not match the data type defined in the coordinator configuration. """ def stop(self) -> List[int]: """ Stops the execution of the participant and returns its serialized state. The serialized state can be passed to the `spawn_async_participant` function to restore a participant. After calling `stop`, the participant is consumed. Every further method call on the handle of `AsyncParticipant` leads to an `UninitializedParticipant` exception. Note: The serialized state contains unencrypted **private key(s)**. 
If used in production, it is important that the serialized state is securely saved. Returns: The serialized state of the participant. """ ``` ## Enable logging of `xaynet-mobile` If you are interested in what `xaynet-mobile` is doing under the hood, you can turn on the logging via the environment variable `XAYNET__CLIENT`. For example: `XAYNET__CLIENT=info python examples/participate_in_update.py` ## How can I ... ? We have created a few [examples](./examples/README.md) that show the basic methods in action. But if something is missing, not very clear or not working properly, please let us know by opening an issue. We are happy to help and open to ideas or feedback :) ================================================ FILE: bindings/python/examples/README.md ================================================ # Examples Some examples that show how the `ParticipantABC` or `AsyncParticipant` can be used. ## Getting Started All examples in this section work without changing the coordinator [config.toml](../../../configs/config.toml) or [docker-dev.toml](../../../configs/docker-dev.toml). 
- [`hello_world.py`](./hello_world.py) A basic `ParticipantABC` example - [`hello_world_async.py`](./hello_world_async.py) A basic `AsyncParticipant` example - [`download_global_model.py`](./download_global_model.py) A `ParticipantABC` that only downloads the latest global model - [`download_global_model_async.py`](./download_global_model_async.py) An `AsyncParticipant` that only downloads the latest global model - [`multiple_participants.py`](./download_global_model_async.py) Spawn multiple `ParticipantABC`s in a single process - [`participate_in_update.py`](./participate_in_update.py) Only train a model when there is enough battery left - [`restore.py`](./restore.py) Save and restore the state of an `AsyncParticipant` ## Keras House Prices - [`keras_house_prices`](./keras_house_prices/) A full machine learning example ================================================ FILE: bindings/python/examples/download_global_model.py ================================================ """A `ParticipantABC` that only downloads the latest global model""" import json import logging from typing import Optional import xaynet_sdk LOG = logging.getLogger(__name__) class Participant(xaynet_sdk.ParticipantABC): def __init__(self, model: list) -> None: self.model = model super().__init__() def deserialize_training_input(self, global_model: list) -> list: return global_model def train_round(self, training_input: Optional[list]) -> list: pass def serialize_training_result(self, training_result: list) -> list: pass def participate_in_update_task(self) -> bool: return False def on_new_global_model(self, global_model: Optional[list]) -> None: LOG.info("new global model") if global_model is not None: with open("global_model.bin", "w") as filehandle: filehandle.write(json.dumps(global_model)) def main() -> None: logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)8s %(message)s", level=logging.DEBUG, datefmt="%b %d %H:%M:%S", ) participant = xaynet_sdk.spawn_participant( 
"http://127.0.0.1:8081", Participant, args=([0.1, 0.2, 0.345, 0.3],) ) try: participant.join() except KeyboardInterrupt: participant.stop() if __name__ == "__main__": main() ================================================ FILE: bindings/python/examples/download_global_model_async.py ================================================ """An `AsyncParticipant` that only downloads the latest global model""" import json import logging import xaynet_sdk LOG = logging.getLogger(__name__) def main() -> None: logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)8s %(message)s", level=logging.DEBUG, datefmt="%b %d %H:%M:%S", ) (participant, global_model_notifier) = xaynet_sdk.spawn_async_participant( "http://127.0.0.1:8081" ) try: while global_model_notifier.wait(): LOG.info("a new global model") global_model = participant.get_global_model() if global_model is not None: with open("global_model.bin", "w") as filehandle: filehandle.write(json.dumps(global_model)) except KeyboardInterrupt: participant.stop() if __name__ == "__main__": main() ================================================ FILE: bindings/python/examples/hello_world.py ================================================ """A basic `ParticipantABC` example""" import json import logging import time from typing import Optional import xaynet_sdk LOG = logging.getLogger(__name__) class Participant(xaynet_sdk.ParticipantABC): def __init__(self, model: list) -> None: self.model = model super().__init__() def deserialize_training_input(self, global_model: list) -> list: return global_model def train_round(self, training_input: Optional[list]) -> list: LOG.info("training") time.sleep(3.0) LOG.info("training done") return self.model def serialize_training_result(self, training_result: list) -> list: return training_result def participate_in_update_task(self) -> bool: return True def on_new_global_model(self, global_model: Optional[list]) -> None: if global_model is not None: with open("global_model.bin", "w") as 
filehandle: filehandle.write(json.dumps(global_model)) def main() -> None: logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)8s %(message)s", level=logging.DEBUG, datefmt="%b %d %H:%M:%S", ) participant = xaynet_sdk.spawn_participant( "http://127.0.0.1:8081", Participant, args=([0.1, 0.2, 0.345, 0.3],) ) try: participant.join() except KeyboardInterrupt: participant.stop() if __name__ == "__main__": main() ================================================ FILE: bindings/python/examples/hello_world_async.py ================================================ """A basic `AsyncParticipant` example""" import logging import time import xaynet_sdk LOG = logging.getLogger(__name__) def training(): LOG.info("training") time.sleep(10.0) LOG.info("training done") def main() -> None: logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)8s %(message)s", level=logging.DEBUG, datefmt="%b %d %H:%M:%S", ) (participant, global_model_notifier) = xaynet_sdk.spawn_async_participant( "http://127.0.0.1:8081" ) try: while global_model_notifier.wait(): LOG.info("a new global model") participant.get_global_model() training() participant.set_local_model([0.1, 0.2, 0.345, 0.3]) except KeyboardInterrupt: participant.stop() if __name__ == "__main__": main() ================================================ FILE: bindings/python/examples/keras_house_prices/.gitignore ================================================ data/ ================================================ FILE: bindings/python/examples/keras_house_prices/README.md ================================================ # `keras_house_prices` Example **Prerequisites** - Python >=3.7.1 <=3.8 1. Adjust the coordinator settings Change the model length to `55117` and the `bound_type` to `B2` in [`docker-dev.toml`](../../../../configs/docker-dev.toml). ```toml [model] length = 55117 [mask] bound_type = "B2" ``` Curious what the `bond_type` is? 
You can find an explanation [here](https://docs.rs/xaynet-core/0.2.0/xaynet_core/mask/index.html#bound-type). 2. Start the coordinator ```shell # in the root of the repository docker-compose -f docker/docker-compose.yml up --build ``` **All the commands in this section are run from the `bindings/python/examples/keras_house_prices` directory.** 3. Install the SDK: Follow the installation steps described in [bindings/python/README.md](../../README.md). 4. Install the example: ```shell pip install -e . ``` 5. Download the dataset from Kaggle: https://www.kaggle.com/c/house-prices-advanced-regression-techniques/data 6. Move the downloaded archive into a `data/` directory and extract it there (we use `bindings/python/examples/keras_house_prices/data/` here; any location works as long as you pass it via `--data-directory` in the following steps): ```shell (cd ./data ; unzip house-prices-advanced-regression-techniques.zip) ``` 7. Prepare the data: ```shell split-data --data-directory data --number-of-participants 10 ``` 8. Run one participant: ```shell XAYNET__CLIENT=info run-participant --data-directory data --coordinator-url http://127.0.0.1:8081 ``` 9.
Repeat the previous step to run more participants ================================================ FILE: bindings/python/examples/keras_house_prices/keras_house_prices/__init__.py ================================================ ================================================ FILE: bindings/python/examples/keras_house_prices/keras_house_prices/data_handlers/__init__.py ================================================ ================================================ FILE: bindings/python/examples/keras_house_prices/keras_house_prices/data_handlers/data_handler.py ================================================ """DataHandler base class to read, preprocess and split data for each example.""" from abc import ABC, abstractmethod import logging import os from typing import Dict, List, Optional import numpy as np import pandas as pd LOG = logging.getLogger(__name__) class DataHandler(ABC): # pylint: disable=too-many-instance-attributes """Base class to handle data preparation Args: data_directory: path to the directory where the data is stored homogeneity: The level of homogeneity in the assignment of training samples to each participants. It can take three values: - `iid`: meaning samples are randomly assigned to participants. - `intermediate`: half of the samples are randomly assigned to participants, half of the samples follow the 'total_split' logic. - `total_split`: if there are more participants than labels, samples are split among participants so that each participant has samples from only one class. if there are more classes than participants, samples are split so that no class is repeated between participants. n_participants: The number of participants into which the dataset will be split. NOTE: the random seed is set in the initialisation and will make the results reproducible. 
""" TEST_RATIO: float = 0.1 MINIMUM_PARTICIPANT_N_SAMPLES: int = 20 def __init__( self, data_directory: str, homogeneity: str = "iid", n_participants: int = 10, ) -> None: self.homogeneity: str = homogeneity self.n_participants: int = n_participants self.participant_ids: List[str] = [str(p) for p in range(self.n_participants)] self.data_dir: str = data_directory self.parts_dir: str = os.path.join(self.data_dir, "split_data") if not os.path.exists(self.parts_dir): os.mkdir(self.parts_dir) LOG.info("created %s dir", self.parts_dir) self.train_file_path: str = os.path.join(self.data_dir, "train.csv") self.test_file_path: str = os.path.join(self.data_dir, "test.csv") self.train_df: pd.DataFrame = pd.DataFrame() self.test_df: pd.DataFrame = pd.DataFrame() self.labels: List[str] = [] # set the seed that will be used by numpy to make the results reproducible. np.random.seed(42) def read_data(self) -> None: """Find the train_set CSV file and load it into a dataframe""" self.train_df = pd.read_csv(self.train_file_path, index_col=None) @abstractmethod def preprocess_data(self) -> None: """Abstract method to be implemented by the testcase data handling subclass, to preprocess the data. """ raise NotImplementedError() def create_testset(self) -> None: """Create testset by sampling and removing a TEST_RATIO percentage of samples from self.train_df. Save the data locally. """ n_test_samples: int = int(len(self.train_df) * self.TEST_RATIO) test_indexes: np.ndarray = np.random.choice( self.train_df.index, n_test_samples, replace=False ) self.test_df = self.train_df.loc[test_indexes, :] self.train_df = self.train_df.drop(test_indexes) self.test_df.to_csv(self.test_file_path) def make_discrete_y(self) -> pd.Series: """Split a continuous Y variable into discrete bins, one per participant. Returns: discrete_y: The discrete dependent variable. 
""" discrete_y: pd.Series = pd.cut( self.train_df["Y"], bins=self.n_participants, labels=range(self.n_participants), ) self.labels = list(set(discrete_y)) return discrete_y def make_iid_split( self, input_df: pd.DataFrame, target_length: int, assigned_samples: Optional[List[str]] = None, ) -> np.ndarray: """Randomly select samples so that each participant has a similar amount of samples. Args: input_df: DataFrame containing the samples to be selected. target_length: Length of the full dataset considered for IID split. assigned_samples: List of sample IDs already assigned to previous participants. Returns: The selected sample indexes. """ if assigned_samples is not None: input_df = input_df.drop(assigned_samples) samples_ids_per_participant: int = int(target_length / self.n_participants) selected_sample_ids: np.ndarray = np.random.choice( input_df.index, samples_ids_per_participant, replace=False ) return selected_sample_ids @staticmethod def split_lists( longer_list: List[str], shorter_list: List[str] ) -> Dict[str, List[str]]: """Split the lists of labels and participant IDs. We use longer and shorter list to make sure that the elements of the longer list are distributed to the elements of the shorter. For example: - If there are more participants than labels, the samples of each label will be distributed to different participants, and each participant will have samples from only one label. - If there are more labels than participants, each participant will have samples from more than one label, but samples from a single label will belong to only one participant. Args: longer_list: List of either labels or participant IDS, whichever is longer. shorter_list: List of either labels or participant IDS, whichever is shorter. Returns: Dictionary whose keys are the elements of the shorted list, and its values are a sample without replacement of the elements of the longer list. 
""" ratio: int = len(longer_list) // len(shorter_list) splits: List[List[str]] = [ longer_list[i : i + ratio] for i in range(0, len(longer_list), ratio) ] splits_by_shorter_element: Dict[str, List[str]] = { item: splits[i] for i, item in enumerate(shorter_list) } return splits_by_shorter_element def make_total_split( self, discrete_y: pd.Series, participant_id: str, participant_ids: List[str] ) -> np.ndarray: """Select labels for one participant. If there are more labels than participants, it will select a list of labels not assigned to any other participant. If there are more participants than labels, it will select one label only for this participant (the label may re-occur for other participants). Args: discrete_y: The discrete dependent variable. participant_id: The ID of the participant for which we are currently selecting the samples for its dataset. participant_ids: List of all participant IDs. Returns: List of selected samples for the current participant. """ labels_by_participant_id: Dict[str, List[str]] selected_labels: List[str] if len(self.labels) >= self.n_participants: labels_by_participant_id = self.split_lists( list(self.labels), participant_ids ) selected_labels = labels_by_participant_id[participant_id] else: participant_ids_by_label = self.split_lists(participant_ids, self.labels) selected_labels = [ label for label, ids in participant_ids_by_label.items() if participant_id in ids ] selected_samples: np.ndarray = np.array( [i for i, label in discrete_y.items() if label in selected_labels] ) return selected_samples def make_intermediate_split( self, assigned_samples: List[str], participant_id: str, discrete_y: pd.Series ) -> np.ndarray: """Handles an intermediate split, 50% IID and 50% total_split. Args: assigned_samples: Samples that have already been assigned to a participant. participant_id: The ID of the participant that will have samples assigned to. discrete_y: The discrete dependent variable. 
Raises: AssertionError: If the selected samples are not unique. Typically if there was replacement, or the random seed had not been set. Returns: The IDs of the selected samples for this participant. """ remaining_samples_df: pd.DataFrame = self.train_df.drop(assigned_samples) first_half_df: pd.DataFrame = remaining_samples_df.sample(frac=0.5) second_half_df: pd.DataFrame = remaining_samples_df.drop(first_half_df.index) target_length: int = len(self.train_df) // 2 iid_samples: np.ndarray = self.make_iid_split(first_half_df, target_length) second_half_y: pd.Series = discrete_y.loc[second_half_df.index] total_split_samples: np.ndarray = self.make_total_split( second_half_y, participant_id, self.participant_ids ) selected_samples: np.ndarray = np.concatenate( (iid_samples, total_split_samples) ) if len(set(selected_samples)) != len(selected_samples): raise AssertionError return selected_samples def split_data(self) -> None: """Split the data. Continuous variables (for regression) are made discrete only for the purpose of splitting the data (not for analysis). For each participant ID, it performs the data split according to the level of homogeneity selected. Saves the dataframe for each participant locally. 
""" discrete_y: pd.Series = self.make_discrete_y() np.random.shuffle(self.labels) np.random.shuffle(self.participant_ids) assigned_samples: List[str] = [] selected_samples: np.ndarray for participant_id in self.participant_ids: if self.homogeneity == "iid": selected_samples = self.make_iid_split( self.train_df, len(self.train_df), assigned_samples ) elif self.homogeneity == "total_split": selected_samples = self.make_total_split( discrete_y, participant_id, self.participant_ids ) else: selected_samples = self.make_intermediate_split( assigned_samples, participant_id, discrete_y ) participant_df: pd.DataFrame = self.train_df.loc[selected_samples, :] LOG.info( "participant %s df has shape %s", participant_id, participant_df.shape ) if len(participant_df) < self.MINIMUM_PARTICIPANT_N_SAMPLES: LOG.info( "participant %s has only %d samples.", participant_id, len(participant_df), ) LOG.info("consider decreasing the number of participants") # TODO: edge case: non-IID splits (especially 'total_split') with # too many participants may lead to an empty df. Pandas will save # the CSV anyway, but we may have problems reading the files later. 
# Solve this with: https://xainag.atlassian.net/browse/AP-154 output_filepath: str = os.path.join( self.parts_dir, f"data_part_{participant_id}.csv" ) participant_df.to_csv(output_filepath, index=False) LOG.info("participant df saved to %s", output_filepath) assigned_samples.extend(participant_df.index) def run(self) -> None: """One function to run them all.""" self.read_data() self.preprocess_data() self.create_testset() self.split_data() ================================================ FILE: bindings/python/examples/keras_house_prices/keras_house_prices/data_handlers/regression_data.py ================================================ """Implementation of the RegressionData subclass, to handle the data of regression examples.""" import argparse import logging from keras_house_prices.data_handlers.data_handler import DataHandler import numpy as np import pandas as pd from sklearn.preprocessing import MinMaxScaler LOG = logging.getLogger(__name__) class RegressionData(DataHandler): """Data processing logic that is specific to the house prices dataset.""" def __init__( self, data_directory: str, homogeneity: str, n_participants: int ) -> None: super().__init__( data_directory, homogeneity=homogeneity, n_participants=n_participants ) def fill_nan(self) -> None: """Filling missing data in the dataframe.""" self.train_df["PoolQC"] = self.train_df["PoolQC"].fillna("None") self.train_df["MiscFeature"] = self.train_df["MiscFeature"].fillna("None") self.train_df["Alley"] = self.train_df["Alley"].fillna("None") self.train_df["Fence"] = self.train_df["Fence"].fillna("None") self.train_df["FireplaceQu"] = self.train_df["FireplaceQu"].fillna("None") self.train_df["LotFrontage"] = self.train_df.groupby("Neighborhood")[ "LotFrontage" ].transform(lambda x: x.fillna(x.median())) for col in ("GarageType", "GarageFinish", "GarageQual", "GarageCond"): self.train_df[col] = self.train_df[col].fillna("None") for col in ("GarageYrBlt", "GarageArea", "GarageCars"): self.train_df[col] = 
self.train_df[col].fillna(0) for col in ( "BsmtFinSF1", "BsmtFinSF2", "BsmtUnfSF", "TotalBsmtSF", "BsmtFullBath", "BsmtHalfBath", ): self.train_df[col] = self.train_df[col].fillna(0) for col in ( "BsmtQual", "BsmtCond", "BsmtExposure", "BsmtFinType1", "BsmtFinType2", ): self.train_df[col] = self.train_df[col].fillna("None") self.train_df["MSZoning"] = self.train_df["MSZoning"].fillna( self.train_df["MSZoning"].mode()[0] ) self.train_df["MasVnrType"] = self.train_df["MasVnrType"].fillna("None") self.train_df["MasVnrArea"] = self.train_df["MasVnrArea"].fillna(0) self.train_df = self.train_df.drop(["Utilities"], axis=1) self.train_df["Functional"] = self.train_df["Functional"].fillna("Typ") self.train_df["Electrical"] = self.train_df["Electrical"].fillna( self.train_df["Electrical"].mode()[0] ) self.train_df["KitchenQual"] = self.train_df["KitchenQual"].fillna( self.train_df["KitchenQual"].mode()[0] ) self.train_df["Exterior1st"] = self.train_df["Exterior1st"].fillna( self.train_df["Exterior1st"].mode()[0] ) self.train_df["Exterior2nd"] = self.train_df["Exterior2nd"].fillna( self.train_df["Exterior2nd"].mode()[0] ) self.train_df["SaleType"] = self.train_df["SaleType"].fillna( self.train_df["SaleType"].mode()[0] ) self.train_df["MSSubClass"] = self.train_df["MSSubClass"].fillna("None") no_nulls_in_dataset = not self.train_df.isnull().values.any() if no_nulls_in_dataset: LOG.info("No missing values") LOG.info("data shape is %s", self.train_df.shape) def hot_encoding(self) -> None: """Hot encoding of the categorical features.""" self.train_df: pd.DataFrame = pd.get_dummies( self.train_df, dummy_na=True, drop_first=True ) LOG.info("data shape is %s", self.train_df.shape) def scaling(self) -> None: """Scales the features in minmax way and the process in log(1+x).""" self.train_df = self.train_df.rename(columns={"SalePrice": "Y"}) self.train_df["Y"] = np.log1p(self.train_df["Y"]) scaler = MinMaxScaler() cols = self.train_df.drop("Y", axis=1).columns train = pd.DataFrame( 
scaler.fit_transform(self.train_df.drop("Y", axis=1)), columns=cols ) self.train_df[cols] = train def preprocess_data(self) -> None: """Call methods that execute the preprocessing.""" self.train_df.drop("Id", axis=1, inplace=True) self.fill_nan() self.hot_encoding() self.scaling() def main() -> None: """Initialise and run the regression data preparation.""" logging.basicConfig(level=logging.DEBUG) parser = argparse.ArgumentParser(description="Prepare data for regression") parser.add_argument( "--data-directory", type=str, help="path to the directory that contains the raw data", ) parser.add_argument( "--number-of-participants", type=int, help="number of participants into which the dataset will be split", ) args = parser.parse_args() regression_data = RegressionData( args.data_directory, "total_split", args.number_of_participants, ) regression_data.run() ================================================ FILE: bindings/python/examples/keras_house_prices/keras_house_prices/participant.py ================================================ """Tensorflow Keras regression test case""" import argparse import logging import os import random from typing import List, Optional, Tuple from keras_house_prices.regressor import Regressor import numpy as np import pandas as pd from tabulate import tabulate from xaynet_sdk import ParticipantABC, spawn_participant LOG = logging.getLogger(__name__) class Participant( # pylint: disable=too-few-public-methods,too-many-instance-attributes ParticipantABC ): """An example of a Keras implementation of a participant for federated learning. The attributes for the model and the datasets are only for convenience, they might as well be loaded elsewhere. Attributes: regressor: The model to be trained. trainset_x: A dataset for training. trainset_y: Labels for training. testset_x: A dataset for test. testset_y: Labels for test. number_samples: The number of samples in the training dataset. 
performance_metrics: metrics collected after each round of training """ def __init__(self, dataset_dir: str) -> None: """Initialize a custom participant.""" super().__init__() self.load_random_dataset(dataset_dir) self.regressor = Regressor(len(self.trainset_x.columns)) self.performance_metrics: List[Tuple[float, float]] = [] def load_random_dataset(self, dataset_dir: str) -> None: """Load a random dataset from the data directory""" i = random.randrange(0, 10, 1) LOG.info("Train on sample number %d", i) trainset_file_path = os.path.join( dataset_dir, "split_data", f"data_part_{i}.csv" ) trainset = pd.read_csv(trainset_file_path, index_col=None) self.trainset_x = trainset.drop("Y", axis=1) self.trainset_y = trainset["Y"] self.number_of_samples = len(trainset) testset_file_path = os.path.join(dataset_dir, "test.csv") testset = pd.read_csv(testset_file_path, index_col=None) testset_x = testset.drop("Y", axis=1) self.testset_x: pd.DataFrame = testset_x.drop(testset_x.columns[0], axis=1) self.testset_y = testset["Y"] def train_round(self, training_input: Optional[np.ndarray]) -> np.ndarray: """Train a model in a federated learning round. A model is given in terms of its weights and the model is trained on the participant's dataset for a number of epochs. The weights of the updated model are returned. Args: weights: The weights of the model to be trained. Returns: The updated model weights . 
""" if training_input is None: # This is the first round: the coordinator doesn't have a # global model yet, so we need to initialize the weights self.regressor = Regressor(len(self.trainset_x.columns)) return self.regressor.get_weights() weights = training_input epochs = 10 self.regressor.set_weights(weights) self.regressor.train_n_epochs(epochs, self.trainset_x, self.trainset_y) loss: float r_squared: float loss, r_squared = self.regressor.evaluate_on_test( self.testset_x, self.testset_y ) LOG.info("loss = %f, R² = %f", loss, r_squared) self.performance_metrics.append((loss, r_squared)) return self.regressor.get_weights() def deserialize_training_input(self, global_model: list) -> np.ndarray: return np.array(global_model) def serialize_training_result(self, training_result: np.ndarray) -> list: return training_result.tolist() def on_stop(self) -> None: table = tabulate(self.performance_metrics, headers=["Loss", "R²"]) print(table) def main() -> None: """Entry point to start a participant.""" parser = argparse.ArgumentParser(description="Prepare data for regression") parser.add_argument( "--data-directory", type=str, help="path to the directory that contains the data", ) parser.add_argument( "--coordinator-url", type=str, required=True, help="URL of the coordinator", ) args = parser.parse_args() # pylint: disable=invalid-name logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)8s %(message)s", level=logging.DEBUG, datefmt="%b %d %H:%M:%S", ) participant = spawn_participant( args.coordinator_url, Participant, args=(args.data_directory,) ) try: participant.join() except KeyboardInterrupt: participant.stop() if __name__ == "__main__": main() ================================================ FILE: bindings/python/examples/keras_house_prices/keras_house_prices/regressor.py ================================================ """Wrapper for tensorflow regression neural network.""" from typing import List, Tuple import numpy as np import pandas as pd from 
sklearn.metrics import r2_score from tensorflow.keras import Sequential # pylint: disable=import-error from tensorflow.keras.layers import Dense # pylint: disable=import-error class Regressor: """Neural network class for the Boston pricing house problem. Attributes: model: Keras Sequential model """ def __init__(self, dim: int): self.model = Sequential() self.model.add(Dense(144, input_dim=dim, activation="relu")) self.model.add(Dense(72, activation="relu")) self.model.add(Dense(18, activation="relu")) self.model.add(Dense(1, activation="linear")) self.model.compile(optimizer="adam", loss="mean_squared_error") def train_n_epochs( self, n_epochs: int, x_train: pd.DataFrame, y_train: pd.DataFrame ) -> None: """Training function for the built in model. Args: n_epochs (int): Number of epochs to be trained. x_train (~pd.dataframe): Features dataset for training. y_train(~pd.dataframe): Labels for training. """ self.model.fit(x_train, y_train, epochs=n_epochs, verbose=0) def evaluate_on_test( self, x_test: pd.DataFrame, y_test: pd.DataFrame ) -> Tuple[float, float]: """Evaluating on testset. Args: x_test (dataframe): Feature set for evaluation. y_test (dataframe): Dependent variable for evaluation. Returns: test_loss: Value of the testing loss. 
r_squared: Value of R-squared, to be shown as 'accuracy' metric to the Coordinator """ y_pred: np.ndarray = self.model.predict(x_test) r_squared: float = r2_score(y_test, y_pred) test_loss: float = self.model.evaluate(x_test, y_test) return test_loss, r_squared def get_shapes(self) -> List[Tuple[int, ...]]: return [weight.shape for weight in self.model.get_weights()] def get_weights(self) -> np.ndarray: return np.concatenate(self.model.get_weights(), axis=None) def set_weights(self, weights: np.ndarray) -> None: shapes = self.get_shapes() # expand the flat weights indices: np.ndarray = np.cumsum([np.prod(shape) for shape in shapes]) tensorflow_weights: List[np.ndarray] = np.split( weights, indices_or_sections=indices ) tensorflow_weights = [ np.reshape(weight, newshape=shape) for weight, shape in zip(tensorflow_weights, shapes) ] # apply the weights to the tensorflow model self.model.set_weights(tensorflow_weights) ================================================ FILE: bindings/python/examples/keras_house_prices/setup.py ================================================ # pylint: disable=invalid-name from setuptools import find_packages, setup setup( name="keras_house_prices", version="0.1", author=["Xayn Engineering"], author_email="engineering@xaynet.dev", license="Apache License Version 2.0", python_requires=">=3.7.1, <=3.8", packages=find_packages(), install_requires=[ "pandas==1.4.3", "scikit-learn==1.1.2", "tensorflow==2.9.1", "numpy>=1.19.2,<1.24.0", "tabulate~=0.8.7", ], entry_points={ "console_scripts": [ "run-participant=keras_house_prices.participant:main", "split-data=keras_house_prices.data_handlers.regression_data:main", ] }, ) ================================================ FILE: bindings/python/examples/multiple_participants.py ================================================ """Spawn multiple `ParticipantABC`s in a single process""" import json import logging import time from typing import Optional import xaynet_sdk LOG = 
logging.getLogger(__name__) class Participant(xaynet_sdk.ParticipantABC): def __init__(self, p_id: int, model: list) -> None: self.p_id = p_id self.model = model super().__init__() def deserialize_training_input(self, global_model: list) -> list: return global_model def train_round(self, training_input: Optional[list]) -> list: LOG.info("participant %s: start training", self.p_id) time.sleep(5.0) LOG.info("participant %s: training done", self.p_id) return self.model def serialize_training_result(self, training_result: list) -> list: return training_result def participate_in_update_task(self) -> bool: return True def on_new_global_model(self, global_model: Optional[list]) -> None: if global_model is not None: with open("global_model.bin", "w") as filehandle: filehandle.write(json.dumps(global_model)) def main() -> None: logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)8s %(message)s", level=logging.DEBUG, datefmt="%b %d %H:%M:%S", ) participant = xaynet_sdk.spawn_participant( "http://127.0.0.1:8081", Participant, args=( 1, [0.1, 0.2, 0.345, 0.3], ), ) participant_2 = xaynet_sdk.spawn_participant( "http://127.0.0.1:8081", Participant, args=( 2, [0.3, 0.4, 0.45, 0.1], ), ) participant_3 = xaynet_sdk.spawn_participant( "http://127.0.0.1:8081", Participant, args=( 3, [0.123, 0.1567, 0.123, 0.46], ), ) try: participant.join() participant_2.join() participant_3.join() except KeyboardInterrupt: participant.stop() participant_2.stop() participant_3.stop() if __name__ == "__main__": main() ================================================ FILE: bindings/python/examples/participate_in_update.py ================================================ """Only train a model when there is enough battery left""" import json import logging from random import randint import time from typing import Optional import xaynet_sdk LOG = logging.getLogger(__name__) def get_battery_level(): return randint(1, 100) class Participant(xaynet_sdk.ParticipantABC): def __init__(self, model: 
list) -> None: self.model = model super().__init__() def deserialize_training_input(self, global_model: list) -> list: return global_model def train_round(self, training_input: Optional[list]) -> list: LOG.info("training") time.sleep(3.0) LOG.info("training done") return self.model def serialize_training_result(self, training_result: list) -> list: return training_result def participate_in_update_task(self) -> bool: if get_battery_level() < 20: LOG.info("low battery, skip training") return False LOG.info("enough battery, participate in update task") return True def on_new_global_model(self, global_model: Optional[list]) -> None: if global_model is not None: with open("global_model.bin", "w") as filehandle: filehandle.write(json.dumps(global_model)) def main() -> None: logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)8s %(message)s", level=logging.DEBUG, datefmt="%b %d %H:%M:%S", ) participant = xaynet_sdk.spawn_participant( "http://127.0.0.1:8081", Participant, args=([0.1, 0.2, 0.345, 0.3],) ) try: participant.join() except KeyboardInterrupt: participant.stop() if __name__ == "__main__": main() ================================================ FILE: bindings/python/examples/restore.py ================================================ """Save and restore the state of an `AsyncParticipant`""" import json import logging import xaynet_sdk LOG = logging.getLogger(__name__) def main() -> None: logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)8s %(message)s", level=logging.DEBUG, datefmt="%b %d %H:%M:%S", ) try: with open("state.bin", "r") as filehandle: restored_state = json.loads(filehandle.read()) except IOError: LOG.info("no saved state available, initialize new participant") restored_state = None (participant, _) = xaynet_sdk.spawn_async_participant( "http://127.0.0.1:8081", restored_state ) state = participant.stop() with open("state.bin", "w") as filehandle: filehandle.write(json.dumps(state)) if __name__ == "__main__": main() 
================================================ FILE: bindings/python/migration_guide.md ================================================ # Migration from `v0.8.0` to `v.0.11.0` To demonstrate the API changes from `v0.8.0` to `v.0.11.0`, we will use the keras example which is available in both versions. For reasons of clarity, some parts of the code have been removed. ## [`v0.8.0`](https://github.com/xaynetwork/xaynet/blob/v0.8.0/python/sdk/xain_sdk/participant.py#L24) ```bash pip install xain-sdk ``` ```python from xain_sdk import ParticipantABC, configure_logging, run_participant class Participant(ParticipantABC): def train_round( self, training_input: Optional[np.ndarray] ) -> Tuple[np.ndarray, int]: if training_input is None: self.regressor = Regressor(len(self.trainset_x.columns)) return (self.regressor.get_weights(), 0) return (self.regressor.get_weights(), self.number_of_samples) def deserialize_training_input(self, data: bytes) -> Optional[np.ndarray]: if not data: return None reader = BytesIO(data) return np.load(reader, allow_pickle=False) def serialize_training_result( self, training_result: Tuple[np.ndarray, int] ) -> bytes: (weights, number_of_samples) = training_result writer = BytesIO() writer.write(number_of_samples.to_bytes(4, byteorder="big")) np.save(writer, weights, allow_pickle=False) return writer.getbuffer()[:] def main() -> None: participant = Participant(args.data_directory) run_participant( participant, args.coordinator_url, heartbeat_period=args.heartbeat_period ) ``` ## [`v0.11.0`](https://github.com/xaynetwork/xaynet/blob/v0.11.0/bindings/python/xaynet_sdk/participant.py) ```bash pip install xaynet-sdk-python ``` ```python # - renamed `run_participant` to `spawn_participant` # - removed `configure_logging` from xaynet_sdk import ParticipantABC, spawn_participant class Participant(ParticipantABC): # Returns: # - returns a `np.ndarray` instead of `Tuple[np.ndarray, int]` # The scalar has been moved to the `spawn_participant` function. 
# This change is only temporary. In a future version it will again # be possible to set the scalar in the `train_round` method. def train_round(self, training_input: Optional[np.ndarray]) -> np.ndarray: if training_input is None: self.regressor = Regressor(len(self.trainset_x.columns)) return self.regressor.get_weights() return self.regressor.get_weights() # Args: # - renamed `data` to `global_model` # - provides a `list` instead of `Optional[bytes]` # - `deserialize_training_input` is not called if `global_model` is `None` # therefore the `None` case no longer needs to be handled. # # Returns: # - returns a `np.ndarray` instead of `Optional[np.ndarray]` def deserialize_training_input(self, global_model: list) -> np.ndarray: return np.array(global_model) # Args: # - provides a `np.ndarray` instead of `Tuple[np.ndarray, int]` # # Returns: # - returns a `list` instead of `bytes` def serialize_training_result(self, training_result: np.ndarray) -> list: return training_result.tolist() def main() -> None: # - `spawn_participant` spawns the participant in a separate thread instead of the main thread. # # Args: # - removed `heartbeat_period` # - `Participant` is instantiated in the participant thread instead of the main thread. # This ensures that both the participant as well as the model of `Participant` live on # the same thread. If they don't live on the same thread, it can cause problems with some # of the ml frameworks. 
participant = spawn_participant( args.coordinator_url, Participant, args=(args.data_directory,) scalar = 1 / number_of_samples ) try: participant.join() except KeyboardInterrupt: participant.stop() ``` ================================================ FILE: bindings/python/src/lib.rs ================================================ pub mod python_ffi; ================================================ FILE: bindings/python/src/python_ffi.rs ================================================ use pyo3::create_exception; use pyo3::exceptions::PyException; use pyo3::types::PyList; use pyo3::{prelude::*, wrap_pyfunction}; use tracing::debug; use tracing_subscriber::{EnvFilter, FmtSubscriber}; use xaynet_core::mask::IntoPrimitives; use xaynet_core::mask::{DataType, FromPrimitives, Model}; use xaynet_sdk::settings::MaxMessageSize; use crate::from_primitives; use crate::into_primitives; create_exception!(xaynet_sdk, CryptoInit, PyException); create_exception!(xaynet_sdk, ParticipantInit, PyException); create_exception!(xaynet_sdk, ParticipantRestore, PyException); create_exception!(xaynet_sdk, UninitializedParticipant, PyException); create_exception!(xaynet_sdk, LocalModelLengthMisMatch, PyException); create_exception!(xaynet_sdk, LocalModelDataTypeMisMatch, PyException); create_exception!(xaynet_sdk, GlobalModelUnavailable, PyException); create_exception!(xaynet_sdk, GlobalModelDataTypeMisMatch, PyException); #[pymodule] fn xaynet_sdk(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_function(wrap_pyfunction!(init_logging, m)?)?; m.add("CryptoInit", py.get_type::())?; m.add("ParticipantInit", py.get_type::())?; m.add("ParticipantRestore", py.get_type::())?; m.add( "UninitializedParticipant", py.get_type::(), )?; m.add( "LocalModelLengthMisMatch", py.get_type::(), )?; m.add( "LocalModelDataTypeMisMatch", py.get_type::(), )?; m.add( "GlobalModelUnavailable", py.get_type::(), )?; m.add( "GlobalModelDataTypeMisMatch", py.get_type::(), )?; Ok(()) } #[pyclass] 
#[text_signature = "(url, scalar, /)"] struct Participant { inner: Option, } #[pymethods] impl Participant { #[new] pub fn new(url: String, scalar: f64, state: Option>) -> PyResult { sodiumoxide::init() .map_err(|_| CryptoInit::new_err("failed to initialize crypto library"))?; let inner = if let Some(state) = state { debug!("restore participant"); xaynet_mobile::Participant::restore(&state, &url).map_err(|err| { ParticipantRestore::new_err(format!("failed to restore participant: {}", err)) })? } else { debug!("initialize participant"); let mut settings = xaynet_mobile::Settings::new(); settings.set_url(url); settings.set_keys(xaynet_core::crypto::SigningKeyPair::generate()); settings.set_scalar(scalar); settings.set_max_message_size(MaxMessageSize::unlimited()); xaynet_mobile::Participant::new(settings).map_err(|err| { ParticipantInit::new_err(format!("failed to initialize participant: {}", err)) })? }; Ok(Self { inner: Some(inner) }) } #[text_signature = "($self)"] pub fn tick(&mut self) -> PyResult<()> { let inner = match self.inner { Some(ref mut inner) => inner, None => { return Err(UninitializedParticipant::new_err( "called 'tick' on an uninitialized participant. this is a bug.", )) } }; inner.tick(); Ok(()) } #[text_signature = "($self, local_model)"] pub fn set_model(&mut self, local_model: &PyList) -> PyResult<()> { let inner = match self.inner { Some(ref mut inner) => inner, None => { return Err(UninitializedParticipant::new_err( "called 'set_model' on an uninitialized participant. 
this is a bug.", )) } }; let local_model_config = inner.local_model_config(); if local_model.len() != local_model_config.len { return Err(LocalModelLengthMisMatch::new_err(format!( "the local model length is incompatible with the model length of the current model configuration {} != {}", local_model.len(), local_model_config.len ))); } debug!( "convert local model to {:?} datatype", local_model_config.data_type ); match local_model_config.data_type { DataType::F32 => from_primitives!(inner, local_model, f32), DataType::F64 => from_primitives!(inner, local_model, f64), DataType::I32 => from_primitives!(inner, local_model, i32), DataType::I64 => from_primitives!(inner, local_model, i64), } } /// Check whether the participant internal state machine made progress while /// executing the PET protocol. If so, the participant state likely changed. #[text_signature = "($self)"] pub fn made_progress(&self) -> PyResult { let inner = match self.inner { Some(ref inner) => inner, None => { return Err(UninitializedParticipant::new_err( "called 'made_progress' on an uninitialized participant. this is a bug.", )) } }; Ok(inner.made_progress()) } /// Check whether the participant internal state machine is waiting for the /// participant to load its model into the store. If this method returns `true`, the /// caller should make sure to call [`Participant::set_model()`] at some point. #[text_signature = "($self)"] pub fn should_set_model(&self) -> PyResult { let inner = match self.inner { Some(ref inner) => inner, None => { return Err(UninitializedParticipant::new_err( "called 'should_set_model' on an uninitialized participant. this is a bug.", )) } }; Ok(inner.should_set_model()) } #[text_signature = "($self)"] pub fn task(&self) -> PyResult { let inner = match self.inner { Some(ref inner) => inner, None => { return Err(UninitializedParticipant::new_err( "called 'task' on an uninitialized participant. 
this is a bug.", )) } }; // FIXME: // Returning an enum is currently not supported: https://github.com/PyO3/pyo3/pull/1045 let task_as_u8 = match inner.task() { xaynet_mobile::Task::None => 0, xaynet_mobile::Task::Sum => 1, xaynet_mobile::Task::Update => 2, }; Ok(task_as_u8) } #[text_signature = "($self)"] pub fn new_global_model(&self) -> PyResult { let inner = match self.inner { Some(ref inner) => inner, None => { return Err(UninitializedParticipant::new_err( "called 'new_global_model' on an uninitialized participant. this is a bug.", )) } }; Ok(inner.new_global_model()) } #[text_signature = "($self)"] pub fn global_model(&mut self, py: Python) -> PyResult>> { let inner = match self.inner { Some(ref mut inner) => inner, None => { return Err(UninitializedParticipant::new_err( "called 'global_model' on an uninitialized participant. this is a bug.", )) } }; let global_model = inner .global_model() .map_err(|_| GlobalModelUnavailable::new_err("failed to fetch global model"))?; let global_model = match global_model { Some(global_model) => global_model, None => return Ok(None), }; match inner.local_model_config().data_type { DataType::F32 => into_primitives!(py, global_model, f32), DataType::F64 => into_primitives!(py, global_model, f64), DataType::I32 => into_primitives!(py, global_model, i32), DataType::I64 => into_primitives!(py, global_model, i64), } } #[text_signature = "($self)"] pub fn save(&mut self) -> PyResult> { let inner = match self.inner.take() { Some(inner) => inner, None => { return Err(UninitializedParticipant::new_err( "called 'save' on an uninitialized participant. this is a bug.", )) } }; Ok(inner.save()) } } #[macro_export] macro_rules! 
into_primitives { ($py:expr, $global_model:expr, $data_type:ty) => { if let Ok(global_model) = $global_model .into_primitives() .collect::, _>>() { let py_list = PyList::new($py, global_model.into_iter()); Ok(Some(py_list.into())) } else { Err(GlobalModelDataTypeMisMatch::new_err( "the global model data type is incompatible with the data type of the current model configuration", )) } }; } #[macro_export] macro_rules! from_primitives { ($participant:expr, $local_model:expr, $data_type:ty) => {{ let model: Vec<$data_type> = $local_model.extract() .map_err(|err| LocalModelDataTypeMisMatch::new_err(format!("{}", err)))?; let converted_model = Model::from_primitives(model.into_iter()); if let Ok(converted_model) = converted_model { $participant.set_model(converted_model); Ok(()) } else { Err(LocalModelDataTypeMisMatch::new_err( "the local model data type is incompatible with the data type of the current model configuration" )) }} }; } #[pyfunction] fn init_logging() { let env_filter = EnvFilter::try_from_env("XAYNET__CLIENT"); if let Ok(filter) = env_filter { let _fmt_subscriber = FmtSubscriber::builder() .with_env_filter(filter) .with_ansi(true) .try_init(); } } ================================================ FILE: bindings/python/xaynet_sdk/__init__.py ================================================ import threading from typing import List, Optional, Tuple from .async_participant import * from .participant import * def spawn_participant( coordinator_url: str, participant: ParticipantABC, args: Tuple = (), kwargs: dict = {}, state: Optional[List[int]] = None, scalar: float = 1.0, ): """ Spawns a `InternalParticipant` in a separate thread and returns a participant handle. If a `state` is passed, this state is restored, otherwise a new `InternalParticipant` is created. Args: coordinator_url: The url of the coordinator. participant: A class that implements `ParticipantABC`. args: The args that get passed to the constructor of the `participant` class. 
kwargs: The kwargs that get passed to the constructor of the `participant` class. state: A serialized participant state. Defaults to `None`. scalar: The scalar used for masking. Defaults to `1.0`. Note: The `scalar` is used later when the models are aggregated in order to scale their weights. It can be used when you want to weight the participants updates differently. For example: If not all participant updates should be weighted equally but proportionally to their training samples, the scalar would be set to `scalar = 1 / number_of_samples`. Returns: The `InternalParticipant`. Raises: CryptoInit: If the initialization of the underling crypto library has failed. ParticipantInit: If the participant cannot be initialized. This is most likely caused by an invalid `coordinator_url`. ParticipantRestore: If the participant cannot be restored due to invalid serialized state. This exception can never be thrown if the `state` is `None`. Exception: Any exception that can be thrown during the instantiation of `participant`. """ internal_participant = InternalParticipant( coordinator_url, participant, args, kwargs, state, scalar ) # spawns the internal participant in a thread. # `start` calls the `run` method of `InternalParticipant` # https://docs.python.org/3.8/library/threading.html#threading.Thread.start # https://docs.python.org/3.8/library/threading.html#threading.Thread.run internal_participant.start() return internal_participant def spawn_async_participant( coordinator_url: str, state: Optional[List[int]] = None, scalar: float = 1.0 ) -> (AsyncParticipant, threading.Event): """ Spawns a `AsyncParticipant` in a separate thread and returns a participant handle together with a global model notifier. If a `state` is passed, this state is restored, otherwise a new participant is created. Args: coordinator_url: The url of the coordinator. state: A serialized participant state. Defaults to `None`. scalar: The scalar used for masking. Defaults to `1.0`. 
Note: The `scalar` is used later when the models are aggregated in order to scale their weights. It can be used when you want to weight the participants updates differently. For example: If not all participant updates should be weighted equally but proportionally to their training samples, the scalar would be set to `scalar = 1 / number_of_samples`. Returns: A tuple which consists of an `AsyncParticipant` and a global model notifier. Raises: CryptoInit: If the initialization of the underling crypto library has failed. ParticipantInit: If the participant cannot be initialized. This is most likely caused by an invalid `coordinator_url`. ParticipantRestore: If the participant cannot be restored due to invalid serialized state. This exception can never be thrown if the `state` is `None`. """ notifier = threading.Event() async_participant = AsyncParticipant(coordinator_url, notifier, state, scalar) async_participant.start() return (async_participant, notifier) ================================================ FILE: bindings/python/xaynet_sdk/async_participant.py ================================================ import logging import threading from typing import List, Optional from justbackoff import Backoff from xaynet_sdk import xaynet_sdk # rust participant logging xaynet_sdk.init_logging() # python participant logging LOG = logging.getLogger("participant") class AsyncParticipant(threading.Thread): def __init__( self, coordinator_url: str, notifier, state, scalar, ): # xaynet rust participant self._xaynet_participant = xaynet_sdk.Participant( coordinator_url, scalar, state ) self._exit_event = threading.Event() self._poll_period = Backoff(min_ms=100, max_ms=10000, factor=1.2, jitter=False) # new global model notifier self._notifier = notifier # calls to an external lib are thread-safe https://stackoverflow.com/a/42023362 # however, if a user calls `stop` in the middle of the `_tick` call, the # `save` method will be executed (which consumes the participant) and every 
following call # will fail with a call on an uninitialized participant. Therefore we lock during `tick`. self._tick_lock = threading.Lock() super().__init__(daemon=True) def run(self): try: self._run() except Exception as err: # pylint: disable=broad-except LOG.error("unrecoverable error: %s shut down participant", err) self._exit_event.set() def _notify(self): if self._notifier.is_set() is False: LOG.debug("notify that a new global model is available") self._notifier.set() def _run(self): while not self._exit_event.is_set(): self._tick() def _tick(self): with self._tick_lock: self._xaynet_participant.tick() new_global_model = self._xaynet_participant.new_global_model() made_progress = self._xaynet_participant.made_progress() if new_global_model: self._notify() if made_progress: self._poll_period.reset() self._exit_event.wait(timeout=self._poll_period.duration()) else: self._exit_event.wait(timeout=self._poll_period.duration()) def get_global_model(self) -> Optional[list]: """ Fetches the current global model. This method can be called at any time. If no global model exists (usually in the first round), the method returns `None`. Returns: The current global model in the form of a list or `None`. The data type of the elements match the data type defined in the coordinator configuration. Raises: GlobalModelUnavailable: If the participant cannot connect to the coordinator to get the global model. GlobalModelDataTypeMisMatch: If the data type of the global model does not match the data type defined in the coordinator configuration. """ LOG.debug("get global model") self._notifier.clear() with self._tick_lock: return self._xaynet_participant.global_model() def set_local_model(self, local_model: list): """ Sets a local model. This method can be called at any time. Internally the participant first caches the local model. As soon as the participant is selected as an update participant, the currently cached local model is used. 
This means that the cache is empty after this operation. If a local model is already in the cache and `set_local_model` is called with a new local model, the current cached local model will be replaced by the new one. If the participant is an update participant and there is no local model in the cache, the participant waits until a local model is set or until a new round has been started. Args: local_model: The local model in the form of a list. The data type of the elements must match the data type defined in the coordinator configuration. Raises: LocalModelLengthMisMatch: If the length of the local model does not match the length defined in the coordinator configuration. LocalModelDataTypeMisMatch: If the data type of the local model does not match the data type defined in the coordinator configuration. """ LOG.debug("set local model in model store") with self._tick_lock: self._xaynet_participant.set_model(local_model) def stop(self) -> List[int]: """ Stops the execution of the participant and returns its serialized state. The serialized state can be passed to the `spawn_async_participant` function to restore a participant. After calling `stop`, the participant is consumed. Every further method call on the handle of `AsyncParticipant` leads to an `UninitializedParticipant` exception. Note: The serialized state contains unencrypted **private key(s)**. If used in production, it is important that the serialized state is securely saved. Returns: The serialized state of the participant. 
""" LOG.debug("stop participant") self._exit_event.set() self._notifier.clear() with self._tick_lock: return self._xaynet_participant.save() ================================================ FILE: bindings/python/xaynet_sdk/participant.py ================================================ from abc import ABC, abstractmethod import logging import threading from typing import List, Optional, TypeVar from justbackoff import Backoff from xaynet_sdk import xaynet_sdk # rust participant logging xaynet_sdk.init_logging() # python participant logging LOG = logging.getLogger("participant") TrainingResult = TypeVar("TrainingResult") TrainingInput = TypeVar("TrainingInput") class ParticipantABC(ABC): @abstractmethod def train_round(self, training_input: Optional[TrainingInput]) -> TrainingResult: """ Trains a model. `training_input` is the deserialized global model (see `deserialize_training_input`). If no global model exists (usually in the first round), `training_input` will be `None`. In this case the weights of the model should be initialized and returned. Args: self: The participant. training_input: The deserialized global model (weights of the global model) or None. Returns: The updated model weights (the local model). """ raise NotImplementedError() @abstractmethod def serialize_training_result(self, training_result: TrainingResult) -> list: """ Serializes the `training_result` into a `list`. The data type of the elements must match the data type defined in the coordinator configuration. Args: self: The participant. training_result: The `TrainingResult` of `train_round`. Returns: The `training_result` as a `list`. """ raise NotImplementedError() @abstractmethod def deserialize_training_input(self, global_model: list) -> TrainingInput: """ Deserializes the `global_model` from a `list` to the type of `TrainingInput`. The data type of the elements matches the data type defined in the coordinator configuration. 
If no global model exists (usually in the first round), the method will not be called by the `InternalParticipant`. Args: self: The participant. global_model: The global model. Returns: The `TrainingInput` for `train_round`. """ raise NotImplementedError() def participate_in_update_task(self) -> bool: """ A callback used by the `InternalParticipant` to determine whether the `train_round` method should be called. This callback is only called if the participant is selected as an update participant. If `participate_in_update_task` returns `False`, `train_round` will not be called by the `InternalParticipant`. If the method is not overridden, it returns `True` by default. Returns: Whether the `train_round` method should be called when the participant is an update participant. """ return True def on_new_global_model(self, global_model: Optional[TrainingInput]) -> None: """ A callback that is called by the `InternalParticipant` once a new global model is available. If no global model exists (usually in the first round), `global_model` will be `None`. If a global model exists, `global_model` is already the deserialized global model. (See `deserialize_training_input`) If the method is not overridden, it does nothing by default. Args: self: The participant. global_model: The deserialized global model or `None`. """ def on_stop(self) -> None: """ A callback that is called by the `InternalParticipant` before the `InternalParticipant` thread is stopped. This callback can be used, for example, to show performance values ​​that have been collected in the participant over the course of the training rounds. If the method is not overridden, it does nothing by default. Args: self: The participant. 
""" class InternalParticipant(threading.Thread): def __init__( self, coordinator_url: str, participant, p_args, p_kwargs, state, scalar, ): # xaynet rust participant self._xaynet_participant = xaynet_sdk.Participant( coordinator_url, scalar, state ) # https://github.com/python/cpython/blob/3.9/Lib/multiprocessing/process.py#L80 # stores the Participant class with its args and kwargs # the participant is created in the `run` method to ensure that the participant/ ml # model is initialized on the participant thread otherwise the participant lives on the main # thread which can created issues with some of the ml frameworks. self._participant = participant self._p_args = tuple(p_args) self._p_kwargs = dict(p_kwargs) self._exit_event = threading.Event() self._poll_period = Backoff(min_ms=100, max_ms=10000, factor=1.2, jitter=False) # global model cache self._global_model = None self._error_on_fetch_global_model = False self._tick_lock = threading.Lock() super().__init__(daemon=True) def run(self): self._participant = self._participant(*self._p_args, *self._p_kwargs) try: self._run() except Exception as err: # pylint: disable=broad-except LOG.error("unrecoverable error: %s shut down participant", err) self._exit_event.set() def _fetch_global_model(self): LOG.debug("fetch global model") try: global_model = self._xaynet_participant.global_model() except ( xaynet_sdk.GlobalModelUnavailable, xaynet_sdk.GlobalModelDataTypeMisMatch, ) as err: LOG.warning("failed to get global model: %s", err) self._error_on_fetch_global_model = True else: if global_model is not None: self._global_model = self._participant.deserialize_training_input( global_model ) else: self._global_model = None self._error_on_fetch_global_model = False def _train(self): LOG.debug("train model") data = self._participant.train_round(self._global_model) local_model = self._participant.serialize_training_result(data) try: self._xaynet_participant.set_model(local_model) except ( 
xaynet_sdk.LocalModelLengthMisMatch, xaynet_sdk.LocalModelDataTypeMisMatch, ) as err: LOG.warning("failed to set local model: %s", err) def _run(self): while not self._exit_event.is_set(): self._tick() def _tick(self): with self._tick_lock: self._xaynet_participant.tick() if ( self._xaynet_participant.new_global_model() or self._error_on_fetch_global_model ): self._fetch_global_model() if not self._error_on_fetch_global_model: self._participant.on_new_global_model(self._global_model) if ( self._xaynet_participant.should_set_model() and self._participant.participate_in_update_task() and not self._error_on_fetch_global_model ): self._train() made_progress = self._xaynet_participant.made_progress() if made_progress: self._poll_period.reset() self._exit_event.wait(timeout=self._poll_period.duration()) else: self._exit_event.wait(timeout=self._poll_period.duration()) def stop(self) -> List[int]: """ Stops the execution of the participant and returns its serialized state. The serialized state can be passed to the `spawn_participant` function to restore a participant. After calling `stop`, the participant is consumed. Every further method call on the handle of `InternalParticipant` leads to an `UninitializedParticipant` exception. Note: The serialized state contains unencrypted **private key(s)**. If used in production, it is important that the serialized state is securely saved. Returns: The serialized state of the participant. 
""" LOG.debug("stopping participant") self._exit_event.set() with self._tick_lock: state = self._xaynet_participant.save() LOG.debug("participant stopped") self._participant.on_stop() return state ================================================ FILE: configs/config.toml ================================================ [log] filter = "xaynet=debug,http=warn,info" [api] bind_address = "127.0.0.1:8081" tls_certificate = "/app/ssl/tls.pem" tls_key = "/app/ssl/tls.key" # tls_client_auth = "/app/ssl/trust_anchor.pem" [pet.sum] prob = 0.5 count = { min = 1, max = 100 } time = { min = 5, max = 3600 } [pet.update] prob = 0.9 count = { min = 3, max = 10000 } time = { min = 10, max = 3600 } [pet.sum2] count = { min = 1, max = 100 } time = { min = 5, max = 3600 } [mask] group_type = "Prime" data_type = "F32" bound_type = "B0" model_type = "M3" [model] length = 4 [metrics.influxdb] url = "http://127.0.0.1:8086" db = "metrics" [redis] url = "redis://127.0.0.1/" [s3] access_key = "minio" secret_access_key = "minio123" region = ["minio", "http://localhost:9000"] [restore] enable = true ================================================ FILE: configs/docker-dev.toml ================================================ [log] filter = "xaynet=debug,http=warn,info" [api] bind_address = "0.0.0.0:8081" tls_certificate = "/app/ssl/tls.pem" tls_key = "/app/ssl/tls.key" # tls_client_auth = "/app/ssl/trust_anchor.pem" [pet.sum] prob = 0.01 count = { min = 1, max = 100 } time = { min = 5, max = 3600 } [pet.update] prob = 0.1 count = { min = 3, max = 10000 } time = { min = 10, max = 3600 } [pet.sum2] count = { min = 1, max = 100 } time = { min = 5, max = 3600 } [mask] group_type = "Prime" data_type = "F32" bound_type = "B0" model_type = "M3" [model] length = 4 [metrics.influxdb] url = "http://influxdb:8086" db = "metrics" [redis] url = "redis://redis" [s3] access_key = "minio" secret_access_key = "minio123" region = ["minio", "http://minio:9000"] [restore] enable = true 
================================================
FILE: docker/.dev.env
================================================
MINIO_ACCESS_KEY=minio
MINIO_SECRET_KEY=minio123


================================================
FILE: docker/Dockerfile
================================================
# Build stage. The builder is pinned to the same Ubuntu release (20.04 "focal") as the
# runtime stage below. The coordinator binary is therefore linked against the glibc and
# OpenSSL versions that exist at runtime. A floating `stable` Debian base can ship newer
# versions than ubuntu:20.04 provides.
FROM buildpack-deps:focal-curl AS builder

# install build dependencies: libc, openssl
# `update` and `install` share one RUN so a cached, stale package index is never reused
RUN apt-get update && \
    apt-get install -y build-essential libssl-dev pkg-config

# Install Rust
ENV RUSTUP_HOME=/usr/local/rustup \
    CARGO_HOME=/usr/local/cargo \
    PATH=/usr/local/cargo/bin:$PATH
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y --profile minimal

COPY rust/ /rust/
WORKDIR /rust/xaynet-server

# https://github.com/linkerd/linkerd2-proxy/blob/main/Dockerfile#L31
# Controls which profile the coordinator is compiled with.
# If set to RELEASE_BUILD=1, the coordinator is compiled using the release profile.
# Default is development profile.
ARG RELEASE_BUILD=0

# Controls which optional features the coordinator is compiled with.
# Syntax:
#   default features:  -
#   single feature:    COORDINATOR_FEATURES=tls
#   multiple features: COORDINATOR_FEATURES=tls,metrics
#   all features:      COORDINATOR_FEATURES=full
ARG COORDINATOR_FEATURES

RUN mkdir -p /out && \
    echo "RELEASE_BUILD=$RELEASE_BUILD COORDINATOR_FEATURES=$COORDINATOR_FEATURES" && \
    if [ "$RELEASE_BUILD" -eq "0" ]; \
    then \
      cargo build --features="$COORDINATOR_FEATURES" && \
      mv /rust/target/debug/coordinator /out/coordinator; \
    else \
      cargo build --features="$COORDINATOR_FEATURES" --release && \
      mv /rust/target/release/coordinator /out/coordinator; \
    fi

FROM ubuntu:20.04
# the binary only needs the shared OpenSSL library at runtime, not the -dev headers
RUN apt-get update && \
    apt-get install -y --no-install-recommends libssl1.1 && \
    rm -rf /var/lib/apt/lists/*
COPY --from=builder /out/coordinator /app/coordinator
ENTRYPOINT ["/app/coordinator", "-c", "/app/config.toml"]


================================================
FILE: docker/docker-compose.yml
================================================
version: "3.8"
services:
  coordinator:
    image: xaynetwork/xaynet:development
    build:
      context: ..
      dockerfile: docker/Dockerfile
    depends_on:
      - minio
      - redis
      - influxdb
    volumes:
      - ${PWD}/configs/docker-dev.toml:/app/config.toml
    networks:
      - xaynet
    ports:
      - "8081:8081"
    # temporary fix:
    # The coordinator crashes if Redis is not ready or busy at startup
    restart: unless-stopped

  influxdb:
    image: influxdb:1.8
    hostname: influxdb
    container_name: influxdb
    environment:
      INFLUXDB_DB: metrics
      INFLUXDB_DATA_QUERY_LOG_ENABLED: 'false'
      INFLUXDB_HTTP_LOG_ENABLED: 'false'
    volumes:
      - influxdb-data:/var/lib/influxdb
    networks:
      - xaynet
    ports:
      - "8086:8086"

  minio:
    image: minio/minio
    hostname: minio
    container_name: minio
    env_file:
      - .dev.env
    command: server /data
    volumes:
      - minio-data:/data
    networks:
      - xaynet
    ports:
      - "9000:9000"

  redis:
    image: redis:6
    hostname: redis
    container_name: redis
    entrypoint: /usr/local/bin/redis-server --appendonly yes --appendfsync everysec
    # using combination of RDB and AOF for persistence: https://redis.io/topics/persistence
    volumes:
      - redis-data:/data
    networks:
      - xaynet
    ports:
      - "6379:6379"

volumes:
  minio-data:
  redis-data:
  influxdb-data:

networks:
  xaynet:


================================================
FILE: k8s/coordinator/base/deployment.yaml
================================================
apiVersion: apps/v1
kind: Deployment
metadata:
  name: coordinator-deployment
spec:
  selector:
    matchLabels:
      app: coordinator
  replicas: 1
  strategy:
    type: Recreate
  template:
    metadata:
      labels:
        app: coordinator
    spec:
      containers:
        - name: coordinator
          image: coordinator
          imagePullPolicy: Always
          ports:
            - containerPort: 8081
              protocol: TCP
          env:
            - name: REDIS_AUTH
              valueFrom:
                secretKeyRef:
                  name: redis-auth
                  key: redis-password
            - name: XAYNET__REDIS__URL
              value: "redis://:$(REDIS_AUTH)@redis-master"
            - name: XAYNET__S3__ACCESS_KEY
              valueFrom:
                secretKeyRef:
                  name: minio-auth
                  key: accesskey
            - name: XAYNET__S3__SECRET_ACCESS_KEY
              valueFrom:
                secretKeyRef:
                  name: minio-auth
                  key: secretkey


================================================
FILE: k8s/coordinator/base/kustomization.yaml
================================================
apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization

commonLabels:
  app.kubernetes.io/component: backend
  app.kubernetes.io/name: coordinator
  app.kubernetes.io/part-of: xaynet

resources:
  - deployment.yaml
  - service.yaml


================================================
FILE: k8s/coordinator/base/service.yaml
================================================
apiVersion: v1
kind: Service
metadata:
  name: coordinator-service
spec:
  type: ClusterIP
  ports:
    - port: 8081
      targetPort: 8081
      name: http-port
  selector:
    app: coordinator


================================================
FILE: k8s/coordinator/development/cert-volume-mount.yaml
================================================
apiVersion: apps/v1
kind: Deployment
metadata:
  name: coordinator-deployment
spec:
  template:
    spec:
      volumes:
        - name: tls-certificate
          secret:
            secretName: dev-coordinator
            items:
              - key: tls.key
                path: tls.key
                mode: 0400
              - key: tls.crt
                path: tls.pem
                mode: 0444
      containers:
        - name: coordinator
          volumeMounts:
            - name: tls-certificate
              mountPath: "/app/ssl"
              readOnly: true


================================================
FILE: k8s/coordinator/development/config-volume-mount.yaml
================================================
apiVersion: apps/v1
kind: Deployment
metadata:
  name: coordinator-deployment
spec:
  template:
    spec:
      volumes:
        - name: config-volume
          configMap:
            name: config-toml
            items:
              - key: config.toml
                path: config.toml
      containers:
        - name: coordinator
          volumeMounts:
            - name: config-volume
              mountPath: /app/config.toml
              subPath: config.toml


================================================
FILE: k8s/coordinator/development/config.toml
================================================
[log]
filter = "xaynet=debug,http=warn,info"

[api]
bind_address = "0.0.0.0:8081"
tls_certificate = "/app/ssl/tls.pem"
tls_key = "/app/ssl/tls.key"

[pet.sum]
prob = 0.5
count = { min = 1, max = 100 }
time = { min = 5, max = 3600 }

[pet.update]
prob = 0.9
count = { min = 3, max = 10000 }
time = { min = 10, max = 3600 }

[pet.sum2]
count = { min = 1, max = 100 }
time = { min = 5, max = 3600 }

[mask]
group_type = "Prime"
data_type = "F32"
bound_type = "B0"
model_type = "M3"

[model]
length = 4

[metrics.influxdb]
url = "http://influxdb:8086"
db = "metrics"

[redis]
# The url is configured via the environment variable `XAYNET__REDIS__URL`.
# `XAYNET__REDIS__URL` depends on the environment variable `REDIS_AUTH`,
# which is defined as a Kubernetes secret and exposed to the coordinator pod.
# See: k8s/coordinator/base/deployment.yaml

[s3]
# The access_key and secret_access_key are configured via the environment variables
# `XAYNET__S3__ACCESS_KEY` and `XAYNET__S3__SECRET_ACCESS_KEY`.
# See: k8s/coordinator/base/deployment.yaml
region = ["minio", "http://minio:9000"]

[restore]
enable = true


================================================
FILE: k8s/coordinator/development/history-limit.yaml
================================================
apiVersion: apps/v1
kind: Deployment
metadata:
  name: coordinator-deployment
spec:
  revisionHistoryLimit: 0


================================================
FILE: k8s/coordinator/development/ingress.yaml
================================================
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: coordinator-ingress
  annotations:
    kubernetes.io/ingress.class: "nginx"
    cert-manager.io/cluster-issuer: "letsencrypt-production"
spec:
  tls:
    - hosts:
        - dev-coordinator.xaynet.dev
      secretName: dev-coordinator
  rules:
    - host: dev-coordinator.xaynet.dev
      http:
        paths:
          - path: /
            pathType: Prefix
            backend:
              service:
                name: coordinator-service
                port:
                  number: 8081


================================================
FILE: k8s/coordinator/development/kustomization.yaml
================================================
apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization

namespace: xaynet

images:
  - name: coordinator
    newName: xaynetwork/xaynet
    newTag: development

configMapGenerator:
  - name: config-toml
    files:
      - config.toml

bases:
  - ../base

patchesStrategicMerge:
  - history-limit.yaml
  - config-volume-mount.yaml
  - cert-volume-mount.yaml

resources:
  - ingress.yaml


================================================
FILE: rust/.gitignore
================================================
# https://github.com/github/gitignore/blob/master/Rust.gitignore
# Generated by Cargo
# will have compiled files and executables
/target/

# These are backup files generated by rustfmt
**/*.rs.bk

/benches/target/


================================================
FILE: rust/Cargo.toml
================================================
[workspace]
members = [
    "xaynet",
    # "xaynet-analytics",
    "xaynet-core",
    "xaynet-mobile",
    "xaynet-server",
    "xaynet-sdk",

    # internals
    "benches",
    "examples",
]

[workspace.metadata]
# minimum supported rust version
msrv = "1.51.0"


================================================
FILE: rust/benches/Cargo.toml
================================================
[package]
name = "benches"
version = "0.0.0"
authors = ["Xayn Engineering "]
edition = "2018"
description = "The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant."
readme = "../../README.md"
homepage = "https://xaynet.dev/"
repository = "https://github.com/xaynetwork/xaynet/"
license-file = "../../LICENSE"
keywords = ["federated-learning", "fl", "ai", "machine-learning"]
categories = ["science", "cryptography"]
publish = false

[dev-dependencies]
criterion = { version = "0.3.6", features = ["html_reports"] }
num = "0.4.0"
paste = "1.0.8"
xaynet-core = { path = "../xaynet-core", features = ["testutils"] }

[[bench]]
name = "sum_message"
path = "messages/sum.rs"
harness = false

[[bench]]
name = "update_message"
path = "messages/update.rs"
harness = false

[[bench]]
name = "models_from_primitives"
path = "models/from_primitives.rs"
harness = false

[[bench]]
name = "models_to_primitives"
path = "models/to_primitives.rs"
harness = false


================================================
FILE: rust/benches/messages/sum.rs
================================================
use std::time::Duration;

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use xaynet_core::{
    crypto::{ByteObject, SecretSigningKey},
    message::Message,
    testutils::messages as helpers,
};

// `Message::to_bytes` takes a secret key as argument. It is not
// actually used, since the message we generate already contains a
// (dummy) signature.
fn participant_sk() -> SecretSigningKey { SecretSigningKey::from_slice(vec![2; 64].as_slice()).unwrap() } fn to_bytes(crit: &mut Criterion) { let (sum_message, _) = helpers::message(helpers::sum::payload); let buf_len = sum_message.buffer_length(); let mut pre_allocated_buf = vec![0; buf_len]; // the benchmarks run under 20 ns. The results for such // benchmarks can vary a bit more so we: // - eliminate outliers a bit more aggressively (confidence level) // - increase the noise threshold // // Note: criterion always reports p = 0.0 so lowering the // significance level doesn't change anything let mut crit = crit.benchmark_group("serialize sum message to bytes"); crit.confidence_level(0.9).noise_threshold(0.05); crit.bench_function("compute sum message buffer length", |bench| { bench.iter(|| black_box(&sum_message).buffer_length()) }); crit.bench_function("serialize sum message to bytes", |bench| { bench.iter(|| { sum_message.to_bytes( black_box(&mut pre_allocated_buf), black_box(&participant_sk()), ) }) }); } fn from_bytes(crit: &mut Criterion) { let sum_message = helpers::message(helpers::sum::payload).0; let mut bytes = vec![0; sum_message.buffer_length()]; sum_message.to_bytes(&mut bytes, &participant_sk()); // This benchmark is also quite unstable so make it a bit more // relaxed let mut crit = crit.benchmark_group("deserialize sum message from bytes"); crit.confidence_level(0.9).noise_threshold(0.05); crit.bench_function("deserialize sum message from bytes", |bench| { bench.iter(|| Message::from_byte_slice(&black_box(bytes.as_slice()))) }); } criterion_group!( name = bench_sum_message; // By default criterion collection 100 sample and the // measurement time is 5 seconds, but the results are // quite unstable with this configuration. 
This // config makes the benchmarks running longer but // provide more reliable results config = Criterion::default().sample_size(1000).measurement_time(Duration::new(10, 0)); targets = to_bytes, from_bytes, ); criterion_main!(bench_sum_message); ================================================ FILE: rust/benches/messages/update.rs ================================================ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use paste::paste; use xaynet_core::{ message::{FromBytes, ToBytes, Update}, testutils::multipart as helpers, }; fn make_update(dict_len: usize, mask_len: usize, total_expected_len: usize) -> (Update, Vec) { let update = helpers::update(dict_len, mask_len); // just check that we made our calculation right // message size = dict_len + mask_len + 64*2 assert_eq!(update.buffer_length(), total_expected_len); let mut bytes = vec![0; update.buffer_length()]; update.to_bytes(&mut bytes); (update, bytes) } macro_rules! fn_from_bytes { ($name: ident, $dict_len: expr, $mask_len: expr, $total_len: expr) => { paste! 
{ #[allow(non_snake_case)] fn [](crit: &mut Criterion) { let (_, bytes) = make_update($dict_len, $mask_len, $total_len); let name = &stringify!($name)[1..]; let mut crit = crit.benchmark_group(format!("deserialize {} update from bytes", name)); crit.bench_function( format!("deserialize {} update from bytes slice", name).as_str(), |bench| { bench.iter(|| Update::from_byte_slice(&black_box(bytes.as_slice()))) }, ); // it's less overhead to clone the iterator of bytes instead of re-creating it // again in every benchmark iteration let iter = bytes.into_iter(); crit.bench_function( format!("deserialize {} update from bytes stream", name).as_str(), |bench| { bench.iter(|| Update::from_byte_stream(black_box(&mut iter.clone()))) }, ); } } }; } // Get an update that corresponds to: // - 1 sum participant (1 entry in the seed dict) // - a 42 bytes serialized masked model fn_from_bytes!(_tiny, 116, 42, 286); // Get an update that corresponds to: // - 1k sum participants (1k entries in the seed dict) // - a 6kB serialized masked model fn_from_bytes!(_100kB, 112_004, 6_018, 118_150); // Get an update that corresponds to: // - 10k sum participants (10k entries in the seed dict) // - a 60kB serialized masked model fn_from_bytes!(_1MB, 1_120_004, 60_018, 1_180_150); // Get an update that corresponds to: // - 10k sum participants (10k entries in the seed dict) // - a ~1MB serialized masked model fn_from_bytes!(_2MB, 1_120_004, 1_000_020, 2_120_152); // Get an update that corresponds to: // - 10k sum participants (10k entries in the seed dict) // - a ~9MB serialized masked model fn_from_bytes!(_10MB, 1_120_004, 9_000_018, 10_120_150); criterion_group!( name = bench_update_message; config = Criterion::default(); targets = from_bytes_tiny, from_bytes_100kB, from_bytes_1MB, from_bytes_2MB, from_bytes_10MB, ); criterion_main!(bench_update_message); ================================================ FILE: rust/benches/models/from_primitives.rs 
================================================ use std::time::Duration; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use paste::paste; use xaynet_core::mask::{FromPrimitives, Model}; fn make_vector(bytes_size: usize) -> Vec { // 1 i32 -> 4 bytes assert_eq!(bytes_size % 4, 0); let n_elements = bytes_size / 4; vec![0_i32; n_elements] } macro_rules! fn_from_primitives { ($name: ident, $size: expr) => { paste! { #[allow(non_snake_case)] fn [](crit: &mut Criterion) { let vector = make_vector($size); let name = &stringify!($name)[1..]; let iter = vector.into_iter(); crit.bench_function( format!("convert {} model from primitive vector", name).as_str(), |bench| { bench.iter(|| Model::from_primitives(black_box(iter.clone()))) }, ); } } }; } // 4 bytes fn_from_primitives!(_tiny, 4); // 100kB = 102_400 bytes fn_from_primitives!(_100kB, 102_400); // 1MB = 1_024_000 bytes fn_from_primitives!(_1MB, 1_024_000); criterion_group!( name = bench_model_from_primitives; config = Criterion::default().sample_size(1000).measurement_time(Duration::new(10, 0)); targets = from_primitives_tiny, from_primitives_100kB, from_primitives_1MB, ); criterion_main!(bench_model_from_primitives); ================================================ FILE: rust/benches/models/to_primitives.rs ================================================ use std::{iter, time::Duration}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use num::{bigint::BigInt, rational::Ratio}; use paste::paste; use xaynet_core::mask::{IntoPrimitives, Model}; fn make_model(bytes_size: usize) -> Model { // 1 i32 -> 4 bytes assert_eq!(bytes_size % 4, 0); let n_elements = bytes_size / 4; iter::repeat(Ratio::from(BigInt::from(0))) .take(n_elements) .collect() } macro_rules! fn_to_primitives { ($name: ident, $size: expr) => { paste! 
{ #[allow(non_snake_case)] fn [](crit: &mut Criterion) { let model = make_model($size); let name = &stringify!($name)[1..]; crit.bench_function( format!("convert {} model to primitive vector", name).as_str(), |bench| { bench.iter(|| black_box(&model).to_primitives().collect::, _>>()) } ); } } }; } // 4 bytes fn_to_primitives!(_tiny, 4); // 100kB = 102_400 bytes fn_to_primitives!(_100kB, 102_400); // 1MB = 1_024_000 bytes fn_to_primitives!(_1MB, 1_024_000); criterion_group!( name = bench_model_to_primitives; config = Criterion::default().sample_size(1000).measurement_time(Duration::new(10, 0)); targets = to_primitives_tiny, to_primitives_100kB, to_primitives_1MB, ); criterion_main!(bench_model_to_primitives); ================================================ FILE: rust/examples/Cargo.toml ================================================ [package] name = "examples" version = "0.0.0" authors = ["Xayn Engineering "] edition = "2018" description = "The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant." 
readme = "../../README.md" homepage = "https://xaynet.dev/" repository = "https://github.com/xaynetwork/xaynet/" license-file = "../../LICENSE" keywords = ["federated-learning", "fl", "ai", "machine-learning"] categories = ["science", "cryptography"] publish = false # https://github.com/http-rs/tide/issues/225 # https://github.com/dependabot/dependabot-core/issues/1156 autobins = false [dev-dependencies] async-trait = "0.1.57" reqwest = { version = "0.11.10", default-features = false, features = ["rustls-tls"] } structopt = "0.3.26" tokio = { version = "1.20.1", features = ["sync", "time", "macros", "rt-multi-thread", "signal"] } tracing = "0.1.36" tracing-futures = "0.2.5" tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } xaynet-core = { path = "../xaynet-core" } xaynet-sdk = { path = "../xaynet-sdk", features = ["reqwest-client"] } [[example]] name = "test-drive" path = "test-drive/main.rs" ================================================ FILE: rust/examples/test-drive/main.rs ================================================ use std::{fs::File, io::Read, sync::Arc, time::Duration}; use structopt::StructOpt; use tracing::error_span; use tracing_futures::Instrument; use tracing_subscriber::{EnvFilter, FmtSubscriber}; use xaynet_core::{ crypto::SigningKeyPair, mask::{FromPrimitives, Model}, }; use xaynet_sdk::{ client::{Client, ClientError}, settings::PetSettings, }; mod participant; mod settings; #[tokio::main] async fn main() -> Result<(), ClientError> { let _fmt_subscriber = FmtSubscriber::builder() .with_env_filter(EnvFilter::from_default_env()) .with_ansi(true) .init(); let opt = settings::Opt::from_args(); // dummy local model for clients let len = opt.len as usize; let model = Arc::new(Model::from_primitives(vec![0; len].into_iter()).unwrap()); for id in 0..opt.nb_client { spawn_participant(id as u32, &opt, model.clone())?; } tokio::signal::ctrl_c().await.unwrap(); Ok(()) } fn generate_agent_config() -> PetSettings { let keys = 
SigningKeyPair::generate(); PetSettings::new(keys) } fn build_http_client(settings: &settings::Opt) -> reqwest::Client { let builder = reqwest::ClientBuilder::new(); let builder = if let Some(ref path) = settings.certificate { let mut buf = Vec::new(); File::open(path).unwrap().read_to_end(&mut buf).unwrap(); let root_cert = reqwest::Certificate::from_pem(&buf).unwrap(); builder.use_rustls_tls().add_root_certificate(root_cert) } else { builder }; let builder = if let Some(ref path) = settings.identity { let mut buf = Vec::new(); File::open(path).unwrap().read_to_end(&mut buf).unwrap(); let identity = reqwest::Identity::from_pem(&buf).unwrap(); builder.use_rustls_tls().identity(identity) } else { builder }; builder.build().unwrap() } fn spawn_participant( id: u32, settings: &settings::Opt, model: Arc, ) -> Result<(), ClientError> { let config = generate_agent_config(); let http_client = build_http_client(settings); let client = Client::new(http_client, &settings.url).unwrap(); let (participant, agent) = participant::Participant::new(config, client, model); tokio::spawn(async move { participant .run() .instrument(error_span!("participant", id = id)) .await; }); tokio::spawn(async move { agent .run(Duration::from_secs(1)) .instrument(error_span!("agent", id = id)) .await; }); Ok(()) } ================================================ FILE: rust/examples/test-drive/participant.rs ================================================ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; use tokio::{sync::mpsc, time::sleep}; use tracing::{info, warn}; use xaynet_core::mask::Model; use xaynet_sdk::{ client::Client, settings::PetSettings, ModelStore, Notify, StateMachine, TransitionOutcome, XaynetClient, }; enum Event { Update, Sum, NewRound, Idle, } pub struct Participant { // FIXME: XaynetClient requires the client to be mutable. 
// This may make it easier to implement clients, but as a result we can't
    // wrap the client in an Arc, which would allow us to share the
    // same client with all the participants. Maybe XaynetClient
    // should have methods that take &self?
    xaynet_client: Client,
    // receiving end of the channel that `Notifier` pushes `Event`s into
    notifications: mpsc::Receiver,
}

/// Owns a PET `StateMachine` and repeatedly drives its transitions.
pub struct Agent(StateMachine);

impl Agent {
    /// Wraps a new `StateMachine` built from the given settings, client,
    /// model store and notifier.
    fn new(settings: PetSettings, xaynet_client: X, model_store: M, notify: N) -> Self
    where
        X: XaynetClient + Send + 'static,
        M: ModelStore + Send + 'static,
        N: Notify + Send + 'static,
    {
        Agent(StateMachine::new(
            settings,
            xaynet_client,
            model_store,
            notify,
        ))
    }

    /// Drives the state machine forever; this future never resolves.
    ///
    /// After a `TransitionOutcome::Pending` it sleeps for `tick` before
    /// retrying; after a `TransitionOutcome::Complete` it proceeds
    /// immediately with the next transition.
    pub async fn run(mut self, tick: Duration) {
        loop {
            self = match self.0.transition().await {
                TransitionOutcome::Pending(state_machine) => {
                    sleep(tick).await;
                    Self(state_machine)
                }
                TransitionOutcome::Complete(state_machine) => Self(state_machine),
            };
        }
    }
}

impl Participant {
    /// Creates a participant together with the agent that drives it.
    ///
    /// The agent gets a clone of `xaynet_client` and forwards its
    /// notifications to the participant through a bounded channel of
    /// capacity 10.
    pub fn new(
        settings: PetSettings,
        xaynet_client: Client,
        model: Arc,
    ) -> (Self, Agent) {
        let (tx, rx) = mpsc::channel::(10);
        let notifier = Notifier(tx);
        let agent = Agent::new(settings, xaynet_client.clone(), LocalModel(model), notifier);
        let participant = Self {
            xaynet_client,
            notifications: rx,
        };
        (participant, agent)
    }

    /// Logs every notification and downloads the latest global model when a
    /// new round starts. Returns once the notification channel is closed.
    pub async fn run(mut self) {
        use Event::*;
        loop {
            match self.notifications.recv().await {
                Some(Sum) => {
                    info!("taking part in the sum task");
                }
                Some(Update) => {
                    info!("taking part to the update task");
                }
                Some(Idle) => {
                    info!("waiting");
                }
                Some(NewRound) => {
                    info!("new round started, downloading latest global model");
                    // a failed download is only logged; the participant keeps running
                    if let Err(e) = self.xaynet_client.get_model().await {
                        warn!("failed to download latest model: {}", e);
                    }
                }
                None => {
                    warn!("notifications channel closed, terminating");
                    return;
                }
            }
        }
    }
}

/// `Notify` implementation that forwards state-machine notifications to the
/// `Participant` as `Event`s.
///
/// It uses `try_send`, so a notification that cannot be enqueued (e.g. the
/// channel is full or closed) is dropped with a warning instead of blocking.
struct Notifier(mpsc::Sender);

impl Notify for Notifier {
    fn new_round(&mut self) {
        if let Err(e) = self.0.try_send(Event::NewRound) {
            warn!("failed to notify participant: {}", e);
        }
    }

    fn sum(&mut self) {
        if let Err(e) = self.0.try_send(Event::Sum) {
            warn!("failed to notify participant: {}", e);
        }
    }

    fn update(&mut self) {
        if let Err(e) = self.0.try_send(Event::Update) {
            warn!("failed to notify participant: {}", e);
        }
    }

    fn idle(&mut self) {
        if let Err(e) = self.0.try_send(Event::Idle) {
            warn!("failed to notify participant: {}", e);
        }
    }
}

/// `ModelStore` that always hands out the same in-memory model.
pub struct LocalModel(Arc);

#[async_trait]
impl ModelStore for LocalModel {
    type Model = Arc;
    // loading from memory cannot fail
    type Error = std::convert::Infallible;

    async fn load_model(&mut self) -> Result, Self::Error> {
        // clones the shared pointer, not the model data
        Ok(Some(self.0.clone()))
    }
}
================================================
FILE: rust/examples/test-drive/settings.rs
================================================
use std::path::PathBuf;

use structopt::StructOpt;

// Command-line options of the test-drive example. Plain `//` comments are used
// on purpose: structopt turns `///` doc comments into `--help` text.
#[derive(Debug, StructOpt)]
#[structopt(name = "Test Drive")]
pub struct Opt {
    #[structopt(
        default_value = "http://127.0.0.1:8081",
        short,
        help = "The URL of the coordinator"
    )]
    pub url: String,

    #[structopt(default_value = "4", short, help = "The length of the model")]
    pub len: u32,

    #[structopt(
        default_value = "1",
        short,
        help = "The time period at which to poll for service data, in seconds"
    )]
    pub period: u64,

    #[structopt(default_value = "10", short, help = "The number of clients")]
    pub nb_client: u32,

    #[structopt(
        short,
        long,
        parse(from_os_str),
        help = "Trusted DER/PEM encoded TLS server certificate"
    )]
    pub certificate: Option,

    #[structopt(
        short,
        long,
        parse(from_os_str),
        help = "The PEM encoded TLS client identity"
    )]
    pub identity: Option,
}
================================================
FILE: rust/rustfmt.toml
================================================
# requires nightly rustfmt until the options are stabilized
format_code_in_doc_comments = true
imports_granularity = "Crate"
imports_layout = "HorizontalVertical"
================================================
FILE: rust/xaynet/Cargo.toml
================================================
[package]
name = "xaynet"
version = "0.11.0"
authors = ["Xayn Engineering "]
edition = "2018"
description = "The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant."
readme = "../../README.md"
homepage = "https://xaynet.dev/"
repository = "https://github.com/xaynetwork/xaynet/"
license-file = "../../LICENSE"
keywords = ["federated-learning", "fl", "ai", "machine-learning"]
categories = ["science", "cryptography"]

# docs.rs builds with every feature and the `docsrs` cfg (used in lib.rs)
[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[badges]
codecov = { repository = "xaynetwork/xaynet", branch = "master", service = "github" }
maintenance = { status = "actively-developed" }

[dependencies]
xaynet-core = { path = "../xaynet-core", version = "0.2.0" }
# feature: mobile
xaynet-mobile = { path = "../xaynet-mobile", version = "0.1.0", optional = true }
# feature: sdk
xaynet-sdk = { path = "../xaynet-sdk", version = "0.1.0", optional = true }
# feature: server
xaynet-server = { path = "../xaynet-server", version = "0.2.0", optional = true }

[features]
default = []
full = ["mobile", "sdk", "server"]
mobile = ["xaynet-mobile"]
sdk = ["xaynet-sdk"]
server = ["xaynet-server"]
================================================
FILE: rust/xaynet/src/lib.rs
================================================
// `doc_cfg` is only enabled for docs.rs builds (see `[package.metadata.docs.rs]`).
#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(
    doc,
    forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links)
)]
#![doc(
    html_logo_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png",
    html_favicon_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png",
    issue_tracker_base_url = "https://github.com/xaynetwork/xaynet/issues"
)]
//! # Xaynet: Train on the Edge with Federated Learning
//!
//! Want a framework that supports federated learning on the edge, in
//! desktop browsers, integrates well with mobile apps, is performant, and
//! preserves privacy?
Welcome to XayNet, written entirely in Rust! //! //! ## Making federated learning easy for developers //! //! Frameworks for machine learning - including those expressly for //! federated learning - exist already. These frameworks typically //! facilitate federated learning of cross-silo use cases - for example in //! collaborative learning across a limited number of hospitals or for //! instance across multiple banks working on a common use case without //! the need to share valuable and sensitive data. //! //! This repository focusses on masked cross-device federated learning to //! enable the orchestration of machine learning across millions of low-power //! edge devices, such as smartphones or even cars. By doing this, we hope //! to also increase the pace and scope of adoption of federated learning //! in practice and especially allow the protection of end-user data. All //! data remains in private local premises, whereby only encrypted AI //! models get automatically and asynchronously aggregated. Thus, we //! provide a solution to the AI privacy dilemma and bridge the //! often-existing gap between privacy and convenience. Imagine, for //! example, a voice assistant that learns new words directly on the device //! and shares this knowledge with all other instances, without recording //! and collecting your voice input centrally. Or, think about a search //! engine that learns to personalise search results without collecting //! your often sensitive search queries centrally… There are thousands of //! such use cases that today still trade privacy for //! convenience. We think this shouldn’t be the case and we want to //! provide an alternative to overcome this dilemma. //! //! Concretely, we provide developers with: //! //! - **App dev tools**: An SDK to integrate federated learning into //! apps written in Dart or other languages of choice for mobile development, //! as well as frameworks like Flutter. //!
- **Privacy via cross-device federated learning**: Train your AI //! models locally on edge devices such as mobile phones, browsers, //! or even in cars. Federated learning automatically aggregates the //! local models into a global model. Thus, all insights inherent in //! the local models are captured, while the user data stays //! private on end devices. //! - **Security and privacy via homomorphic encryption**: Aggregate //! models with the highest security and trust. Xayn’s masking //! protocol encrypts all models homomorphically. This enables you //! to aggregate encrypted local models into a global one – without //! having to decrypt local models at all. This protects private and //! even the most sensitive data. //! //! ## The case for writing this framework in Rust //! //! Our framework for federated learning is not only a framework for //! machine learning as such. Rather, it supports the federation of //! machine learning that takes place on possibly heterogeneous devices //! and where use cases involve many such devices. //! //! The programming language in which this framework is written should //! therefore give us strong support for the following: //! //! - **Runs "everywhere"**: the language should not require its own //! runtime and code should compile on a wide range of devices. //! - **Memory and concurrency safety**: code that compiles should be both //! memory safe and free of data races. //! - **Secure communication**: state-of-the-art cryptography should be //! available in vetted implementations. //! - **Asynchronous communication**: abstractions for asynchronous //! communication should exist that make federated learning scale. //! - **Fast and functional**: the language should offer functional //! abstractions but also compile code into fast executables. //! //! Rust is one of the very few modern programming languages //! that meet these requirements: //! //! - its concepts of Ownership and Borrowing make it both memory and //!
thread-safe (hence avoiding many common concurrency issues).
//! - it has a strong and static type discipline and traits, which
//!   describe shareable functionality of a type.
//! - it is a modern systems programming language, with some functional
//!   style features such as pattern matching, closures and iterators.
//! - its idiomatic code compares favourably to idiomatic C in performance.
//! - it compiles to WASM and can therefore be applied natively in browser
//!   settings.
//! - it is widely deployable and doesn't necessarily depend on a runtime,
//!   unlike languages such as Java and their need for a virtual machine
//!   to run its code. Foreign Function Interfaces support calls from
//!   other languages/frameworks, including Dart, Python and Flutter.
//! - it compiles to LLVM IR, and so it can draw from the abundant tool
//!   suites for LLVM.

// The core crate is always re-exported: it is a non-optional dependency.
pub use xaynet_core as core;

// The remaining crates are optional dependencies, re-exported only when the
// matching Cargo feature is enabled; on docs.rs the feature requirement is
// shown via `doc(cfg(...))`.
#[cfg(feature = "mobile")]
#[cfg_attr(docsrs, doc(cfg(feature = "mobile")))]
pub use xaynet_mobile as mobile;

#[cfg(feature = "sdk")]
#[cfg_attr(docsrs, doc(cfg(feature = "sdk")))]
pub use xaynet_sdk as sdk;

#[cfg(feature = "server")]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
pub use xaynet_server as server;
================================================
FILE: rust/xaynet-analytics/Cargo.toml
================================================
[package]
name = "xaynet-analytics"
version = "0.1.0"
authors = ["Xayn Engineering "]
edition = "2018"
description = "The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant."
readme = "../../README.md"
homepage = "https://xaynet.dev/"
repository = "https://github.com/xaynetwork/xaynet/"
license-file = "../../LICENSE"
keywords = ["federated-learning", "fl", "ai", "machine-learning"]
categories = ["science", "cryptography"]
publish = false

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[dependencies]
anyhow = "1.0.59"
chrono = "0.4.19"
# pinned to a specific git revision rather than a crates.io release
isar-core = { git = "https://github.com/isar/isar-core", rev = "59d9008be33343d1fd313c659e50e2835365a19d" }
================================================
FILE: rust/xaynet-analytics/src/controller.rs
================================================
//! In this file the `AnalyticsController` is defined.

use anyhow::{anyhow, Error, Result};
use chrono::{DateTime, Datelike, Duration, NaiveDate, Utc};

use crate::{
    data_combination::data_combiner::DataCombiner,
    database::{
        analytics_event::{
            adapter::AnalyticsEventAdapter,
            data_model::{AnalyticsEvent, AnalyticsEventType},
        },
        common::{CollectionNames, Repo, SchemaGenerator},
        controller_data::{adapter::ControllerDataAdapter, data_model::ControllerData},
        isar::IsarDb,
        screen_route::{adapter::ScreenRouteAdapter, data_model::ScreenRoute},
    },
    sender::Sender,
};

/// The `AnalyticsController` is the core component of the library. It exposes public functions to the FFI wrapper, and it’s responsible for:
/// - Instantiating the other necessary components (`DataCombiner`, `Sender` and `IsarDb`)
/// - Receiving incoming data recorded by the mobile framework (via FFI) and saving it to the db via `IsarDb`.
/// - Checking if the library needs to send data to the XayNet coordinator via `Sender`.
/// - Holding some simple state (`self.is_charging`, `self.is_connected_to_wifi`) so that it knows whether it’s appropriate to send data to XayNet.
///
/// ## Fields
///
/// * `db` - Singleton instance of `IsarDb`, used to operate with the database.
/// * `is_charging` - Boolean flag representing whether the phone is currently charging or not.
/// * `is_connected_to_wifi` - Boolean flag representing whether the phone is currently connected to the wifi or not.
/// * `last_time_data_sent` - Timestamp representing when analytics data was last sent to the coordinator. If `None`, data was never sent before.
/// * `combiner` - `DataCombiner` component responsible for calculating `DataPoints` based on `AnalyticsEvents` and `ScreenRoutes`.
/// * `sender` - `Sender` component responsible for preparing the message to be sent to the coordinator for aggregation.
/// * `send_frequency_hours` - `Duration` in hours representing periods within which we want to send data to the coordinator only once.
struct AnalyticsController {
    db: IsarDb,
    is_charging: bool,
    is_connected_to_wifi: bool,
    last_time_data_sent: Option>,
    combiner: DataCombiner,
    sender: Sender,
    send_frequency_hours: Duration,
}

// TODO: remove allow dead code when AnalyticsController is integrated with FFI layer: https://xainag.atlassian.net/browse/XN-1415
#[allow(dead_code)]
impl AnalyticsController {
    /// Upper bound for, and default value of, the send frequency in hours.
    const MAX_SEND_FREQUENCY_HOURS: u8 = 24;

    /// Opens the database at `path` with the schemas of all analytics
    /// collections, restores `last_time_data_sent` from the stored
    /// `ControllerData`, and validates the requested send frequency.
    ///
    /// # Errors
    /// Fails if a schema cannot be generated, the database cannot be opened,
    /// the stored controller data cannot be read, or
    /// `input_send_frequency_hours` exceeds `MAX_SEND_FREQUENCY_HOURS`.
    pub fn init(
        path: String,
        is_charging: bool,
        is_connected_to_wifi: bool,
        input_send_frequency_hours: Option,
    ) -> Result {
        let schemas = vec![
            AnalyticsEventAdapter::get_schema(&CollectionNames::ANALYTICS_EVENTS)?,
            ControllerDataAdapter::get_schema(&CollectionNames::CONTROLLER_DATA)?,
            ScreenRouteAdapter::get_schema(&CollectionNames::SCREEN_ROUTES)?,
        ];
        let db = IsarDb::new(&path, schemas)?;
        let last_time_data_sent = Self::get_last_time_data_sent(&db)?;
        let send_frequency_hours = Self::validate_send_frequency(input_send_frequency_hours)?;
        Ok(AnalyticsController {
            db,
            is_charging,
            is_connected_to_wifi,
            last_time_data_sent,
            combiner: DataCombiner,
            sender: Sender,
            send_frequency_hours,
        })
    }

    /// Consumes the controller and disposes of its database via `IsarDb::dispose`.
    pub fn dispose(self) -> Result<(), Error> {
        self.db.dispose()
    }

    /// Saves a new `AnalyticsEvent` to the db.
    ///
    /// If `option_screen_route_name` is given, the matching `ScreenRoute` is
    /// looked up (and saved first if it is new) and attached to the event.
    pub fn save_analytics_event(
        &self,
        name: &str,
        event_type: AnalyticsEventType,
        timestamp: DateTime,
        option_screen_route_name: Option<&str>,
    ) -> Result<(), Error> {
        // Option<Result<_>> -> Result<Option<_>> so a db error aborts the save
        let option_screen_route = option_screen_route_name
            .map(|screen_route_name| self.add_screen_route_if_new(screen_route_name, timestamp))
            .transpose()?;
        let event = AnalyticsEvent::new(name, event_type, timestamp, option_screen_route);
        event.save(&self.db, &CollectionNames::ANALYTICS_EVENTS)?;
        Ok(())
    }

    /// Flips the cached Wi-Fi connectivity flag.
    ///
    /// NOTE: this toggles rather than sets the flag, so it must be called
    /// exactly once per real connectivity change to stay in sync.
    pub fn change_connectivity_status(&mut self) {
        self.is_connected_to_wifi = !self.is_connected_to_wifi;
    }

    /// Flips the cached charging flag.
    ///
    /// NOTE: this toggles rather than sets the flag, so it must be called
    /// exactly once per real charging-state change to stay in sync.
    pub fn change_state_of_charge(&mut self) {
        self.is_charging = !self.is_charging;
    }

    /// Sends the collected data if `should_send_data` allows it; otherwise a no-op.
    pub fn maybe_send_data(&mut self) -> Result<(), Error> {
        if self.should_send_data() {
            self.send_data()
        } else {
            Ok(())
        }
    }

    /// Test-only access to the underlying database.
    #[cfg(test)]
    fn db(&self) -> &IsarDb {
        &self.db
    }

    /// Check whether `input_send_frequency_hours` is at most `MAX_SEND_FREQUENCY_HOURS`, otherwise return an `Error`.
    /// If it is within range, return it as a `Duration` in hours (note that `Some(0)` is accepted).
    /// If it's `None`, assign `Self::MAX_SEND_FREQUENCY_HOURS` and turn it into a `Duration` as well.
    fn validate_send_frequency(input_send_frequency_hours: Option) -> Result {
        let send_frequency_hours =
            input_send_frequency_hours.unwrap_or(Self::MAX_SEND_FREQUENCY_HOURS);
        if send_frequency_hours > Self::MAX_SEND_FREQUENCY_HOURS {
            Err(anyhow!(
                "input_send_frequency_hours must be between 0 and {}",
                Self::MAX_SEND_FREQUENCY_HOURS
            ))
        } else {
            Ok(Duration::hours(send_frequency_hours as i64))
        }
    }

    /// Data may only be sent while charging and connected to Wi-Fi, and at
    /// most once per send period.
    fn should_send_data(&self) -> bool {
        let can_send_data = self.is_charging && self.is_connected_to_wifi;
        can_send_data && !self.did_send_already_in_this_period()
    }

    /// Check whether the new incoming `screen_route_name` already exists in the `ScreenRoutes` saved to the db.
    /// If it exists, return the existing `ScreenRoute` object from the db.
    /// If it doesn't exist, create the new `ScreenRoute` object, save it to db, and return a clone of it.
fn add_screen_route_if_new( &self, screen_route_name: &str, timestamp: DateTime, ) -> Result { let existing_screen_routes = ScreenRoute::get_all(&self.db, &CollectionNames::SCREEN_ROUTES)?; if let Some(existing_screen_route) = existing_screen_routes .into_iter() .find(|existing_route| existing_route.name == screen_route_name) { Ok(existing_screen_route) } else { let screen_route = ScreenRoute::new(screen_route_name, timestamp); screen_route .clone() .save(&self.db, &CollectionNames::SCREEN_ROUTES)?; Ok(screen_route) } } fn get_last_time_data_sent(db: &IsarDb) -> Result>, Error> { Ok( ControllerData::get_all(db, &CollectionNames::CONTROLLER_DATA)? .last() .map(|data| data.time_data_sent), ) } /// This method implements a sliding 'time window' of `self.send_frequency_hours` duration, to check whether we have /// already sent data in the current window, or not. /// /// An alternative implementation could be based on simply checking whether: /// `last_time_data_sent > Utc::now() - self.send_frequency_hours` /// /// In the current implementation, it might be easier to then group the aggregated data on the coordinator side, /// to then be displayed in the UI, especially if `self.send_frequency_hours == Duration::hours(24)`. /// /// The more dynamic approach however implies that if, for example, `self.send_frequency_hours == Duration::hours(6)`, /// and the last time we sent the data was at 5AM, we would be able to send again at 7AM, while with the simpler solution /// we wouldn't be able to send again until 11AM. /// /// The correct approach to be chosen should very much depend on the amount of data available for aggregation, /// and it's possible that `MAX_SEND_FREQUENCY_HOURS` should be increased to more than 24. /// In that case, this function below will need to be reworked, because it' coupled with `MAX_SEND_FREQUENCY_HOURS` being 24. 
/// /// Only once it's more clear how the aggregation will work on the coordinator side, there will be more information /// to decide the approach here. fn did_send_already_in_this_period(&self) -> bool { self.last_time_data_sent .map(|last_time_data_sent| { let now = Utc::now(); let start_of_day: DateTime = DateTime::from_utc( NaiveDate::from_ymd(now.year(), now.month(), now.day()).and_hms(0, 0, 0), Utc, ); let mut end_of_current_period = start_of_day + self.send_frequency_hours; while now > end_of_current_period { end_of_current_period = end_of_current_period + self.send_frequency_hours; } let start_of_current_period = end_of_current_period - self.send_frequency_hours; last_time_data_sent > start_of_current_period }) .unwrap_or(false) } /// Retrive all `AnalyticsEvents` and `ScreenRoutes` from the db and pass them to the `DataCombiner`. /// The `DataCombiner` will init all `DataPoints` and pack them in a `Vec`, which will be the input to the `Sender`. /// After that, save the new time_data_sent inside `ControllerData`, and cache it in `self.last_time_data_sent` fn send_data(&mut self) -> Result<(), Error> { let events = AnalyticsEvent::get_all(&self.db, &CollectionNames::ANALYTICS_EVENTS)?; let screen_routes = ScreenRoute::get_all(&self.db, &CollectionNames::SCREEN_ROUTES)?; let time_data_sent = Utc::now(); self.sender .send(self.combiner.init_data_points(&events, &screen_routes)?) 
.and_then(|_| { ControllerData::new(time_data_sent) .save(&self.db, &CollectionNames::CONTROLLER_DATA) }) .map(|_| self.last_time_data_sent = Some(time_data_sent)) } } #[cfg(test)] mod tests { use super::*; use std::{env, fs, path::PathBuf}; fn get_path(test_name: &str) -> PathBuf { let temp_dir = env::temp_dir(); temp_dir.join(test_name) } fn get_controller( test_name: &str, input_send_data_frequency: Option, ) -> AnalyticsController { let path_buf = get_path(test_name); let path = path_buf.to_str().unwrap().to_string(); if !path_buf.exists() { fs::create_dir(path.clone()).unwrap(); } AnalyticsController::init(path, true, true, input_send_data_frequency).unwrap() } fn remove_dir(test_name: &str) { let path = get_path(test_name); std::fs::remove_dir_all(path).unwrap(); } fn cleanup(controller: AnalyticsController, test_name: &str) { remove_dir(test_name); controller.dispose().unwrap(); } #[test] fn test_dispose() { let test_name = "test_dispose"; let controller = get_controller(test_name, None); assert!(controller.dispose().is_ok()); remove_dir(test_name); } #[test] fn test_save_analytics_event_no_screen_route() { let test_name = "test_save_analytics_event_no_screen_route"; let controller = get_controller(test_name, None); let name = "test"; let event_type = AnalyticsEventType::AppEvent; let timestamp = DateTime::parse_from_rfc3339("2021-01-01T01:01:00+00:00") .unwrap() .with_timezone(&Utc); let existing_analytics_events = AnalyticsEvent::get_all(controller.db(), CollectionNames::ANALYTICS_EVENTS).unwrap(); assert!(existing_analytics_events.is_empty()); assert!(controller .save_analytics_event(name, event_type, timestamp, None) .is_ok()); let analytics_event = AnalyticsEvent::new(name, event_type, timestamp, None); let all_analytics_events = AnalyticsEvent::get_all(controller.db(), CollectionNames::ANALYTICS_EVENTS).unwrap(); assert_eq!(all_analytics_events.len(), 1); assert_eq!(all_analytics_events.first(), Some(&analytics_event)); cleanup(controller, test_name); 
} #[test] fn test_save_analytics_event_with_screen_route() { let test_name = "test_save_analytics_event_with_screen_route"; let controller = get_controller(test_name, None); let name = "test"; let event_type = AnalyticsEventType::ScreenEnter; let timestamp = DateTime::parse_from_rfc3339("2021-01-01T01:01:00+00:00") .unwrap() .with_timezone(&Utc); let screen_route_name = "route"; assert!(controller .save_analytics_event(name, event_type, timestamp, Some(screen_route_name)) .is_ok()); let screen_route = ScreenRoute::new(screen_route_name, timestamp); let analytics_event = AnalyticsEvent::new(name, event_type, timestamp, Some(screen_route)); let all_analytics_events = AnalyticsEvent::get_all(controller.db(), CollectionNames::ANALYTICS_EVENTS).unwrap(); assert_eq!(all_analytics_events.len(), 1); assert_eq!(all_analytics_events.first(), Some(&analytics_event)); cleanup(controller, test_name); } #[test] fn test_change_connectivity_status() { let test_name = "test_change_connectivity_status"; let mut controller = get_controller(test_name, None); assert!(controller.is_connected_to_wifi); controller.change_connectivity_status(); assert!(!controller.is_connected_to_wifi); cleanup(controller, test_name); } #[test] fn test_change_state_of_charge() { let test_name = "test_change_state_of_charge"; let mut controller = get_controller(test_name, None); assert!(controller.is_charging); controller.change_state_of_charge(); assert!(!controller.is_charging); cleanup(controller, test_name); } #[test] fn test_validate_send_data_frequency_when_none() { assert_eq!( AnalyticsController::validate_send_frequency(None).unwrap(), Duration::hours(AnalyticsController::MAX_SEND_FREQUENCY_HOURS as i64) ) } #[test] fn test_validate_send_data_frequency_when_more_than_24() { assert!(AnalyticsController::validate_send_frequency(Some(25)).is_err()); } #[test] fn test_validate_send_data_frequency_when_less_than_24() { assert_eq!( AnalyticsController::validate_send_frequency(Some(6)).unwrap(), 
Duration::hours(6) ) } #[test] fn test_validate_send_data_frequency_when_0() { assert_eq!( AnalyticsController::validate_send_frequency(Some(0)).unwrap(), Duration::hours(0) ) } #[test] fn test_add_screen_route_if_new_with_new_route() { let test_name = "test_add_screen_route_if_new_with_new_route"; let controller = get_controller(test_name, None); let screen_route_name = "route"; let timestamp = DateTime::parse_from_rfc3339("2021-01-01T01:01:00+00:00") .unwrap() .with_timezone(&Utc); let screen_route = ScreenRoute::new(screen_route_name, timestamp); let existing_screen_routes = ScreenRoute::get_all(controller.db(), CollectionNames::SCREEN_ROUTES).unwrap(); assert!(existing_screen_routes.is_empty()); assert_eq!( controller .add_screen_route_if_new(screen_route_name, timestamp) .unwrap(), screen_route ); let retrieved_screen_routes = ScreenRoute::get_all(controller.db(), CollectionNames::SCREEN_ROUTES).unwrap(); assert_eq!(retrieved_screen_routes.len(), 1); assert_eq!(retrieved_screen_routes.first(), Some(&screen_route)); cleanup(controller, test_name); } #[test] fn test_add_screen_route_if_new_without_new_route() { let test_name = "test_add_screen_route_if_new_without_new_route"; let controller = get_controller(test_name, None); let screen_route_name = "route"; let first_timestamp = DateTime::parse_from_rfc3339("2021-01-01T01:01:00+00:00") .unwrap() .with_timezone(&Utc); let first_screen_route = ScreenRoute::new(screen_route_name, first_timestamp); assert!(controller .add_screen_route_if_new(screen_route_name, first_timestamp) .is_ok()); // if we call controller.add_screen_route_if_new() with the same screen_route_name, but a new_timestamp, // we expect to get the first_screen_route back, with the first_timestamp let new_timestamp = DateTime::parse_from_rfc3339("2021-02-02T02:02:00+00:00") .unwrap() .with_timezone(&Utc); assert_eq!( controller .add_screen_route_if_new(screen_route_name, new_timestamp) .unwrap(), first_screen_route ); let retrieved_screen_routes = 
ScreenRoute::get_all(controller.db(), CollectionNames::SCREEN_ROUTES).unwrap(); assert_eq!(retrieved_screen_routes.len(), 1); assert_eq!(retrieved_screen_routes.first(), Some(&first_screen_route)); cleanup(controller, test_name); } #[test] fn test_get_last_time_data_sent() { let test_name = "test_get_last_time_data_sent_is_none"; let controller = get_controller(test_name, None); let last_time_data_sent = AnalyticsController::get_last_time_data_sent(controller.db()); assert!(last_time_data_sent.is_ok()); assert!(last_time_data_sent.unwrap().is_none()); let timestamp = DateTime::parse_from_rfc3339("2021-03-03T03:03:00+00:00") .unwrap() .with_timezone(&Utc); let controller_data = ControllerData::new(timestamp); let existing_controller_data = ControllerData::get_all(controller.db(), CollectionNames::CONTROLLER_DATA).unwrap(); assert!(existing_controller_data.is_empty()); assert!(controller_data .save(controller.db(), CollectionNames::CONTROLLER_DATA) .is_ok()); let last_time_data_sent = AnalyticsController::get_last_time_data_sent(controller.db()); assert!(last_time_data_sent.is_ok()); assert_eq!(last_time_data_sent.unwrap(), Some(timestamp)); cleanup(controller, test_name); } #[test] fn test_did_send_already_in_this_period_never_sent_before() { let test_name = "test_did_send_already_in_this_period_never_sent_before"; let controller = get_controller(test_name, Some(24)); assert!(!controller.did_send_already_in_this_period()); cleanup(controller, test_name); } #[test] fn test_did_send_already_in_this_period_inside_24h() { let test_name = "test_did_send_already_in_this_period_inside_24h"; let initial_controller = get_controller(test_name, Some(24)); let timestamp = Utc::now(); let controller_data = ControllerData::new(timestamp); assert!(controller_data .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA) .is_ok()); // init controller again, to read self.last_time_data_sent from db assert!(initial_controller.dispose().is_ok()); let controller = 
get_controller(test_name, Some(24)); assert!(controller.did_send_already_in_this_period()); cleanup(controller, test_name); } #[test] fn test_did_send_already_in_this_period_outside_24h() { let test_name = "test_did_send_already_in_this_period_outside_24h"; let initial_controller = get_controller(test_name, Some(24)); let timestamp = Utc::now() - Duration::hours(25); let controller_data = ControllerData::new(timestamp); assert!(controller_data .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA) .is_ok()); // init controller again, to read self.last_time_data_sent from db assert!(initial_controller.dispose().is_ok()); let controller = get_controller(test_name, Some(24)); assert!(!controller.did_send_already_in_this_period()); cleanup(controller, test_name); } #[test] fn test_did_send_already_in_this_period_inside_12h() { let test_name = "test_did_send_already_in_this_period_inside_12h"; let initial_controller = get_controller(test_name, Some(12)); let timestamp = Utc::now(); let controller_data = ControllerData::new(timestamp); assert!(controller_data .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA) .is_ok()); // init controller again, to read self.last_time_data_sent from db assert!(initial_controller.dispose().is_ok()); let controller = get_controller(test_name, Some(12)); assert!(controller.did_send_already_in_this_period()); cleanup(controller, test_name); } #[test] fn test_did_send_already_in_this_period_outside_12h() { let test_name = "test_did_send_already_in_this_period_outside_12h"; let initial_controller = get_controller(test_name, Some(12)); let timestamp = Utc::now() - Duration::hours(13); let controller_data = ControllerData::new(timestamp); assert!(controller_data .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA) .is_ok()); // init controller again, to read self.last_time_data_sent from db assert!(initial_controller.dispose().is_ok()); let controller = get_controller(test_name, Some(12)); 
assert!(!controller.did_send_already_in_this_period()); cleanup(controller, test_name); } #[test] fn test_did_send_already_in_this_period_inside_6h() { let test_name = "test_did_send_already_in_this_period_inside_6h"; let initial_controller = get_controller(test_name, Some(6)); let timestamp = Utc::now(); let controller_data = ControllerData::new(timestamp); assert!(controller_data .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA) .is_ok()); // init controller again, to read self.last_time_data_sent from db assert!(initial_controller.dispose().is_ok()); let controller = get_controller(test_name, Some(6)); assert!(controller.did_send_already_in_this_period()); cleanup(controller, test_name); } #[test] fn test_did_send_already_in_this_period_outside_6h() { let test_name = "test_did_send_already_in_this_period_outside_6h"; let initial_controller = get_controller(test_name, Some(6)); let timestamp = Utc::now() - Duration::hours(7); let controller_data = ControllerData::new(timestamp); assert!(controller_data .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA) .is_ok()); // init controller again, to read self.last_time_data_sent from db assert!(initial_controller.dispose().is_ok()); let controller = get_controller(test_name, Some(6)); assert!(!controller.did_send_already_in_this_period()); cleanup(controller, test_name); } #[test] fn test_did_send_already_in_this_period_outside_twice_6h() { let test_name = "test_did_send_already_in_this_period_outside_twice_6h"; let initial_controller = get_controller(test_name, Some(6)); let timestamp = Utc::now() - Duration::hours(13); let controller_data = ControllerData::new(timestamp); assert!(controller_data .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA) .is_ok()); // init controller again, to read self.last_time_data_sent from db assert!(initial_controller.dispose().is_ok()); let controller = get_controller(test_name, Some(6)); assert!(!controller.did_send_already_in_this_period()); 
cleanup(controller, test_name); } #[test] fn test_did_send_already_in_this_period_outside_thrice_6h() { let test_name = "test_did_send_already_in_this_period_outside_thrice_6h"; let initial_controller = get_controller(test_name, Some(6)); let timestamp = Utc::now() - Duration::hours(19); let controller_data = ControllerData::new(timestamp); assert!(controller_data .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA) .is_ok()); // init controller again, to read self.last_time_data_sent from db assert!(initial_controller.dispose().is_ok()); let controller = get_controller(test_name, Some(6)); assert!(!controller.did_send_already_in_this_period()); cleanup(controller, test_name); } } ================================================ FILE: rust/xaynet-analytics/src/data_combination/data_combiner.rs ================================================ //! Declaration and implementation of `DataCombiner`. use anyhow::{Error, Result}; use chrono::{DateTime, Datelike, Duration, NaiveDate, Utc}; use std::iter::empty; use crate::{ data_combination::data_points::data_point::{ CalcScreenActiveTime, CalcScreenEnterCount, CalcWasActiveEachPastPeriod, CalcWasActivePastNDays, DataPoint, DataPointMetadata, Period, PeriodUnit, }, database::{ analytics_event::data_model::AnalyticsEvent, screen_route::data_model::ScreenRoute, }, }; /// `DataCombiner` is responsible for instantiating the `DataPoint` variants. When it’s time to send the data to XayNet, /// the `AnalyticsEvents` and `ScreenRoutes` are retrieved from the db (by the `AnalyticsController`) and passed to the `DataCombiner`, /// which then instantiates the various `DataPoint` variants and packs them in a `Vec`, which will be utilised by the `Sender`. /// /// Possible improvements include: /// - Move the `DataPointMetadatas` to a sort of config, and pass them to the `DataCombiner`. /// - Turn `DataCombiner` into a trait on each `DataPoint`. 
/// See: https://xainag.atlassian.net/browse/XN-1651 pub struct DataCombiner; impl<'screen> DataCombiner { pub fn init_data_points( &self, events: &[AnalyticsEvent], screen_routes: &[ScreenRoute], ) -> Result, Error> { let end_period = Utc::now(); let one_day_period_metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period); let was_active_each_period_metadatas = vec![ DataPointMetadata::new(Period::new(PeriodUnit::Days, 7), end_period), DataPointMetadata::new(Period::new(PeriodUnit::Weeks, 6), end_period), DataPointMetadata::new(Period::new(PeriodUnit::Months, 3), end_period), ]; let was_active_past_days_metadatas = vec![ one_day_period_metadata, DataPointMetadata::new(Period::new(PeriodUnit::Days, 7), end_period), DataPointMetadata::new(Period::new(PeriodUnit::Days, 28), end_period), ]; let data_points = empty::() .chain(Self::init_screen_active_time_vars( one_day_period_metadata, events, screen_routes, )) .chain(Self::init_screen_enter_count_vars( one_day_period_metadata, events, screen_routes, )) .chain(Self::init_was_active_each_past_period_vars( was_active_each_period_metadatas, events, )) .chain(Self::init_was_active_past_n_days_vars( was_active_past_days_metadatas, events, )) .collect(); Ok(data_points) } fn init_screen_active_time_vars( metadata: DataPointMetadata, events: &[AnalyticsEvent], screen_routes: &[ScreenRoute], ) -> Vec { let mut screen_active_time_vars: Vec = screen_routes .iter() .map(|route| { let events_this_route = Self::get_events_single_route(route, events); CalcScreenActiveTime::new( metadata, Self::filter_events_in_this_period(metadata, events_this_route.as_slice()), ) }) .map(DataPoint::ScreenActiveTime) .collect(); screen_active_time_vars.push(DataPoint::ScreenActiveTime(CalcScreenActiveTime::new( metadata, Self::filter_events_in_this_period(metadata, events), ))); screen_active_time_vars } fn init_screen_enter_count_vars( metadata: DataPointMetadata, events: &[AnalyticsEvent], screen_routes: &[ScreenRoute], ) -> 
Vec { screen_routes .iter() .map(|route| { let events_this_route = Self::get_events_single_route(&route, events); CalcScreenEnterCount::new( metadata, Self::filter_events_in_this_period(metadata, events_this_route.as_slice()), ) }) .map(DataPoint::ScreenEnterCount) .collect() } fn init_was_active_each_past_period_vars( metadatas: Vec, events: &[AnalyticsEvent], ) -> Vec { metadatas .iter() .map(|metadata| { let period_thresholds = (0..metadata.period.n) .map(|i| Self::get_start_of_period(*metadata, Some(i))) .collect(); CalcWasActiveEachPastPeriod::new( *metadata, Self::filter_events_in_this_period(*metadata, events), period_thresholds, ) }) .map(DataPoint::WasActiveEachPastPeriod) .collect() } fn init_was_active_past_n_days_vars( metadatas: Vec, events: &[AnalyticsEvent], ) -> Vec { metadatas .iter() .map(|metadata| { CalcWasActivePastNDays::new( *metadata, Self::filter_events_in_this_period(*metadata, events), ) }) .map(DataPoint::WasActivePastNDays) .collect() } // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517 fn filter_events_in_this_period( metadata: DataPointMetadata, events: &[AnalyticsEvent], ) -> Vec { let start_of_period = Self::get_start_of_period(metadata, None); Self::filter_events_before_end_of_period(metadata.end, events) .iter() .filter(|event| event.timestamp > start_of_period) .cloned() .collect() } fn get_start_of_period( metadata: DataPointMetadata, n_periods_override: Option, ) -> DateTime { let n_periods = if let Some(n_periods) = n_periods_override { n_periods } else { metadata.period.n }; let avg_days_per_month = 365.0 / 12.0; let midnight_end_of_period = get_midnight(metadata.end); match metadata.period.unit { PeriodUnit::Days => midnight_end_of_period - Duration::days(n_periods as i64), PeriodUnit::Weeks => midnight_end_of_period - Duration::weeks(n_periods as i64), PeriodUnit::Months => { midnight_end_of_period - Duration::days((n_periods as f64 * avg_days_per_month) as i64) } } } // TODO: return an 
iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517 fn filter_events_before_end_of_period( end_of_period: DateTime, events: &[AnalyticsEvent], ) -> Vec { let midnight_end_of_period = get_midnight(end_of_period); events .iter() .filter(|event| event.timestamp < midnight_end_of_period) .cloned() .collect() } // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517 fn get_events_single_route( route: &ScreenRoute, all_events: &[AnalyticsEvent], ) -> Vec { all_events .iter() .filter(|event| event.screen_route.as_ref() == Some(route)) .cloned() .collect() } } fn get_midnight(timestamp: DateTime) -> DateTime { DateTime::from_utc( NaiveDate::from_ymd(timestamp.year(), timestamp.month(), timestamp.day()).and_hms(0, 0, 0), Utc, ) } #[cfg(test)] mod tests { use chrono::{DateTime, Duration, Utc}; use crate::{ data_combination::{ data_combiner::{get_midnight, DataCombiner}, data_points::data_point::{ CalcScreenActiveTime, CalcScreenEnterCount, CalcWasActiveEachPastPeriod, CalcWasActivePastNDays, DataPoint, DataPointMetadata, Period, PeriodUnit, }, }, database::{ analytics_event::data_model::{AnalyticsEvent, AnalyticsEventType}, screen_route::data_model::ScreenRoute, }, }; #[test] fn test_init_screen_active_time_vars() { let end_period = DateTime::parse_from_rfc3339("2021-01-01T01:01:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period); let screen_route = ScreenRoute::new("home_screen", end_period + Duration::days(1)); let first_event = AnalyticsEvent::new( "test1", AnalyticsEventType::ScreenEnter, end_period - Duration::hours(12), Some(screen_route.clone()), ); let all_events = vec![ first_event.clone(), AnalyticsEvent::new( "test1", AnalyticsEventType::AppEvent, end_period - Duration::hours(13), None, ), ]; let expected_output = vec![ DataPoint::ScreenActiveTime(CalcScreenActiveTime::new(metadata, vec![first_event])), 
DataPoint::ScreenActiveTime(CalcScreenActiveTime::new(metadata, all_events.clone())), ]; let actual_output = DataCombiner::init_screen_active_time_vars(metadata, &all_events, &[screen_route]); assert_eq!(actual_output, expected_output); } #[test] fn test_init_screen_enter_count_vars() { let end_period = DateTime::parse_from_rfc3339("2021-02-02T02:02:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period); let screen_route = ScreenRoute::new("home_screen", end_period + Duration::days(1)); let events = vec![AnalyticsEvent::new( "test1", AnalyticsEventType::ScreenEnter, end_period - Duration::hours(12), Some(screen_route.clone()), )]; let expected_output = vec![DataPoint::ScreenEnterCount(CalcScreenEnterCount::new( metadata, events.clone(), ))]; let actual_output = DataCombiner::init_screen_enter_count_vars(metadata, &events, &[screen_route]); assert_eq!(actual_output, expected_output); } #[test] fn test_init_was_active_each_past_period_vars() { let end_period = DateTime::parse_from_rfc3339("2021-03-03T03:03:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period); let events = vec![AnalyticsEvent::new( "test1", AnalyticsEventType::AppEvent, end_period - Duration::hours(12), None, )]; let period_thresholds = vec![get_midnight(end_period)]; let expected_output = vec![DataPoint::WasActiveEachPastPeriod( CalcWasActiveEachPastPeriod::new(metadata, events.clone(), period_thresholds), )]; let actual_output = DataCombiner::init_was_active_each_past_period_vars(vec![metadata], &events); assert_eq!(actual_output, expected_output); } #[test] fn test_init_was_active_past_n_days_vars() { let end_period = DateTime::parse_from_rfc3339("2021-04-04T04:04:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period); let events = vec![AnalyticsEvent::new( "test1", 
AnalyticsEventType::AppEvent, end_period - Duration::hours(12), None, )]; let expected_output = vec![DataPoint::WasActivePastNDays(CalcWasActivePastNDays::new( metadata, events.clone(), ))]; let actual_output = DataCombiner::init_was_active_past_n_days_vars(vec![metadata], &events); assert_eq!(actual_output, expected_output); } #[test] fn test_filter_events_in_this_period() { let end_period = DateTime::parse_from_rfc3339("2021-05-05T05:05:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 3), end_period); let event_before = AnalyticsEvent::new( "test1", AnalyticsEventType::AppEvent, end_period - Duration::days(5), None, ); let event_during = AnalyticsEvent::new( "test2", AnalyticsEventType::AppEvent, end_period - Duration::days(1), None, ); let event_after = AnalyticsEvent::new( "test3", AnalyticsEventType::AppEvent, end_period + Duration::days(2), None, ); let events = vec![event_before, event_during.clone(), event_after]; let expected_output = vec![event_during]; let actual_output = DataCombiner::filter_events_in_this_period(metadata, &events); assert_eq!(actual_output, expected_output); } #[test] fn test_get_start_of_period_one_day() { let end_period = DateTime::parse_from_rfc3339("2021-01-01T00:00:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period); let expected_output = end_period - Duration::days(1); let actual_output = DataCombiner::get_start_of_period(metadata, None); assert_eq!(actual_output, expected_output); } #[test] fn test_get_start_of_period_one_day_with_override() { let end_period = DateTime::parse_from_rfc3339("2021-02-02T00:00:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 2), end_period); let expected_output = end_period - Duration::days(1); let actual_output = DataCombiner::get_start_of_period(metadata, Some(1)); assert_eq!(actual_output, 
expected_output); } #[test] fn test_get_start_of_period_one_week() { let end_period = DateTime::parse_from_rfc3339("2021-03-03T00:00:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Weeks, 1), end_period); let expected_output = end_period - Duration::weeks(1); let actual_output = DataCombiner::get_start_of_period(metadata, None); assert_eq!(actual_output, expected_output); } #[test] fn test_get_start_of_period_one_month() { let end_period = DateTime::parse_from_rfc3339("2021-04-04T00:00:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Months, 1), end_period); let expected_output = end_period - Duration::days(30); let actual_output = DataCombiner::get_start_of_period(metadata, None); assert_eq!(actual_output, expected_output); } #[test] fn text_filter_events_before_end_of_period() { let end_of_period = Utc::now(); let event_before = AnalyticsEvent::new( "test1", AnalyticsEventType::AppEvent, end_of_period - Duration::days(1), None, ); let event_after = AnalyticsEvent::new( "test2", AnalyticsEventType::AppEvent, end_of_period + Duration::days(1), None, ); let events = vec![event_before.clone(), event_after]; let expected_output = vec![event_before]; let actual_output = DataCombiner::filter_events_before_end_of_period(end_of_period, &events); assert_eq!(actual_output, expected_output); } #[test] fn test_get_events_single_route() { let timestamp = Utc::now(); let home_route = ScreenRoute::new("home_screen", timestamp + Duration::days(1)); let other_route = ScreenRoute::new("other_screen", timestamp + Duration::days(2)); let home_route_event = AnalyticsEvent::new( "test1", AnalyticsEventType::AppEvent, timestamp, Some(home_route.clone()), ); let other_route_event = AnalyticsEvent::new( "test2", AnalyticsEventType::ScreenEnter, timestamp, Some(other_route), ); let all_events = [home_route_event.clone(), other_route_event]; let expected_output = 
vec![home_route_event];
        let actual_output = DataCombiner::get_events_single_route(&home_route, &all_events);
        assert_eq!(actual_output, expected_output);
    }

    #[test]
    fn test_get_midnight() {
        let timestamp = DateTime::parse_from_rfc3339("2021-01-01T21:21:21-02:00")
            .unwrap()
            .with_timezone(&Utc);
        let expected_output = DateTime::parse_from_rfc3339("2021-01-01T00:00:00-00:00")
            .unwrap()
            .with_timezone(&Utc);
        let actual_output = get_midnight(timestamp);
        assert_eq!(actual_output, expected_output);
    }
}
================================================
FILE: rust/xaynet-analytics/src/data_combination/data_points/data_point.rs
================================================
//! File containing various structs used to define `DataPoints`.

use chrono::{DateTime, Utc};

use crate::database::analytics_event::data_model::AnalyticsEvent;

/// Unit in which a `Period` is measured.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum PeriodUnit {
    Days,
    Weeks,
    Months,
}

/// `Period` combines the unit of a period with the number of such units.
/// For example, a `Period` of three weeks is represented by `Period::new(PeriodUnit::Weeks, 3)`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Period {
    pub unit: PeriodUnit,
    pub n: u32,
}

impl Period {
    pub fn new(unit: PeriodUnit, n: u32) -> Self {
        Self { unit, n }
    }
}

/// `DataPointMetadata` contains a `Period` and the moment that period ends. It is used to
/// decide which `AnalyticsEvents` fall inside the `Period` and must therefore be included in the
/// calculation of a specific `DataPoint`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct DataPointMetadata {
    pub period: Period,
    pub end: DateTime<Utc>,
}

impl DataPointMetadata {
    pub fn new(period: Period, end: DateTime<Utc>) -> Self {
        Self { period, end }
    }
}

/// Common interface of the `Calc*` structs wrapped by the `DataPoint` variants.
pub trait CalculateDataPoints {
    /// The metadata (period and end of period) the data point was created with.
    fn metadata(&self) -> DataPointMetadata;
    /// Computes the values of the data point.
    fn calculate(&self) -> Vec<u32>;
}

/// `DataPoint` is an enum whose variants represent data points that will need to be aggregated and shown to the user.
/// They are the actual analytics information that is valuable to the user. Each `DataPoint` refers to a specific `Period`.
/// ## Variants:
/// * `ScreenActiveTime`: How much time was spent on a specific screen.
/// * `ScreenEnterCount`: How many times the user entered a specific screen.
/// * `WasActiveEachPastPeriod`: Whether the user was active or not in each specified period (in general, not by screen).
/// * `WasActivePastNDays`: Whether the user was active or not in the past N days (in general, not by screen).
///
/// There are still more variants to be implemented: https://xainag.atlassian.net/browse/XN-1687
#[derive(Debug, PartialEq, Eq)]
pub enum DataPoint {
    ScreenActiveTime(CalcScreenActiveTime),
    ScreenEnterCount(CalcScreenEnterCount),
    WasActiveEachPastPeriod(CalcWasActiveEachPastPeriod),
    WasActivePastNDays(CalcWasActivePastNDays),
}

#[allow(dead_code)] // TODO: will be called when preparing the data to be sent to the coordinator
impl DataPoint {
    /// Delegates to the `CalculateDataPoints::calculate` of the wrapped variant.
    fn calculate(&self) -> Vec<u32> {
        match self {
            DataPoint::ScreenActiveTime(data) => data.calculate(),
            DataPoint::ScreenEnterCount(data) => data.calculate(),
            DataPoint::WasActiveEachPastPeriod(data) => data.calculate(),
            DataPoint::WasActivePastNDays(data) => data.calculate(),
        }
    }
}

#[derive(Debug, PartialEq, Eq)]
// TODO: accept an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517
pub struct CalcScreenActiveTime {
    pub metadata: DataPointMetadata,
    pub events: Vec<AnalyticsEvent>,
}

#[derive(Debug, PartialEq, Eq)]
// TODO: accept an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517
pub struct CalcScreenEnterCount {
    pub metadata: DataPointMetadata,
    pub events: Vec<AnalyticsEvent>,
}

#[derive(Debug, PartialEq, Eq)]
// TODO: accept an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517
pub struct CalcWasActiveEachPastPeriod {
    pub metadata: DataPointMetadata,
    pub events: Vec<AnalyticsEvent>,
    pub period_thresholds: Vec<DateTime<Utc>>,
}

#[derive(Debug, PartialEq, Eq)]
// TODO: accept an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517
pub struct CalcWasActivePastNDays {
    pub metadata: DataPointMetadata,
    pub events: Vec<AnalyticsEvent>,
}
================================================
FILE: rust/xaynet-analytics/src/data_combination/data_points/mod.rs
================================================
pub mod data_point;
pub mod screen_active_time;
pub mod screen_enter_count;
pub mod was_active_each_past_period;
pub mod was_active_past_n_days;
================================================
FILE: rust/xaynet-analytics/src/data_combination/data_points/screen_active_time.rs
================================================
use chrono::Duration;

use crate::{
    data_combination::data_points::data_point::{
        CalcScreenActiveTime, CalculateDataPoints, DataPointMetadata,
    },
    database::analytics_event::data_model::{AnalyticsEvent, AnalyticsEventType},
};

impl CalcScreenActiveTime {
    pub fn new(metadata: DataPointMetadata, events: Vec<AnalyticsEvent>) -> Self {
        Self { metadata, events }
    }

    /// Returns the `ScreenEnter` and `AppEvent` events, keeping their original order.
    // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517
    fn get_screen_and_app_events(&self) -> Vec<AnalyticsEvent> {
        self.events
            .iter()
            .filter(|event| {
                matches!(
                    event.event_type,
                    AnalyticsEventType::ScreenEnter | AnalyticsEventType::AppEvent
                )
            })
            .cloned()
            .collect()
    }
}

impl CalculateDataPoints for CalcScreenActiveTime {
    fn metadata(&self) -> DataPointMetadata {
        self.metadata
    }

    /// Returns a single value: the total time, in milliseconds, spent on screens.
    ///
    /// Only `ScreenEnter` and `AppEvent` events are considered. The events are assumed to be
    /// ordered newest first, as in this module's tests. For every `ScreenEnter` event, the gap
    /// up to the event right before it in the list (the next newer one) is added to the total.
    /// A negative gap (events out of order) counts as zero, and the total saturates at
    /// `u32::MAX`.
    fn calculate(&self) -> Vec<u32> {
        let screen_and_app_events = self.get_screen_and_app_events();
        let total_milliseconds: i64 = screen_and_app_events
            .windows(2)
            .filter(|pair| pair[1].event_type == AnalyticsEventType::ScreenEnter)
            .map(|pair| {
                pair[0]
                    .timestamp
                    .signed_duration_since(pair[1].timestamp)
                    // A negative duration used to wrap around to a huge value when cast to `u32`.
                    .max(Duration::zero())
                    .num_milliseconds()
            })
            .sum();
        // Sum in `i64` and saturate, instead of overflowing a `u32` accumulator.
        let value = total_milliseconds.min(i64::from(u32::MAX)) as u32;
        vec![value]
    }
}

#[cfg(test)]
mod tests {
    use chrono::{DateTime, Duration, Utc};

    use super::*;
    use crate::{
        data_combination::data_points::data_point::{Period, PeriodUnit},
        database::screen_route::data_model::ScreenRoute,
    };

    #[test]
    fn test_get_screen_and_app_events() {
        let end_period = DateTime::parse_from_rfc3339("2021-01-01T01:01:00-00:00")
            .unwrap()
            .with_timezone(&Utc);
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);
        let screen_route = ScreenRoute::new("home_screen", end_period + Duration::days(1));
        let screen_enter_event = AnalyticsEvent::new(
            "test1",
            AnalyticsEventType::ScreenEnter,
            end_period - Duration::hours(10),
            Some(screen_route),
        );
        let app_event = AnalyticsEvent::new(
            "test1",
            AnalyticsEventType::AppEvent,
            end_period - Duration::hours(12),
            None,
        );
        let events = vec![
            screen_enter_event.clone(),
            AnalyticsEvent::new(
                "test1",
                AnalyticsEventType::AppError,
                end_period - Duration::hours(11),
                None,
            ),
            app_event.clone(),
            AnalyticsEvent::new(
                "test1",
                AnalyticsEventType::UserAction,
                end_period - Duration::hours(13),
                None,
            ),
        ];
        let screen_active_time = CalcScreenActiveTime::new(metadata, events);
        let expected_output = vec![screen_enter_event, app_event];
        let actual_output = screen_active_time.get_screen_and_app_events();
        assert_eq!(actual_output, expected_output);
    }

    #[test]
    fn test_calculate_when_no_events() {
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), Utc::now());
        let screen_active_time = CalcScreenActiveTime::new(metadata, Vec::new());
        assert_eq!(screen_active_time.calculate(), vec![0]);
    }

    #[test]
    fn test_calculate_when_one_screen_enter_event() {
        let end_period = DateTime::parse_from_rfc3339("2021-03-03T03:03:00-00:00")
            .unwrap()
            .with_timezone(&Utc);
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);
        let screen_route = ScreenRoute::new("home_screen", end_period + Duration::days(1));
        let events = vec![AnalyticsEvent::new(
            "test1",
            AnalyticsEventType::ScreenEnter,
            end_period - Duration::hours(12),
            Some(screen_route),
        )];
        let screen_active_time = CalcScreenActiveTime::new(metadata, events);
        assert_eq!(screen_active_time.calculate(), vec![0]);
    }

    #[test]
    fn test_calculate_when_two_screen_enter_events() {
        let end_period = DateTime::parse_from_rfc3339("2021-03-03T03:03:00-00:00")
            .unwrap()
            .with_timezone(&Utc);
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);
        let screen_route = ScreenRoute::new("home_screen", end_period + Duration::days(1));
        let events = vec![
            AnalyticsEvent::new(
                "test1",
                AnalyticsEventType::ScreenEnter,
                end_period - Duration::hours(12),
                Some(screen_route.clone()),
            ),
            AnalyticsEvent::new(
                "test2",
                AnalyticsEventType::ScreenEnter,
                end_period - Duration::hours(15),
                Some(screen_route),
            ),
        ];
        let time_between_events =
            events.first().unwrap().timestamp - events.last().unwrap().timestamp;
        let screen_active_time = CalcScreenActiveTime::new(metadata, events);
        assert_eq!(
            screen_active_time.calculate(),
            vec![time_between_events.num_milliseconds() as u32]
        );
    }

    #[test]
    fn test_calculate_when_events_out_of_order() {
        let end_period = DateTime::parse_from_rfc3339("2021-05-05T05:05:00-00:00")
            .unwrap()
            .with_timezone(&Utc);
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);
        let screen_route = ScreenRoute::new("home_screen", end_period + Duration::days(1));
        // oldest first: the second `ScreenEnter` is newer, so its gap is negative
        let events = vec![
            AnalyticsEvent::new(
                "test1",
                AnalyticsEventType::ScreenEnter,
                end_period - Duration::hours(15),
                Some(screen_route.clone()),
            ),
            AnalyticsEvent::new(
                "test2",
                AnalyticsEventType::ScreenEnter,
                end_period - Duration::hours(12),
                Some(screen_route),
            ),
        ];
        let screen_active_time = CalcScreenActiveTime::new(metadata, events);
        assert_eq!(screen_active_time.calculate(), vec![0]);
    }

    #[test]
    fn test_calculate_when_mixed_type_events() {
        let end_period = DateTime::parse_from_rfc3339("2021-04-04T04:04:00-00:00")
            .unwrap()
            .with_timezone(&Utc);
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);
        let screen_route = ScreenRoute::new("home_screen", end_period + Duration::days(1));
        let first = AnalyticsEvent::new(
            "test1",
            AnalyticsEventType::ScreenEnter,
            end_period - Duration::hours(12),
            Some(screen_route.clone()),
        );
        let second = AnalyticsEvent::new(
            "test1",
            AnalyticsEventType::AppEvent,
            end_period - Duration::hours(13),
            None,
        );
        let third = AnalyticsEvent::new(
            "test2",
            AnalyticsEventType::ScreenEnter,
            end_period - Duration::hours(14),
            Some(screen_route.clone()),
        );
        let fourth = AnalyticsEvent::new(
            "test2",
            AnalyticsEventType::ScreenEnter,
            end_period - Duration::hours(14),
            Some(screen_route),
        );
        let events = vec![first.clone(), second.clone(), third.clone(), fourth.clone()];
        let time_between_events =
            first.timestamp - second.timestamp + (third.timestamp - fourth.timestamp);
        let screen_active_time = CalcScreenActiveTime::new(metadata, events);
        assert_eq!(
            screen_active_time.calculate(),
            vec![time_between_events.num_milliseconds() as u32]
        );
    }
}
================================================
FILE: rust/xaynet-analytics/src/data_combination/data_points/screen_enter_count.rs
================================================
use crate::{
    data_combination::data_points::data_point::{
        CalcScreenEnterCount, CalculateDataPoints, DataPointMetadata,
    },
    database::analytics_event::data_model::{AnalyticsEvent, AnalyticsEventType},
};

impl CalcScreenEnterCount {
    pub fn new(metadata: DataPointMetadata, events: Vec<AnalyticsEvent>) -> Self {
        Self { metadata, events }
    }
}

impl CalculateDataPoints for CalcScreenEnterCount {
    fn metadata(&self) -> DataPointMetadata {
        self.metadata
    }

    /// Returns a single value: the number of `ScreenEnter` events.
    fn calculate(&self) -> Vec<u32> {
        let value = self
            .events
            .iter()
            .filter(|event| event.event_type == AnalyticsEventType::ScreenEnter)
            .count() as u32;
        vec![value]
    }
}

#[cfg(test)]
mod tests {
    use chrono::{DateTime, Duration, Utc};

    use super::*;
    use crate::{
        data_combination::data_points::data_point::{Period, PeriodUnit},
        database::screen_route::data_model::ScreenRoute,
    };

    #[test]
    fn test_calculate_when_no_events() {
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), Utc::now());
        let screen_enter_count = CalcScreenEnterCount::new(metadata, Vec::new());
        assert_eq!(screen_enter_count.calculate(), vec![0]);
    }

    #[test]
    fn test_calculate_when_one_event() {
        let end_period = DateTime::parse_from_rfc3339("2021-01-01T01:01:00-00:00")
            .unwrap()
            .with_timezone(&Utc);
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);
        let screen_route = ScreenRoute::new("home_screen", end_period + Duration::days(1));
        let events = vec![AnalyticsEvent::new(
            "test1",
            AnalyticsEventType::ScreenEnter,
            end_period - Duration::hours(12),
            Some(screen_route),
        )];
        let screen_enter_count = CalcScreenEnterCount::new(metadata, events);
assert_eq!(screen_enter_count.calculate(), vec![1]);
    }

    #[test]
    fn test_calculate_when_two_events() {
        let end_period = DateTime::parse_from_rfc3339("2021-02-02T02:02:00-00:00")
            .unwrap()
            .with_timezone(&Utc);
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);
        let screen_route = ScreenRoute::new("home_screen", end_period + Duration::days(1));
        let events = vec![
            AnalyticsEvent::new(
                "test1",
                AnalyticsEventType::ScreenEnter,
                end_period - Duration::hours(9),
                Some(screen_route.clone()),
            ),
            AnalyticsEvent::new(
                "test2",
                AnalyticsEventType::ScreenEnter,
                end_period - Duration::hours(18),
                Some(screen_route),
            ),
        ];
        let screen_enter_count = CalcScreenEnterCount::new(metadata, events);
        assert_eq!(screen_enter_count.calculate(), vec![2]);
    }
}
================================================
FILE: rust/xaynet-analytics/src/data_combination/data_points/was_active_each_past_period.rs
================================================
use chrono::{DateTime, Utc};
use std::collections::BTreeMap;

use crate::{
    data_combination::data_points::data_point::{
        CalcWasActiveEachPastPeriod, CalculateDataPoints, DataPointMetadata,
    },
    database::analytics_event::data_model::AnalyticsEvent,
};

impl CalcWasActiveEachPastPeriod {
    pub fn new(
        metadata: DataPointMetadata,
        events: Vec<AnalyticsEvent>,
        period_thresholds: Vec<DateTime<Utc>>,
    ) -> Self {
        Self {
            metadata,
            events,
            period_thresholds,
        }
    }

    /// Groups the event timestamps by period, keyed by the newer threshold of each period.
    ///
    /// `period_thresholds` is read newest first. Each pair of consecutive thresholds delimits
    /// the half-open period `[older, newer)`, so an event exactly on a threshold is counted in
    /// the period that threshold opens instead of being dropped. With fewer than two thresholds
    /// there is no period, and the returned map is empty.
    // TODO: this could possibly later be optimised by sorting events by timestamp and caching the last timestamp added to the BTreeMap
    fn group_timestamps_by_period_threshold(
        &self,
    ) -> BTreeMap<DateTime<Utc>, Vec<DateTime<Utc>>> {
        let newer_thresholds = self.period_thresholds.iter();
        let older_thresholds = self.period_thresholds.iter().skip(1);
        newer_thresholds
            .zip(older_thresholds)
            .map(|(newer_threshold, older_threshold)| {
                let timestamps: Vec<DateTime<Utc>> = self
                    .events
                    .iter()
                    .map(|event| event.timestamp)
                    .filter(|timestamp| {
                        timestamp < newer_threshold && timestamp >= older_threshold
                    })
                    .collect();
                (*newer_threshold, timestamps)
            })
            .collect()
    }
}

impl CalculateDataPoints for CalcWasActiveEachPastPeriod {
    fn metadata(&self) -> DataPointMetadata {
        self.metadata
    }

    /// Returns one value per period, newest period first: `1` if at least one event falls in
    /// that period, `0` otherwise.
    fn calculate(&self) -> Vec<u32> {
        let timestamps_by_period_threshold = self.group_timestamps_by_period_threshold();
        // The BTreeMap is ordered by ascending threshold (oldest first); since we are travelling
        // 'back in time', reverse it so that the newest period comes first.
        timestamps_by_period_threshold
            .values()
            .rev()
            .map(|timestamps| u32::from(!timestamps.is_empty()))
            .collect::<Vec<u32>>()
    }
}

#[cfg(test)]
mod tests {
    use chrono::{DateTime, Duration, Utc};

    use super::*;
    use crate::{
        data_combination::data_points::data_point::{Period, PeriodUnit},
        database::analytics_event::data_model::AnalyticsEventType,
    };

    #[test]
    fn test_calculate_event_on_threshold_counts_in_the_period_it_opens() {
        let end_period = DateTime::parse_from_rfc3339("2021-07-07T00:00:00-00:00")
            .unwrap()
            .with_timezone(&Utc);
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 2), end_period);
        let events = vec![AnalyticsEvent::new(
            "test1",
            AnalyticsEventType::UserAction,
            end_period - Duration::days(1),
            None,
        )];
        let period_thresholds = vec![
            end_period,
            end_period - Duration::days(1),
            end_period - Duration::days(2),
        ];
        let was_active_each_past_period =
            CalcWasActiveEachPastPeriod::new(metadata, events, period_thresholds);
        assert_eq!(was_active_each_past_period.calculate(), vec![1, 0]);
    }

    #[test]
    fn test_calculate_no_events_in_a_period() {
        let end_period = DateTime::parse_from_rfc3339("2021-02-02T00:00:00-00:00")
            .unwrap()
            .with_timezone(&Utc);
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);
        let period_thresholds = vec![end_period, end_period - Duration::days(1)];
        let was_active_each_past_period =
            CalcWasActiveEachPastPeriod::new(metadata, Vec::new(), period_thresholds);
        assert_eq!(was_active_each_past_period.calculate(), vec![0]);
    }

    #[test]
    fn test_calculate_one_event_in_a_period() {
        let end_period = DateTime::parse_from_rfc3339("2021-03-03T00:00:00-00:00")
            .unwrap()
            .with_timezone(&Utc);
        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);
        let events = vec![AnalyticsEvent::new(
            "test1",
            AnalyticsEventType::UserAction,
            end_period -
Duration::hours(12), None, )]; let period_thresholds = vec![end_period, end_period - Duration::days(1)]; let was_active_each_past_period = CalcWasActiveEachPastPeriod::new(metadata, events, period_thresholds); assert_eq!(was_active_each_past_period.calculate(), vec![1]); } #[test] fn test_calculate_no_events_in_two_periods() { let end_period = DateTime::parse_from_rfc3339("2021-04-04T00:00:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 2), end_period); let period_thresholds = vec![ end_period, end_period - Duration::days(1), end_period - Duration::days(2), ]; let was_active_each_past_period = CalcWasActiveEachPastPeriod::new(metadata, Vec::new(), period_thresholds); assert_eq!(was_active_each_past_period.calculate(), vec![0, 0]); } #[test] fn test_calculate_one_event_in_one_period_zero_in_another() { let end_period = DateTime::parse_from_rfc3339("2021-05-05T00:00:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 2), end_period); let events = vec![AnalyticsEvent::new( "test1", AnalyticsEventType::UserAction, end_period - Duration::hours(12), None, )]; let period_thresholds = vec![ end_period, end_period - Duration::days(1), end_period - Duration::days(2), ]; let was_active_each_past_period = CalcWasActiveEachPastPeriod::new(metadata, events, period_thresholds); assert_eq!(was_active_each_past_period.calculate(), vec![1, 0]); } #[test] fn test_calculate_two_events_in_one_period_zero_in_another() { let end_period = DateTime::parse_from_rfc3339("2021-06-06T00:00:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 2), end_period); let events = vec![ AnalyticsEvent::new( "test1", AnalyticsEventType::UserAction, end_period - Duration::hours(12), None, ), AnalyticsEvent::new( "test2", AnalyticsEventType::AppError, end_period - Duration::hours(15), None, ), ]; let period_thresholds = vec![ 
end_period, end_period - Duration::days(1), end_period - Duration::days(2), ]; let was_active_each_past_period = CalcWasActiveEachPastPeriod::new(metadata, events, period_thresholds); assert_eq!(was_active_each_past_period.calculate(), vec![1, 0]); } #[test] fn test_calculate_two_periods_with_one_event_each() { let end_period = DateTime::parse_from_rfc3339("2021-07-07T00:00:00-00:00") .unwrap() .with_timezone(&Utc); let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 2), end_period); let events = vec![ AnalyticsEvent::new( "test1", AnalyticsEventType::UserAction, end_period - Duration::hours(12), None, ), AnalyticsEvent::new( "test2", AnalyticsEventType::AppError, end_period - Duration::hours(36), None, ), ]; let period_thresholds = vec![ end_period, end_period - Duration::days(1), end_period - Duration::days(2), ]; let was_active_each_past_period = CalcWasActiveEachPastPeriod::new(metadata, events, period_thresholds); assert_eq!(was_active_each_past_period.calculate(), vec![1, 1]); } } ================================================ FILE: rust/xaynet-analytics/src/data_combination/data_points/was_active_past_n_days.rs ================================================ use crate::{ data_combination::data_points::data_point::{ CalcWasActivePastNDays, CalculateDataPoints, DataPointMetadata, }, database::analytics_event::data_model::AnalyticsEvent, }; impl CalcWasActivePastNDays { pub fn new(metadata: DataPointMetadata, events: Vec) -> Self { Self { metadata, events } } } impl CalculateDataPoints for CalcWasActivePastNDays { fn metadata(&self) -> DataPointMetadata { self.metadata } fn calculate(&self) -> Vec { vec![!self.events.is_empty() as u32] } } #[cfg(test)] mod tests { use chrono::{Duration, Utc}; use super::*; use crate::{ data_combination::data_points::data_point::{Period, PeriodUnit}, database::analytics_event::data_model::AnalyticsEventType, }; #[test] fn test_calculate_without_events() { let metadata = 
DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), Utc::now()); let was_active_past_n_days = CalcWasActivePastNDays::new(metadata, Vec::new()); assert_eq!(was_active_past_n_days.calculate(), vec![0]); } #[test] fn test_calculate_with_events() { let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), Utc::now()); let events = vec![AnalyticsEvent::new( "test1", AnalyticsEventType::AppEvent, metadata.end - Duration::hours(12), None, )]; let was_active_past_n_days = CalcWasActivePastNDays::new(metadata, events); assert_eq!(was_active_past_n_days.calculate(), vec![1]); } } ================================================ FILE: rust/xaynet-analytics/src/data_combination/mod.rs ================================================ pub mod data_combiner; pub mod data_points; ================================================ FILE: rust/xaynet-analytics/src/database/analytics_event/adapter.rs ================================================ //! This file contains struct and impls for `AnalyticsEventAdapter` and `AnalyticsEventRelationalAdapter`, //! as well as the implementation of `IsarAdapter` for `AnalyticsEventAdapter`. use anyhow::{anyhow, Error, Result}; use isar_core::object::{ data_type::DataType, isar_object::{IsarObject, Property}, object_builder::ObjectBuilder, }; use std::{convert::TryFrom, vec::IntoIter}; use crate::database::{ common::{FieldProperty, IsarAdapter, RelationalField, Repo, SchemaGenerator}, isar::IsarDb, screen_route::data_model::ScreenRoute, }; /// `AnalyticsEventAdapter` allows to convert an `IsarObject` from the db to an `AnalyticsEvent`. It is an intermediate /// representation. 
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct AnalyticsEventAdapter {
    pub name: String,
    /// Discriminant of the event's `AnalyticsEventType`, stored as a plain `i32`.
    pub event_type: i32,
    /// Timestamp kept as a string (the `AnalyticsEvent` conversions use RFC 3339).
    pub timestamp: String,
    /// String form of the `RelationalField` pointing at the event's screen route, if any.
    pub screen_route_field: Option,
}

impl AnalyticsEventAdapter {
    /// Builds an adapter. `name` and `screen_route_field` are converted into their stored
    /// `String` form via `into()`.
    pub fn new>(
        name: N,
        event_type: i32,
        timestamp: String,
        screen_route_field: Option,
    ) -> Self {
        Self {
            name: name.into(),
            event_type,
            timestamp,
            screen_route_field: screen_route_field.map(|field| field.into()),
        }
    }
}

impl<'event> IsarAdapter<'event> for AnalyticsEventAdapter {
    /// The object id is `"<name>-<timestamp>"`.
    // NOTE(review): two events with the same name and timestamp map to the same oid.
    fn get_oid(&self) -> String {
        format!("{}-{}", self.name, self.timestamp)
    }

    /// Properties stored for each event. This order must match the write order in
    /// `write_with_object_builder`.
    fn into_field_properties() -> IntoIter {
        vec![
            FieldProperty::new("oid", DataType::String, true),
            FieldProperty::new("name", DataType::String, false),
            FieldProperty::new("event_type", DataType::Int, false),
            FieldProperty::new("timestamp", DataType::String, false),
            FieldProperty::new("screen_route_field", DataType::String, false),
        ]
        .into_iter()
    }

    /// Writes the fields in the same order as `into_field_properties` declares them.
    /// (Presumably `ObjectBuilder` expects schema order — keep the two lists in sync.)
    fn write_with_object_builder(&self, object_builder: &mut ObjectBuilder) {
        object_builder.write_string(Some(&self.get_oid()));
        object_builder.write_string(Some(&self.name));
        object_builder.write_int(self.event_type);
        object_builder.write_string(Some(&self.timestamp));
        // `None` is written when the event has no screen route.
        object_builder.write_string(self.screen_route_field.as_deref());
    }

    /// Rebuilds an adapter from a stored object.
    ///
    /// `oid` is not read back, because it is derived from `name` and `timestamp`.
    /// A missing `name` or `timestamp` is an error. `screen_route_field` is optional, but when
    /// present it must parse as a `RelationalField`.
    fn read(
        isar_object: &'event IsarObject,
        isar_properties: &'event [(String, Property)],
    ) -> Result {
        let name_property = Self::find_property_by_name("name", isar_properties)?;
        let event_type_property = Self::find_property_by_name("event_type", isar_properties)?;
        let timestamp_property = Self::find_property_by_name("timestamp", isar_properties)?;
        let screen_route_field_property =
            Self::find_property_by_name("screen_route_field", isar_properties)?;

        let name_field = isar_object
            .read_string(name_property)
            .ok_or_else(|| anyhow!("unable to read name"))?;
        let event_type_field = isar_object.read_int(event_type_property);
        let timestamp_field = isar_object
            .read_string(timestamp_property)
            .ok_or_else(|| anyhow!("unable to read timestamp"))?
            .to_string();
        // Validate the stored relation string by parsing it; `transpose` turns
        // Option<Result<_>> into Result<Option<_>> so a parse error propagates.
        let screen_route_field = isar_object
            .read_string(screen_route_field_property)
            .map(RelationalField::try_from)
            .transpose()?;

        Ok(AnalyticsEventAdapter::new(
            name_field,
            event_type_field,
            timestamp_field,
            screen_route_field,
        ))
    }
}

impl<'event> SchemaGenerator<'event, AnalyticsEventAdapter> for AnalyticsEventAdapter {}

/// `AnalyticsEventRelationalAdapter` is needed as an intermediate step when saving/retrieving events
/// from the db, because `AnalyticsEvent` contains an `Option<ScreenRoute>`. When that option is `Some`,
/// the `ScreenRoute` has to be retrieved from a different collection in Isar.
pub struct AnalyticsEventRelationalAdapter {
    pub name: String,
    pub event_type: i32,
    pub timestamp: String,
    pub screen_route: Option,
}

impl AnalyticsEventRelationalAdapter {
    /// Resolves `adapter.screen_route_field` by loading the referenced `ScreenRoute` from `db`.
    ///
    /// Fails if the stored relation string is malformed or the `ScreenRoute` lookup fails.
    pub fn new(adapter: AnalyticsEventAdapter, db: &IsarDb) -> Result {
        let screen_route = adapter
            .screen_route_field
            .map(|screen_route_field| {
                let relational_field = RelationalField::try_from(screen_route_field.as_str())?;
                ScreenRoute::get(
                    &relational_field.value,
                    db,
                    &relational_field.collection_name,
                )
            })
            .transpose()?;

        Ok(Self {
            name: adapter.name,
            event_type: adapter.event_type,
            timestamp: adapter.timestamp,
            screen_route,
        })
    }
}
================================================
FILE: rust/xaynet-analytics/src/database/analytics_event/data_model.rs
================================================
//! In this file `AnalyticsEvent` and `AnalyticsEventType` are declared, together with some conversion methods to/from adapters.

use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use std::convert::{From, Into, TryFrom, TryInto};

use crate::database::{
    analytics_event::adapter::{AnalyticsEventAdapter, AnalyticsEventRelationalAdapter},
    common::RelationalField,
    screen_route::data_model::ScreenRoute,
};

/// The type of `AnalyticsEvent` recorded on the framework side.
/// ## Variants:
/// * `AppEvent`: It refers to Flutter's `AppLifeCyclesEvents` (or the equivalent in other frameworks):
///   https://flutter.dev/docs/get-started/flutter-for/android-devs#how-do-i-listen-to-android-activity-lifecycle-events
/// * `AppError`: A known error logged by the developers
/// * `ScreenEnter`: Registers when the user enters a specific screen
/// * `UserAction`: A custom event logged by the developer (eg: clicked on a specific button)
///
/// The explicit discriminants are the only `i32` values accepted by the `TryFrom` impl below.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum AnalyticsEventType {
    AppEvent = 0,
    AppError = 1,
    ScreenEnter = 2,
    UserAction = 3,
}

impl TryFrom for AnalyticsEventType {
    type Error = anyhow::Error;

    /// Maps an `i32` back onto the variant whose discriminant equals it.
    /// Any other value is an error.
    fn try_from(v: i32) -> Result {
        let variants = [
            AnalyticsEventType::AppEvent,
            AnalyticsEventType::AppError,
            AnalyticsEventType::ScreenEnter,
            AnalyticsEventType::UserAction,
        ];
        variants
            .iter()
            .copied()
            .find(|variant| *variant as i32 == v)
            .ok_or_else(|| {
                anyhow!(
                    "i32 value {:?} is not mapped to an AnalyticsEventType variant",
                    v
                )
            })
    }
}

/// The core data model of the library. It represents an event recorded on the mobile framework side.
/// It can be logged manually by the developers, or automatically detected by Flutter/the mobile framework side.
/// ## Fields:
/// * `name`: The name of the event.
/// * `event_type`: The type of event.
/// * `timestamp`: When the event was created.
/// * `screen_route`: Optional field representing the screen on which the event was recorded.
#[derive(Debug, PartialEq, Eq, Clone)] pub struct AnalyticsEvent { pub name: String, pub event_type: AnalyticsEventType, pub timestamp: DateTime, pub screen_route: Option, } impl AnalyticsEvent { pub fn new>( name: N, event_type: AnalyticsEventType, timestamp: DateTime, screen_route: Option, ) -> Self { Self { name: name.into(), event_type, timestamp, screen_route, } } } impl TryFrom for AnalyticsEvent { type Error = anyhow::Error; fn try_from(adapter: AnalyticsEventRelationalAdapter) -> Result { let event = AnalyticsEvent::new( adapter.name, adapter .event_type .try_into() .map_err(|_| anyhow!("unable to convert event_type into enum"))?, DateTime::parse_from_rfc3339(&adapter.timestamp)?.with_timezone(&Utc), adapter.screen_route, ); Ok(event) } } impl From for AnalyticsEventAdapter { fn from(ae: AnalyticsEvent) -> Self { AnalyticsEventAdapter::new( ae.name, ae.event_type as i32, ae.timestamp.to_rfc3339(), ae.screen_route.map(RelationalField::from), ) } } #[cfg(test)] mod tests { use super::*; use crate::database::common::CollectionNames; #[test] fn test_analytics_event_type_try_from_valid_i32() { assert_eq!( AnalyticsEventType::try_from(0).unwrap(), AnalyticsEventType::AppEvent ); assert_eq!( AnalyticsEventType::try_from(1).unwrap(), AnalyticsEventType::AppError ); assert_eq!( AnalyticsEventType::try_from(2).unwrap(), AnalyticsEventType::ScreenEnter ); assert_eq!( AnalyticsEventType::try_from(3).unwrap(), AnalyticsEventType::UserAction ); } #[test] fn test_analytics_event_type_invalid_i32() { assert!(AnalyticsEventType::try_from(42).is_err()); } #[test] fn test_analytics_event_try_from_relational_adapter_without_screen_route() { let timestamp = "2021-01-01T01:01:00+00:00"; let relational_adapter = AnalyticsEventRelationalAdapter { name: "test".to_string(), event_type: 0, timestamp: timestamp.to_string(), screen_route: None, }; let analytics_event = AnalyticsEvent::new( "test", AnalyticsEventType::AppEvent, DateTime::parse_from_rfc3339(timestamp) .unwrap() 
.with_timezone(&Utc), None, ); assert_eq!( AnalyticsEvent::try_from(relational_adapter).unwrap(), analytics_event ); } #[test] fn test_analytics_event_try_from_relational_adapter_with_screen_route() { let timestamp_str = "2021-01-01T01:01:00+00:00"; let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str) .unwrap() .with_timezone(&Utc); let screen_route = ScreenRoute::new("route", timestamp_parsed); let relational_adapter = AnalyticsEventRelationalAdapter { name: "test".to_string(), event_type: 2, timestamp: timestamp_str.to_string(), screen_route: Some(screen_route.clone()), }; let analytics_event = AnalyticsEvent::new( "test", AnalyticsEventType::ScreenEnter, timestamp_parsed, Some(screen_route), ); assert_eq!( AnalyticsEvent::try_from(relational_adapter).unwrap(), analytics_event ); } #[test] fn test_analytics_event_try_into_adapter_without_screen_route() { let timestamp_str = "2021-01-01T01:01:00+00:00"; let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str) .unwrap() .with_timezone(&Utc); let analytics_event = AnalyticsEvent::new("test", AnalyticsEventType::AppError, timestamp_parsed, None); let actual_analytics_event_adapter: AnalyticsEventAdapter = analytics_event.try_into().unwrap(); let expected_analytics_event_adapter = AnalyticsEventAdapter::new("test", 1, timestamp_str.to_string(), None); assert_eq!( actual_analytics_event_adapter, expected_analytics_event_adapter ); } #[test] fn test_analytics_event_try_into_adapter_with_screen_route() { let timestamp_str = "2021-01-01T01:01:00+00:00"; let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str) .unwrap() .with_timezone(&Utc); let screen_route = ScreenRoute::new("route", timestamp_parsed); let relationa_field = RelationalField { value: "route".to_string(), collection_name: CollectionNames::SCREEN_ROUTES.to_string(), }; let analytics_event = AnalyticsEvent::new( "test", AnalyticsEventType::UserAction, timestamp_parsed, Some(screen_route), ); let 
actual_analytics_event_adapter: AnalyticsEventAdapter = analytics_event.try_into().unwrap(); let expected_analytics_event_adapter = AnalyticsEventAdapter::new("test", 3, timestamp_str.to_string(), Some(relationa_field)); assert_eq!( actual_analytics_event_adapter, expected_analytics_event_adapter ); } } ================================================ FILE: rust/xaynet-analytics/src/database/analytics_event/mod.rs ================================================ pub mod adapter; pub mod data_model; pub mod repo; ================================================ FILE: rust/xaynet-analytics/src/database/analytics_event/repo.rs ================================================ //! Implementations of the methods needed to save and get `AnalyticsEvents` to/from Isar. use anyhow::{anyhow, Error, Result}; use std::convert::{Into, TryFrom}; use crate::database::{ analytics_event::{ adapter::{AnalyticsEventAdapter, AnalyticsEventRelationalAdapter}, data_model::AnalyticsEvent, }, common::{IsarAdapter, Repo}, isar::IsarDb, }; /// Inside `get()` and `get_all()` there is an intermediate conversion from `Adapter` to `RelationalAdapter`, /// and then to data model (`AnalyticsEvent`), which is different than other data models where /// they can be converted directly from Adapter to data model. impl<'db> Repo<'db, AnalyticsEvent> for AnalyticsEvent { fn save(self, db: &'db IsarDb, collection_name: &str) -> Result<(), Error> { let mut object_builder = db.get_object_builder(collection_name)?; let event_adapter: AnalyticsEventAdapter = self.into(); event_adapter.write_with_object_builder(&mut object_builder); db.put(collection_name, object_builder.finish().as_bytes()) } // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517 fn get_all(db: &'db IsarDb, collection_name: &str) -> Result, Error> { let isar_properties = db.get_collection_properties(collection_name)?; db.get_all_isar_objects(collection_name)? 
.into_iter() .map(|(_, isar_object)| AnalyticsEventAdapter::read(&isar_object, isar_properties)) .map(|adapter| AnalyticsEventRelationalAdapter::new(adapter?, &db)) .map(|relational_adapter| AnalyticsEvent::try_from(relational_adapter?)) .collect() } fn get(oid: &str, db: &'db IsarDb, collection_name: &str) -> Result { let isar_properties = db.get_collection_properties(collection_name)?; let object_id = db.get_object_id_from_str(collection_name, oid)?; let mut transaction = db.get_read_transaction()?; let isar_object = db.get_isar_object_by_id(&object_id, collection_name, &mut transaction)?; if let Some(isar_object) = isar_object { let adapter = AnalyticsEventAdapter::read(&isar_object, isar_properties)?; let relational_adapter = AnalyticsEventRelationalAdapter::new(adapter, &db)?; AnalyticsEvent::try_from(relational_adapter) } else { Err(anyhow!("unable to get {:?} object", object_id)) } } } ================================================ FILE: rust/xaynet-analytics/src/database/common.rs ================================================ //! This file contains traits and structs that are common to other components involved with the database. //! It could be split up in smaller files, especially if more traits receive a default implementation. //! See: https://xainag.atlassian.net/browse/XN-1692 use anyhow::{anyhow, Error, Result}; use isar_core::{ index::IndexType, object::{ data_type::DataType, isar_object::{IsarObject, Property}, object_builder::ObjectBuilder, }, schema::collection_schema::{ CollectionSchema, IndexPropertySchema, IndexSchema, PropertySchema, }, }; use std::{convert::TryFrom, vec::IntoIter}; use crate::database::isar::IsarDb; /// `IsarAdapter` trait needs to be implemented for each data model adapters. /// This is needed to be able to tell Isar how to write/read objects to/from a collection. /// /// The implementations of these methods could actually be automated by a macro, since they are always the same. 
/// See: https://xainag.atlassian.net/browse/XN-1689 pub trait IsarAdapter<'object>: Sized { fn get_oid(&self) -> String; fn into_field_properties() -> IntoIter; fn write_with_object_builder(&self, object_builder: &mut ObjectBuilder); fn read( isar_object: &'object IsarObject, isar_properties: &'object [(String, Property)], ) -> Result; fn find_property_by_name( name: &str, isar_properties: &[(String, Property)], ) -> Result { isar_properties .iter() .find(|(isar_property_name, _)| isar_property_name == name) .map(|(_, property)| *property) .ok_or_else(|| anyhow!("failed to retrieve property {:?}", name)) } } /// This trait is implemented directly for each data model to have a high level API for `AnalyticsController` to /// save/get objects from the db. /// /// Consider using default implementations here, to reduce boiler plate code in repo.rs files. /// See: https://xainag.atlassian.net/browse/XN-1688 pub trait Repo<'db, M> where M: Sized, { fn save(self, db: &'db IsarDb, collection_name: &str) -> Result<(), Error>; fn get_all(db: &'db IsarDb, collection_name: &str) -> Result, Error>; fn get(object_id: &str, db: &'db IsarDb, collection_name: &str) -> Result; } /// `FieldProperty` is a simple struct that holds data used to register properties and indexes for Isar schemas. pub struct FieldProperty { pub name: String, pub data_type: DataType, pub is_oid: bool, pub index_type: IndexType, pub is_case_sensitive: bool, pub is_unique: bool, } impl FieldProperty { pub fn new>(name: N, data_type: DataType, is_oid: bool) -> Self { Self { name: name.into(), data_type, is_oid, index_type: IndexType::Value, is_case_sensitive: data_type == DataType::String, is_unique: true, } } } /// `SchemaGenerator` is needed to register the `PropertySchema` and `IndexSchema` for each `FieldProperty`. /// `PropertySchema` and `IndexSchema` are imported from Isar, while `FieldProperty` is an internal struct to /// make it convenient to iterate through each property (see the fold below). 
///
/// When `Ok` it returns a `CollectionSchema` that is needed by Isar to manage a collection.
pub trait SchemaGenerator<'object, A>
where
    A: IsarAdapter<'object>,
{
    /// Builds the `CollectionSchema` called `name` from `A::into_field_properties()`.
    /// It creates one `PropertySchema` and one single-property `IndexSchema` per `FieldProperty`.
    fn get_schema(name: &str) -> Result {
        let (properties, indexes) = A::into_field_properties().fold(
            (Vec::new(), Vec::new()),
            |(mut properties, mut indexes), prop| {
                let property_schema =
                    PropertySchema::new(&prop.name, prop.data_type, prop.is_oid);
                // Non-string properties get `None`. String properties take their flag from
                // `FieldProperty`, so the schema stays consistent with how the property was built.
                let is_index_case_sensitive =
                    Some(prop.is_case_sensitive).filter(|_| prop.data_type == DataType::String);
                let index_property_schema = vec![IndexPropertySchema::new(
                    &prop.name,
                    prop.index_type,
                    is_index_case_sensitive,
                )];
                let index_schema = IndexSchema::new(index_property_schema, prop.is_unique);

                properties.push(property_schema);
                indexes.push(index_schema);
                (properties, indexes)
            },
        );
        Ok(CollectionSchema::new(name, properties, indexes))
    }
}

/// `RelationalField` allows saving a reference to a data model instance inside another data model.
/// Its string form is `"<value>=<collection_name>"`.
///
/// ## Fields
/// * `value` - a `String` representing an id with which the data model can be identified
/// * `collection_name` - the name of the collection where the object is saved
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct RelationalField {
    pub value: String,
    pub collection_name: String,
}

// NOTE: `str::rsplit_once` (stable since Rust 1.52) would express this split more directly.
impl TryFrom<&str> for RelationalField {
    type Error = anyhow::Error;

    /// Parses `"<value>=<collection_name>"`. Fails if `data` contains no '='.
    fn try_from(data: &str) -> Result {
        // Split on the *last* '=' so that a `value` containing '=' still round-trips through
        // `From<RelationalField> for String`. Collection names (see `CollectionNames`) contain no '='.
        let mut parts = data.rsplitn(2, '=');
        match (parts.next(), parts.next()) {
            (Some(collection_name), Some(value)) => Ok(Self {
                value: value.to_string(),
                collection_name: collection_name.to_string(),
            }),
            _ => Err(anyhow!(
                "data {:?} is not a str made of two elements separated by '='",
                data
            )),
        }
    }
}

impl From for String {
    /// Serialises the field as `"<value>=<collection_name>"`.
    fn from(rf: RelationalField) -> String {
        [rf.value, rf.collection_name].join("=")
    }
}

/// Stores the name of each collection.
Whenever you need to make an operation on an `IsarCollection`, /// these `str`s are needed. pub struct CollectionNames; impl CollectionNames { pub const ANALYTICS_EVENTS: &'static str = "analytics_events"; pub const CONTROLLER_DATA: &'static str = "controller_data"; pub const SCREEN_ROUTES: &'static str = "screen_routes"; } ================================================ FILE: rust/xaynet-analytics/src/database/controller_data/adapter.rs ================================================ //! This file contains struct and impl for `ControllerDataAdapter` the implementation of `IsarAdapter` //! for `ControllerDataAdapter`. use anyhow::{anyhow, Error, Result}; use isar_core::object::{ data_type::DataType, isar_object::{IsarObject, Property}, object_builder::ObjectBuilder, }; use std::vec::IntoIter; use crate::database::common::{FieldProperty, IsarAdapter, SchemaGenerator}; /// Allows to convert an IsarObject from the db to a `ControllerData`. #[derive(Debug, PartialEq, Eq, Clone)] pub struct ControllerDataAdapter { pub time_data_sent: String, } impl ControllerDataAdapter { pub fn new>(time_data_sent: T) -> Self { Self { time_data_sent: time_data_sent.into(), } } } impl<'ctrl> IsarAdapter<'ctrl> for ControllerDataAdapter { fn get_oid(&self) -> String { self.time_data_sent.clone() } fn into_field_properties() -> IntoIter { vec![ FieldProperty::new("oid", DataType::String, true), FieldProperty::new("time_data_sent", DataType::String, false), ] .into_iter() } fn write_with_object_builder(&self, object_builder: &mut ObjectBuilder) { object_builder.write_string(Some(&self.get_oid())); object_builder.write_string(Some(&self.time_data_sent)); } fn read( isar_object: &'ctrl IsarObject, isar_properties: &'ctrl [(String, Property)], ) -> Result { let time_data_sent_property = Self::find_property_by_name("time_data_sent", isar_properties)?; let time_data_sent_data = isar_object .read_string(time_data_sent_property) .ok_or_else(|| anyhow!("unable to read time_data_sent"))?; 
Ok(ControllerDataAdapter::new(time_data_sent_data)) } } impl<'ctrl> SchemaGenerator<'ctrl, ControllerDataAdapter> for ControllerDataAdapter {} ================================================ FILE: rust/xaynet-analytics/src/database/controller_data/data_model.rs ================================================ //! In this file `ControllerData` is declared, together with some conversion methods to/from adapters. use anyhow::Result; use chrono::{DateTime, Utc}; use std::convert::TryFrom; use crate::database::controller_data::adapter::ControllerDataAdapter; /// Holds some metadata useful for the `AnalyticsController`. For now it only contains `time_data_sent`, /// which is the time when analytics data was last sent to the coordinator for aggregation. #[derive(Debug, PartialEq, Eq, Clone)] pub struct ControllerData { pub time_data_sent: DateTime, } impl ControllerData { pub fn new(time_data_sent: DateTime) -> Self { Self { time_data_sent } } } impl TryFrom for ControllerData { type Error = anyhow::Error; fn try_from(adapter: ControllerDataAdapter) -> Result { Ok(ControllerData::new( DateTime::parse_from_rfc3339(&adapter.time_data_sent)?.with_timezone(&Utc), )) } } impl From for ControllerDataAdapter { fn from(cd: ControllerData) -> Self { ControllerDataAdapter::new(cd.time_data_sent.to_rfc3339()) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_controller_data_try_from_adapter() { let timestamp_str = "2021-01-01T01:01:00+00:00"; let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str) .unwrap() .with_timezone(&Utc); let controller_data = ControllerData::new(timestamp_parsed); let adapter = ControllerDataAdapter::new(timestamp_str); assert_eq!(ControllerData::try_from(adapter).unwrap(), controller_data); } #[test] fn test_adapter_into_controller_data() { let timestamp_str = "2021-01-01T01:01:00+00:00"; let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str) .unwrap() .with_timezone(&Utc); let controller_data = 
ControllerData::new(timestamp_parsed);
        let actual_adapter: ControllerDataAdapter = controller_data.into();
        let expected_adapter = ControllerDataAdapter::new(timestamp_str);
        assert_eq!(actual_adapter, expected_adapter);
    }
}

================================================
FILE: rust/xaynet-analytics/src/database/controller_data/mod.rs
================================================
pub mod adapter;
pub mod data_model;
pub mod repo;

================================================
FILE: rust/xaynet-analytics/src/database/controller_data/repo.rs
================================================
//! Implementations of the methods needed to save and get `ControllerData` to/from Isar.

use anyhow::{anyhow, Error, Result};
use std::convert::{Into, TryFrom};

use crate::database::{
    common::{IsarAdapter, Repo},
    controller_data::{adapter::ControllerDataAdapter, data_model::ControllerData},
    isar::IsarDb,
};

// NOTE(review): the return types of `get_all` and `get` below appear to have lost their generic
// arguments in this copy (`Result, Error>` / bare `Result`) — confirm against upstream.
impl<'db> Repo<'db, ControllerData> for ControllerData {
    /// Converts `self` into a `ControllerDataAdapter`, serialises it with an object builder for
    /// `collection_name` and stores the resulting bytes via `IsarDb::put`.
    fn save(self, db: &'db IsarDb, collection_name: &str) -> Result<(), Error> {
        let mut object_builder = db.get_object_builder(collection_name)?;
        let data_adapter: ControllerDataAdapter = self.into();
        data_adapter.write_with_object_builder(&mut object_builder);
        db.put(collection_name, object_builder.finish().as_bytes())
    }

    // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517
    /// Reads every object stored in `collection_name` and converts each one into a
    /// `ControllerData`. Collecting into a `Result` stops at the first read or conversion
    /// error and returns that error.
    fn get_all(db: &'db IsarDb, collection_name: &str) -> Result, Error> {
        let isar_properties = db.get_collection_properties(collection_name)?;
        db.get_all_isar_objects(collection_name)?
            .into_iter()
            // only the object payload is needed; the first tuple element is ignored
            .map(|(_, isar_object)| ControllerDataAdapter::read(&isar_object, isar_properties))
            .map(|data_adapter| ControllerData::try_from(data_adapter?))
            .collect()
    }

    /// Fetches the single object whose string id is `oid` from `collection_name`.
    ///
    /// Returns an error if the object id cannot be built from `oid`, the lookup fails, no
    /// object with that id exists, or the stored data cannot be converted to `ControllerData`.
    fn get(oid: &str, db: &'db IsarDb, collection_name: &str) -> Result {
        let isar_properties = db.get_collection_properties(collection_name)?;
        let object_id = db.get_object_id_from_str(collection_name, oid)?;
        // the transaction lives in this scope because the object returned below borrows it
        let mut transaction = db.get_read_transaction()?;
        let isar_object = db.get_isar_object_by_id(&object_id, collection_name, &mut transaction)?;
        if let Some(isar_object) = isar_object {
            let data_adapter = ControllerDataAdapter::read(&isar_object, isar_properties)?;
            ControllerData::try_from(data_adapter)
        } else {
            Err(anyhow!("unable to get {:?} object", object_id))
        }
    }
}

================================================
FILE: rust/xaynet-analytics/src/database/isar.rs
================================================
//! `IsarDb` is an internal abstraction on top of Isar that wraps `IsarInstance`, the main singleton from Isar.

use anyhow::{anyhow, Error, Result};
use isar_core::{
    collection::IsarCollection,
    instance::IsarInstance,
    object::{
        isar_object::{IsarObject, Property},
        object_builder::ObjectBuilder,
        object_id::ObjectId,
    },
    schema::{collection_schema::CollectionSchema, Schema},
    txn::IsarTxn,
};
use std::sync::Arc;

/// `IsarDb` is the internal singleton wrapping the `IsarInstance`, which is the singleton coming from Isar.
/// `IsarDb` exposes public methods for the `AnalyticsController` to save/get models via the `Repo` impls and the adapters.
// NOTE(review): several type annotations in this file (e.g. `Arc,`, `Vec)`, `-> Result {`) appear to
// have lost their generic arguments in this copy — confirm against upstream.
pub struct IsarDb {
    instance: Arc,
}

impl IsarDb {
    /// Size limit handed to `IsarInstance::open` (presumably in bytes — confirm against the Isar docs).
    const MAX_SIZE: usize = 10000000;

    /// Opens the Isar database located at `path`.
    ///
    /// `IsarInstance` is the singleton from Isar that coordinates the whole database.
    ///
    /// `Vec<CollectionSchema>` is required by Isar to register each data model `IsarCollection`.
    /// An `IsarCollection` organises data for a single data model (eg: `AnalyticsEvents`).
    ///
    /// # Errors
    /// Fails if the schemas cannot be combined into a `Schema` or if Isar cannot open the instance.
    pub fn new(path: &str, collection_schemas: Vec) -> Result {
        IsarInstance::open(
            path,
            IsarDb::MAX_SIZE,
            IsarDb::get_schema(collection_schemas)?,
        )
        .map_err(|error| anyhow!("failed to create IsarInstance: {:?}", error))
        .map(|instance| IsarDb { instance })
    }

    /// Returns every object stored in `collection_name`, queried inside a new read transaction.
    ///
    /// # Errors
    /// Fails if the collection is unknown, the transaction cannot be started or the query fails.
    pub fn get_all_isar_objects(
        &self,
        collection_name: &str,
    ) -> Result, Error> {
        self.get_collection(collection_name)?
            .new_query_builder()
            .build()
            .find_all_vec(&mut self.begin_txn(false)?)
            .map_err(|error| {
                anyhow!(
                    "failed to find all objects from collection {}: {:?}",
                    collection_name,
                    error,
                )
            })
    }

    /// Transactions are needed to write and read from Isar.
    /// This method is public because it's called inside `Repo::get()`, before passing it to `get_isar_object_by_id()`,
    /// so that the transaction is in scope when called, and the lifetimes are valid.
    pub fn get_read_transaction(&self) -> Result {
        self.begin_txn(false)
    }

    /// Looks up the object with `object_id` in `collection_name` using the caller's `transaction`.
    ///
    /// The returned object borrows `transaction` (lifetime `'txn`). Callers treat an `Ok(None)`
    /// as "no such object" (presumably no object with that id exists — confirm with Isar docs).
    pub fn get_isar_object_by_id<'txn>(
        &self,
        object_id: &ObjectId,
        collection_name: &str,
        transaction: &'txn mut IsarTxn,
    ) -> Result>, Error> {
        self.get_collection(collection_name)?
            .get(transaction, object_id)
            .map_err(|error| anyhow!("unable to get {:?} object ({:?})", object_id, error))
    }

    /// Stores the serialised `object` in `collection_name` inside a write transaction.
    ///
    /// The transaction is committed only if the insert succeeds; otherwise it is dropped without
    /// being committed. Note that the error message embeds the full `object` bytes.
    pub fn put(&self, collection_name: &str, object: &[u8]) -> Result<(), Error> {
        let mut transaction = self.begin_txn(true)?;
        self.get_collection(collection_name)?
            .put(&mut transaction, IsarObject::new(object))
            .and_then(|_| transaction.commit())
            .map_err(|error| {
                anyhow!(
                    "failed to add object {:?} to collection: {} | {:?}",
                    object,
                    collection_name,
                    error,
                )
            })
    }

    /// Returns a new `ObjectBuilder` for `collection_name`; the repos use it to serialise an
    /// adapter before handing the bytes to `put`.
    pub fn get_object_builder(&self, collection_name: &str) -> Result {
        Ok(self
            .get_collection(collection_name)?
            .new_object_builder(None))
    }

    /// When `Ok`, this method returns a valid `ObjectId` that can be used to retrieve a single object from a collection.
    ///
    /// # Errors
    /// Fails if the collection is unknown or `oid` cannot be turned into a string object id.
    pub fn get_object_id_from_str(
        &self,
        collection_name: &str,
        oid: &str,
    ) -> Result {
        self.get_collection(collection_name)?
            .new_string_oid(oid)
            .map_err(|error| anyhow!("could not get the object id from {:?}: {:?}", oid, error))
    }

    /// Returns the properties from a collection that were registered via the `CollectionSchema`, and are needed to
    /// read/write objects to/from the collection.
    pub fn get_collection_properties(
        &self,
        collection_name: &str,
    ) -> Result<&[(String, Property)], Error> {
        Ok(self.get_collection(collection_name)?.get_properties())
    }

    /// Closes the underlying `IsarInstance`, consuming `self`.
    ///
    /// `close` handing something back (`Some(_)`) is treated as a failure — presumably it means
    /// other references to the instance are still alive (confirm against the Isar docs).
    pub fn dispose(self) -> Result<(), Error> {
        match self.instance.close() {
            Some(_) => Err(anyhow!("could not close the IsarInstance")),
            None => Ok(()),
        }
    }

    /// The `Schema` is needed to open the `IsarInstance` and is automatically produced by Isar
    /// based on the `Vec<CollectionSchema>` provided when calling `IsarDb::new()`.
    fn get_schema(collection_schemas: Vec) -> Result {
        Schema::new(collection_schemas).map_err(|error| {
            anyhow!(
                "failed to add collection schemas to instance schema: {:?}",
                error
            )
        })
    }

    /// Looks up a registered collection by name, failing with "wrong collection name" otherwise.
    fn get_collection(&self, collection_name: &str) -> Result<&IsarCollection, Error> {
        self.instance
            .get_collection_by_name(collection_name)
            .ok_or_else(|| anyhow!("wrong collection name: {}", collection_name))
    }

    /// Transactions are needed to read/write objects from Isar. Write transactions should stay private.
    fn begin_txn(&self, is_write: bool) -> Result {
        self.instance
            .begin_txn(is_write)
            .map_err(|error| anyhow!("failed to begin transaction: {:?}", error))
    }
}

================================================
FILE: rust/xaynet-analytics/src/database/mod.rs
================================================
pub mod analytics_event;
pub mod common;
pub mod controller_data;
pub mod isar;
pub mod screen_route;

================================================
FILE: rust/xaynet-analytics/src/database/screen_route/adapter.rs
================================================
//! This file contains the `ScreenRouteAdapter` struct and its implementation of
//! `IsarAdapter`.
use anyhow::{anyhow, Error, Result}; use isar_core::object::{ data_type::DataType, isar_object::{IsarObject, Property}, object_builder::ObjectBuilder, }; use std::vec::IntoIter; use crate::database::common::{FieldProperty, IsarAdapter, SchemaGenerator}; /// Allows to convert an `IsarObject` from the db to a `ScreenRoute`. #[derive(Debug, PartialEq, Eq, Clone)] pub struct ScreenRouteAdapter { pub name: String, pub created_at: String, } impl ScreenRouteAdapter { pub fn new>(name: S, created_at: S) -> Self { Self { name: name.into(), created_at: created_at.into(), } } } impl<'screen> IsarAdapter<'screen> for ScreenRouteAdapter { fn get_oid(&self) -> String { self.name.clone() } fn into_field_properties() -> IntoIter { vec![ FieldProperty::new("oid", DataType::String, true), FieldProperty::new("name", DataType::String, false), FieldProperty::new("created_at", DataType::String, false), ] .into_iter() } fn write_with_object_builder(&self, object_builder: &mut ObjectBuilder) { object_builder.write_string(Some(&self.get_oid())); object_builder.write_string(Some(&self.name)); object_builder.write_string(Some(&self.created_at)); } fn read( isar_object: &'screen IsarObject, isar_properties: &'screen [(String, Property)], ) -> Result { let name_property = Self::find_property_by_name("name", isar_properties)?; let created_at_property = Self::find_property_by_name("created_at", isar_properties)?; let name_data = isar_object .read_string(name_property) .ok_or_else(|| anyhow!("unable to read name"))?; let created_at_data = isar_object .read_string(created_at_property) .ok_or_else(|| anyhow!("unable to read created_at"))?; Ok(ScreenRouteAdapter::new( name_data.to_string(), created_at_data.to_string(), )) } } impl<'screen> SchemaGenerator<'screen, ScreenRouteAdapter> for ScreenRouteAdapter {} ================================================ FILE: rust/xaynet-analytics/src/database/screen_route/data_model.rs ================================================ //! 
In this file `ScreenRoute` is declared, together with some conversion methods to/from adapters. use anyhow::Result; use chrono::{DateTime, Utc}; use std::convert::{Into, TryFrom}; use crate::database::{ common::{CollectionNames, RelationalField}, screen_route::adapter::ScreenRouteAdapter, }; /// A `ScreenRoute` is the internal representation of a screen in the app. #[derive(Debug, PartialEq, Eq, Clone)] pub struct ScreenRoute { pub name: String, pub created_at: DateTime, } impl ScreenRoute { pub fn new>(name: N, created_at: DateTime) -> Self { Self { name: name.into(), created_at, } } } impl TryFrom for ScreenRoute { type Error = anyhow::Error; fn try_from(adapter: ScreenRouteAdapter) -> Result { Ok(ScreenRoute::new( adapter.name, DateTime::parse_from_rfc3339(&adapter.created_at)?.with_timezone(&Utc), )) } } impl From for ScreenRouteAdapter { fn from(sr: ScreenRoute) -> Self { ScreenRouteAdapter::new(sr.name, sr.created_at.to_rfc3339()) } } impl From for RelationalField { fn from(screen_route: ScreenRoute) -> Self { Self { value: screen_route.name, collection_name: CollectionNames::SCREEN_ROUTES.to_string(), } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_screen_route_try_from_adapter() { let timestamp_str = "2021-01-01T01:01:00+00:00"; let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str) .unwrap() .with_timezone(&Utc); let screen_route = ScreenRoute::new("route", timestamp_parsed); let adapter = ScreenRouteAdapter::new("route", timestamp_str); assert_eq!(ScreenRoute::try_from(adapter).unwrap(), screen_route); } #[test] fn test_adapter_into_screen_route() { let timestamp_str = "2021-01-01T01:01:00+00:00"; let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str) .unwrap() .with_timezone(&Utc); let screen_route = ScreenRoute::new("route", timestamp_parsed); let actual_adapter: ScreenRouteAdapter = screen_route.into(); let expected_adapter = ScreenRouteAdapter::new("route", timestamp_str); assert_eq!(actual_adapter, 
expected_adapter); } #[test] fn test_screen_route_from_relational_field() { let timestamp_str = "2021-01-01T01:01:00+00:00"; let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str) .unwrap() .with_timezone(&Utc); let screen_route = ScreenRoute::new("route", timestamp_parsed); let relational_field = RelationalField { value: "route".to_string(), collection_name: CollectionNames::SCREEN_ROUTES.to_string(), }; assert_eq!(RelationalField::from(screen_route), relational_field); } } ================================================ FILE: rust/xaynet-analytics/src/database/screen_route/mod.rs ================================================ pub mod adapter; pub mod data_model; pub mod repo; ================================================ FILE: rust/xaynet-analytics/src/database/screen_route/repo.rs ================================================ //! Implementations of the methods needed to save and get ScreenRoute to/from Isar. use anyhow::{anyhow, Error, Result}; use std::convert::{Into, TryFrom}; use crate::database::{ common::{IsarAdapter, Repo}, isar::IsarDb, screen_route::{adapter::ScreenRouteAdapter, data_model::ScreenRoute}, }; impl<'db> Repo<'db, ScreenRoute> for ScreenRoute { fn save(self, db: &'db IsarDb, collection_name: &str) -> Result<(), Error> { let mut object_builder = db.get_object_builder(collection_name)?; let route_adapter: ScreenRouteAdapter = self.into(); route_adapter.write_with_object_builder(&mut object_builder); db.put(collection_name, object_builder.finish().as_bytes()) } // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517 fn get_all(db: &'db IsarDb, collection_name: &str) -> Result, Error> { let isar_properties = db.get_collection_properties(collection_name)?; db.get_all_isar_objects(collection_name)? 
.into_iter() .map(|(_, isar_object)| ScreenRouteAdapter::read(&isar_object, isar_properties)) .map(|screen_route_adapter| ScreenRoute::try_from(screen_route_adapter?)) .collect() } fn get(oid: &str, db: &'db IsarDb, collection_name: &str) -> Result { let isar_properties = db.get_collection_properties(collection_name)?; let object_id = db.get_object_id_from_str(collection_name, oid)?; let mut transaction = db.get_read_transaction()?; let isar_object = db.get_isar_object_by_id(&object_id, collection_name, &mut transaction)?; if let Some(isar_object) = isar_object { let screen_route_adapter = ScreenRouteAdapter::read(&isar_object, isar_properties)?; ScreenRoute::try_from(screen_route_adapter) } else { Err(anyhow!("unable to get {:?} object", object_id)) } } } ================================================ FILE: rust/xaynet-analytics/src/lib.rs ================================================ #![cfg_attr(doc, forbid(broken_intra_doc_links, private_intra_doc_links))] #![doc( html_logo_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png", html_favicon_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png", issue_tracker_base_url = "https://github.com/xaynetwork/xaynet/issues" )] //! This crate containes the Rust component of Federated Analytics, //! a framework that allows mobile applications to collect and aggregate //! analytics data via the _Privacy-Enhancing Technology_ (PET) protocol. #[cfg(not(tarpaulin))] pub mod controller; #[cfg(not(tarpaulin))] pub mod data_combination; #[cfg(not(tarpaulin))] pub mod database; #[cfg(not(tarpaulin))] pub mod sender; ================================================ FILE: rust/xaynet-analytics/src/sender.rs ================================================ //! In this file `Sender` is just stubbed and will need to be implemented. 
use anyhow::{Error, Result}; use crate::data_combination::data_points::data_point::DataPoint; /// `Sender` receives a `Vec` from the `DataCombiner`. /// /// It will need to call the exposed `calculate()` method on each `DataPoint` variant and compose the messages /// that will then need to reach the XayNet coordinator. /// /// These messages should contain not only the actual data that is the output of calling `calculate()` on the variant, /// but also some extra data so that the coordinator knows how to aggregate each `DataPoint` variant. /// This is in line with the research done on the “global spec” idea. pub struct Sender; impl Sender { pub fn send(&self, _data_points: Vec) -> Result<(), Error> { // TODO: https://xainag.atlassian.net/browse/XN-1647 todo!() } } ================================================ FILE: rust/xaynet-core/Cargo.toml ================================================ [package] name = "xaynet-core" version = "0.2.0" authors = ["Xayn Engineering "] edition = "2018" description = "The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant." 
readme = "../../README.md" homepage = "https://xaynet.dev/" repository = "https://github.com/xaynetwork/xaynet/" license-file = "../../LICENSE" keywords = ["federated-learning", "fl", "ai", "machine-learning"] categories = ["science", "cryptography"] [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] [dependencies] anyhow = "1.0.62" bitflags = "1.3.2" derive_more = { version = "0.99.17", default-features = false, features = [ "as_ref", "as_mut", "display", "from", "index", "index_mut", "into", ] } num = { version = "0.4.0", features = ["serde"] } rand = "0.8.5" rand_chacha = "0.3.1" serde = { version = "1.0.144", features = ["derive"] } sodiumoxide = "0.2.7" thiserror = "1.0.32" [features] testutils = [] [dev-dependencies] paste = "1.0.8" ================================================ FILE: rust/xaynet-core/src/common.rs ================================================ use serde::{Deserialize, Serialize}; use sodiumoxide::{self, crypto::box_}; use crate::{crypto::ByteObject, mask::MaskConfigPair, CoordinatorPublicKey}; /// The round parameters. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct RoundParameters { /// The public key of the coordinator used for encryption. pub pk: CoordinatorPublicKey, /// Fraction of participants to be selected for the sum task. pub sum: f64, /// Fraction of participants to be selected for the update task. pub update: f64, /// The random round seed. pub seed: RoundSeed, /// The masking configuration pub mask_config: MaskConfigPair, /// The length of the model. pub model_length: usize, } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] /// A seed for a round. pub struct RoundSeed(box_::Seed); impl ByteObject for RoundSeed { const LENGTH: usize = box_::SEEDBYTES; /// Creates a round seed from a slice of bytes. /// /// # Errors /// Fails if the length of the input is invalid. 
fn from_slice(bytes: &[u8]) -> Option { box_::Seed::from_slice(bytes).map(Self) } /// Creates a round seed initialized to zero. fn zeroed() -> Self { Self(box_::Seed([0_u8; Self::LENGTH])) } /// Gets the round seed as a slice. fn as_slice(&self) -> &[u8] { self.0.as_ref() } } ================================================ FILE: rust/xaynet-core/src/crypto/encrypt.rs ================================================ //! Wrappers around some of the [sodiumoxide] encryption primitives. //! //! See the [crypto module] documentation since this is a private module anyways. //! //! [sodiumoxide]: https://docs.rs/sodiumoxide/ //! [crypto module]: crate::crypto use derive_more::{AsMut, AsRef, From}; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::{box_, sealedbox}; use super::ByteObject; /// Number of additional bytes in a ciphertext compared to the corresponding plaintext. pub const SEALBYTES: usize = sealedbox::SEALBYTES; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] /// A `C25519` key pair for asymmetric authenticated encryption. pub struct EncryptKeyPair { /// The `C25519` public key. pub public: PublicEncryptKey, /// The `C25519` secret key. pub secret: SecretEncryptKey, } impl EncryptKeyPair { /// Generates a new random `C25519` key pair for encryption. pub fn generate() -> Self { let (pk, sk) = box_::gen_keypair(); Self { public: PublicEncryptKey(pk), secret: SecretEncryptKey(sk), } } /// Deterministically derives a new `C25519` key pair for encryption from a seed. pub fn derive_from_seed(seed: &EncryptKeySeed) -> Self { let (pk, sk) = seed.derive_encrypt_key_pair(); Self { public: pk, secret: sk, } } } #[derive( AsRef, AsMut, From, Serialize, Deserialize, Hash, Eq, Ord, PartialEq, Copy, Clone, PartialOrd, Debug, )] /// A `C25519` public key for asymmetric authenticated encryption. 
pub struct PublicEncryptKey(box_::PublicKey); impl ByteObject for PublicEncryptKey { const LENGTH: usize = box_::PUBLICKEYBYTES; fn zeroed() -> Self { Self(box_::PublicKey([0_u8; box_::PUBLICKEYBYTES])) } fn as_slice(&self) -> &[u8] { self.0.as_ref() } fn from_slice(bytes: &[u8]) -> Option { box_::PublicKey::from_slice(bytes).map(Self) } } impl PublicEncryptKey { /// Encrypts a message `m` with this public key. /// /// The resulting ciphertext length is [`SEALBYTES`]` + m.len()`. /// /// The function creates a new ephemeral key pair for the message and attaches the ephemeral /// public key to the ciphertext. The ephemeral secret key is zeroed out and is not accessible /// after this function returns. pub fn encrypt(&self, m: &[u8]) -> Vec { sealedbox::seal(m, self.as_ref()) } } #[derive(thiserror::Error, Debug)] #[error("decryption of a message failed")] /// An error related to the decryption of a message. pub struct DecryptionError; #[derive(AsRef, AsMut, From, Serialize, Deserialize, Eq, PartialEq, Clone, Debug)] /// A `C25519` secret key for asymmetric authenticated encryption. /// /// When this goes out of scope, its contents will be zeroed out. pub struct SecretEncryptKey(box_::SecretKey); impl SecretEncryptKey { /// Decrypts the ciphertext `c` using this secret key and the associated public key, and returns /// the decrypted message. /// /// # Errors /// Returns `Err(DecryptionError)` if decryption fails. pub fn decrypt(&self, c: &[u8], pk: &PublicEncryptKey) -> Result, DecryptionError> { sealedbox::open(c, pk.as_ref(), self.as_ref()).map_err(|_| DecryptionError) } /// Computes the corresponding public key for this secret key. 
pub fn public_key(&self) -> PublicEncryptKey { PublicEncryptKey(self.0.public_key()) } } impl ByteObject for SecretEncryptKey { const LENGTH: usize = box_::SECRETKEYBYTES; fn zeroed() -> Self { Self(box_::SecretKey([0_u8; box_::SECRETKEYBYTES])) } fn as_slice(&self) -> &[u8] { self.0.as_ref() } fn from_slice(bytes: &[u8]) -> Option { box_::SecretKey::from_slice(bytes).map(Self) } } #[derive(AsRef, AsMut, From, Serialize, Deserialize, Eq, PartialEq, Clone)] /// A seed that can be used for `C25519` encryption key pair generation. /// /// When this goes out of scope, its contents will be zeroed out. pub struct EncryptKeySeed(box_::Seed); impl EncryptKeySeed { /// Deterministically derives a new key pair from this seed. pub fn derive_encrypt_key_pair(&self) -> (PublicEncryptKey, SecretEncryptKey) { let (pk, sk) = box_::keypair_from_seed(self.as_ref()); (PublicEncryptKey(pk), SecretEncryptKey(sk)) } } impl ByteObject for EncryptKeySeed { const LENGTH: usize = box_::SEEDBYTES; fn from_slice(bytes: &[u8]) -> Option { box_::Seed::from_slice(bytes).map(Self) } fn zeroed() -> Self { Self(box_::Seed([0; box_::SEEDBYTES])) } fn as_slice(&self) -> &[u8] { self.0.as_ref() } } ================================================ FILE: rust/xaynet-core/src/crypto/hash.rs ================================================ //! Wrappers around some of the [sodiumoxide] hashing primitives. //! //! See the [crypto module] documentation since this is a private module anyways. //! //! [sodiumoxide]: https://docs.rs/sodiumoxide/ //! [crypto module]: crate::crypto use derive_more::{AsMut, AsRef, From}; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::hash::sha256; use super::ByteObject; #[derive( AsRef, AsMut, From, Serialize, Deserialize, Hash, Eq, Ord, PartialEq, Copy, Clone, PartialOrd, Debug, )] /// A digest of the `SHA256` hash function. 
pub struct Sha256(sha256::Digest); impl ByteObject for Sha256 { const LENGTH: usize = sha256::DIGESTBYTES; fn zeroed() -> Self { Self(sha256::Digest([0_u8; sha256::DIGESTBYTES])) } fn as_slice(&self) -> &[u8] { self.0.as_ref() } fn from_slice(bytes: &[u8]) -> Option { sha256::Digest::from_slice(bytes).map(Self) } } impl Sha256 { /// Computes the digest of the message `m`. pub fn hash(m: &[u8]) -> Self { Self(sha256::hash(m)) } } ================================================ FILE: rust/xaynet-core/src/crypto/mod.rs ================================================ //! Wrappers around some of the [sodiumoxide] crypto primitives. //! //! The wrappers provide methods defined on structs instead of the sodiumoxide functions. This is //! done for the `C25519` encryption and `Ed25519` signature key pairs and their corresponding seeds //! as well as the `SHA256` hash function. Additionally, some methods for slicing and signature //! eligibility are available. //! //! # Examples //! ## Encryption of messages //! ``` //! # use xaynet_core::crypto::EncryptKeyPair; //! let keys = EncryptKeyPair::generate(); //! let message = b"Hello world!".to_vec(); //! let cipher = keys.public.encrypt(&message); //! assert_eq!(message, keys.secret.decrypt(&cipher, &keys.public).unwrap()); //! ``` //! //! ## Signing of messages //! ``` //! # use xaynet_core::crypto::SigningKeyPair; //! let keys = SigningKeyPair::generate(); //! let message = b"Hello world!".to_vec(); //! let signature = keys.secret.sign_detached(&message); //! assert!(keys.public.verify_detached(&signature, &message)); //! ``` //! //! 
[sodiumoxide]: https://docs.rs/sodiumoxide/

pub(crate) mod encrypt;
pub(crate) mod hash;
pub(crate) mod prng;
pub(crate) mod sign;

use sodiumoxide::randombytes::randombytes;

pub use self::{
    encrypt::{EncryptKeyPair, EncryptKeySeed, PublicEncryptKey, SecretEncryptKey, SEALBYTES},
    hash::Sha256,
    prng::generate_integer,
    sign::{PublicSigningKey, SecretSigningKey, Signature, SigningKeyPair, SigningKeySeed},
};

/// An interface for slicing into cryptographic byte objects.
pub trait ByteObject: Sized {
    /// Length in bytes of this object.
    const LENGTH: usize;

    /// Creates a new object with all the bytes initialized to `0`.
    fn zeroed() -> Self;

    /// Gets the object byte representation.
    fn as_slice(&self) -> &[u8];

    /// Creates an object from the given buffer.
    ///
    /// # Errors
    /// Returns `None` if the length of the byte-slice isn't equal to the length of the object.
    fn from_slice(bytes: &[u8]) -> Option<Self>;

    /// Creates an object from the given buffer.
    ///
    /// # Panics
    /// Panics if the length of the byte-slice isn't equal to the length of the object.
    fn from_slice_unchecked(bytes: &[u8]) -> Self {
        Self::from_slice(bytes).unwrap()
    }

    /// Generates an object with random bytes.
    fn generate() -> Self {
        // safe unwrap: length of slice is guaranteed by constants
        Self::from_slice_unchecked(randombytes(Self::LENGTH).as_slice())
    }

    /// A helper for instantiating an object filled with the given value.
    fn fill_with(value: u8) -> Self {
        // a `Vec` is used because an array length cannot depend on the generic `Self::LENGTH`
        Self::from_slice_unchecked(&vec![value; Self::LENGTH])
    }
}

================================================
FILE: rust/xaynet-core/src/crypto/prng.rs
================================================
//! PRNG utilities for the crypto primitives.
//!
//! See the [crypto module] documentation since this is a private module anyway.
//!
//! [sodiumoxide]: https://docs.rs/sodiumoxide/
//! [crypto module]: crate::crypto

use num::{bigint::BigUint, traits::identities::Zero};
use rand::RngCore;
use rand_chacha::ChaCha20Rng;

/// Generates a secure pseudo-random integer.
///
/// Draws from a uniform distribution over the integers between zero (included) and
/// `max_int` (excluded). Employs the `ChaCha20` stream cipher as a PRNG.
///
/// Uses rejection sampling: as many bytes as `max_int` occupies are drawn and redrawn until
/// the resulting little-endian integer is below `max_int`. If `max_int` is zero the range is
/// empty and zero is returned without consuming any randomness.
///
/// The output sequence for a given seed is deterministic and must stay stable, since callers
/// may rely on reproducing the same values from the same seed.
pub fn generate_integer(prng: &mut ChaCha20Rng, max_int: &BigUint) -> BigUint {
    if max_int.is_zero() {
        return BigUint::zero();
    }
    let mut bytes = max_int.to_bytes_le();
    // start at `max_int` so that the loop body runs at least once
    let mut rand_int = max_int.clone();
    while &rand_int >= max_int {
        prng.fill_bytes(&mut bytes);
        rand_int = BigUint::from_bytes_le(&bytes);
    }
    rand_int
}

#[cfg(test)]
mod tests {
    use num::traits::{pow::Pow, Num};
    use rand::SeedableRng;

    use super::*;

    #[test]
    fn test_generate_integer() {
        let mut prng = ChaCha20Rng::from_seed([0_u8; 32]);
        let max_int = BigUint::from(u128::MAX).pow(2_usize);
        assert_eq!(
            generate_integer(&mut prng, &max_int),
            BigUint::from_str_radix(
                "90034050956742099321159087842304570510687605373623064829879336909608119744630",
                10
            )
            .unwrap()
        );
        assert_eq!(
            generate_integer(&mut prng, &max_int),
            BigUint::from_str_radix(
                "60790020689334235010238064028215988394112077193561636249125918224917556969946",
                10
            )
            .unwrap()
        );
        assert_eq!(
            generate_integer(&mut prng, &max_int),
            BigUint::from_str_radix(
                "107415344426328791036720294006773438815099086866510488084511304829720271980447",
                10
            )
            .unwrap()
        );
        assert_eq!(
            generate_integer(&mut prng, &max_int),
            BigUint::from_str_radix(
                "50343610553303623842889112417183549658912134525854625844144939347139411162921",
                10
            )
            .unwrap()
        );
        assert_eq!(
            generate_integer(&mut prng, &max_int),
            BigUint::from_str_radix(
                "42382469383990928111449714288937630103705168010724718767641573929365517895981",
                10
            )
            .unwrap()
        );
    }
}

================================================
FILE: rust/xaynet-core/src/crypto/sign.rs
================================================
//! Wrappers around some of the [sodiumoxide] signing primitives.
//!
//! See the [crypto module] documentation since this is a private module anyway.
//!
//! [sodiumoxide]: https://docs.rs/sodiumoxide/
//!
[crypto module]: crate::crypto use std::convert::TryInto; use derive_more::{AsMut, AsRef, From}; use num::{ bigint::{BigUint, ToBigInt}, rational::Ratio, }; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::{hash::sha256, sign}; use super::ByteObject; #[derive(Debug, Clone, Serialize, Deserialize)] /// A `Ed25519` key pair for signatures. pub struct SigningKeyPair { /// The `Ed25519` public key. pub public: PublicSigningKey, /// The `Ed25519` secret key. pub secret: SecretSigningKey, } impl SigningKeyPair { /// Generates a new random `Ed25519` key pair for signing. pub fn generate() -> Self { let (pk, sk) = sign::gen_keypair(); Self { public: PublicSigningKey(pk), secret: SecretSigningKey(sk), } } pub fn derive_from_seed(seed: &SigningKeySeed) -> Self { let (pk, sk) = seed.derive_signing_key_pair(); Self { public: pk, secret: sk, } } } #[derive( AsRef, AsMut, From, Serialize, Deserialize, Hash, Eq, Ord, PartialEq, Copy, Clone, PartialOrd, Debug, )] /// An `Ed25519` public key for signatures. pub struct PublicSigningKey(sign::PublicKey); impl PublicSigningKey { /// Verifies the signature `s` against the message `m` and this public key. /// /// Returns `true` if the signature is valid and `false` otherwise. pub fn verify_detached(&self, s: &Signature, m: &[u8]) -> bool { sign::verify_detached(s.as_ref(), m, self.as_ref()) } } impl ByteObject for PublicSigningKey { const LENGTH: usize = sign::PUBLICKEYBYTES; fn zeroed() -> Self { Self(sign::PublicKey([0_u8; sign::PUBLICKEYBYTES])) } fn as_slice(&self) -> &[u8] { self.0.as_ref() } fn from_slice(bytes: &[u8]) -> Option { sign::PublicKey::from_slice(bytes).map(Self) } } #[derive(AsRef, AsMut, From, Serialize, Deserialize, Eq, PartialEq, Clone, Debug)] /// An `Ed25519` secret key for signatures. /// /// When this goes out of scope, its contents will be zeroed out. pub struct SecretSigningKey(sign::SecretKey); impl SecretSigningKey { /// Signs a message `m` with this secret key. 
pub fn sign_detached(&self, m: &[u8]) -> Signature { sign::sign_detached(m, self.as_ref()).into() } /// Computes the corresponding public key for this secret key. pub fn public_key(&self) -> PublicSigningKey { PublicSigningKey(self.0.public_key()) } } impl ByteObject for SecretSigningKey { const LENGTH: usize = sign::SECRETKEYBYTES; fn zeroed() -> Self { Self(sign::SecretKey([0_u8; Self::LENGTH])) } fn as_slice(&self) -> &[u8] { self.0.as_ref() } fn from_slice(bytes: &[u8]) -> Option { sign::SecretKey::from_slice(bytes).map(Self) } } #[derive(AsRef, AsMut, From, Eq, PartialEq, Copy, Clone, Debug)] /// An `Ed25519` signature detached from its message. pub struct Signature(sign::Signature); mod manually_derive_serde_for_signature { //! TODO: //! remove this if sodiumoxide decides to reintroduce serialization of signatures //! use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; use crate::crypto::{sign::Signature, ByteObject}; impl Serialize for Signature { fn serialize(&self, serializer: S) -> Result where S: Serializer, { self.as_slice().serialize(serializer) } } impl<'de> Deserialize<'de> for Signature { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { let bytes = <&[u8] as Deserialize>::deserialize(deserializer)?; Self::from_slice(bytes).ok_or_else(|| { D::Error::custom(format!( "invalid length {}, expected {}", bytes.len(), Self::LENGTH, )) }) } } } impl ByteObject for Signature { const LENGTH: usize = sign::SIGNATUREBYTES; fn zeroed() -> Self { Self(sign::Signature::new([0_u8; Self::LENGTH])) } fn as_slice(&self) -> &[u8] { self.0.as_ref() } fn from_slice(bytes: &[u8]) -> Option { bytes.try_into().ok().map(Self) } } impl Signature { /// Computes the floating point representation of the hashed signature and ensures that it is /// below the given threshold: /// ```no_rust /// int(hash(signature)) / (2**hashbits - 1) <= threshold. 
/// ``` pub fn is_eligible(&self, threshold: f64) -> bool { if threshold < 0_f64 { return false; } else if threshold > 1_f64 { return true; } // safe unwraps: `to_bigint` never fails for `BigUint`s let numer = BigUint::from_bytes_le(sha256::hash(self.as_slice()).as_ref()) .to_bigint() .unwrap(); let denom = BigUint::from_bytes_le([u8::MAX; sha256::DIGESTBYTES].as_ref()) .to_bigint() .unwrap(); // safe unwrap: `threshold` is guaranteed to be finite Ratio::new(numer, denom) <= Ratio::from_float(threshold).unwrap() } } #[derive(AsRef, AsMut, From, Serialize, Deserialize, Eq, PartialEq, Clone)] /// A seed that can be used for `Ed25519` signing key pair generation. /// /// When this goes out of scope, its contents will be zeroed out. pub struct SigningKeySeed(sign::Seed); impl SigningKeySeed { /// Deterministically derives a new signing key pair from this seed. pub fn derive_signing_key_pair(&self) -> (PublicSigningKey, SecretSigningKey) { let (pk, sk) = sign::keypair_from_seed(&self.0); (PublicSigningKey(pk), SecretSigningKey(sk)) } } impl ByteObject for SigningKeySeed { const LENGTH: usize = sign::SEEDBYTES; fn from_slice(bytes: &[u8]) -> Option { sign::Seed::from_slice(bytes).map(Self) } fn zeroed() -> Self { Self(sign::Seed([0; sign::PUBLICKEYBYTES])) } fn as_slice(&self) -> &[u8] { self.0.as_ref() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_signature_is_eligible() { // eligible signature let sig = Signature::from_slice_unchecked(&[ 172, 29, 85, 219, 118, 44, 107, 32, 219, 253, 25, 242, 53, 45, 111, 62, 102, 130, 24, 8, 222, 199, 34, 120, 166, 163, 223, 229, 100, 50, 252, 244, 250, 88, 196, 151, 136, 48, 39, 198, 166, 86, 29, 151, 13, 81, 69, 198, 40, 148, 134, 126, 7, 202, 1, 56, 174, 43, 89, 28, 242, 194, 4, 0, ]); assert!(sig.is_eligible(0.5_f64)); // ineligible signature let sig = Signature::from_slice_unchecked(&[ 119, 2, 197, 174, 52, 165, 229, 22, 218, 210, 240, 188, 220, 232, 149, 129, 211, 13, 61, 217, 186, 79, 102, 15, 109, 237, 83, 193, 12, 
117, 210, 66, 99, 230, 30, 131, 63, 108, 28, 222, 48, 92, 153, 71, 159, 220, 115, 181, 183, 155, 146, 182, 205, 89, 140, 234, 100, 40, 199, 248, 23, 147, 172, 0, ]); assert!(!sig.is_eligible(0.5_f64)); } } ================================================ FILE: rust/xaynet-core/src/lib.rs ================================================ #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr( doc, forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links) )] #![doc( html_logo_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png", html_favicon_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png", issue_tracker_base_url = "https://github.com/xaynetwork/xaynet/issues" )] //! `xaynet_core` provides basic building blocks for implementing the //! _Privacy-Enhancing Technology_ (PET), a privacy preserving //! protocol for federated machine learning. Download the [whitepaper] //! for an introduction. //! //! [whitepaper]: https://uploads-ssl.webflow.com/5f0c5c0bb18a279f0a62919e/5f157004da6585f299fa542b_XayNet%20Whitepaper%202.1.pdf pub mod common; pub mod crypto; pub mod mask; pub mod message; #[cfg(any(feature = "testutils", test))] #[cfg_attr(docsrs, doc(cfg(feature = "testutils")))] pub mod testutils; use std::collections::HashMap; use thiserror::Error; use self::crypto::{ encrypt::{PublicEncryptKey, SecretEncryptKey}, sign::{PublicSigningKey, SecretSigningKey, Signature}, }; #[derive(Error, Debug)] #[error("initialization failed: insufficient system entropy to generate secrets")] /// An error related to insufficient system entropy for secrets at program startup. pub struct InitError; /// A public encryption key that identifies a coordinator. pub type CoordinatorPublicKey = PublicEncryptKey; /// A secret encryption key that belongs to the public key of a /// coordinator. pub type CoordinatorSecretKey = SecretEncryptKey; /// A public signature key that identifies a participant. 
pub type ParticipantPublicKey = PublicSigningKey;

/// A secret signature key that belongs to the public key of a
/// participant.
pub type ParticipantSecretKey = SecretSigningKey;

/// A public signature key that identifies a sum participant.
pub type SumParticipantPublicKey = ParticipantPublicKey;

/// A secret signature key that belongs to the public key of a sum
/// participant.
pub type SumParticipantSecretKey = ParticipantSecretKey;

/// A public encryption key generated by a sum participant. It is used
/// by the update participants to encrypt their masking seed for each
/// sum participant.
pub type SumParticipantEphemeralPublicKey = PublicEncryptKey;

/// The secret counterpart of [`SumParticipantEphemeralPublicKey`].
pub type SumParticipantEphemeralSecretKey = SecretEncryptKey;

/// A public signature key that identifies an update participant.
pub type UpdateParticipantPublicKey = ParticipantPublicKey;

/// A secret signature key that belongs to the public key of an update
/// participant.
pub type UpdateParticipantSecretKey = ParticipantSecretKey;

/// A signature to prove a participant's eligibility for a task.
///
/// Eligibility is checked by comparing the hashed signature against a
/// threshold (see `Signature::is_eligible`).
pub type ParticipantTaskSignature = Signature;

/// A dictionary created during the sum phase of the protocol. It maps the public key of every sum
/// participant to the ephemeral public key generated by that sum participant.
pub type SumDict = HashMap;

/// Local seed dictionaries are sent by update participants. They contain the participant's masking
/// seed, encrypted with the ephemeral public key of each sum participant.
pub type LocalSeedDict = HashMap;

/// A dictionary created during the update phase of the protocol. The global seed dictionary is
/// built from the local seed dictionaries sent by the update participants. It maps each sum
/// participant to the encrypted masking seeds of all the update participants.
pub type SeedDict = HashMap;

/// Values of [`SeedDict`]: for a single sum participant, the encrypted masking seeds of all the
/// update participants. Sent to sum participants.
pub type UpdateSeedDict = HashMap; ================================================ FILE: rust/xaynet-core/src/mask/config/mod.rs ================================================ //! Masking configuration parameters. //! //! See the [mask module] documentation since this is a private module anyways. //! //! [mask module]: crate::mask pub(crate) mod serialization; use std::convert::TryFrom; use num::{ bigint::{BigInt, BigUint}, rational::Ratio, traits::{pow::Pow, Num}, }; use serde::{Deserialize, Serialize}; use thiserror::Error; // target dependent maximum bytes per mask object element #[cfg(target_pointer_width = "16")] const MAX_BPN: u64 = u16::MAX as u64; #[cfg(target_pointer_width = "32")] const MAX_BPN: u64 = u32::MAX as u64; #[derive(Debug, Error)] /// Errors related to invalid masking configurations. pub enum InvalidMaskConfigError { #[error("invalid group type")] GroupType, #[error("invalid data type")] DataType, #[error("invalid bound type")] BoundType, #[error("invalid model type")] ModelType, } #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[repr(u8)] /// The order of the finite group. pub enum GroupType { /// A finite group of exact integer order. Integer = 0, /// A finite group of prime order. Prime = 1, /// A finite group of power-of-two order. Power2 = 2, } impl TryFrom for GroupType { type Error = InvalidMaskConfigError; fn try_from(byte: u8) -> Result { match byte { 0 => Ok(GroupType::Integer), 1 => Ok(GroupType::Prime), 2 => Ok(GroupType::Power2), _ => Err(InvalidMaskConfigError::GroupType), } } } #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[repr(u8)] /// The original primitive data type of the numerical values to be masked. pub enum DataType { /// Numbers of type f32. F32 = 0, /// Numbers of type f64. F64 = 1, /// Numbers of type i32. I32 = 2, /// Numbers of type i64. 
I64 = 3, } impl TryFrom for DataType { type Error = InvalidMaskConfigError; fn try_from(byte: u8) -> Result { match byte { 0 => Ok(DataType::F32), 1 => Ok(DataType::F64), 2 => Ok(DataType::I32), 3 => Ok(DataType::I64), _ => Err(InvalidMaskConfigError::DataType), } } } #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[repr(u8)] /// The bounds of the numerical values. /// /// For a value `v` to be absolutely bounded by another value `b`, it has to hold that /// `-b <= v <= b` or equivalently `|v| <= b`. pub enum BoundType { /// Numerical values absolutely bounded by 1. B0 = 0, /// Numerical values absolutely bounded by 100. B2 = 2, /// Numerical values absolutely bounded by 10_000. B4 = 4, /// Numerical values absolutely bounded by 1_000_000. B6 = 6, /// Numerical values absolutely bounded by their original primitive data type's maximum absolute /// value. Bmax = 255, } impl TryFrom for BoundType { type Error = InvalidMaskConfigError; fn try_from(byte: u8) -> Result { match byte { 0 => Ok(BoundType::B0), 2 => Ok(BoundType::B2), 4 => Ok(BoundType::B4), 6 => Ok(BoundType::B6), 255 => Ok(BoundType::Bmax), _ => Err(InvalidMaskConfigError::ModelType), } } } #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[repr(u8)] /// The maximum number of models to be aggregated. pub enum ModelType { /// At most 1_000 models to be aggregated. M3 = 3, /// At most 1_000_000 models to be aggregated. M6 = 6, /// At most 1_000_000_000 models to be aggregated. M9 = 9, /// At most 1_000_000_000_000 models to be aggregated. M12 = 12, } impl ModelType { /// Gets the maximum number of models that can be aggregated for this model type. 
pub fn max_nb_models(&self) -> usize { 10_usize.pow(*self as u8 as u32) } } impl TryFrom for ModelType { type Error = InvalidMaskConfigError; fn try_from(byte: u8) -> Result { match byte { 3 => Ok(ModelType::M3), 6 => Ok(ModelType::M6), 9 => Ok(ModelType::M9), 12 => Ok(ModelType::M12), _ => Err(InvalidMaskConfigError::ModelType), } } } #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] /// A masking configuration. /// /// This configuration is applied for masking, aggregation and unmasking of models. pub struct MaskConfig { /// The order of the finite group. pub group_type: GroupType, /// The original primitive data type of the numerical values to be masked. pub data_type: DataType, /// The bounds of the numerical values. pub bound_type: BoundType, /// The maximum number of models to be aggregated. pub model_type: ModelType, } impl MaskConfig { /// Returns the number of bytes needed for an element of a mask object. /// /// # Panics /// Panics if the bytes per number can't be represented as usize. pub(crate) fn bytes_per_number(&self) -> usize { let max_number = self.order() - BigUint::from(1_u8); let bpn = (max_number.bits() + 7) / 8; // the largest bpn from the masking configuration catalogue is currently 173, hence this is // almost impossible on 32 bits targets and smaller targets are currently not of interest #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))] if bpn > MAX_BPN { panic!("the employed masking config is not supported on the target") } bpn as usize } /// Gets the additional shift value for masking/unmasking. 
pub fn add_shift(&self) -> Ratio { use BoundType::{Bmax, B0, B2, B4, B6}; use DataType::{F32, F64, I32, I64}; match self.bound_type { B0 => Ratio::from_integer(BigInt::from(1)), B2 => Ratio::from_integer(BigInt::from(100)), B4 => Ratio::from_integer(BigInt::from(10_000)), B6 => Ratio::from_integer(BigInt::from(1_000_000)), Bmax => match self.data_type { // safe unwraps: all numbers are finite F32 => Ratio::from_float(f32::MAX).unwrap(), F64 => Ratio::from_float(f64::MAX).unwrap(), I32 => Ratio::from_integer(-BigInt::from(i32::MIN)), I64 => Ratio::from_integer(-BigInt::from(i64::MIN)), }, } } /// Gets the exponential shift value for masking/unmasking. pub fn exp_shift(&self) -> BigInt { use BoundType::{Bmax, B0, B2, B4, B6}; use DataType::{F32, F64, I32, I64}; match self.data_type { F32 => match self.bound_type { B0 | B2 | B4 | B6 => BigInt::from(10).pow(10_u8), Bmax => BigInt::from(10).pow(45_u8), }, F64 => match self.bound_type { B0 | B2 | B4 | B6 => BigInt::from(10).pow(20_u8), Bmax => BigInt::from(10).pow(324_u16), }, I32 | I64 => BigInt::from(10).pow(10_u8), } } /// Gets the finite group order value for masking/unmasking. 
pub fn order(&self) -> BigUint { use BoundType::{Bmax, B0, B2, B4, B6}; use DataType::{F32, F64, I32, I64}; use GroupType::{Integer, Power2, Prime}; use ModelType::{M12, M3, M6, M9}; let order_str = match self.group_type { Integer => match self.data_type { F32 => match self.bound_type { B0 => match self.model_type { M3 => "20_000_000_000_001", M6 => "20_000_000_000_000_001", M9 => "20_000_000_000_000_000_001", M12 => "20_000_000_000_000_000_000_001", } B2 => match self.model_type { M3 => "2_000_000_000_000_001", M6 => "2_000_000_000_000_000_001", M9 => "2_000_000_000_000_000_000_001", M12 => "2_000_000_000_000_000_000_000_001", } B4 => match self.model_type { M3 => "200_000_000_000_000_001", M6 => "200_000_000_000_000_000_001", M9 => "200_000_000_000_000_000_000_001", M12 => "200_000_000_000_000_000_000_000_001", } B6 => match self.model_type { M3 => "20_000_000_000_000_000_001", M6 => "20_000_000_000_000_000_000_001", M9 => "20_000_000_000_000_000_000_000_001", M12 => "20_000_000_000_000_000_000_000_000_001", } Bmax => match self.model_type { M3 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001", M6 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001", M9 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001", M12 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001", } } F64 => match self.bound_type { B0 => match self.model_type { M3 => "200_000_000_000_000_000_000_001", M6 => "200_000_000_000_000_000_000_000_001", M9 => "200_000_000_000_000_000_000_000_000_001", M12 => "200_000_000_000_000_000_000_000_000_000_001", } B2 => match self.model_type { M3 => "20_000_000_000_000_000_000_000_001", M6 => "20_000_000_000_000_000_000_000_000_001", M9 => 
"20_000_000_000_000_000_000_000_000_000_001", M12 => "20_000_000_000_000_000_000_000_000_000_000_001", } B4 => match self.model_type { M3 => "2_000_000_000_000_000_000_000_000_001", M6 => "2_000_000_000_000_000_000_000_000_000_001", M9 => "2_000_000_000_000_000_000_000_000_000_000_001", M12 => "2_000_000_000_000_000_000_000_000_000_000_000_001", } B6 => match self.model_type { M3 => "200_000_000_000_000_000_000_000_000_001", M6 => "200_000_000_000_000_000_000_000_000_000_001", M9 => "200_000_000_000_000_000_000_000_000_000_000_001", M12 => "200_000_000_000_000_000_000_000_000_000_000_000_001", } Bmax => match self.model_type { M3 => "359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001", M6 => 
"359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001", M9 => "359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001", M12 => 
"359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001", } } I32 => match self.bound_type { B0 => match self.model_type { M3 => "20_000_000_000_001", M6 => "20_000_000_000_000_001", M9 => "20_000_000_000_000_000_001", M12 => "20_000_000_000_000_000_000_001", } B2 => match self.model_type { M3 => "2_000_000_000_000_001", M6 => "2_000_000_000_000_000_001", M9 => "2_000_000_000_000_000_000_001", M12 => "2_000_000_000_000_000_000_000_001", } B4 => match self.model_type { M3 => "200_000_000_000_000_001", M6 => "200_000_000_000_000_000_001", M9 => "200_000_000_000_000_000_000_001", M12 => "200_000_000_000_000_000_000_000_001", } B6 => match self.model_type { M3 => "20_000_000_000_000_000_001", M6 => "20_000_000_000_000_000_000_001", M9 => "20_000_000_000_000_000_000_000_001", M12 => "20_000_000_000_000_000_000_000_000_001", } Bmax => match self.model_type { M3 => "42_949_672_950_000_000_000_001", M6 => "42_949_672_950_000_000_000_000_001", M9 => "42_949_672_950_000_000_000_000_000_001", M12 => "42_949_672_950_000_000_000_000_000_000_001", } } I64 => match self.bound_type { B0 => match self.model_type { M3 => "20_000_000_000_001", M6 => "20_000_000_000_000_001", M9 => 
"20_000_000_000_000_000_001", M12 => "20_000_000_000_000_000_000_001", } B2 => match self.model_type { M3 => "2_000_000_000_000_001", M6 => "2_000_000_000_000_000_001", M9 => "2_000_000_000_000_000_000_001", M12 => "2_000_000_000_000_000_000_000_001", } B4 => match self.model_type { M3 => "200_000_000_000_000_001", M6 => "200_000_000_000_000_000_001", M9 => "200_000_000_000_000_000_000_001", M12 => "200_000_000_000_000_000_000_000_001", } B6 => match self.model_type { M3 => "20_000_000_000_000_000_001", M6 => "20_000_000_000_000_000_000_001", M9 => "20_000_000_000_000_000_000_000_001", M12 => "20_000_000_000_000_000_000_000_000_001", } Bmax => match self.model_type { M3 => "184_467_440_737_095_516_150_000_000_000_001", M6 => "184_467_440_737_095_516_150_000_000_000_000_001", M9 => "184_467_440_737_095_516_150_000_000_000_000_000_001", M12 => "184_467_440_737_095_516_150_000_000_000_000_000_000_001", } } } Prime => match self.data_type { F32 => match self.bound_type { B0 => match self.model_type { M3 => "20_000_000_000_021", M6 => "20_000_000_000_000_003", M9 => "20_000_000_000_000_000_011", M12 => "20_000_000_000_000_000_000_003", } B2 => match self.model_type { M3 => "2_000_000_000_000_021", M6 => "2_000_000_000_000_000_057", M9 => "2_000_000_000_000_000_000_069", M12 => "2_000_000_000_000_000_000_000_003", } B4 => match self.model_type { M3 => "200_000_000_000_000_003", M6 => "200_000_000_000_000_000_089", M9 => "200_000_000_000_000_000_000_069", M12 => "200_000_000_000_000_000_000_000_027", } B6 => match self.model_type { M3 => "20_000_000_000_000_000_011", M6 => "20_000_000_000_000_000_000_003", M9 => "20_000_000_000_000_000_000_000_009", M12 => "20_000_000_000_000_000_000_000_000_131", } Bmax => match self.model_type { M3 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_281", M6 => 
"680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_323", M9 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_191", M12 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_083", } } F64 => match self.bound_type { B0 => match self.model_type { M3 => "200_000_000_000_000_000_000_069", M6 => "200_000_000_000_000_000_000_000_027", M9 => "200_000_000_000_000_000_000_000_000_017", M12 => "200_000_000_000_000_000_000_000_000_000_159", } B2 => match self.model_type { M3 => "20_000_000_000_000_000_000_000_009", M6 => "20_000_000_000_000_000_000_000_000_131", M9 => "20_000_000_000_000_000_000_000_000_000_047", M12 => "20_000_000_000_000_000_000_000_000_000_000_203", } B4 => match self.model_type { M3 => "2_000_000_000_000_000_000_000_000_039", M6 => "2_000_000_000_000_000_000_000_000_000_071", M9 => "2_000_000_000_000_000_000_000_000_000_000_017", M12 => "2_000_000_000_000_000_000_000_000_000_000_000_041", } B6 => match self.model_type { M3 => "200_000_000_000_000_000_000_000_000_017", M6 => "200_000_000_000_000_000_000_000_000_000_159", M9 => "200_000_000_000_000_000_000_000_000_000_000_003", M12 => "200_000_000_000_000_000_000_000_000_000_000_000_023", } Bmax => match self.model_type { M3 => 
"359_538_626_972_463_140_000_000_000_000_000_000_000_593_874_019_667_231_666_067_439_096_529_924_969_333_439_983_391_110_599_943_465_644_007_133_099_721_551_828_263_813_044_710_323_667_390_405_279_670_626_898_022_875_314_671_948_577_301_533_414_396_469_719_048_504_306_012_596_386_638_859_340_084_030_210_314_832_025_518_258_115_226_051_894_034_477_843_584_650_149_420_090_374_373_134_876_775_786_923_748_346_298_936_467_612_015_276_401_624_887_654_050_299_443_392_510_555_689_981_501_608_709_494_004_423_956_258_647_440_955_320_257_123_787_935_493_476_104_132_776_728_548_437_783_283_112_428_445_450_269_488_453_346_610_914_359_272_368_862_786_051_728_965_455_746_393_095_846_720_860_347_644_662_201_994_241_194_193_316_457_656_284_847_050_135_299_403_149_697_261_199_957_835_824_000_531_233_031_619_352_921_347_101_423_914_861_961_738_035_659_301", M6 => "359_538_626_972_463_139_999_999_999_999_999_999_999_903_622_106_309_601_840_402_558_296_261_360_055_843_460_163_714_984_640_183_652_353_129_826_112_739_444_431_322_400_938_984_152_600_575_421_591_212_739_537_896_016_542_591_595_727_264_024_538_428_559_469_178_136_611_680_881_710_150_818_089_794_351_154_869_285_409_959_876_691_068_635_451_827_253_162_844_058_791_343_487_286_852_635_234_799_336_668_682_655_217_329_655_102_622_197_942_194_212_857_658_834_043_465_713_831_143_523_811_067_060_369_640_438_677_832_007_091_511_212_788_398_470_391_285_320_720_769_417_737_628_120_102_221_909_739_846_753_580_817_462_645_602_854_496_103_866_327_474_145_187_363_329_320_852_679_912_679_009_543_036_760_757_409_720_574_191_338_832_841_104_183_169_976_025_577_743_061_881_721_861_634_977_765_641_182_996_194_573_448_626_763_720_938_201_976_656_541_039_724_303", M9 => 
"359_538_626_972_463_139_999_999_999_999_999_999_999_904_930_781_891_526_077_660_862_016_966_437_766_478_934_820_885_791_914_528_679_207_262_530_042_483_798_832_910_003_057_874_958_310_694_484_517_139_841_166_977_272_287_522_418_122_134_527_125_053_808_273_636_647_181_903_383_717_418_169_782_215_585_647_900_802_728_035_567_327_931_187_710_919_458_230_957_036_511_507_150_288_137_858_111_024_099_126_399_746_768_695_036_546_643_813_753_385_062_385_762_652_380_150_346_615_796_407_577_297_605_069_883_839_431_646_689_072_072_214_687_584_099_356_273_959_025_519_093_953_786_032_481_175_596_842_406_101_871_239_892_163_505_527_137_519_569_046_747_947_203_065_300_865_116_331_411_924_515_285_552_096_042_635_874_474_960_733_445_241_451_746_509_870_642_272_026_256_695_499_704_624_475_309_137_281_644_358_183_373_160_068_523_639_023_207_643_484_888_657_559_597", M12 => "359_538_626_972_463_139_999_999_999_999_999_999_999_904_931_540_467_867_407_238_817_633_447_114_203_759_664_620_787_471_913_925_990_313_859_370_016_783_101_785_327_523_046_787_247_090_978_931_042_236_128_228_564_142_680_745_383_377_953_776_024_143_512_065_781_667_978_525_748_300_241_659_425_164_472_387_573_470_260_831_720_974_578_793_447_369_507_661_739_490_218_806_790_001_765_109_117_055_431_552_295_585_457_639_803_896_262_637_528_011_897_242_316_426_079_400_392_728_240_523_639_775_219_294_589_603_009_325_941_759_217_573_340_626_063_716_838_671_315_192_395_974_939_441_284_468_885_927_433_422_082_497_928_190_254_190_935_717_337_452_741_850_223_510_814_859_331_413_287_559_285_438_144_477_756_395_583_878_761_313_295_130_567_342_888_620_541_025_745_968_373_350_261_259_032_809_052_052_475_301_496_416_128_372_300_050_762_773_363_722_300_553_930_211_649", } } I32 => match self.bound_type { B0 => match self.model_type { M3 => "20_000_000_000_021", M6 => "20_000_000_000_000_003", M9 => "20_000_000_000_000_000_011", M12 => "20_000_000_000_000_000_000_003", } B2 => match self.model_type { M3 => "2_000_000_000_000_021", M6 => 
"2_000_000_000_000_000_057", M9 => "2_000_000_000_000_000_000_069", M12 => "2_000_000_000_000_000_000_000_003", } B4 => match self.model_type { M3 => "200_000_000_000_000_003", M6 => "200_000_000_000_000_000_089", M9 => "200_000_000_000_000_000_000_069", M12 => "200_000_000_000_000_000_000_000_027", } B6 => match self.model_type { M3 => "20_000_000_000_000_000_011", M6 => "20_000_000_000_000_000_000_003", M9 => "20_000_000_000_000_000_000_000_009", M12 => "20_000_000_000_000_000_000_000_000_131", } Bmax => match self.model_type { M3 => "42_949_672_950_000_000_000_029", M6 => "42_949_672_950_000_000_000_000_049", M9 => "42_949_672_950_000_000_000_000_000_043", M12 => "42_949_672_950_000_000_000_000_000_000_109", } } I64 => match self.bound_type { B0 => match self.model_type { M3 => "20_000_000_000_021", M6 => "20_000_000_000_000_003", M9 => "20_000_000_000_000_000_011", M12 => "20_000_000_000_000_000_000_003", } B2 => match self.model_type { M3 => "2_000_000_000_000_021", M6 => "2_000_000_000_000_000_057", M9 => "2_000_000_000_000_000_000_069", M12 => "2_000_000_000_000_000_000_000_003", } B4 => match self.model_type { M3 => "200_000_000_000_000_003", M6 => "200_000_000_000_000_000_089", M9 => "200_000_000_000_000_000_000_069", M12 => "200_000_000_000_000_000_000_000_027", } B6 => match self.model_type { M3 => "20_000_000_000_000_000_011", M6 => "20_000_000_000_000_000_000_003", M9 => "20_000_000_000_000_000_000_000_009", M12 => "20_000_000_000_000_000_000_000_000_131", } Bmax => match self.model_type { M3 => "184_467_440_737_095_516_150_000_000_000_073", M6 => "184_467_440_737_095_516_150_000_000_000_000_013", M9 => "184_467_440_737_095_516_150_000_000_000_000_000_167", M12 => "184_467_440_737_095_516_150_000_000_000_000_000_000_089", } } }, Power2 => match self.data_type { F32 => match self.bound_type { B0 => match self.model_type { M3 => "35_184_372_088_832", M6 => "36_028_797_018_963_968", M9 => "36_893_488_147_419_103_232", M12 => 
"37_778_931_862_957_161_709_568", } B2 => match self.model_type { M3 => "2_251_799_813_685_248", M6 => "2_305_843_009_213_693_952", M9 => "2_361_183_241_434_822_606_848", M12 => "2_417_851_639_229_258_349_412_352", } B4 => match self.model_type { M3 => "288_230_376_151_711_744", M6 => "295_147_905_179_352_825_856", M9 => "302_231_454_903_657_293_676_544", M12 => "309_485_009_821_345_068_724_781_056", } B6 => match self.model_type { M3 => "36_893_488_147_419_103_232", M6 => "37_778_931_862_957_161_709_568", M9 => "38_685_626_227_668_133_590_597_632", M12 => "39_614_081_257_132_168_796_771_975_168", } Bmax => match self.model_type { M3 => "994_646_472_819_573_284_310_764_496_293_641_680_200_912_301_594_695_434_880_927_953_786_318_994_025_066_751_066_112", M6 => "1_018_517_988_167_243_043_134_222_844_204_689_080_525_734_196_832_968_125_318_070_224_677_190_649_881_668_353_091_698_688", M9 => "1_042_962_419_883_256_876_169_444_192_465_601_618_458_351_817_556_959_360_325_703_910_069_443_225_478_828_393_565_899_456_512", M12 => "1_067_993_517_960_455_041_197_510_853_084_776_057_301_352_261_178_326_384_973_520_803_911_109_862_890_320_275_011_481_043_468_288", } } F64 => match self.bound_type { B0 => match self.model_type { M3 => "302_231_454_903_657_293_676_544", M6 => "309_485_009_821_345_068_724_781_056", M9 => "316_912_650_057_057_350_374_175_801_344", M12 => "324_518_553_658_426_726_783_156_020_576_256", } B2 => match self.model_type { M3 => "38_685_626_227_668_133_590_597_632", M6 => "39_614_081_257_132_168_796_771_975_168", M9 => "20_282_409_603_651_670_423_947_251_286_016", M12 => "20_769_187_434_139_310_514_121_985_316_880_384", } B4 => match self.model_type { M3 => "2_475_880_078_570_760_549_798_248_448", M6 => "2_535_301_200_456_458_802_993_406_410_752", M9 => "2_596_148_429_267_413_814_265_248_164_610_048", M12 => "2_658_455_991_569_831_745_807_614_120_560_689_152", } B6 => match self.model_type { M3 => "316_912_650_057_057_350_374_175_801_344", M6 => 
"324_518_553_658_426_726_783_156_020_576_256", M9 => "332_306_998_946_228_968_225_951_765_070_086_144", M12 => "340_282_366_920_938_463_463_374_607_431_768_211_456", } Bmax => match self.model_type { M3 => "596_143_540_225_991_923_146_302_416_688_458_341_289_203_474_674_553_062_792_993_127_033_853_365_765_018_588_197_722_567_551_977_295_508_215_323_031_793_155_057_153_946_025_631_943_349_443_566_464_703_583_960_364_782_216_884_718_655_637_955_371_883_889_285_523_680_681_542_682_622_992_485_998_454_422_254_346_205_188_269_982_058_330_848_165_814_218_528_432_304_958_458_516_472_675_321_199_923_576_436_128_746_194_040_030_388_187_813_654_706_961_312_852_788_047_760_914_640_519_973_439_182_188_222_756_017_424_664_821_230_981_616_162_111_762_973_371_192_278_908_910_941_031_147_045_555_738_506_834_254_728_517_124_812_756_790_583_181_174_762_115_337_827_697_771_072_593_076_558_961_853_936_203_969_690_859_453_400_618_497_370_766_001_868_317_217_344_149_071_638_768_630_396_860_838_478_405_181_466_899_321_747_678_290_733_613_480_879_657_473_540_096", M6 => "610_450_985_191_415_729_301_813_674_688_981_341_480_144_358_066_742_336_300_024_962_082_665_846_543_379_034_314_467_909_173_224_750_600_412_490_784_556_190_778_525_640_730_247_109_989_830_212_059_856_469_975_413_536_990_089_951_903_373_266_300_809_102_628_376_249_017_899_707_005_944_305_662_417_328_388_450_514_112_788_461_627_730_788_521_793_759_773_114_680_277_461_520_868_019_528_908_721_742_270_595_836_102_696_991_117_504_321_182_419_928_384_361_254_960_907_176_591_892_452_801_722_560_740_102_161_842_856_776_940_525_174_950_002_445_284_732_100_893_602_724_803_615_894_574_649_076_230_998_276_842_001_535_808_262_953_557_177_522_956_406_105_935_562_517_578_335_310_396_376_938_430_672_864_963_440_080_282_233_341_307_664_385_913_156_830_560_408_649_358_099_077_526_385_498_601_886_905_822_104_905_469_622_569_711_220_204_420_769_252_905_058_304", M9 => 
"625_101_808_836_009_706_805_057_202_881_516_893_675_667_822_660_344_152_371_225_561_172_649_826_860_420_131_138_015_138_993_382_144_614_822_390_563_385_539_357_210_256_107_773_040_629_586_137_149_293_025_254_823_461_877_852_110_749_054_224_692_028_521_091_457_278_994_329_299_974_086_968_998_315_344_269_773_326_451_495_384_706_796_327_446_316_810_007_669_432_604_120_597_368_851_997_602_531_064_085_090_136_169_161_718_904_324_424_890_798_006_665_585_925_079_968_948_830_097_871_668_963_902_197_864_613_727_085_339_587_097_779_148_802_503_971_565_671_315_049_190_198_902_676_044_440_654_060_542_235_486_209_572_667_661_264_442_549_783_507_359_852_478_016_018_000_215_357_845_889_984_953_009_013_722_562_642_209_006_941_499_048_331_175_072_594_493_858_456_942_693_455_387_018_750_568_332_191_561_835_423_200_893_511_384_289_489_326_867_714_974_779_703_296", M12 => "640_104_252_248_073_939_768_378_575_750_673_299_123_883_850_404_192_412_028_134_974_640_793_422_705_070_214_285_327_502_329_223_316_085_578_127_936_906_792_301_783_302_254_359_593_604_696_204_440_876_057_860_939_224_962_920_561_407_031_526_084_637_205_597_652_253_690_193_203_173_465_056_254_274_912_532_247_886_286_331_273_939_759_439_305_028_413_447_853_498_986_619_491_705_704_445_544_991_809_623_132_299_437_221_600_158_028_211_088_177_158_825_559_987_281_888_203_602_020_220_589_019_035_850_613_364_456_535_387_737_188_125_848_373_764_066_883_247_426_610_370_763_676_340_269_507_229_757_995_249_137_878_602_411_685_134_789_170_978_311_536_488_937_488_402_432_220_526_434_191_344_591_881_230_051_904_145_622_023_108_095_025_491_123_274_336_761_711_059_909_318_098_316_307_200_581_972_164_159_319_473_357_714_955_657_512_437_070_712_540_134_174_416_175_104", } } I32 => match self.bound_type { B0 => match self.model_type { M3 => "35_184_372_088_832", M6 => "36_028_797_018_963_968", M9 => "36_893_488_147_419_103_232", M12 => "37_778_931_862_957_161_709_568", } B2 => match self.model_type { M3 => "2_251_799_813_685_248", M6 => 
"2_305_843_009_213_693_952", M9 => "2_361_183_241_434_822_606_848", M12 => "2_417_851_639_229_258_349_412_352", } B4 => match self.model_type { M3 => "288_230_376_151_711_744", M6 => "295_147_905_179_352_825_856", M9 => "302_231_454_903_657_293_676_544", M12 => "309_485_009_821_345_068_724_781_056", } B6 => match self.model_type { M3 => "36_893_488_147_419_103_232", M6 => "37_778_931_862_957_161_709_568", M9 => "38_685_626_227_668_133_590_597_632", M12 => "39_614_081_257_132_168_796_771_975_168", } Bmax => match self.model_type { M3 => "75_557_863_725_914_323_419_136", M6 => "77_371_252_455_336_267_181_195_264", M9 => "79_228_162_514_264_337_593_543_950_336", M12 => "81_129_638_414_606_681_695_789_005_144_064", } } I64 => match self.bound_type { B0 => match self.model_type { M3 => "35_184_372_088_832", M6 => "36_028_797_018_963_968", M9 => "36_893_488_147_419_103_232", M12 => "37_778_931_862_957_161_709_568", } B2 => match self.model_type { M3 => "2_251_799_813_685_248", M6 => "2_305_843_009_213_693_952", M9 => "2_361_183_241_434_822_606_848", M12 => "2_417_851_639_229_258_349_412_352", } B4 => match self.model_type { M3 => "288_230_376_151_711_744", M6 => "295_147_905_179_352_825_856", M9 => "302_231_454_903_657_293_676_544", M12 => "309_485_009_821_345_068_724_781_056", } B6 => match self.model_type { M3 => "36_893_488_147_419_103_232", M6 => "37_778_931_862_957_161_709_568", M9 => "38_685_626_227_668_133_590_597_632", M12 => "39_614_081_257_132_168_796_771_975_168", } Bmax => match self.model_type { M3 => "324_518_553_658_426_726_783_156_020_576_256", M6 => "332_306_998_946_228_968_225_951_765_070_086_144", M9 => "340_282_366_920_938_463_463_374_607_431_768_211_456", M12 => "348_449_143_727_040_986_586_495_598_010_130_648_530_944", } } } }; // safe unwrap: string and radix are valid BigUint::from_str_radix(order_str, 10).unwrap() } } #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] /// Convenience struct for a pair of masking 
configurations. /// /// One configuration is intended for (un)masking a vector of values, the other /// for a unit value. pub struct MaskConfigPair { pub vect: MaskConfig, pub unit: MaskConfig, } impl From for MaskConfigPair { /// Creates two copies of the given masking configuration as a pair. fn from(config: MaskConfig) -> Self { Self { vect: config, unit: config, } } } ================================================ FILE: rust/xaynet-core/src/mask/config/serialization.rs ================================================ //! Serialization of masking configurations. //! //! See the [mask module] documentation since this is a private module anyways. //! //! [mask module]: crate::mask use std::convert::TryInto; use anyhow::{anyhow, Context}; use crate::{ mask::config::MaskConfig, message::{ traits::{FromBytes, ToBytes}, DecodeError, }, }; const GROUP_TYPE_FIELD: usize = 0; const DATA_TYPE_FIELD: usize = 1; const BOUND_TYPE_FIELD: usize = 2; const MODEL_TYPE_FIELD: usize = 3; pub(crate) const MASK_CONFIG_BUFFER_LEN: usize = 4; /// A buffer for serialized masking configurations. pub struct MaskConfigBuffer { inner: T, } impl> MaskConfigBuffer { /// Creates a new buffer from `bytes`. /// /// # Errors /// Fails if the `bytes` don't conform to the required buffer length for masking configurations. pub fn new(bytes: T) -> Result { let buffer = Self { inner: bytes }; buffer .check_buffer_length() .context("not a valid MaskConfigBuffer")?; Ok(buffer) } /// Creates a new buffer from `bytes`. pub fn new_unchecked(bytes: T) -> Self { Self { inner: bytes } } /// Checks if this buffer conforms to the required buffer length for masking configurations. /// /// # Errors /// Fails if the buffer is too small. 
pub fn check_buffer_length(&self) -> Result<(), DecodeError> { let len = self.inner.as_ref().len(); if len < MASK_CONFIG_BUFFER_LEN { return Err(anyhow!( "invalid buffer length: {} < {}", len, MASK_CONFIG_BUFFER_LEN )); } Ok(()) } /// Gets the serialized group type of the masking configuration. /// /// # Panics /// May panic if this buffer is unchecked. pub fn group_type(&self) -> u8 { self.inner.as_ref()[GROUP_TYPE_FIELD] } /// Gets the serialized data type of the masking configuration. /// /// # Panics /// May panic if this buffer is unchecked. pub fn data_type(&self) -> u8 { self.inner.as_ref()[DATA_TYPE_FIELD] } /// Gets the serialized bound type of the masking configuration. /// /// # Panics /// May panic if this buffer is unchecked. pub fn bound_type(&self) -> u8 { self.inner.as_ref()[BOUND_TYPE_FIELD] } /// Gets the serialized model type of the masking configuration. /// /// # Panics /// May panic if this buffer is unchecked. pub fn model_type(&self) -> u8 { self.inner.as_ref()[MODEL_TYPE_FIELD] } } impl> MaskConfigBuffer { /// Sets the serialized group type of the masking configuration. /// /// # Panics /// May panic if this buffer is unchecked. pub fn set_group_type(&mut self, value: u8) { self.inner.as_mut()[GROUP_TYPE_FIELD] = value; } /// Sets the serialized data type of the masking configuration. /// /// # Panics /// May panic if this buffer is unchecked. pub fn set_data_type(&mut self, value: u8) { self.inner.as_mut()[DATA_TYPE_FIELD] = value; } /// Sets the serialized bound type of the masking configuration. /// /// # Panics /// May panic if this buffer is unchecked. pub fn set_bound_type(&mut self, value: u8) { self.inner.as_mut()[BOUND_TYPE_FIELD] = value; } /// Sets the serialized model type of the masking configuration. /// /// # Panics /// May panic if this buffer is unchecked. 
pub fn set_model_type(&mut self, value: u8) {
        self.inner.as_mut()[MODEL_TYPE_FIELD] = value;
    }
}

impl ToBytes for MaskConfig {
    fn buffer_length(&self) -> usize {
        MASK_CONFIG_BUFFER_LEN
    }

    /// Writes the group, data, bound and model type into `buffer`, one byte each and in that
    /// order.
    ///
    /// # Panics
    /// Panics if `buffer` is shorter than `MASK_CONFIG_BUFFER_LEN` bytes, since the writer is
    /// created unchecked.
    fn to_bytes<T: AsMut<[u8]>>(&self, buffer: &mut T) {
        let mut writer = MaskConfigBuffer::new_unchecked(buffer.as_mut());
        writer.set_group_type(self.group_type as u8);
        writer.set_data_type(self.data_type as u8);
        writer.set_bound_type(self.bound_type as u8);
        writer.set_model_type(self.model_type as u8);
    }
}

impl FromBytes for MaskConfig {
    /// Parses a masking configuration from the first `MASK_CONFIG_BUFFER_LEN` bytes of `buffer`.
    ///
    /// # Errors
    /// Fails if `buffer` is too short or if any byte cannot be converted into the corresponding
    /// configuration type.
    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {
        let reader = MaskConfigBuffer::new(buffer.as_ref())?;
        Ok(Self {
            group_type: reader
                .group_type()
                .try_into()
                .context("invalid masking config")?,
            data_type: reader
                .data_type()
                .try_into()
                .context("invalid masking config")?,
            bound_type: reader
                .bound_type()
                .try_into()
                .context("invalid masking config")?,
            model_type: reader
                .model_type()
                .try_into()
                .context("invalid masking config")?,
        })
    }

    /// Parses a masking configuration from the next `MASK_CONFIG_BUFFER_LEN` items of `iter`.
    ///
    /// # Errors
    /// Fails if `iter` yields fewer bytes than required or if the bytes are invalid.
    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(
        iter: &mut I,
    ) -> Result<Self, DecodeError> {
        // Consumes at most `MASK_CONFIG_BUFFER_LEN` items; a short read is then rejected by the
        // length check in `from_byte_slice`.
        let buf: Vec<u8> = iter.take(MASK_CONFIG_BUFFER_LEN).collect();
        Self::from_byte_slice(&buf)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mask::config::{BoundType, DataType, GroupType, MaskConfig, ModelType};

    #[test]
    fn serialize() {
        let config = MaskConfig {
            group_type: GroupType::Prime,
            data_type: DataType::F64,
            bound_type: BoundType::Bmax,
            model_type: ModelType::M9,
        };
        let mut buf = vec![0xff; 4];
        config.to_bytes(&mut buf);
        assert_eq!(buf, vec![1, 1, 255, 9]);
    }

    #[test]
    fn deserialize() {
        let bytes = vec![1, 1, 255, 9];
        let config = MaskConfig::from_byte_slice(&bytes).unwrap();
        assert_eq!(
            config,
            MaskConfig {
                group_type: GroupType::Prime,
                data_type: DataType::F64,
                bound_type: BoundType::Bmax,
                model_type: ModelType::M9,
            }
        );
    }

    #[test]
    fn deserialize_too_short() {
        let bytes = vec![1, 1, 255];
        assert!(MaskConfig::from_byte_slice(&bytes).is_err());
    }

    #[test]
    fn stream_deserialize() {
        let mut bytes = vec![1, 1, 255, 9].into_iter();
        let config = MaskConfig::from_byte_stream(&mut bytes).unwrap();
        assert_eq!(
            config,
            MaskConfig {
                group_type: GroupType::Prime,
                data_type: DataType::F64,
bound_type: BoundType::Bmax,
                model_type: ModelType::M9,
            }
        );
    }
}

================================================
FILE: rust/xaynet-core/src/mask/masking.rs
================================================
//! Masking, aggregation and unmasking of models.
//!
//! See the [mask module] documentation since this is a private module anyway.
//!
//! [mask module]: crate::mask

use std::iter::{self, Iterator};

use num::{
    bigint::{BigInt, BigUint, ToBigInt},
    clamp,
    rational::Ratio,
    traits::clamp_max,
};
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use thiserror::Error;

use crate::{
    crypto::{prng::generate_integer, ByteObject},
    mask::{
        config::MaskConfigPair,
        model::Model,
        object::{MaskObject, MaskUnit, MaskVect},
        scalar::Scalar,
        seed::MaskSeed,
    },
};

#[derive(Debug, Error, Eq, PartialEq)]
/// Errors related to the unmasking of models.
///
/// These are the errors reported by `Aggregation::validate_unmasking`.
pub enum UnmaskingError {
    /// Nothing has been aggregated yet.
    #[error("there is no model to unmask")]
    NoModel,
    /// More models were aggregated than the vector masking configuration allows.
    #[error("too many models were aggregated for the current unmasking configuration")]
    TooManyModels,
    /// More models were aggregated than the unit (scalar) masking configuration allows.
    #[error("too many scalars were aggregated for the current unmasking configuration")]
    TooManyScalars,
    /// The vector part of the mask differs from the aggregated model in configuration or length.
    #[error("the masked model is incompatible with the mask used for unmasking")]
    MaskManyMismatch,
    /// The unit part of the mask has a different configuration than the aggregated scalar.
    #[error("the masked scalar is incompatible with the mask used for unmasking")]
    MaskOneMismatch,
    /// The mask does not pass `MaskObject::is_valid`.
    #[error("the mask is invalid")]
    InvalidMask,
}

#[derive(Debug, Error)]
/// Errors related to the aggregation of masks and models.
pub enum AggregationError {
    // TODO rename Model -> Vector; or use MaskMany/One terminology
    /// The object to aggregate does not pass `MaskObject::is_valid`.
    #[error("the object to aggregate is invalid")]
    InvalidObject,
    /// Aggregating one more object would exceed the number of models the vector masking
    /// configuration allows.
    #[error("too many models were aggregated for the current unmasking configuration")]
    TooManyModels,
    /// Aggregating one more object would exceed the number of models the unit masking
    /// configuration allows.
    #[error("too many scalars were aggregated for the current unmasking configuration")]
    TooManyScalars,
    /// The vector part of the object differs from the aggregated model in configuration or
    /// length.
    #[error("the model to aggregate is incompatible with the current aggregated model")]
    ModelMismatch,
    /// The unit part of the object has a different configuration than the aggregated scalar.
    #[error("the scalar to aggregate is incompatible with the current aggregated scalar")]
    ScalarMismatch,
}

#[derive(Debug, Clone)]
/// An aggregator for masks and masked models.
pub struct Aggregation {
    /// Number of objects aggregated so far.
    nb_models: usize,
    /// The modular sum of everything aggregated so far (an empty object before the first
    /// aggregation).
    object: MaskObject,
    /// Length of the vector part that every aggregated object must have.
    object_size: usize,
}

impl From<MaskObject> for Aggregation {
    /// Starts an aggregation that already contains `object` as its only model.
    fn from(object: MaskObject) -> Self {
        Self {
            nb_models: 1,
            object_size: object.vect.data.len(),
            object,
        }
    }
}

impl From<Aggregation> for MaskObject {
    /// Extracts the aggregated mask object.
    fn from(aggr: Aggregation) -> Self {
        aggr.object
    }
}

#[allow(clippy::len_without_is_empty)]
impl Aggregation {
    /// Creates a new, empty aggregator for masks or masked models.
    pub fn new(config: MaskConfigPair, object_size: usize) -> Self {
        Self {
            nb_models: 0,
            object: MaskObject::empty(config, object_size),
            object_size,
        }
    }

    /// Gets the length of the aggregated mask object.
    pub fn len(&self) -> usize {
        self.object_size
    }

    /// Gets the masking configurations of the aggregator.
    pub fn config(&self) -> MaskConfigPair {
        MaskConfigPair {
            vect: self.object.vect.config,
            unit: self.object.unit.config,
        }
    }

    /// Validates if unmasking of the aggregated masked model with the given `mask` may be
    /// safely performed.
    ///
    /// This should be checked before calling [`unmask()`], since unmasking may return garbage
    /// values otherwise.
    ///
    /// # Errors
    /// Fails in one of the following cases:
    /// - The aggregator has not yet aggregated any models.
    /// - The number of aggregated masked models is larger than the chosen masking configuration
    ///   allows.
/// - The masking configuration of the aggregator and of the `mask` don't coincide.
    /// - The length of the aggregated masked model and the `mask` don't coincide.
    /// - The `mask` itself is invalid.
    ///
    /// Even though it does not produce any meaningful values, it is safe and technically possible
    /// due to the [`MaskObject`] type to validate that:
    /// - a mask may unmask another mask
    /// - a masked model may unmask a mask
    /// - a masked model may unmask another masked model
    ///
    /// [`unmask()`]: Aggregation::unmask
    pub fn validate_unmasking(&self, mask: &MaskObject) -> Result<(), UnmaskingError> {
        // We cannot perform unmasking without at least one real model
        if self.nb_models == 0 {
            return Err(UnmaskingError::NoModel);
        }

        // Both the vector and the unit configuration bound how many models may be summed up.
        if self.nb_models > self.object.vect.config.model_type.max_nb_models() {
            return Err(UnmaskingError::TooManyModels);
        }
        if self.nb_models > self.object.unit.config.model_type.max_nb_models() {
            return Err(UnmaskingError::TooManyScalars);
        }

        if self.object.vect.config != mask.vect.config || self.object_size != mask.vect.data.len() {
            return Err(UnmaskingError::MaskManyMismatch);
        }
        if self.object.unit.config != mask.unit.config {
            return Err(UnmaskingError::MaskOneMismatch);
        }

        if !mask.is_valid() {
            return Err(UnmaskingError::InvalidMask);
        }

        Ok(())
    }

    /// Unmasks the aggregated masked model with the given `mask`.
    ///
    /// It should be checked that [`validate_unmasking()`] succeeds before calling this, since
    /// unmasking may return garbage values otherwise. The unmasking performs the steps described
    /// for [`mask()`] in reverse order.
    ///
    /// # Panics
    /// This may panic if [`validate_unmasking()`] fails. It may also panic if the unmasked scalar
    /// sum is zero, because every unmasked weight is divided by it.
    ///
    /// Even though it does not produce any meaningful values, it is safe and technically possible
    /// due to the [`MaskObject`] type to unmask:
    /// - a mask with another mask
    /// - a mask with a masked model
    /// - a masked model with another masked model
    ///
    /// provided that [`validate_unmasking()`] succeeds.
///
    /// [`validate_unmasking()`]: Aggregation::validate_unmasking
    /// [`mask()`]: Masker::mask
    pub fn unmask(self, mask_obj: MaskObject) -> Model {
        let MaskObject { vect, unit } = self.object;
        let (masked_n, config_n) = (vect.data, vect.config);
        let (masked_1, config_1) = (unit.data, unit.config);
        let mask_n = mask_obj.vect.data;
        let mask_1 = mask_obj.unit.data;

        // unmask scalar sum
        let scaled_add_shift_1 = config_1.add_shift() * BigInt::from(self.nb_models);
        let exp_shift_1 = config_1.exp_shift();
        let order_1 = config_1.order();
        // PANIC_SAFE: as for the weights below, the subtraction can only underflow if the mask
        // is invalid.
        let n = (masked_1 + &order_1 - mask_1) % &order_1;
        // UNWRAP_SAFE: to_bigint never fails for BigUint
        let ratio = Ratio::<BigInt>::from(n.to_bigint().unwrap());
        let scalar_sum = ratio / &exp_shift_1 - &scaled_add_shift_1;

        // unmask global model
        let scaled_add_shift_n = config_n.add_shift() * BigInt::from(self.nb_models);
        let exp_shift_n = config_n.exp_shift();
        let order_n = config_n.order();
        masked_n
            .into_iter()
            .zip(mask_n)
            .map(|(masked, mask)| {
                // PANIC_SAFE: The subtraction panics if it
                // underflows, which can only happen if:
                //
                // mask > order_n
                //
                // If the mask is valid, we are guaranteed that this
                // cannot happen. Thus this subtraction may panic only
                // if given an invalid mask.
                let n = (masked + &order_n - mask) % &order_n;
                // UNWRAP_SAFE: to_bigint never fails for BigUint
                let ratio = Ratio::<BigInt>::from(n.to_bigint().unwrap());
                let unmasked = ratio / &exp_shift_n - &scaled_add_shift_n;
                // scaling correction
                unmasked / &scalar_sum
            })
            .collect()
    }

    /// Validates if aggregation of the aggregated mask object with the given `object` may be safely
    /// performed.
    ///
    /// This should be checked before calling [`aggregate()`], since aggregation may return garbage
    /// values otherwise.
    ///
    /// # Errors
    /// Fails in one of the following cases:
    /// - The masking configuration of the aggregator and of the `object` don't coincide.
    /// - The length of the aggregated masks or masked model and the `object` don't coincide. The
    ///   expected length is fixed when the aggregator is created, so this applies to an empty
    ///   aggregator as well.
/// - The new number of aggregated masks or masked models would exceed the number that the
    ///   chosen masking configuration allows.
    /// - The `object` itself is invalid.
    ///
    /// Even though it does not produce any meaningful values, it is safe and technically possible
    /// due to the [`MaskObject`] type to validate that a mask may be aggregated with a masked
    /// model.
    ///
    /// [`aggregate()`]: Aggregation::aggregate
    pub fn validate_aggregation(&self, object: &MaskObject) -> Result<(), AggregationError> {
        if self.object.vect.config != object.vect.config {
            return Err(AggregationError::ModelMismatch);
        }
        if self.object.unit.config != object.unit.config {
            return Err(AggregationError::ScalarMismatch);
        }
        // Checked even for an empty aggregator: `object_size` is fixed at construction.
        if self.object_size != object.vect.data.len() {
            return Err(AggregationError::ModelMismatch);
        }

        // `>=` because `object` would become the `nb_models + 1`-th aggregated object.
        if self.nb_models >= self.object.vect.config.model_type.max_nb_models() {
            return Err(AggregationError::TooManyModels);
        }
        if self.nb_models >= self.object.unit.config.model_type.max_nb_models() {
            return Err(AggregationError::TooManyScalars);
        }

        if !object.is_valid() {
            return Err(AggregationError::InvalidObject);
        }

        Ok(())
    }

    /// Aggregates the aggregated mask object with the given `object`.
    ///
    /// It should be checked that [`validate_aggregation()`] succeeds before calling this, since
    /// aggregation may return garbage values otherwise. If nothing has been aggregated yet,
    /// `object` simply becomes the aggregated object.
    ///
    /// Even though it does not produce any meaningful values, it is safe and technically possible
    /// due to the [`MaskObject`] type to aggregate a mask with a masked model, provided that
    /// [`validate_aggregation()`] succeeds.
/// /// [`validate_aggregation()`]: Aggregation::validate_aggregation pub fn aggregate(&mut self, object: MaskObject) { if self.nb_models == 0 { self.object = object; self.nb_models = 1; return; } let order_n = self.object.vect.config.order(); for (i, j) in self .object .vect .data .iter_mut() .zip(object.vect.data.into_iter()) { *i = (&*i + j) % &order_n } let order_1 = self.object.unit.config.order(); let a = &mut self.object.unit.data; let b = object.unit.data; *a = (&*a + b) % &order_1; self.nb_models += 1; } } /// A masker for models. pub struct Masker { config: MaskConfigPair, seed: MaskSeed, } impl Masker { /// Creates a new masker with the given masking `config`uration with a randomly generated seed. pub fn new(config: MaskConfigPair) -> Self { Self { config, seed: MaskSeed::generate(), } } /// Creates a new masker with the given masking `config`uration and `seed`. pub fn with_seed(config: MaskConfigPair, seed: MaskSeed) -> Self { Self { config, seed } } } impl Masker { /// Masks the given `model` wrt the masking configuration. Enforces bounds on the scalar and /// weights. /// /// The masking proceeds in the following steps: /// - Clamp the scalar and the weights according to the masking configuration. /// - Scale the weights by the scalar. /// - Shift the weights into the non-negative reals. /// - Shift the weights into the non-negative integers. /// - Shift the weights into the finite group. /// - Mask the weights with random elements from the finite group. /// /// The `scalar` is also masked, following a similar process. /// /// The random elements are derived from a seeded PRNG. Unmasking as performed in [`unmask()`] /// proceeds in reverse order. 
///
    /// [`unmask()`]: Aggregation::unmask
    pub fn mask(self, scalar: Scalar, model: &Model) -> (MaskSeed, MaskObject) {
        // Draw the random integers before `self` is destructured: the first one masks the
        // scalar, the remaining ones mask the weights in order.
        let (random_int, mut random_ints) = self.random_ints();
        let Self { config, seed } = self;
        let MaskConfigPair {
            vect: config_n,
            unit: config_1,
        } = config;

        // clamp the scalar
        // NOTE(review): only an upper bound is enforced here; the non-negativity of the shifted
        // scalar below presumably relies on `Scalar` never being negative — confirm.
        let add_shift_1 = config_1.add_shift();
        let scalar_ratio = scalar.into();
        let scalar_clamped = clamp_max(&scalar_ratio, &add_shift_1);

        let exp_shift_n = config_n.exp_shift();
        let add_shift_n = config_n.add_shift();
        let order_n = config_n.order();
        // Scaled weights are clamped to the symmetric interval [-add_shift_n, add_shift_n].
        let higher_bound = &add_shift_n;
        let lower_bound = -&add_shift_n;

        // mask the (scaled) weights
        let masked_weights = model
            .iter()
            .zip(&mut random_ints)
            .map(|(weight, rand_int)| {
                let scaled = scalar_clamped * weight;
                let scaled_clamped = clamp(&scaled, &lower_bound, higher_bound);
                // PANIC_SAFE: shifted weight is guaranteed to be non-negative
                let shifted = ((scaled_clamped + &add_shift_n) * &exp_shift_n)
                    .to_integer()
                    .to_biguint()
                    .unwrap();
                (shifted + rand_int) % &order_n
            })
            .collect();
        let masked_model = MaskVect::new_unchecked(config_n, masked_weights);

        // mask the scalar
        // PANIC_SAFE: shifted scalar is guaranteed to be non-negative
        let shifted = ((scalar_clamped + &add_shift_1) * config_1.exp_shift())
            .to_integer()
            .to_biguint()
            .unwrap();
        let masked = (shifted + random_int) % config_1.order();
        let masked_scalar = MaskUnit::new_unchecked(config_1, masked);

        // The seed is returned alongside the masked object so the matching mask can be
        // re-derived from it later.
        (seed, MaskObject::new_unchecked(masked_model, masked_scalar))
    }

    /// Randomly generates integers wrt the masking configurations.
    ///
    /// The first is generated wrt the scalar configuration, while the rest are
    /// wrt the vector configuration and returned as an iterator.
fn random_ints(&self) -> (BigUint, impl Iterator<Item = BigUint>) {
        let order_n = self.config.vect.order();
        let order_1 = self.config.unit.order();
        let mut prng = ChaCha20Rng::from_seed(self.seed.as_array());

        // The scalar's integer is drawn first, so it is the first output of the PRNG stream.
        let int = generate_integer(&mut prng, &order_1);
        // The iterator takes ownership of the PRNG and is endless (it always yields `Some`);
        // callers bound it, e.g. by zipping it with the model weights.
        let ints = iter::from_fn(move || Some(generate_integer(&mut prng, &order_n)));
        (int, ints)
    }
}

#[cfg(test)]
mod tests {
    use std::iter;

    use num::traits::Signed;
    use rand::{
        distributions::{Distribution, Uniform},
        SeedableRng,
    };
    use rand_chacha::ChaCha20Rng;

    use super::*;
    use crate::mask::{
        config::{
            BoundType::{Bmax, B0, B2, B4, B6},
            DataType::{F32, F64, I32, I64},
            GroupType::{Integer, Power2, Prime},
            MaskConfig,
            ModelType::M3,
        },
        model::FromPrimitives,
        scalar::FromPrimitive,
    };

    /// Generate tests for masking and unmasking of a single model:
    /// - generate random weights from a uniform distribution with a seeded PRNG
    /// - create a model from the weights and mask it
    /// - check that all masked weights belong to the chosen finite group
    /// - unmask the masked model
    /// - check that all unmasked weights are equal to the original weights (up to a tolerance
    /// determined by the masking configuration)
    ///
    /// The arguments to the macro are:
    /// - a suffix for the test name
    /// - the group type of the model (variants of `GroupType`)
    /// - the data type of the model (either primitives or variants of `DataType`)
    /// - an absolute bound for the weights (optional, choices: 1, 100, 10_000, 1_000_000)
    /// - the number of weights
    macro_rules! test_masking {
        ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr $(,)?) => {
            paste::item! {
                #[test]
                fn []() {
                    // Step 1: Build the masking config
                    let config = MaskConfig {
                        group_type: $group,
                        data_type: paste::expr! { [<$data:upper>] },
                        bound_type: match $bound {
                            1 => B0,
                            100 => B2,
                            10_000 => B4,
                            1_000_000 => B6,
                            _ => Bmax,
                        },
                        model_type: M3,
                    };
                    let vect_len = $len as usize;

                    // Step 2: Generate a random model
                    let bound = if $bound == 0 {
                        paste::expr!
{ [<$data:lower>]::MAX / (2.1 as [<$data:lower>]) } } else { paste::expr! { $bound as [<$data:lower>] } }; let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array()); let random_weights = Uniform::new_inclusive(-bound, bound) .sample_iter(&mut prng) .take(vect_len); let model = Model::from_primitives(random_weights).unwrap(); assert_eq!(model.len(), vect_len); // Step 3 (actual test): // a. mask the model // b. derive the mask corresponding to the seed used // c. unmask the model and check it against the original one. let (mask_seed, masked_model) = Masker::new(config.into()).mask(Scalar::unit(), &model); assert_eq!(masked_model.vect.data.len(), vect_len); assert!(masked_model.is_valid()); let mask = mask_seed.derive_mask(vect_len, config.into()); let aggregation = Aggregation::from(masked_model); let unmasked_model = aggregation.unmask(mask); let tolerance = Ratio::from_integer(config.exp_shift()).recip(); assert!( model.iter() .zip(unmasked_model.iter()) .all(|(weight, unmasked_weight)| { (weight - unmasked_weight).abs() <= tolerance }) ); } } }; ($suffix:ident, $group:ty, $data:ty, $len:expr $(,)?) 
=> { test_masking!($suffix, $group, $data, 0, $len); }; } test_masking!(int_f32_b0, Integer, f32, 1, 10); test_masking!(int_f32_b2, Integer, f32, 100, 10); test_masking!(int_f32_b4, Integer, f32, 10_000, 10); test_masking!(int_f32_b6, Integer, f32, 1_000_000, 10); test_masking!(int_f32_bmax, Integer, f32, 10); test_masking!(prime_f32_b0, Prime, f32, 1, 10); test_masking!(prime_f32_b2, Prime, f32, 100, 10); test_masking!(prime_f32_b4, Prime, f32, 10_000, 10); test_masking!(prime_f32_b6, Prime, f32, 1_000_000, 10); test_masking!(prime_f32_bmax, Prime, f32, 10); test_masking!(pow_f32_b0, Power2, f32, 1, 10); test_masking!(pow_f32_b2, Power2, f32, 100, 10); test_masking!(pow_f32_b4, Power2, f32, 10_000, 10); test_masking!(pow_f32_b6, Power2, f32, 1_000_000, 10); test_masking!(pow_f32_bmax, Power2, f32, 10); test_masking!(int_f64_b0, Integer, f64, 1, 10); test_masking!(int_f64_b2, Integer, f64, 100, 10); test_masking!(int_f64_b4, Integer, f64, 10_000, 10); test_masking!(int_f64_b6, Integer, f64, 1_000_000, 10); test_masking!(int_f64_bmax, Integer, f64, 10); test_masking!(prime_f64_b0, Prime, f64, 1, 10); test_masking!(prime_f64_b2, Prime, f64, 100, 10); test_masking!(prime_f64_b4, Prime, f64, 10_000, 10); test_masking!(prime_f64_b6, Prime, f64, 1_000_000, 10); test_masking!(prime_f64_bmax, Prime, f64, 10); test_masking!(pow_f64_b0, Power2, f64, 1, 10); test_masking!(pow_f64_b2, Power2, f64, 100, 10); test_masking!(pow_f64_b4, Power2, f64, 10_000, 10); test_masking!(pow_f64_b6, Power2, f64, 1_000_000, 10); test_masking!(pow_f64_bmax, Power2, f64, 10); test_masking!(int_i32_b0, Integer, i32, 1, 10); test_masking!(int_i32_b2, Integer, i32, 100, 10); test_masking!(int_i32_b4, Integer, i32, 10_000, 10); test_masking!(int_i32_b6, Integer, i32, 1_000_000, 10); test_masking!(int_i32_bmax, Integer, i32, 10); test_masking!(prime_i32_b0, Prime, i32, 1, 10); test_masking!(prime_i32_b2, Prime, i32, 100, 10); test_masking!(prime_i32_b4, Prime, i32, 10_000, 10); 
test_masking!(prime_i32_b6, Prime, i32, 1_000_000, 10); test_masking!(prime_i32_bmax, Prime, i32, 10); test_masking!(pow_i32_b0, Power2, i32, 1, 10); test_masking!(pow_i32_b2, Power2, i32, 100, 10); test_masking!(pow_i32_b4, Power2, i32, 10_000, 10); test_masking!(pow_i32_b6, Power2, i32, 1_000_000, 10); test_masking!(pow_i32_bmax, Power2, i32, 10); test_masking!(int_i64_b0, Integer, i64, 1, 10); test_masking!(int_i64_b2, Integer, i64, 100, 10); test_masking!(int_i64_b4, Integer, i64, 10_000, 10); test_masking!(int_i64_b6, Integer, i64, 1_000_000, 10); test_masking!(int_i64_bmax, Integer, i64, 10); test_masking!(prime_i64_b0, Prime, i64, 1, 10); test_masking!(prime_i64_b2, Prime, i64, 100, 10); test_masking!(prime_i64_b4, Prime, i64, 10_000, 10); test_masking!(prime_i64_b6, Prime, i64, 1_000_000, 10); test_masking!(prime_i64_bmax, Prime, i64, 10); test_masking!(pow_i64_b0, Power2, i64, 1, 10); test_masking!(pow_i64_b2, Power2, i64, 100, 10); test_masking!(pow_i64_b4, Power2, i64, 10_000, 10); test_masking!(pow_i64_b6, Power2, i64, 1_000_000, 10); test_masking!(pow_i64_bmax, Power2, i64, 10); /// Generate tests for masking and unmasking of a single model: /// - generate random scalar from a uniform distribution with a seeded PRNG /// - scale a model of unit weights and mask it /// - check that all masked weights belong to the chosen finite group /// - unmask the masked model /// - check that all unmasked weights are equal to the original weights (up to a tolerance /// determined by the masking configuration) /// /// The arguments to the macro are: /// - a suffix for the test name /// - the group type of the model and scalar (variants of `GroupType`) /// - the data type of the model and scalar (either float primitives or float variants of /// `DataType`) /// - an absolute bound for the scalar (optional, choices: 1, 100, 10_000, 1_000_000) /// - the number of weights macro_rules! test_masking_scalar { ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr $(,)?) 
=> { paste::item! { #[test] fn []() { // Step 1: Build the masking config let config = MaskConfig { group_type: $group, data_type: paste::expr! { [<$data:upper>] }, bound_type: match $bound { 1 => B0, 100 => B2, 10_000 => B4, 1_000_000 => B6, _ => Bmax, }, model_type: M3, }; let vect_len = $len as usize; // Step 2: Generate a random scalar from (0, bound] // take vector [1, ..., 1] as the model to scale let bound = if $bound == 0 { paste::expr! { [<$data:lower>]::MAX / (2.1 as [<$data:lower>]) } } else { paste::expr! { $bound as [<$data:lower>] } }; let eps = [<$data:lower>]::EPSILON; let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array()); let random_weight = Uniform::new_inclusive(eps, bound).sample(&mut prng); let scalar = Scalar::from_primitive(random_weight).unwrap(); let model = Model::from_primitives(iter::repeat(1).take(vect_len)).unwrap(); assert_eq!(model.len(), vect_len); // Step 3 (actual test): // a. mask the model // b. derive the mask corresponding to the seed used // c. unmask the model and check it against the expected [1, ..., 1] let (mask_seed, masked_model) = Masker::new(config.into()).mask(scalar, &model); assert_eq!(masked_model.vect.data.len(), vect_len); assert!(masked_model.is_valid()); let mask = mask_seed.derive_mask(vect_len, config.into()); let unmasked_model = Aggregation::from(masked_model).unmask(mask); let tolerance = Ratio::from_integer(config.exp_shift()).recip(); let expected_weight = Ratio::from_integer(BigInt::from(1)); assert!( unmasked_model .iter() .all(|unmasked_weight| { (unmasked_weight - &expected_weight).abs() <= tolerance }) ); } } }; ($suffix:ident, $group:ty, $data:ty, $len:expr $(,)?) 
=> { test_masking_scalar!($suffix, $group, $data, 0, $len); }; } test_masking_scalar!(int_f32_b0, Integer, f32, 1, 10); test_masking_scalar!(int_f32_b2, Integer, f32, 100, 10); test_masking_scalar!(int_f32_b4, Integer, f32, 10_000, 10); test_masking_scalar!(int_f32_b6, Integer, f32, 1_000_000, 10); test_masking_scalar!(int_f32_bmax, Integer, f32, 10); test_masking_scalar!(prime_f32_b0, Prime, f32, 1, 10); test_masking_scalar!(prime_f32_b2, Prime, f32, 100, 10); test_masking_scalar!(prime_f32_b4, Prime, f32, 10_000, 10); test_masking_scalar!(prime_f32_b6, Prime, f32, 1_000_000, 10); test_masking_scalar!(prime_f32_bmax, Prime, f32, 10); test_masking_scalar!(pow_f32_b0, Power2, f32, 1, 10); test_masking_scalar!(pow_f32_b2, Power2, f32, 100, 10); test_masking_scalar!(pow_f32_b4, Power2, f32, 10_000, 10); test_masking_scalar!(pow_f32_b6, Power2, f32, 1_000_000, 10); test_masking_scalar!(pow_f32_bmax, Power2, f32, 10); test_masking_scalar!(int_f64_b0, Integer, f64, 1, 10); test_masking_scalar!(int_f64_b2, Integer, f64, 100, 10); test_masking_scalar!(int_f64_b4, Integer, f64, 10_000, 10); test_masking_scalar!(int_f64_b6, Integer, f64, 1_000_000, 10); test_masking_scalar!(int_f64_bmax, Integer, f64, 10); test_masking_scalar!(prime_f64_b0, Prime, f64, 1, 10); test_masking_scalar!(prime_f64_b2, Prime, f64, 100, 10); test_masking_scalar!(prime_f64_b4, Prime, f64, 10_000, 10); test_masking_scalar!(prime_f64_b6, Prime, f64, 1_000_000, 10); test_masking_scalar!(prime_f64_bmax, Prime, f64, 10); test_masking_scalar!(pow_f64_b0, Power2, f64, 1, 10); test_masking_scalar!(pow_f64_b2, Power2, f64, 100, 10); test_masking_scalar!(pow_f64_b4, Power2, f64, 10_000, 10); test_masking_scalar!(pow_f64_b6, Power2, f64, 1_000_000, 10); test_masking_scalar!(pow_f64_bmax, Power2, f64, 10); /// Generate tests for aggregation of multiple masked models: /// - generate random integers from a uniform distribution with a seeded PRNG /// - create a masked model from the integers and aggregate it to the 
aggregated masked models /// - check that all integers belong to the chosen finite group /// /// The arguments to the macro are: /// - a suffix for the test name /// - the group type of the model (variants of `GroupType`) /// - the data type of the model (variants of `DataType`) /// - the bound type of the model (variants of `BoundType`) /// - the number of integers per masked model /// - the number of masked models macro_rules! test_aggregation { ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr, $count:expr $(,)?) => { paste::item! { #[test] fn []() { // Step 1: Build the masking config let config = MaskConfig { group_type: $group, data_type: $data, bound_type: $bound, model_type: M3, }; let vect_len = $len as usize; // Step 2: generate random masked models let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array()); let mut masked_models = iter::repeat_with(move || { let order = config.order(); let integer = generate_integer(&mut prng, &order); let integers = iter::repeat_with(|| generate_integer(&mut prng, &order)) .take(vect_len) .collect::>(); MaskObject::new(config.into(), integers, integer).unwrap() }); // Step 3 (actual test): // a. aggregate the masked models // b. 
check the aggregated masked model let mut aggregated_masked_model = Aggregation::new(config.into(), vect_len); for nb in 1..$count as usize + 1 { let masked_model = masked_models.next().unwrap(); assert!( aggregated_masked_model.validate_aggregation(&masked_model).is_ok() ); aggregated_masked_model.aggregate(masked_model); assert_eq!(aggregated_masked_model.nb_models, nb); assert_eq!(aggregated_masked_model.object.vect.data.len(), vect_len); assert_eq!(aggregated_masked_model.object.vect.config, config); assert_eq!(aggregated_masked_model.object.unit.config, config); assert!(aggregated_masked_model.object.is_valid()); } } } }; } test_aggregation!(int_f32_b0, Integer, F32, B0, 10, 5); test_aggregation!(int_f32_b2, Integer, F32, B2, 10, 5); test_aggregation!(int_f32_b4, Integer, F32, B4, 10, 5); test_aggregation!(int_f32_b6, Integer, F32, B6, 10, 5); test_aggregation!(int_f32_bmax, Integer, F32, Bmax, 10, 5); test_aggregation!(prime_f32_b0, Prime, F32, B0, 10, 5); test_aggregation!(prime_f32_b2, Prime, F32, B2, 10, 5); test_aggregation!(prime_f32_b4, Prime, F32, B4, 10, 5); test_aggregation!(prime_f32_b6, Prime, F32, B6, 10, 5); test_aggregation!(prime_f32_bmax, Prime, F32, Bmax, 10, 5); test_aggregation!(pow_f32_b0, Power2, F32, B0, 10, 5); test_aggregation!(pow_f32_b2, Power2, F32, B2, 10, 5); test_aggregation!(pow_f32_b4, Power2, F32, B4, 10, 5); test_aggregation!(pow_f32_b6, Power2, F32, B6, 10, 5); test_aggregation!(pow_f32_bmax, Power2, F32, Bmax, 10, 5); test_aggregation!(int_f64_b0, Integer, F64, B0, 10, 5); test_aggregation!(int_f64_b2, Integer, F64, B2, 10, 5); test_aggregation!(int_f64_b4, Integer, F64, B4, 10, 5); test_aggregation!(int_f64_b6, Integer, F64, B6, 10, 5); test_aggregation!(int_f64_bmax, Integer, F64, Bmax, 10, 5); test_aggregation!(prime_f64_b0, Prime, F64, B0, 10, 5); test_aggregation!(prime_f64_b2, Prime, F64, B2, 10, 5); test_aggregation!(prime_f64_b4, Prime, F64, B4, 10, 5); test_aggregation!(prime_f64_b6, Prime, F64, B6, 10, 5); 
// Remaining `test_aggregation!` invocations: one test per combination of group type
// (Integer, Prime, Power2), data type (F64, I32, I64) and bound type (B0, B2, B4, B6, Bmax).
// Each generated test aggregates 5 masked models of 10 random integers each (the last two
// macro arguments, per the macro's documented argument list).
test_aggregation!(prime_f64_bmax, Prime, F64, Bmax, 10, 5);
// F64 data type, Power2 group
test_aggregation!(pow_f64_b0, Power2, F64, B0, 10, 5);
test_aggregation!(pow_f64_b2, Power2, F64, B2, 10, 5);
test_aggregation!(pow_f64_b4, Power2, F64, B4, 10, 5);
test_aggregation!(pow_f64_b6, Power2, F64, B6, 10, 5);
test_aggregation!(pow_f64_bmax, Power2, F64, Bmax, 10, 5);
// I32 data type, Integer group
test_aggregation!(int_i32_b0, Integer, I32, B0, 10, 5);
test_aggregation!(int_i32_b2, Integer, I32, B2, 10, 5);
test_aggregation!(int_i32_b4, Integer, I32, B4, 10, 5);
test_aggregation!(int_i32_b6, Integer, I32, B6, 10, 5);
test_aggregation!(int_i32_bmax, Integer, I32, Bmax, 10, 5);
// I32 data type, Prime group
test_aggregation!(prime_i32_b0, Prime, I32, B0, 10, 5);
test_aggregation!(prime_i32_b2, Prime, I32, B2, 10, 5);
test_aggregation!(prime_i32_b4, Prime, I32, B4, 10, 5);
test_aggregation!(prime_i32_b6, Prime, I32, B6, 10, 5);
test_aggregation!(prime_i32_bmax, Prime, I32, Bmax, 10, 5);
// I32 data type, Power2 group
test_aggregation!(pow_i32_b0, Power2, I32, B0, 10, 5);
test_aggregation!(pow_i32_b2, Power2, I32, B2, 10, 5);
test_aggregation!(pow_i32_b4, Power2, I32, B4, 10, 5);
test_aggregation!(pow_i32_b6, Power2, I32, B6, 10, 5);
test_aggregation!(pow_i32_bmax, Power2, I32, Bmax, 10, 5);
// I64 data type, Integer group
test_aggregation!(int_i64_b0, Integer, I64, B0, 10, 5);
test_aggregation!(int_i64_b2, Integer, I64, B2, 10, 5);
test_aggregation!(int_i64_b4, Integer, I64, B4, 10, 5);
test_aggregation!(int_i64_b6, Integer, I64, B6, 10, 5);
test_aggregation!(int_i64_bmax, Integer, I64, Bmax, 10, 5);
// I64 data type, Prime group
test_aggregation!(prime_i64_b0, Prime, I64, B0, 10, 5);
test_aggregation!(prime_i64_b2, Prime, I64, B2, 10, 5);
test_aggregation!(prime_i64_b4, Prime, I64, B4, 10, 5);
test_aggregation!(prime_i64_b6, Prime, I64, B6, 10, 5);
test_aggregation!(prime_i64_bmax, Prime, I64, Bmax, 10, 5);
// I64 data type, Power2 group (the Bmax variant, pow_i64_bmax, comes right after this list)
test_aggregation!(pow_i64_b0, Power2, I64, B0, 10, 5);
test_aggregation!(pow_i64_b2, Power2, I64, B2, 10, 5);
test_aggregation!(pow_i64_b4, Power2, I64, B4, 10, 5);
test_aggregation!(pow_i64_b6, Power2, I64, B6, 10, 5);
test_aggregation!(pow_i64_bmax, Power2, I64, Bmax, 10, 5); /// Generate tests for masking, aggregation and unmasking of multiple models: /// - generate random weights from a uniform distribution with a seeded PRNG /// - create a model from the weights, mask and aggregate it to the aggregated masked models /// - derive a mask from the mask seed and aggregate it to the aggregated masks /// - unmask the aggregated masked model /// - check that all aggregated unmasked weights are equal to the averaged original weights (up /// to a tolerance determined by the masking configuration) /// /// The arguments to the macro are: /// - a suffix for the test name /// - the group type of the model (variants of `GroupType`) /// - the data type of the model (either primitives or variants of `DataType`) /// - an absolute bound for the weights (optional, choices: 1, 100, 10_000, 1_000_000) /// - the number of weights per model /// - the number of models macro_rules! test_masking_and_aggregation { ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr, $count:expr $(,)?) => { paste::item! { #[test] fn []() { // Step 1: Build the masking config let config = MaskConfig { group_type: $group, data_type: paste::expr! { [<$data:upper>] }, bound_type: match $bound { 1 => B0, 100 => B2, 10_000 => B4, 1_000_000 => B6, _ => Bmax, }, model_type: M3, }; let vect_len = $len as usize; let model_count = $count as usize; // Step 2: Generate random models let bound = if $bound == 0 { paste::expr! { [<$data:lower>]::MAX / (2.1 as [<$data:lower>]) } } else { paste::expr! { $bound as [<$data:lower>] } }; let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array()); let mut models = iter::repeat_with(move || { Model::from_primitives( Uniform::new_inclusive(-bound, bound) .sample_iter(&mut prng) .take(vect_len) ) .unwrap() }); // Step 3 (actual test): // a. average the model weights for later checks // b. mask the model // c. derive the mask corresponding to the seed used // d. 
aggregate the masked model resp. mask // e. repeat a-d, then unmask the model and check it against the averaged one let mut averaged_model = Model::from_primitives( iter::repeat(paste::expr! { 0 as [<$data:lower>] }).take(vect_len) ) .unwrap(); let mut aggregated_masked_model = Aggregation::new(config.into(), vect_len); let mut aggregated_mask = Aggregation::new(config.into(), vect_len); let scalar = Scalar::new(1, model_count); let scalar_ratio = &scalar.to_ratio(); for _ in 0..model_count { let model = models.next().unwrap(); averaged_model .iter_mut() .zip(model.iter()) .for_each(|(averaged_weight, weight)| { *averaged_weight += scalar_ratio * weight; }); let (mask_seed, masked_model) = Masker::new(config.into()).mask(scalar.clone(), &model); let mask = mask_seed.derive_mask(vect_len, config.into()); assert!( aggregated_masked_model.validate_aggregation(&masked_model).is_ok() ); aggregated_masked_model.aggregate(masked_model); assert!(aggregated_mask.validate_aggregation(&mask).is_ok()); aggregated_mask.aggregate(mask); } let mask = aggregated_mask.into(); assert!(aggregated_masked_model.validate_unmasking(&mask).is_ok()); let unmasked_model = aggregated_masked_model.unmask(mask); let tolerance = Ratio::from_integer(BigInt::from(model_count)) / Ratio::from_integer(config.exp_shift()); assert!( averaged_model.iter() .zip(unmasked_model.iter()) .all(|(averaged_weight, unmasked_weight)| { (averaged_weight - unmasked_weight).abs() <= tolerance }) ); } } }; ($suffix:ident, $group:ty, $data:ty, $len:expr, $count:expr $(,)?) 
=> { test_masking_and_aggregation!($suffix, $group, $data, 0, $len, $count); }; } test_masking_and_aggregation!(int_f32_b0, Integer, f32, 1, 10, 5); test_masking_and_aggregation!(int_f32_b2, Integer, f32, 100, 10, 5); test_masking_and_aggregation!(int_f32_b4, Integer, f32, 10_000, 10, 5); test_masking_and_aggregation!(int_f32_b6, Integer, f32, 1_000_000, 10, 5); test_masking_and_aggregation!(int_f32_bmax, Integer, f32, 10, 5); test_masking_and_aggregation!(prime_f32_b0, Prime, f32, 1, 10, 5); test_masking_and_aggregation!(prime_f32_b2, Prime, f32, 100, 10, 5); test_masking_and_aggregation!(prime_f32_b4, Prime, f32, 10_000, 10, 5); test_masking_and_aggregation!(prime_f32_b6, Prime, f32, 1_000_000, 10, 5); test_masking_and_aggregation!(prime_f32_bmax, Prime, f32, 10, 5); test_masking_and_aggregation!(pow_f32_b0, Power2, f32, 1, 10, 5); test_masking_and_aggregation!(pow_f32_b2, Power2, f32, 100, 10, 5); test_masking_and_aggregation!(pow_f32_b4, Power2, f32, 10_000, 10, 5); test_masking_and_aggregation!(pow_f32_b6, Power2, f32, 1_000_000, 10, 5); test_masking_and_aggregation!(pow_f32_bmax, Power2, f32, 10, 5); test_masking_and_aggregation!(int_f64_b0, Integer, f64, 1, 10, 5); test_masking_and_aggregation!(int_f64_b2, Integer, f64, 100, 10, 5); test_masking_and_aggregation!(int_f64_b4, Integer, f64, 10_000, 10, 5); test_masking_and_aggregation!(int_f64_b6, Integer, f64, 1_000_000, 10, 5); test_masking_and_aggregation!(int_f64_bmax, Integer, f64, 10, 5); test_masking_and_aggregation!(prime_f64_b0, Prime, f64, 1, 10, 5); test_masking_and_aggregation!(prime_f64_b2, Prime, f64, 100, 10, 5); test_masking_and_aggregation!(prime_f64_b4, Prime, f64, 10_000, 10, 5); test_masking_and_aggregation!(prime_f64_b6, Prime, f64, 1_000_000, 10, 5); test_masking_and_aggregation!(prime_f64_bmax, Prime, f64, 10, 5); test_masking_and_aggregation!(pow_f64_b0, Power2, f64, 1, 10, 5); test_masking_and_aggregation!(pow_f64_b2, Power2, f64, 100, 10, 5); test_masking_and_aggregation!(pow_f64_b4, 
Power2, f64, 10_000, 10, 5); test_masking_and_aggregation!(pow_f64_b6, Power2, f64, 1_000_000, 10, 5); test_masking_and_aggregation!(pow_f64_bmax, Power2, f64, 10, 5); test_masking_and_aggregation!(int_i32_b0, Integer, i32, 1, 10, 5); test_masking_and_aggregation!(int_i32_b2, Integer, i32, 100, 10, 5); test_masking_and_aggregation!(int_i32_b4, Integer, i32, 10_000, 10, 5); test_masking_and_aggregation!(int_i32_b6, Integer, i32, 1_000_000, 10, 5); test_masking_and_aggregation!(int_i32_bmax, Integer, i32, 10, 5); test_masking_and_aggregation!(prime_i32_b0, Prime, i32, 1, 10, 5); test_masking_and_aggregation!(prime_i32_b2, Prime, i32, 100, 10, 5); test_masking_and_aggregation!(prime_i32_b4, Prime, i32, 10_000, 10, 5); test_masking_and_aggregation!(prime_i32_b6, Prime, i32, 1_000_000, 10, 5); test_masking_and_aggregation!(prime_i32_bmax, Prime, i32, 10, 5); test_masking_and_aggregation!(pow_i32_b0, Power2, i32, 1, 10, 5); test_masking_and_aggregation!(pow_i32_b2, Power2, i32, 100, 10, 5); test_masking_and_aggregation!(pow_i32_b4, Power2, i32, 10_000, 10, 5); test_masking_and_aggregation!(pow_i32_b6, Power2, i32, 1_000_000, 10, 5); test_masking_and_aggregation!(pow_i32_bmax, Power2, i32, 10, 5); test_masking_and_aggregation!(int_i64_b0, Integer, i64, 1, 10, 5); test_masking_and_aggregation!(int_i64_b2, Integer, i64, 100, 10, 5); test_masking_and_aggregation!(int_i64_b4, Integer, i64, 10_000, 10, 5); test_masking_and_aggregation!(int_i64_b6, Integer, i64, 1_000_000, 10, 5); test_masking_and_aggregation!(int_i64_bmax, Integer, i64, 10, 5); test_masking_and_aggregation!(prime_i64_b0, Prime, i64, 1, 10, 5); test_masking_and_aggregation!(prime_i64_b2, Prime, i64, 100, 10, 5); test_masking_and_aggregation!(prime_i64_b4, Prime, i64, 10_000, 10, 5); test_masking_and_aggregation!(prime_i64_b6, Prime, i64, 1_000_000, 10, 5); test_masking_and_aggregation!(prime_i64_bmax, Prime, i64, 10, 5); test_masking_and_aggregation!(pow_i64_b0, Power2, i64, 1, 10, 5); 
test_masking_and_aggregation!(pow_i64_b2, Power2, i64, 100, 10, 5); test_masking_and_aggregation!(pow_i64_b4, Power2, i64, 10_000, 10, 5); test_masking_and_aggregation!(pow_i64_b6, Power2, i64, 1_000_000, 10, 5); test_masking_and_aggregation!(pow_i64_bmax, Power2, i64, 10, 5); /// Generate tests for masking, aggregation and unmasking of multiple models: /// - generate random scalars from a uniform distribution with a seeded PRNG /// - scale a model of unit weights, mask and aggregate it to the aggregated masked models /// - derive a mask from the mask seed and aggregate it to the aggregated masks /// - unmask the aggregated masked model /// - check that all aggregated unmasked weights are equal to the original unit weights (up /// to a tolerance determined by the masking configuration) /// /// The arguments to the macro are: /// - a suffix for the test name /// - the group type of the model and scalar (variants of `GroupType`) /// - the data type of the model and scalar (either float primitives or float variants of /// `DataType`) /// - an absolute bound for the scalar (optional, choices: 1, 100, 10_000, 1_000_000) /// - the number of weights per model /// - the number of models macro_rules! test_masking_and_aggregation_scalar { ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr, $count:expr $(,)?) => { paste::item! { #[test] fn []() { // Step 1: Build the masking config let config = MaskConfig { group_type: $group, data_type: paste::expr! { [<$data:upper>] }, bound_type: match $bound { 1 => B0, 100 => B2, 10_000 => B4, 1_000_000 => B6, _ => Bmax, }, model_type: M3, }; let vect_len = $len as usize; let model_count = $count as usize; // Step 2: Generate random scalars // take vectors [1, ..., 1] as models to scale let bound = if $bound == 0 { paste::expr! { [<$data:lower>]::MAX / (2 as [<$data:lower>]) } } else { paste::expr! 
{ $bound as [<$data:lower>] } }; let eps = [<$data:lower>]::EPSILON; let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array()); let mut scalars = iter::repeat_with(move || { let random_weight = Uniform::new_inclusive(eps, bound).sample(&mut prng); Scalar::from_primitive(random_weight).unwrap() }); let mut models = iter::repeat(Model::from_primitives(iter::repeat(1).take(vect_len)).unwrap()); // Step 3 (actual test): // a. mask the model // b. derive the mask corresponding to the seed used // c. aggregate the masked model resp. mask // d. repeat a-c, unmask the model and check it against the expected [1, ..., 1] let mut aggregated_masked_model = Aggregation::new(config.into(), vect_len); let mut aggregated_mask = Aggregation::new(config.into(), vect_len); for _ in 0..model_count { let model = models.next().unwrap(); let scalar = scalars.next().unwrap(); let (mask_seed, masked_model) = Masker::new(config.into()).mask(scalar, &model); let mask = mask_seed.derive_mask(vect_len, config.into()); assert!( aggregated_masked_model.validate_aggregation(&masked_model).is_ok() ); aggregated_masked_model.aggregate(masked_model); assert!(aggregated_mask.validate_aggregation(&mask).is_ok()); aggregated_mask.aggregate(mask); } let mask = aggregated_mask.into(); assert!(aggregated_masked_model.validate_unmasking(&mask).is_ok()); let unmasked_model = aggregated_masked_model.unmask(mask); let tolerance = Ratio::from_integer(BigInt::from(model_count)) / Ratio::from_integer(config.exp_shift()); let expected_weight = Ratio::from_integer(BigInt::from(1)); assert!( unmasked_model .iter() .all(|unmasked_weight| { (unmasked_weight - &expected_weight).abs() <= tolerance }) ); } } }; ($suffix:ident, $group:ty, $data:ty, $len:expr, $count:expr $(,)?) 
=> { test_masking_and_aggregation_scalar!($suffix, $group, $data, 0, $len, $count); }; } test_masking_and_aggregation_scalar!(int_f32_b0, Integer, f32, 1, 10, 5); test_masking_and_aggregation_scalar!(int_f32_b2, Integer, f32, 100, 10, 5); test_masking_and_aggregation_scalar!(int_f32_b4, Integer, f32, 10_000, 10, 5); test_masking_and_aggregation_scalar!(int_f32_b6, Integer, f32, 1_000_000, 10, 5); test_masking_and_aggregation_scalar!(int_f32_bmax, Integer, f32, 10, 2); test_masking_and_aggregation_scalar!(prime_f32_b0, Prime, f32, 1, 10, 5); test_masking_and_aggregation_scalar!(prime_f32_b2, Prime, f32, 100, 10, 5); test_masking_and_aggregation_scalar!(prime_f32_b4, Prime, f32, 10_000, 10, 5); test_masking_and_aggregation_scalar!(prime_f32_b6, Prime, f32, 1_000_000, 10, 5); test_masking_and_aggregation_scalar!(prime_f32_bmax, Prime, f32, 10, 2); test_masking_and_aggregation_scalar!(pow_f32_b0, Power2, f32, 1, 10, 5); test_masking_and_aggregation_scalar!(pow_f32_b2, Power2, f32, 100, 10, 5); test_masking_and_aggregation_scalar!(pow_f32_b4, Power2, f32, 10_000, 10, 5); test_masking_and_aggregation_scalar!(pow_f32_b6, Power2, f32, 1_000_000, 10, 5); test_masking_and_aggregation_scalar!(pow_f32_bmax, Power2, f32, 10, 2); test_masking_and_aggregation_scalar!(int_f64_b0, Integer, f64, 1, 10, 2); test_masking_and_aggregation_scalar!(int_f64_b2, Integer, f64, 100, 10, 2); test_masking_and_aggregation_scalar!(int_f64_b4, Integer, f64, 10_000, 10, 2); test_masking_and_aggregation_scalar!(int_f64_b6, Integer, f64, 1_000_000, 10, 2); test_masking_and_aggregation_scalar!(int_f64_bmax, Integer, f64, 10, 2); test_masking_and_aggregation_scalar!(prime_f64_b0, Prime, f64, 1, 10, 2); test_masking_and_aggregation_scalar!(prime_f64_b2, Prime, f64, 100, 10, 2); test_masking_and_aggregation_scalar!(prime_f64_b4, Prime, f64, 10_000, 10, 2); test_masking_and_aggregation_scalar!(prime_f64_b6, Prime, f64, 1_000_000, 10, 2); test_masking_and_aggregation_scalar!(prime_f64_bmax, Prime, f64, 10, 
2); test_masking_and_aggregation_scalar!(pow_f64_b0, Power2, f64, 1, 10, 2); test_masking_and_aggregation_scalar!(pow_f64_b2, Power2, f64, 100, 10, 2); test_masking_and_aggregation_scalar!(pow_f64_b4, Power2, f64, 10_000, 10, 2); test_masking_and_aggregation_scalar!(pow_f64_b6, Power2, f64, 1_000_000, 10, 2); test_masking_and_aggregation_scalar!(pow_f64_bmax, Power2, f64, 10, 2); } ================================================ FILE: rust/xaynet-core/src/mask/mod.rs ================================================ //! Masking, aggregation and unmasking of models. //! //! # Models //! A [`Model`] is a collection of weights/parameters which are represented as finite numerical //! values (i.e. rational numbers) of arbitrary precision. As such, a model in itself is not bound //! to any particular primitive data type, but it can be created from those and converted back into //! them. //! //! Currently, the primitive data types [`f32`], [`f64`], [`i32`] and [`i64`] are supported and //! this might be extended in the future. //! //! ``` //! # use xaynet_core::mask::{FromPrimitives, IntoPrimitives, Model}; //! let weights = vec![0_f32; 10]; //! let model = Model::from_primitives_bounded(weights.into_iter()); //! assert_eq!( //! model.into_primitives_unchecked().collect::>(), //! vec![0_f32; 10], //! ); //! ``` //! //! # Masking configurations //! The masking, aggregation and unmasking of models requires certain information about the models //! to guarantee that no information is lost during the process, which is configured via the //! [`MaskConfig`]. Each masking configuration consists of the group type, data type, bound type and //! model type. Usually, a masking configuration is decided on and configured depending on the //! specific machine learning use case as part of the setup for the XayNet federated learning //! platform. //! //! Currently, those choices are catalogued for certain fixed variants for each type, but we aim //! 
to generalize this in the future to more flexible masking configurations to allow for a more //! fine-grained tradeoff between representability and performance. //! //! ## Group type //! The [`GroupType`] describes the order of the finite group in which the masked model weights are //! embedded. The smaller the gap between the maximum possible embedded weights and the group order //! is, the less information about the masks can theoretically be observed. Specific group orders, //! on the other hand, potentially provide higher performance, which always makes this a tradeoff //! between security and performance. The group type variants are: //! - Integer: no gap but potentially slowest performance. //! - Prime: usually small gap with higher performance. //! - Power2: usually larger gap with potentially highest performance. //! //! ## Data type //! The [`DataType`] describes the original primitive data type of the model weights. This, in //! combination with the bound type, influences the preserved decimal places of the model weights //! during the masking, aggregation and unmasking process, which are: //! - F32: 10 decimal places for bounded model weights and 45 decimal places for unbounded. //! - F64: 20 decimal places for bounded model weights and 324 decimal places for unbounded. //! - I32 and I64: 10 decimal places (required for scaled aggregation). //! //! Currently the primitive data types [`f32`], [`f64`], [`i32`] and [`i64`] are supported via the //! data type variants. //! //! ## Bound type //! The [`BoundType`] describes the absolute bounds on all model weights. The smaller the bounds of //! the model weights, the fewer bytes are required to represent the masked model weights. These //! bounds are enforced on the model weights before masking them to prevent information loss during //! the masking, aggregation and unmasking process. The bound type variants are: //! - B0: all model weights are absolutely bounded by 1. //!
- B2: all model weights are absolutely bounded by 100. //! - B4: all model weights are absolutely bounded by 10,000. //! - B6: all model weights are absolutely bounded by 1,000,000. //! - Bmax: all model weights are absolutely bounded by their primitive data type's absolute //! maximum value. //! //! ## Model type //! The [`ModelType`] describes the maximum number of masked models that can be aggregated without //! information loss. The smaller this maximum number of masked models, the fewer bytes are required //! to represent masked model weights. The model type variants are: //! - M3: at most 1,000 masked models may be aggregated. //! - M6: at most 1,000,000 masked models may be aggregated. //! - M9: at most 1,000,000,000 masked models may be aggregated. //! - M12: at most 1,000,000,000,000 masked models may be aggregated. //! //! # Masking, aggregation and unmasking //! Local models should be masked (i.e. encrypted) before they are communicated somewhere else to //! protect the possibly sensitive information learned from local data. The masking should allow //! for masked models to be aggregated while they are still masked (i.e. homomorphic encryption). //! Then the aggregated masked model can safely be unmasked without jeopardizing the secrecy of //! personal information if the model is generalized enough. //! //! ## Masking //! A [`Model`] can be masked with a [`Masker`], which requires a [`MaskConfig`]. During the //! masking, the model weights are scaled, then embedded as elements of the chosen finite group and //! finally masked by randomly generated elements from that very same finite group. The scalar //! provides the necessary means to perform different aggregation strategies, for example federated //! averaging. The masked model is returned as a [`MaskObject`] and the mask used to mask the model //! can be generated via the additionally returned [`MaskSeed`]. //! //! ``` //!
# use xaynet_core::mask::{BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, Model, ModelType, Scalar}; //! // create local models and a fitting masking configuration //! let number_weights = 10; //! let scalar = Scalar::new(1, 2_u8); //! let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter()); //! let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter()); //! let config = MaskConfig { //! group_type: GroupType::Prime, //! data_type: DataType::F32, //! bound_type: BoundType::B0, //! model_type: ModelType::M3, //! }; //! //! // mask the local models //! let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar.clone(), &local_model_1); //! let (local_mask_seed_2, masked_local_model_2) = Masker::new(config.into()).mask(scalar, &local_model_2); //! //! // derive the masks of the local masked models //! let local_mask_1 = local_mask_seed_1.derive_mask(number_weights, config.into()); //! let local_mask_2 = local_mask_seed_2.derive_mask(number_weights, config.into()); //! ``` //! //! ## Aggregation //! Masked models can be aggregated via an [`Aggregation`]. Masks themselves can be aggregated via //! an [`Aggregation`] as well. An aggregated masked model can only be unmasked by the aggregation //! of masks for each model. Aggregation should always be validated beforehand so that it may be //! safely performed wrt the chosen masking configuration without possible loss of information. //! //! ``` //! # use xaynet_core::mask::{Aggregation, BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, MaskObject, Model, ModelType, Scalar}; //! # let number_weights = 10; //! # let scalar = Scalar::new(1, 2_u8); //! # let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter()); //! # let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter()); //! 
# let config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3}; //! # let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar.clone(), &local_model_1); //! # let (local_mask_seed_2, masked_local_model_2) = Masker::new(config.into()).mask(scalar, &local_model_2); //! # let local_model_mask_1 = local_mask_seed_1.derive_mask(number_weights, config.into()); //! # let local_model_mask_2 = local_mask_seed_2.derive_mask(number_weights, config.into()); //! // aggregate the local model masks (similarly for local scalar masks) //! let mut mask_aggregator = Aggregation::new(config.into(), number_weights); //! if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_1) { //! mask_aggregator.aggregate(local_model_mask_1); //! }; //! if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_2) { //! mask_aggregator.aggregate(local_model_mask_2); //! }; //! let global_mask: MaskObject = mask_aggregator.into(); //! //! // aggregate the local masked models //! let mut model_aggregator = Aggregation::new(config.into(), number_weights); //! if let Ok(_) = model_aggregator.validate_aggregation(&masked_local_model_1) { //! model_aggregator.aggregate(masked_local_model_1); //! }; //! if let Ok(_) = model_aggregator.validate_aggregation(&masked_local_model_2) { //! model_aggregator.aggregate(masked_local_model_2); //! }; //! ``` //! //! ## Unmasking //! A masked model can be unmasked by the corresponding mask via an [`Aggregation`]. Unmasking //! should always be validated beforehand so that it may be safely performed wrt the chosen mask //! configuration without possible loss of information. //! //! ```no_run //! # use xaynet_core::mask::{Aggregation, BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, MaskObject, Model, ModelType, Scalar}; //! # let number_weights = 10; //! # let scalar = Scalar::new(1, 2_u8); //! 
# let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter()); //! # let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter()); //! # let config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3}; //! # let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar.clone(), &local_model_1); //! # let (local_mask_seed_2, masked_local_model_2) = Masker::new(config.into()).mask(scalar, &local_model_2); //! # let local_model_mask_1 = local_mask_seed_1.derive_mask(number_weights, config.into()); //! # let local_model_mask_2 = local_mask_seed_2.derive_mask(number_weights, config.into()); //! # let mut mask_aggregator = Aggregation::new(config.into(), number_weights); //! # if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_1) { mask_aggregator.aggregate(local_model_mask_1); }; //! # if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_2) { mask_aggregator.aggregate(local_model_mask_2); }; //! # let global_mask: MaskObject = mask_aggregator.into(); //! # let mut model_aggregator = Aggregation::new(config.into(), number_weights); //! # if let Ok(_) = model_aggregator.validate_aggregation(&masked_local_model_1) { model_aggregator.aggregate(masked_local_model_1); }; //! # if let Ok(_) = model_aggregator.validate_aggregation(&masked_local_model_2) { model_aggregator.aggregate(masked_local_model_2); }; //! // unmask the aggregated masked model with the aggregated mask //! if let Ok(_) = model_aggregator.validate_unmasking(&global_mask) { //! let global_model = model_aggregator.unmask(global_mask); //! assert_eq!( //! global_model, //! Model::from_primitives_bounded(vec![0.5_f32; number_weights].into_iter()), //! ); //! }; //! 
``` pub(crate) mod config; pub(crate) mod masking; pub(crate) mod model; pub(crate) mod object; pub(crate) mod scalar; pub(crate) mod seed; pub use self::{ config::{ serialization::MaskConfigBuffer, BoundType, DataType, GroupType, InvalidMaskConfigError, MaskConfig, MaskConfigPair, ModelType, }, masking::{Aggregation, AggregationError, Masker, UnmaskingError}, model::{FromPrimitives, IntoPrimitives, Model, ModelCastError, PrimitiveCastError}, object::{ serialization::vect::MaskVectBuffer, InvalidMaskObjectError, MaskObject, MaskUnit, MaskVect, }, scalar::{FromPrimitive, IntoPrimitive, Scalar, ScalarCastError}, seed::{EncryptedMaskSeed, MaskSeed}, }; ================================================ FILE: rust/xaynet-core/src/mask/model.rs ================================================ //! Model representation and conversion. //! //! See the [mask module] documentation since this is a private module anyways. //! //! [mask module]: crate::mask use std::{ fmt::Debug, iter::{FromIterator, IntoIterator}, slice::{Iter, IterMut}, }; use derive_more::{Display, From, Index, IndexMut, Into}; use num::{ bigint::BigInt, clamp, rational::Ratio, traits::{float::FloatCore, identities::Zero, ToPrimitive}, }; use serde::{Deserialize, Serialize}; use thiserror::Error; #[derive(Debug, Clone, PartialEq, Hash, From, Index, IndexMut, Into, Serialize, Deserialize)] /// A numerical representation of a machine learning model. pub struct Model(Vec>); impl std::convert::AsRef for Model { fn as_ref(&self) -> &Model { self } } #[allow(clippy::len_without_is_empty)] impl Model { /// Gets the number of weights/parameters of this model. pub fn len(&self) -> usize { self.0.len() } /// Creates an iterator that yields references to the weights/parameters of this model. pub fn iter(&self) -> Iter> { self.0.iter() } /// Creates an iterator that yields mutable references to the weights/parameters of this model. 
pub fn iter_mut(&mut self) -> IterMut> { self.0.iter_mut() } } impl FromIterator> for Model { fn from_iter>>(iter: I) -> Self { let data: Vec> = iter.into_iter().collect(); Model(data) } } impl IntoIterator for Model { type Item = Ratio; type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } #[derive(Debug, Display)] /// A primitive data type as a target for model conversion. pub(crate) enum PrimitiveType { F32, F64, I32, I64, } #[derive(Error, Debug)] #[error("Could not convert weight {weight} to primitive type {target}")] /// Errors related to model conversion into primitives. pub struct ModelCastError { weight: Ratio, target: PrimitiveType, } #[derive(Clone, Error, Debug)] #[error("Could not convert primitive type {0:?} to weight")] /// Errors related to weight conversion from primitives. pub struct PrimitiveCastError(pub(crate) P); /// An interface to convert a collection of numerical values into an iterator of primitive values. /// /// This trait is used to convert a [`Model`], which has its own internal representation of the /// weights, into primitive types ([`f32`], [`f64`], [`i32`], [`i64`]). The opposite trait is /// [`FromPrimitives`]. pub trait IntoPrimitives: Sized { /// Creates an iterator from numerical values that yields converted primitive values. /// /// # Errors /// Yields an error for each numerical value that can't be converted into a primitive value. fn into_primitives(self) -> Box>>; /// Creates an iterator from numerical values that yields converted primitive values. /// /// # Errors /// Yields an error for each numerical value that can't be converted into a primitive value. fn to_primitives(&self) -> Box>>; /// Consume this model and into an iterator that yields `P` values. /// /// # Panics /// Panics if a numerical value can't be converted into a primitive value. 
fn into_primitives_unchecked(self) -> Box> { Box::new( self.into_primitives() .map(|res| res.expect("conversion to primitive type failed")), ) } } /// An interface to convert a collection of primitive values into an iterator of numerical values. /// /// This trait is used to convert primitive types ([`f32`], [`f64`], [`i32`], [`i64`]) into a /// [`Model`], which has its own internal representation of the weights. The opposite trait is /// [`IntoPrimitives`]. pub trait FromPrimitives: Sized { /// Creates an iterator from primitive values that yields converted numerical values. /// /// # Errors /// Yields an error for the first encountered primitive value that can't be converted into a /// numerical value due to not being finite. fn from_primitives>(iter: I) -> Result>; /// Creates an iterator from primitive values that yields converted numerical values. /// /// If a primitive value cannot be directly converted into a numerical value due to not being /// finite, it is clamped. fn from_primitives_bounded>(iter: I) -> Self; } impl IntoPrimitives for Model { fn into_primitives(self) -> Box>> { Box::new(self.0.into_iter().map(|i| { i.to_integer().to_i32().ok_or(ModelCastError { weight: i, target: PrimitiveType::I32, }) })) } fn to_primitives(&self) -> Box>> { let vec = self.0.clone(); Box::new(vec.into_iter().map(|i| { i.to_integer().to_i32().ok_or(ModelCastError { weight: i, target: PrimitiveType::I32, }) })) } } impl FromPrimitives for Model { fn from_primitives>(iter: I) -> Result> { Ok(iter.map(|p| Ratio::from_integer(BigInt::from(p))).collect()) } fn from_primitives_bounded>(iter: I) -> Self { Self::from_primitives(iter).unwrap() } } impl IntoPrimitives for Model { fn into_primitives(self) -> Box>> { Box::new(self.0.into_iter().map(|i| { i.to_integer().to_i64().ok_or(ModelCastError { weight: i, target: PrimitiveType::I64, }) })) } fn to_primitives(&self) -> Box>> { let vec = self.0.clone(); Box::new(vec.into_iter().map(|i| { 
i.to_integer().to_i64().ok_or(ModelCastError { weight: i, target: PrimitiveType::I64, }) })) } } impl FromPrimitives for Model { fn from_primitives>(iter: I) -> Result> { Ok(iter.map(|p| Ratio::from_integer(BigInt::from(p))).collect()) } fn from_primitives_bounded>(iter: I) -> Self { Self::from_primitives(iter).unwrap() } } impl IntoPrimitives for Model { fn into_primitives(self) -> Box>> { let iter = self.0.into_iter().map(|r| { ratio_to_float::(&r).ok_or(ModelCastError { weight: r, target: PrimitiveType::F32, }) }); Box::new(iter) } fn to_primitives(&self) -> Box>> { let vec = self.0.clone(); let iter = vec.into_iter().map(|r| { ratio_to_float::(&r).ok_or(ModelCastError { weight: r, target: PrimitiveType::F32, }) }); Box::new(iter) } } impl FromPrimitives for Model { fn from_primitives>(iter: I) -> Result> { iter.map(|f| Ratio::from_float(f).ok_or(PrimitiveCastError(f))) .collect() } fn from_primitives_bounded>(iter: I) -> Self { iter.map(float_to_ratio_bounded::).collect() } } impl IntoPrimitives for Model { fn into_primitives(self) -> Box>> { let iter = self.0.into_iter().map(|r| { ratio_to_float::(&r).ok_or(ModelCastError { weight: r, target: PrimitiveType::F64, }) }); Box::new(iter) } fn to_primitives(&self) -> Box>> { let vec = self.0.clone(); let iter = vec.into_iter().map(|r| { ratio_to_float::(&r).ok_or(ModelCastError { weight: r, target: PrimitiveType::F64, }) }); Box::new(iter) } } impl FromPrimitives for Model { fn from_primitives>(iter: I) -> Result> { iter.map(|f| Ratio::from_float(f).ok_or(PrimitiveCastError(f))) .collect() } fn from_primitives_bounded>(iter: I) -> Self { iter.map(float_to_ratio_bounded::).collect() } } /// Converts a numerical value into a primitive floating point value. /// /// # Errors /// Fails if the numerical value is not representable in the primitive data type. 
pub(crate) fn ratio_to_float(ratio: &Ratio) -> Option { let min_value = Ratio::from_float(F::min_value()).unwrap(); let max_value = Ratio::from_float(F::max_value()).unwrap(); if ratio < &min_value || ratio > &max_value { return None; } let mut numer = ratio.numer().clone(); let mut denom = ratio.denom().clone(); // safe loop: terminates after at most bit-length of ratio iterations loop { if let (Some(n), Some(d)) = (F::from(numer.clone()), F::from(denom.clone())) { if n == F::zero() || d == F::zero() { break Some(F::zero()); } else { let float = n / d; if float.is_finite() { break Some(float); } } } else { numer >>= 1_usize; denom >>= 1_usize; } } } /// Converts the primitive floating point value into a numerical value. /// /// Maps positive/negative infinity to max/min of the primitive data type and NaN to zero. pub(crate) fn float_to_ratio_bounded(f: F) -> Ratio { if f.is_nan() { Ratio::::zero() } else { let finite_f = clamp(f, F::min_value(), F::max_value()); // safe unwrap: clamped weight is guaranteed to be finite Ratio::::from_float(finite_f).unwrap() } } #[cfg(test)] mod tests { use super::*; use std::iter; type R = Ratio; #[test] fn test_model_f32() { let expected_primitives = vec![-1_f32, 0_f32, 1_f32]; let expected_model = Model::from(vec![ R::from_float(-1_f32).unwrap(), R::zero(), R::from_float(1_f32).unwrap(), ]); let actual_model = Model::from_primitives(expected_primitives.iter().cloned()).unwrap(); assert_eq!(actual_model, expected_model); let actual_model = Model::from_primitives_bounded(expected_primitives.iter().cloned()); assert_eq!(actual_model, expected_model); let actual_primitives: Vec = expected_model.into_primitives_unchecked().collect(); assert_eq!(actual_primitives, expected_primitives); } #[test] fn test_model_f64() { let expected_primitives = vec![-1_f64, 0_f64, 1_f64]; let expected_model = Model::from(vec![ R::from_float(-1_f64).unwrap(), R::zero(), R::from_float(1_f64).unwrap(), ]); let actual_model = 
Model::from_primitives(expected_primitives.iter().cloned()).unwrap(); assert_eq!(actual_model, expected_model); let actual_model = Model::from_primitives_bounded(expected_primitives.iter().cloned()); assert_eq!(actual_model, expected_model); let actual_primitives: Vec = expected_model.into_primitives_unchecked().collect(); assert_eq!(actual_primitives, expected_primitives); } #[test] fn test_model_f32_from_weird_primitives() { // +infinity assert!(Model::from_primitives(iter::once(f32::INFINITY)).is_err()); assert_eq!( Model::from_primitives_bounded(iter::once(f32::INFINITY)), vec![R::from_float(f32::MAX).unwrap()].into() ); // -infinity assert!(Model::from_primitives(iter::once(f32::NEG_INFINITY)).is_err()); assert_eq!( Model::from_primitives_bounded(iter::once(f32::NEG_INFINITY)), vec![R::from_float(f32::MIN).unwrap()].into() ); // NaN assert!(Model::from_primitives(iter::once(f32::NAN)).is_err()); assert_eq!( Model::from_primitives_bounded(iter::once(f32::NAN)), vec![R::zero()].into() ); } #[test] fn test_model_f64_from_weird_primitives() { // +infinity assert!(Model::from_primitives(iter::once(f64::INFINITY)).is_err()); assert_eq!( Model::from_primitives_bounded(iter::once(f64::INFINITY)), vec![R::from_float(f64::MAX).unwrap()].into() ); // -infinity assert!(Model::from_primitives(iter::once(f64::NEG_INFINITY)).is_err()); assert_eq!( Model::from_primitives_bounded(iter::once(f64::NEG_INFINITY)), vec![R::from_float(f64::MIN).unwrap()].into() ); // NaN assert!(Model::from_primitives(iter::once(f64::NAN)).is_err()); assert_eq!( Model::from_primitives_bounded(iter::once(f64::NAN)), vec![R::zero()].into() ); } #[test] fn test_model_i32() { let expected_primitives = vec![-1_i32, 0_i32, 1_i32]; let expected_model = Model::from(vec![ R::from_integer(BigInt::from(-1_i32)), R::zero(), R::from_integer(BigInt::from(1_i32)), ]); let actual_model = Model::from_primitives(expected_primitives.iter().cloned()).unwrap(); assert_eq!(actual_model, expected_model); let actual_model 
= Model::from_primitives_bounded(expected_primitives.iter().cloned()); assert_eq!(actual_model, expected_model); let actual_primitives: Vec = expected_model.into_primitives_unchecked().collect(); assert_eq!(actual_primitives, expected_primitives); } #[test] fn test_model_i64() { let expected_primitives = vec![-1_i64, 0_i64, 1_i64]; let expected_model = Model::from(vec![ R::from_integer(BigInt::from(-1_i64)), R::zero(), R::from_integer(BigInt::from(1_i64)), ]); let actual_model = Model::from_primitives(expected_primitives.iter().cloned()).unwrap(); assert_eq!(actual_model, expected_model); let actual_model = Model::from_primitives_bounded(expected_primitives.iter().cloned()); assert_eq!(actual_model, expected_model); let actual_primitives: Vec = expected_model.into_primitives_unchecked().collect(); assert_eq!(actual_primitives, expected_primitives); } #[test] #[allow(clippy::float_cmp)] fn test_ratio_to_float() { let ratio = R::from_float(0_f32).unwrap(); assert_eq!(ratio_to_float::(&ratio).unwrap(), 0_f32); let ratio = R::from_float(0_f64).unwrap(); assert_eq!(ratio_to_float::(&ratio).unwrap(), 0_f64); let ratio = R::from_float(0.1_f32).unwrap(); assert_eq!(ratio_to_float::(&ratio).unwrap(), 0.1_f32); let ratio = R::from_float(0.1_f64).unwrap(); assert_eq!(ratio_to_float::(&ratio).unwrap(), 0.1_f64); let f32_max = R::from_float(f32::max_value()).unwrap(); let ratio = &f32_max * BigInt::from(10_usize) / (f32_max * BigInt::from(100_usize)); assert_eq!(ratio_to_float::(&ratio).unwrap(), 0.1_f32); let f64_max = R::from_float(f64::max_value()).unwrap(); let ratio = &f64_max * BigInt::from(10_usize) / (f64_max * BigInt::from(100_usize)); assert_eq!(ratio_to_float::(&ratio).unwrap(), 0.1_f64); } } ================================================ FILE: rust/xaynet-core/src/mask/object/mod.rs ================================================ //! Masked objects. //! //! See the [mask module] documentation since this is a private module anyways. //! //! 
[mask module]: crate::mask

pub mod serialization;

use std::iter::Iterator;

use num::bigint::BigUint;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::mask::config::{MaskConfig, MaskConfigPair};

#[derive(Error, Debug)]
#[error("the mask object is invalid: data is incompatible with the masking configuration")]
/// Errors related to invalid mask objects.
pub struct InvalidMaskObjectError;

#[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)]
/// A *mask vector* which represents a masked model or its corresponding mask.
pub struct MaskVect {
    pub data: Vec<BigUint>,
    pub config: MaskConfig,
}

impl MaskVect {
    /// Creates a new mask vector from the given data and masking configuration without checking
    /// that the data conforms to the configuration.
    pub fn new_unchecked(config: MaskConfig, data: Vec<BigUint>) -> Self {
        Self { data, config }
    }

    /// Creates a new mask vector from the given data and masking configuration.
    ///
    /// # Errors
    /// Fails if the elements of the mask object don't conform to the given masking configuration.
    pub fn new(config: MaskConfig, data: Vec<BigUint>) -> Result<Self, InvalidMaskObjectError> {
        let obj = Self::new_unchecked(config, data);
        if obj.is_valid() {
            Ok(obj)
        } else {
            Err(InvalidMaskObjectError)
        }
    }

    /// Creates a new mask vector without elements for the given masking configuration.
    ///
    /// `size` only pre-allocates capacity for that many elements; the returned vector is empty.
    pub fn empty(config: MaskConfig, size: usize) -> Self {
        Self {
            data: Vec::with_capacity(size),
            config,
        }
    }

    /// Checks if the elements of this mask vector conform to the masking configuration, i.e.
    /// are all smaller than the group order of the configuration.
    pub fn is_valid(&self) -> bool {
        let order = self.config.order();
        self.data.iter().all(|i| i < &order)
    }
}

#[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)]
/// A *mask unit* which represents a masked scalar or its corresponding mask.
pub struct MaskUnit { pub data: BigUint, pub config: MaskConfig, } impl From<&MaskUnit> for MaskVect { fn from(mask_unit: &MaskUnit) -> Self { Self::new_unchecked(mask_unit.config, vec![mask_unit.data.clone()]) } } impl From for MaskVect { fn from(mask_unit: MaskUnit) -> Self { Self::new_unchecked(mask_unit.config, vec![mask_unit.data]) } } impl MaskUnit { /// Creates a new mask unit from the given mask and masking configuration. pub fn new_unchecked(config: MaskConfig, data: BigUint) -> Self { Self { data, config } } /// Creates a new mask unit from the given mask and masking configuration. /// /// # Errors /// Fails if the mask unit doesn't conform to the given masking configuration. pub fn new(config: MaskConfig, data: BigUint) -> Result { let obj = Self::new_unchecked(config, data); if obj.is_valid() { Ok(obj) } else { Err(InvalidMaskObjectError) } } /// Creates a new mask unit of given masking configuration with default value `1`. pub fn default(config: MaskConfig) -> Self { Self { data: BigUint::from(1_u8), config, } } /// Checks if the data value conforms to the masking configuration. pub fn is_valid(&self) -> bool { self.data < self.config.order() } } #[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)] /// A mask object consisting of a vector part and unit part. pub struct MaskObject { pub vect: MaskVect, pub unit: MaskUnit, } impl MaskObject { /// Creates a new mask object from the given vector and unit. pub fn new_unchecked(vect: MaskVect, unit: MaskUnit) -> Self { Self { vect, unit } } /// Creates a new mask object from the given vector, unit and masking configurations. pub fn new( config: MaskConfigPair, data_vect: Vec, data_unit: BigUint, ) -> Result { let vect = MaskVect::new(config.vect, data_vect)?; let unit = MaskUnit::new(config.unit, data_unit)?; Ok(Self { vect, unit }) } /// Creates a new empty mask object of given size and masking configurations. 
pub fn empty(config: MaskConfigPair, size: usize) -> Self { Self { vect: MaskVect::empty(config.vect, size), unit: MaskUnit::default(config.unit), } } /// Checks if this mask object conforms to the masking configurations. pub fn is_valid(&self) -> bool { self.vect.is_valid() && self.unit.is_valid() } } ================================================ FILE: rust/xaynet-core/src/mask/object/serialization/mod.rs ================================================ //! Serialization of masked objects. //! //! See the [mask module] documentation since this is a private module anyways. //! //! [mask module]: crate::mask pub(crate) mod unit; pub(crate) mod vect; use anyhow::Context; use crate::{ mask::object::{ serialization::{unit::MaskUnitBuffer, vect::MaskVectBuffer}, MaskObject, MaskUnit, MaskVect, }, message::{ traits::{FromBytes, ToBytes}, DecodeError, }, }; // target dependent maximum number of mask object elements #[cfg(target_pointer_width = "16")] const MAX_NB: u32 = u16::MAX as u32; /// A buffer for serialized mask objects. pub struct MaskObjectBuffer { inner: T, } impl> MaskObjectBuffer { /// Creates a new buffer from `bytes`. /// /// # Errors /// Fails if the `bytes` don't conform to the required buffer length for mask objects. pub fn new(bytes: T) -> Result { let buffer = Self { inner: bytes }; buffer .check_buffer_length() .context("not a valid mask object")?; Ok(buffer) } /// Creates a new buffer from `bytes`. pub fn new_unchecked(bytes: T) -> Self { Self { inner: bytes } } /// Checks if this buffer conforms to the required buffer length for mask objects. /// /// # Errors /// Fails if the buffer is too small. pub fn check_buffer_length(&self) -> Result<(), DecodeError> { let inner = self.inner.as_ref(); // check length of vect field MaskVectBuffer::new(inner).context("invalid vector field")?; // check length of unit field MaskUnitBuffer::new(&inner[self.unit_offset()..]).context("invalid unit field")?; Ok(()) } /// Gets the vector part. 
/// /// # Panics /// May panic if this buffer is unchecked. pub fn vect(&self) -> &[u8] { let len = self.unit_offset(); &self.inner.as_ref()[0..len] } /// Gets the offset of the unit field. pub fn unit_offset(&self) -> usize { let vect_buf = MaskVectBuffer::new_unchecked(self.inner.as_ref()); vect_buf.len() } /// Gets the unit part. /// /// # Panics /// May panic if this buffer is unchecked. pub fn unit(&self) -> &[u8] { let offset = self.unit_offset(); &self.inner.as_ref()[offset..] } /// Gets the expected number of bytes of this buffer. /// /// # Panics /// May panic if this buffer is unchecked. pub fn len(&self) -> usize { let unit_offset = self.unit_offset(); let unit_buf = MaskUnitBuffer::new_unchecked(&self.inner.as_ref()[unit_offset..]); unit_offset + unit_buf.len() } } impl + AsMut<[u8]>> MaskObjectBuffer { /// Gets the vector part. /// /// # Panics /// May panic if this buffer is unchecked. pub fn vect_mut(&mut self) -> &mut [u8] { self.inner.as_mut() } /// Gets the unit part. /// /// # Panics /// May panic if this buffer is unchecked. pub fn unit_mut(&mut self) -> &mut [u8] { let offset = self.unit_offset(); &mut self.inner.as_mut()[offset..] 
} } impl ToBytes for MaskObject { fn buffer_length(&self) -> usize { self.vect.buffer_length() + self.unit.buffer_length() } fn to_bytes + AsRef<[u8]>>(&self, buffer: &mut T) { let mut writer = MaskObjectBuffer::new_unchecked(buffer.as_mut()); self.vect.to_bytes(&mut writer.vect_mut()); self.unit.to_bytes(&mut writer.unit_mut()); } } impl FromBytes for MaskObject { fn from_byte_slice>(buffer: &T) -> Result { let reader = MaskObjectBuffer::new(buffer.as_ref())?; let vect = MaskVect::from_byte_slice(&reader.vect()).context("invalid vector part")?; let unit = MaskUnit::from_byte_slice(&reader.unit()).context("invalid unit part")?; Ok(Self { vect, unit }) } fn from_byte_stream + ExactSizeIterator>( iter: &mut I, ) -> Result { let vect = MaskVect::from_byte_stream(iter).context("invalid vector part")?; let unit = MaskUnit::from_byte_stream(iter).context("invalid unit part")?; Ok(Self { vect, unit }) } } #[cfg(test)] pub(crate) mod tests { use super::*; use crate::mask::{ config::{BoundType, DataType, GroupType, MaskConfig, ModelType}, object::serialization::{unit::tests::mask_unit, vect::tests::mask_vect}, MaskObject, }; pub fn mask_config() -> (MaskConfig, Vec) { // config.order() = 20_000_000_000_001 with this config, so the data // should be stored on 6 bytes. 
let config = MaskConfig { group_type: GroupType::Integer, data_type: DataType::I32, bound_type: BoundType::B0, model_type: ModelType::M3, }; let bytes = vec![0x00, 0x02, 0x00, 0x03]; (config, bytes) } pub fn mask_object() -> (MaskObject, Vec) { let (mask_vect, mask_vect_bytes) = mask_vect(); let (mask_unit, mask_unit_bytes) = mask_unit(); let obj = MaskObject::new_unchecked(mask_vect, mask_unit); let bytes = [mask_vect_bytes.as_slice(), mask_unit_bytes.as_slice()].concat(); (obj, bytes) } #[test] fn serialize_mask_object() { let (mask_object, expected) = mask_object(); let mut buf = vec![0xff; 42]; mask_object.to_bytes(&mut buf); assert_eq!(buf, expected); } #[test] fn deserialize_mask_object() { let (expected, bytes) = mask_object(); assert_eq!(MaskObject::from_byte_slice(&&bytes[..]).unwrap(), expected); } #[test] fn deserialize_mask_object_from_stream() { let (expected, bytes) = mask_object(); assert_eq!( MaskObject::from_byte_stream(&mut bytes.into_iter()).unwrap(), expected ); } } ================================================ FILE: rust/xaynet-core/src/mask/object/serialization/unit.rs ================================================ //! Serialization of masked units. //! //! See the [mask module] documentation since this is a private module anyways. //! //! [mask module]: crate::mask use std::ops::Range; use anyhow::{anyhow, Context}; use num::bigint::BigUint; use crate::{ mask::{ config::{serialization::MASK_CONFIG_BUFFER_LEN, MaskConfig}, object::MaskUnit, }, message::{ traits::{FromBytes, ToBytes}, utils::range, DecodeError, }, }; const MASK_CONFIG_FIELD: Range = range(0, MASK_CONFIG_BUFFER_LEN); /// A buffer for serialized mask units. pub struct MaskUnitBuffer { inner: T, } impl> MaskUnitBuffer { /// Creates a new buffer from `bytes`. /// /// # Errors /// Fails if the `bytes` don't conform to the required buffer length for mask units. 
pub fn new(bytes: T) -> Result { let buffer = Self { inner: bytes }; buffer .check_buffer_length() .context("not a valid mask unit")?; Ok(buffer) } /// Creates a new buffer from `bytes`. pub fn new_unchecked(bytes: T) -> Self { Self { inner: bytes } } /// Checks if this buffer conforms to the required buffer length for mask units. /// /// # Errors /// Fails if the buffer is too small. pub fn check_buffer_length(&self) -> Result<(), DecodeError> { let len = self.inner.as_ref().len(); if len < MASK_CONFIG_FIELD.end { return Err(anyhow!( "invalid buffer length: {} < {}", len, MASK_CONFIG_FIELD.end )); } let total_expected_length = self.try_len()?; if len < total_expected_length { return Err(anyhow!( "invalid buffer length: expected {} bytes but buffer has only {} bytes", total_expected_length, len )); } Ok(()) } /// Return the expected length of the underlying byte buffer, /// based on the masking config field of numbers field. This is /// similar to [`len()`] but cannot panic. /// /// [`len()`]: MaskUnitBuffer::len pub fn try_len(&self) -> Result { let config = MaskConfig::from_byte_slice(&self.config()).context("invalid mask unit buffer")?; let data_length = config.bytes_per_number(); Ok(MASK_CONFIG_FIELD.end + data_length) } /// Gets the expected number of bytes of this buffer wrt to the masking configuration. /// /// # Panics /// Panics if the serialized masking configuration is invalid. pub fn len(&self) -> usize { let config = MaskConfig::from_byte_slice(&self.config()).unwrap(); let data_length = config.bytes_per_number(); MASK_CONFIG_FIELD.end + data_length } /// Gets the serialized masking configuration. /// /// # Panics /// May panic if this buffer is unchecked. pub fn config(&self) -> &[u8] { &self.inner.as_ref()[MASK_CONFIG_FIELD] } /// Gets the serialized mask unit element. /// /// # Panics /// May panic if this buffer is unchecked. 
pub fn data(&self) -> &[u8] { &self.inner.as_ref()[MASK_CONFIG_FIELD.end..self.len()] } } impl + AsMut<[u8]>> MaskUnitBuffer { /// Gets the serialized masking configuration. /// /// # Panics /// May panic if this buffer is unchecked. pub fn config_mut(&mut self) -> &mut [u8] { &mut self.inner.as_mut()[MASK_CONFIG_FIELD] } /// Gets the serialized mask unit element. /// /// # Panics /// May panic if this buffer is unchecked. pub fn data_mut(&mut self) -> &mut [u8] { let end = self.len(); &mut self.inner.as_mut()[MASK_CONFIG_FIELD.end..end] } } impl ToBytes for MaskUnit { fn buffer_length(&self) -> usize { MASK_CONFIG_FIELD.end + self.config.bytes_per_number() } fn to_bytes + AsRef<[u8]>>(&self, buffer: &mut T) { let mut writer = MaskUnitBuffer::new_unchecked(buffer.as_mut()); self.config.to_bytes(&mut writer.config_mut()); let data = writer.data_mut(); // FIXME: this allocates a vec which is sub-optimal. See // https://github.com/rust-num/num-bigint/issues/152 let bytes = self.data.to_bytes_le(); // This may panic if the data is invalid and is an // integer that is bigger than what is expected by the // configuration. 
data[..bytes.len()].copy_from_slice(&bytes[..]); // padding for b in data .iter_mut() .take(self.config.bytes_per_number()) .skip(bytes.len()) { *b = 0; } } } impl FromBytes for MaskUnit { fn from_byte_slice>(buffer: &T) -> Result { let reader = MaskUnitBuffer::new(buffer.as_ref())?; let config = MaskConfig::from_byte_slice(&reader.config())?; let data = BigUint::from_bytes_le(reader.data()); Ok(MaskUnit { data, config }) } fn from_byte_stream + ExactSizeIterator>( iter: &mut I, ) -> Result { let config = MaskConfig::from_byte_stream(iter)?; if iter.len() < 4 { return Err(anyhow!("byte stream exhausted")); } let data_len = config.bytes_per_number(); if iter.len() < data_len { return Err(anyhow!( "mask unit is {} bytes long but byte stream only has {} bytes", data_len, iter.len() )); } let mut buf = vec![0; data_len]; for (i, b) in iter.take(data_len).enumerate() { buf[i] = b; } let data = BigUint::from_bytes_le(buf.as_slice()); Ok(MaskUnit { data, config }) } } #[cfg(test)] pub(crate) mod tests { use super::*; use crate::mask::object::serialization::tests::mask_config; pub fn mask_unit() -> (MaskUnit, Vec) { let (config, mut bytes) = mask_config(); let data = BigUint::from(1_u8); let mask_unit = MaskUnit::new_unchecked(config, data); bytes.extend(vec![ // data (6 bytes with this config) 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // 1 ]); (mask_unit, bytes) } #[test] fn serialize_mask_unit() { let (mask_unit, expected) = mask_unit(); let mut buf = vec![0xff; expected.len()]; mask_unit.to_bytes(&mut buf); assert_eq!(buf, expected); } #[test] fn deserialize_mask_unit() { let (expected, bytes) = mask_unit(); assert_eq!(MaskUnit::from_byte_slice(&&bytes[..]).unwrap(), expected); } #[test] fn deserialize_mask_unit_from_stream() { let (expected, bytes) = mask_unit(); assert_eq!( MaskUnit::from_byte_stream(&mut bytes.into_iter()).unwrap(), expected ); } } ================================================ FILE: rust/xaynet-core/src/mask/object/serialization/vect.rs 
================================================
//! Serialization of masked vectors.
//!
//! See the [mask module] documentation since this is a private module anyways.
//!
//! [mask module]: crate::mask

use std::{convert::TryInto, ops::Range};

use anyhow::{anyhow, Context};
use num::bigint::BigUint;

use crate::{
    mask::{
        config::{serialization::MASK_CONFIG_BUFFER_LEN, MaskConfig},
        object::MaskVect,
    },
    message::{
        traits::{FromBytes, ToBytes},
        utils::{range, ChunkableIterator},
        DecodeError,
    },
};

const MASK_CONFIG_FIELD: Range<usize> = range(0, MASK_CONFIG_BUFFER_LEN);
const NUMBERS_FIELD: Range<usize> = range(MASK_CONFIG_FIELD.end, 4);

// target dependent maximum number of mask object elements
#[cfg(target_pointer_width = "16")]
const MAX_NB: u32 = u16::MAX as u32;

/// A buffer for serialized mask vectors: a serialized masking configuration, a big-endian `u32`
/// number of elements and the little-endian elements, each `bytes_per_number()` bytes long.
pub struct MaskVectBuffer<T> {
    inner: T,
}

#[allow(clippy::len_without_is_empty)]
impl<T: AsRef<[u8]>> MaskVectBuffer<T> {
    /// Creates a new buffer from `bytes`.
    ///
    /// # Errors
    /// Fails if the `bytes` don't conform to the required buffer length for mask vectors.
    pub fn new(bytes: T) -> Result<Self, DecodeError> {
        let buffer = Self { inner: bytes };
        buffer
            .check_buffer_length()
            .context("not a valid mask vector")?;
        Ok(buffer)
    }

    /// Creates a new buffer from `bytes` without checking its length.
    pub fn new_unchecked(bytes: T) -> Self {
        Self { inner: bytes }
    }

    /// Checks if this buffer conforms to the required buffer length for mask vectors.
    ///
    /// # Errors
    /// Fails if the buffer is too small or its header is invalid.
    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {
        let len = self.inner.as_ref().len();
        if len < NUMBERS_FIELD.end {
            return Err(anyhow!(
                "invalid buffer length: {} < {}",
                len,
                NUMBERS_FIELD.end
            ));
        }

        let total_expected_length = self.try_len()?;
        if len < total_expected_length {
            return Err(anyhow!(
                "invalid buffer length: expected {} bytes but buffer has only {} bytes",
                total_expected_length,
                len
            ));
        }
        Ok(())
    }

    /// Returns the expected length of the underlying byte buffer, based on the masking
    /// configuration field and the numbers field. This is similar to [`len()`] but returns an
    /// error instead of panicking or overflowing (the buffer must still hold both header
    /// fields).
    ///
    /// [`len()`]: MaskVectBuffer::len
    fn try_len(&self) -> Result<usize, DecodeError> {
        let config =
            MaskConfig::from_byte_slice(&self.config()).context("invalid mask vector buffer")?;
        let bytes_per_number = config.bytes_per_number();
        // both the data length and the total length (header included) must fit in a usize
        self.numbers()
            .checked_mul(bytes_per_number)
            .and_then(|data_length| data_length.checked_add(NUMBERS_FIELD.end))
            .ok_or_else(|| {
                anyhow!("invalid MaskObject buffer: invalid masking config or numbers field")
            })
    }

    /// Gets the expected number of bytes of this buffer wrt to the masking configuration.
    ///
    /// # Panics
    /// Panics if the serialized masking configuration is invalid.
    pub fn len(&self) -> usize {
        let config = MaskConfig::from_byte_slice(&self.config()).unwrap();
        let bytes_per_number = config.bytes_per_number();
        let data_length = self.numbers() * bytes_per_number;
        NUMBERS_FIELD.end + data_length
    }

    /// Gets the number of serialized mask object elements.
    ///
    /// # Panics
    /// May panic if this buffer is unchecked.
    ///
    /// Panics if the number can't be represented as usize on targets smaller than 32 bits.
    pub fn numbers(&self) -> usize {
        // UNWRAP SAFE: the slice is exactly 4 bytes long
        let nb = u32::from_be_bytes(self.inner.as_ref()[NUMBERS_FIELD].try_into().unwrap());

        // smaller targets than 32 bits are currently not of interest
        #[cfg(target_pointer_width = "16")]
        if nb > MAX_NB {
            panic!("16 bit targets or smaller are currently not fully supported")
        }

        nb as usize
    }

    /// Gets the serialized masking configuration.
    ///
    /// # Panics
    /// May panic if this buffer is unchecked.
    pub fn config(&self) -> &[u8] {
        &self.inner.as_ref()[MASK_CONFIG_FIELD]
    }

    /// Gets the serialized mask vector elements.
    ///
    /// # Panics
    /// May panic if this buffer is unchecked.
    pub fn data(&self) -> &[u8] {
        &self.inner.as_ref()[NUMBERS_FIELD.end..self.len()]
    }
}

impl<T: AsRef<[u8]> + AsMut<[u8]>> MaskVectBuffer<T> {
    /// Sets the number of serialized mask vector elements.
    ///
    /// # Panics
    /// May panic if this buffer is unchecked.
    pub fn set_numbers(&mut self, value: u32) {
        self.inner.as_mut()[NUMBERS_FIELD].copy_from_slice(&value.to_be_bytes());
    }

    /// Gets the serialized masking configuration.
    ///
    /// # Panics
    /// May panic if this buffer is unchecked.
    pub fn config_mut(&mut self) -> &mut [u8] {
        &mut self.inner.as_mut()[MASK_CONFIG_FIELD]
    }

    /// Gets the serialized mask vector elements. The configuration and the numbers field must
    /// already have been written.
    ///
    /// # Panics
    /// May panic if this buffer is unchecked.
    pub fn data_mut(&mut self) -> &mut [u8] {
        let end = self.len();
        &mut self.inner.as_mut()[NUMBERS_FIELD.end..end]
    }
}

impl ToBytes for MaskVect {
    fn buffer_length(&self) -> usize {
        NUMBERS_FIELD.end + self.config.bytes_per_number() * self.data.len()
    }

    fn to_bytes<T: AsMut<[u8]>>(&self, buffer: &mut T) {
        let mut writer = MaskVectBuffer::new_unchecked(buffer.as_mut());
        // the header must be written first: the data length is read back from it
        self.config.to_bytes(&mut writer.config_mut());
        // NOTE(review): `as u32` silently truncates vectors with more than u32::MAX elements
        writer.set_numbers(self.data.len() as u32);

        let mut data = writer.data_mut();
        let bytes_per_number = self.config.bytes_per_number();

        for int in self.data.iter() {
            // FIXME: this allocates a vec which is sub-optimal. See
            // https://github.com/rust-num/num-bigint/issues/152
            let bytes = int.to_bytes_le();
            // This may panic if the data is invalid and contains
            // integers that are bigger than what is expected by the
            // configuration.
            data[..bytes.len()].copy_from_slice(&bytes[..]);
            // padding
            for b in data.iter_mut().take(bytes_per_number).skip(bytes.len()) {
                *b = 0;
            }
            data = &mut data[bytes_per_number..];
        }
    }
}

impl FromBytes for MaskVect {
    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {
        let reader = MaskVectBuffer::new(buffer.as_ref())?;

        let config = MaskConfig::from_byte_slice(&reader.config())?;
        let mut data = Vec::with_capacity(reader.numbers());
        let bytes_per_number = config.bytes_per_number();
        for chunk in reader.data().chunks(bytes_per_number) {
            data.push(BigUint::from_bytes_le(chunk));
        }

        Ok(MaskVect { data, config })
    }

    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(
        iter: &mut I,
    ) -> Result<Self, DecodeError> {
        let config = MaskConfig::from_byte_stream(iter)?;
        if iter.len() < 4 {
            return Err(anyhow!("byte stream exhausted"));
        }
        let numbers = u32::from_byte_stream(iter)
            .context("failed to parse the number of items in mask vector")?;
        let bytes_per_number = config.bytes_per_number();

        // the announced number of elements is untrusted input: on 32-bit targets the product
        // can exceed usize::MAX, which must be rejected rather than wrap around
        let data_len = (numbers as usize)
            .checked_mul(bytes_per_number)
            .ok_or_else(|| {
                anyhow!(
                    "mask vector of {} elements of {} bytes each is too large",
                    numbers,
                    bytes_per_number
                )
            })?;
        if iter.len() < data_len {
            return Err(anyhow!(
                "mask vector is {} bytes long but byte stream only has {} bytes",
                data_len,
                iter.len()
            ));
        }

        let mut data = Vec::with_capacity(numbers as usize);
        let mut buf = vec![0; bytes_per_number];
        for chunk in iter.take(data_len).chunks(bytes_per_number).into_iter() {
            for (i, b) in chunk.enumerate() {
                buf[i] = b;
            }
            data.push(BigUint::from_bytes_le(buf.as_slice()));
        }

        Ok(MaskVect { data, config })
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::mask::object::serialization::tests::mask_config;

    pub fn mask_vect() -> (MaskVect, Vec<u8>) {
        let (config, mut bytes) = mask_config();
        let data = vec![
            BigUint::from(1_u8),
            BigUint::from(2_u8),
            BigUint::from(3_u8),
            BigUint::from(4_u8),
        ];
        let mask_vect = MaskVect::new_unchecked(config, data);

        bytes.extend(vec![
            // number of elements
            0x00, 0x00, 0x00, 0x04,
            // data (1 weight => 6 bytes with this config)
            0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // 1
            0x02, 0x00, 0x00, 0x00, 0x00, 0x00, // 2
            0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // 3
            0x04, 0x00, 0x00, 0x00, 0x00, 0x00, // 4
        ]);
        (mask_vect, bytes)
    }

    #[test]
    fn serialize_mask_vect() {
        let (mask_vect, expected) = mask_vect();
        let mut buf = vec![0xff; expected.len()];
        mask_vect.to_bytes(&mut buf);
        assert_eq!(buf, expected);
    }

    #[test]
    fn deserialize_mask_vect() {
        let (expected, bytes) = mask_vect();
        assert_eq!(MaskVect::from_byte_slice(&&bytes[..]).unwrap(), expected);
    }

    #[test]
    fn deserialize_mask_vect_from_stream() {
        let (expected, bytes) = mask_vect();
        assert_eq!(
            MaskVect::from_byte_stream(&mut bytes.into_iter()).unwrap(),
            expected
        );
    }
}

================================================
FILE: rust/xaynet-core/src/mask/scalar.rs
================================================
//! Scalar representation and conversion.
//!
//! See the [mask module] documentation since this is a private module anyways.
//!
//! [mask module]: crate::mask

use crate::mask::{
    model::{ratio_to_float, PrimitiveType},
    PrimitiveCastError,
};
use derive_more::{From, Into};
use num::{
    clamp,
    rational::Ratio,
    traits::{float::FloatCore, ToPrimitive},
    BigInt,
    BigUint,
    One,
    Unsigned,
    Zero,
};
use serde::{Deserialize, Serialize};
use std::{
    convert::{TryFrom, TryInto},
    fmt::Debug,
};
use thiserror::Error;

#[derive(Debug, Clone, PartialEq, Hash, From, Into, Serialize, Deserialize)]
/// A numerical representation of a machine learning scalar.
pub struct Scalar(Ratio<BigUint>);

impl From<Scalar> for Ratio<BigInt> {
    fn from(scalar: Scalar) -> Self {
        let (numer, denom) = scalar.0.into();
        Ratio::new(numer.into(), denom.into())
    }
}

impl TryFrom<Ratio<BigInt>> for Scalar {
    type Error = <BigUint as TryFrom<BigInt>>::Error;

    /// Fails if the numerator or the denominator is negative.
    fn try_from(ratio: Ratio<BigInt>) -> Result<Self, Self::Error> {
        let (numer, denom) = ratio.into();
        Ok(Self(Ratio::new(numer.try_into()?, denom.try_into()?)))
    }
}

impl Scalar {
    /// Constructs a new `Scalar` from the given numerator and denominator.
    pub fn new<U>(numer: U, denom: U) -> Self
    where
        U: Unsigned + Into<BigUint>,
    {
        Self(Ratio::new(numer.into(), denom.into()))
    }

    /// Constructs a `Scalar` representing the given integer.
    pub fn from_integer<U>(u: U) -> Self
    where
        U: Unsigned + Into<BigUint>,
    {
        Self(Ratio::from_integer(u.into()))
    }

    /// Constructs a `Scalar` of unit value.
    pub fn unit() -> Self {
        Self(Ratio::one())
    }

    /// Convenience method for conversion to a non-negative ratio of `BigInt`.
    pub(crate) fn to_ratio(&self) -> Ratio<BigInt> {
        self.clone().into()
    }

    /// Constructs a `Scalar` from a primitive floating point value, clamped where necessary.
    ///
    /// Maps positive infinity to max of the primitive data type, negatives and NaN to zero.
    pub(crate) fn from_float_bounded<F: FloatCore>(f: F) -> Self {
        if f.is_nan() {
            Self(Ratio::zero())
        } else {
            let finite_f = clamp(f, F::zero(), F::max_value());
            // safe unwrap: clamped weight is guaranteed to be finite
            let r = Ratio::from_float(finite_f).unwrap();
            // safe unwrap: bounded non-negative ratio r
            r.try_into().unwrap()
        }
    }
}

#[derive(Error, Debug)]
#[error("Could not convert weight {weight} to primitive type {target}")]
/// Errors related to scalar conversion into primitives.
pub struct ScalarCastError {
    weight: Ratio<BigUint>,
    target: PrimitiveType,
}

/// An interface for conversion into a primitive value.
///
/// This trait is used to convert a [`Scalar`], which has its own internal
/// representation, into a primitive type ([`f32`], [`f64`], [`i32`], [`i64`]).
/// The opposite trait is [`FromPrimitive`].
pub trait IntoPrimitive<P>: Sized {
    /// Consumes into a converted primitive value.
    ///
    /// # Errors
    /// Returns an error if the conversion fails.
    fn into_primitive(self) -> Result<P, ScalarCastError>;

    /// Converts to a primitive value.
    ///
    /// # Errors
    /// Returns an error if the conversion fails.
    fn to_primitive(&self) -> Result<P, ScalarCastError>;

    /// Consumes into a converted primitive value.
    ///
    /// # Panics
    /// Panics if the conversion fails.
    fn into_primitive_unchecked(self) -> P {
        self.into_primitive()
            .expect("conversion to primitive type failed")
    }
}

/// An interface for conversion from a primitive value.
///
/// This trait is used to obtain a [`Scalar`], which has its own representation,
/// from a primitive type ([`f32`], [`f64`], [`i32`], [`i64`]). The opposite
/// trait is [`IntoPrimitive`].
pub trait FromPrimitive<P: Debug>: Sized {
    /// Converts from a primitive value.
    ///
    /// # Errors
    /// Returns an error if the conversion fails.
    fn from_primitive(prim: P) -> Result<Self, PrimitiveCastError<P>>;

    /// Converts from a primitive value.
    ///
    /// If a direct conversion cannot be obtained from the primitive value, it is clamped.
fn from_primitive_bounded(prim: P) -> Self; } impl IntoPrimitive for Scalar { fn into_primitive(self) -> Result { let r = self.0; r.to_integer().to_i32().ok_or(ScalarCastError { weight: r, target: PrimitiveType::I32, }) } fn to_primitive(&self) -> Result { self.clone().into_primitive() } } impl FromPrimitive for Scalar { fn from_primitive(prim: i32) -> Result> { let i = BigUint::try_from(prim).map_err(|_| PrimitiveCastError(prim))?; Ok(Self(Ratio::from_integer(i))) } fn from_primitive_bounded(prim: i32) -> Self { Self::from_primitive(prim).unwrap_or_else(|_| Self(Ratio::zero())) } } impl IntoPrimitive for Scalar { fn into_primitive(self) -> Result { let i = self.0; i.to_integer().to_i64().ok_or(ScalarCastError { weight: i, target: PrimitiveType::I64, }) } fn to_primitive(&self) -> Result { self.clone().into_primitive() } } impl FromPrimitive for Scalar { fn from_primitive(prim: i64) -> Result> { let i = BigUint::try_from(prim).map_err(|_| PrimitiveCastError(prim))?; Ok(Self(Ratio::from_integer(i))) } fn from_primitive_bounded(prim: i64) -> Self { Self::from_primitive(prim).unwrap_or_else(|_| Self(Ratio::zero())) } } impl IntoPrimitive for Scalar { fn into_primitive(self) -> Result { let r = self.to_ratio(); ratio_to_float(&r).ok_or(ScalarCastError { weight: self.0, target: PrimitiveType::F32, }) } fn to_primitive(&self) -> Result { self.clone().into_primitive() } } impl FromPrimitive for Scalar { fn from_primitive(prim: f32) -> Result> { let r = Ratio::from_float(prim).ok_or(PrimitiveCastError(prim))?; r.try_into().map_err(|_| PrimitiveCastError(prim)) } fn from_primitive_bounded(prim: f32) -> Self { Self::from_float_bounded(prim) } } impl IntoPrimitive for Scalar { fn into_primitive(self) -> Result { let r = self.to_ratio(); ratio_to_float(&r).ok_or(ScalarCastError { weight: self.0, target: PrimitiveType::F64, }) } fn to_primitive(&self) -> Result { self.clone().into_primitive() } } impl FromPrimitive for Scalar { fn from_primitive(prim: f64) -> Result> { let r = 
Ratio::from_float(prim).ok_or(PrimitiveCastError(prim))?; r.try_into().map_err(|_| PrimitiveCastError(prim)) } fn from_primitive_bounded(prim: f64) -> Self { Self::from_float_bounded(prim) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_ratio_conversion() { let (numer, denom) = (1_u8, 2_u8); let expected_ratio = Ratio::new(BigInt::from(numer), BigInt::from(denom)); let actual_ratio = Scalar::new(numer, denom).into(); assert_eq!(expected_ratio, actual_ratio); } #[test] fn test_ratio_conversion_ok() { let (numer, denom) = (1_u8, 2_u8); let ratio = Ratio::new(BigInt::from(numer), BigInt::from(denom)); let sc_res = Scalar::try_from(ratio); assert!(sc_res.is_ok()); assert_eq!(sc_res.unwrap(), Scalar::new(numer, denom)); } #[test] fn test_ratio_conversion_err() { let neg_ratio = Ratio::new(BigInt::from(-1), BigInt::from(2)); let sc_res = Scalar::try_from(neg_ratio); assert!(sc_res.is_err()); } #[test] #[allow(clippy::float_cmp)] fn test_scalar_f32() { let prim_sc_pairs = vec![ (0_f32, Scalar::from_integer(0_u8)), (2_f32, Scalar::from_integer(2_u8)), (0.5_f32, Scalar::new(1_u8, 2_u8)), ]; for (prim, sc) in prim_sc_pairs { let converted_sc = Scalar::from_primitive(prim); assert!(converted_sc.is_ok()); assert_eq!(converted_sc.unwrap(), sc); let converted_sc = Scalar::from_primitive_bounded(prim); assert_eq!(converted_sc, sc); let converted_prim: f32 = sc.into_primitive_unchecked(); assert_eq!(converted_prim, prim); } } #[test] fn test_scalar_f32_from_weird_prims() { let prim_pairs = vec![ (f32::INFINITY, f32::MAX), (-1_f32, 0_f32), (f32::NAN, 0_f32), ]; for (weird, fine) in prim_pairs { let weird_res = Scalar::from_primitive(weird); assert!(weird_res.is_err()); let bounded = Scalar::from_primitive_bounded(weird); let fine_res = Scalar::try_from(Ratio::from_float(fine).unwrap()); assert!(fine_res.is_ok()); assert_eq!(bounded, fine_res.unwrap()); } } #[test] #[allow(clippy::float_cmp)] fn test_scalar_f64() { let prim_sc_pairs = vec![ (0_f64, 
Scalar::from_integer(0_u8)), (2_f64, Scalar::from_integer(2_u8)), (0.5_f64, Scalar::new(1_u8, 2_u8)), ]; for (prim, sc) in prim_sc_pairs { let converted_sc = Scalar::from_primitive(prim); assert!(converted_sc.is_ok()); assert_eq!(converted_sc.unwrap(), sc); let converted_sc = Scalar::from_primitive_bounded(prim); assert_eq!(converted_sc, sc); let converted_prim: f64 = sc.into_primitive_unchecked(); assert_eq!(converted_prim, prim); } } #[test] fn test_scalar_f64_from_weird_prims() { let prim_pairs = vec![ (f64::INFINITY, f64::MAX), (-1_f64, 0_f64), (f64::NAN, 0_f64), ]; for (weird, fine) in prim_pairs { let weird_res = Scalar::from_primitive(weird); assert!(weird_res.is_err()); let bounded = Scalar::from_primitive_bounded(weird); let fine_res = Scalar::try_from(Ratio::from_float(fine).unwrap()); assert!(fine_res.is_ok()); assert_eq!(bounded, fine_res.unwrap()); } } #[test] fn test_scalar_i32() { let prim_sc_pairs = vec![ (0_i32, Scalar::from_integer(0_u8)), (2_i32, Scalar::from_integer(2_u8)), ]; for (prim, sc) in prim_sc_pairs { let converted_sc = Scalar::from_primitive(prim); assert!(converted_sc.is_ok()); assert_eq!(converted_sc.unwrap(), sc); let converted_sc = Scalar::from_primitive_bounded(prim); assert_eq!(converted_sc, sc); let converted_prim: i32 = sc.into_primitive_unchecked(); assert_eq!(converted_prim, prim); } } #[test] fn test_scalar_i64() { let prim_sc_pairs = vec![ (0_i64, Scalar::from_integer(0_u8)), (2_i64, Scalar::from_integer(2_u8)), ]; for (prim, sc) in prim_sc_pairs { let converted_sc = Scalar::from_primitive(prim); assert!(converted_sc.is_ok()); assert_eq!(converted_sc.unwrap(), sc); let converted_sc = Scalar::from_primitive_bounded(prim); assert_eq!(converted_sc, sc); let converted_prim: i64 = sc.into_primitive_unchecked(); assert_eq!(converted_prim, prim); } } } ================================================ FILE: rust/xaynet-core/src/mask/seed.rs ================================================ //! Mask seed and mask generation. //! //! 
See the [mask module] documentation since this is a private module anyways. //! //! [mask module]: crate::mask use std::iter; use derive_more::{AsMut, AsRef}; use rand::SeedableRng; use rand_chacha::ChaCha20Rng; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::box_; use thiserror::Error; use crate::{ crypto::{encrypt::SEALBYTES, prng::generate_integer, ByteObject}, mask::{ object::{MaskObject, MaskUnit, MaskVect}, MaskConfigPair, }, SumParticipantEphemeralPublicKey, SumParticipantEphemeralSecretKey, }; #[derive(AsRef, AsMut, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] /// A seed to generate a mask. /// /// When this goes out of scope, its contents will be zeroed out. pub struct MaskSeed(box_::Seed); impl ByteObject for MaskSeed { const LENGTH: usize = box_::SEEDBYTES; fn from_slice(bytes: &[u8]) -> Option { box_::Seed::from_slice(bytes).map(Self) } fn zeroed() -> Self { Self(box_::Seed([0_u8; Self::LENGTH])) } fn as_slice(&self) -> &[u8] { self.0.as_ref() } } impl MaskSeed { /// Gets this seed as an array. pub fn as_array(&self) -> [u8; Self::LENGTH] { (self.0).0 } /// Encrypts this seed with the given public key as an [`EncryptedMaskSeed`]. pub fn encrypt(&self, pk: &SumParticipantEphemeralPublicKey) -> EncryptedMaskSeed { // safe unwrap: length of slice is guaranteed by constants EncryptedMaskSeed::from_slice_unchecked(pk.encrypt(self.as_slice()).as_slice()) } /// Derives a mask of given length from this seed wrt the masking configurations. 
pub fn derive_mask(&self, len: usize, config: MaskConfigPair) -> MaskObject { let MaskConfigPair { vect: config_n, unit: config_1, } = config; let mut prng = ChaCha20Rng::from_seed(self.as_array()); let rand_int = generate_integer(&mut prng, &config_1.order()); let scalar_mask = MaskUnit::new_unchecked(config_1, rand_int); let order_n = config_n.order(); let rand_ints = iter::repeat_with(|| generate_integer(&mut prng, &order_n)) .take(len) .collect(); let model_mask = MaskVect::new_unchecked(config_n, rand_ints); MaskObject::new_unchecked(model_mask, scalar_mask) } } #[derive(AsRef, AsMut, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] /// An encrypted mask seed. pub struct EncryptedMaskSeed(Vec); impl From> for EncryptedMaskSeed { fn from(value: Vec) -> Self { Self(value) } } impl ByteObject for EncryptedMaskSeed { const LENGTH: usize = SEALBYTES + MaskSeed::LENGTH; fn from_slice(bytes: &[u8]) -> Option { if bytes.len() == Self::LENGTH { Some(Self(bytes.to_vec())) } else { None } } fn zeroed() -> Self { Self(vec![0_u8; Self::LENGTH]) } fn as_slice(&self) -> &[u8] { self.0.as_slice() } } #[derive(Debug, Error)] pub enum InvalidMaskSeed { #[error("the encrypted mask seed could not be decrypted")] DecryptionFailed, #[error("the mask seed has an invalid length")] InvalidLength, } impl EncryptedMaskSeed { /// Decrypts this seed as a [`MaskSeed`]. /// /// # Errors /// Fails if the decryption fails. pub fn decrypt( &self, pk: &SumParticipantEphemeralPublicKey, sk: &SumParticipantEphemeralSecretKey, ) -> Result { MaskSeed::from_slice( sk.decrypt(self.as_slice(), pk) .or(Err(InvalidMaskSeed::DecryptionFailed))? 
.as_slice(), ) .ok_or(InvalidMaskSeed::InvalidLength) } } #[cfg(test)] mod tests { use super::*; use crate::{ crypto::encrypt::EncryptKeyPair, mask::config::{BoundType, DataType, GroupType, MaskConfig, ModelType}, }; #[test] fn test_constants() { assert_eq!(MaskSeed::LENGTH, 32); assert_eq!( MaskSeed::zeroed().as_slice(), [0_u8; 32].to_vec().as_slice(), ); assert_eq!(EncryptedMaskSeed::LENGTH, 80); assert_eq!( EncryptedMaskSeed::zeroed().as_slice(), [0_u8; 80].to_vec().as_slice(), ); } #[test] fn test_derive_mask() { let config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3, }; let seed = MaskSeed::generate(); let mask = seed.derive_mask(10, config.into()); assert_eq!(mask.vect.data.len(), 10); assert!(mask .vect .data .iter() .all(|integer| integer < &config.order())); } #[test] fn test_encryption() { let seed = MaskSeed::generate(); assert_eq!(seed.as_slice().len(), 32); assert_ne!(seed, MaskSeed::zeroed()); let EncryptKeyPair { public, secret } = EncryptKeyPair::generate(); let encr_seed = seed.encrypt(&public); assert_eq!(encr_seed.as_slice().len(), 80); let decr_seed = encr_seed.decrypt(&public, &secret).unwrap(); assert_eq!(seed, decr_seed); } } ================================================ FILE: rust/xaynet-core/src/message/message.rs ================================================ //! Message buffers. //! //! See the [message module] documentation since this is a private module anyways. //! //! [message module]: crate::mask use std::convert::{TryFrom, TryInto}; use anyhow::{anyhow, Context}; use serde::{Deserialize, Serialize}; use crate::{ crypto::{ByteObject, PublicEncryptKey, PublicSigningKey, SecretSigningKey, Signature}, message::{Chunk, DecodeError, FromBytes, Payload, Sum, Sum2, ToBytes, Update}, }; /// The minimum number of accepted `sum`/`sum2` messages for the PET protocol to function correctly. 
pub const SUM_COUNT_MIN: u64 = 1; /// The minimum number of accepted `update` messages for the PET protocol to function correctly. pub const UPDATE_COUNT_MIN: u64 = 3; pub(crate) mod ranges { use std::ops::Range; use super::*; use crate::message::utils::range; /// Byte range corresponding to the signature in a message in a /// message header pub const SIGNATURE: Range = range(0, Signature::LENGTH); /// Byte range corresponding to the participant public key in a /// message header pub const PARTICIPANT_PK: Range = range(SIGNATURE.end, PublicSigningKey::LENGTH); /// Byte range corresponding to the coordinator public key in a /// message header pub const COORDINATOR_PK: Range = range(PARTICIPANT_PK.end, PublicEncryptKey::LENGTH); /// Byte range corresponding to the length field in a message header pub const LENGTH: Range = range(COORDINATOR_PK.end, 4); /// Byte range corresponding to the tag in a message header pub const TAG: usize = LENGTH.end; /// Byte range corresponding to the flags in a message header pub const FLAGS: usize = TAG + 1; /// Byte range reserved for future use pub const RESERVED: Range = range(FLAGS + 1, 2); } /// Length in bytes of a message header pub const HEADER_LENGTH: usize = ranges::RESERVED.end; /// A wrapper around a buffer that contains a [`Message`]. /// /// It provides getters and setters to access the different fields of /// the message safely. 
A message is made of a header and a payload: /// /// ```no_rust /// 0 1 2 3 /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// + signature + /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// + participant_pk + /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// + coordinator_pk + /// | | /// + + /// | | /// + + /// | | /// + + /// | | /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ /// | length | /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ /// | tag | flags | reserved | /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ /// | | /// + payload (variable length) + /// | | /// ``` /// /// - `signature` contains the signature of the entire message /// - `participant_pk` contains the public key for verifying the /// signature /// - `coordinator_pk` is the coordinator public encryption key. It is /// embedded in the message for security reasons. See [_Donald /// T. Davis, "Defective Sign & Encrypt in S/MIME, PKCS#7, MOSS, /// PEM, PGP, and XML.", Proc. Usenix Tech. Conf. 2001 (Boston, /// Mass., June 25-30, /// 2001)_](http://world.std.com/~dtd/sign_encrypt/sign_encrypt7.html) /// - `length` is the length in bytes of the _full_ message, _i.e._ /// including the header. This is a 32 bits field so in theory, /// messages can be as big as 2^32 = 4,294,967,296 bytes. 
/// - `tag` indicates the type of message (sum, update, sum2 or /// multipart message) /// - the `flags` field currently supports a single flag, that /// indicates whether this is a multipart message /// /// # Examples /// ## Reading a sum message /// /// ```rust /// use std::convert::TryFrom; /// use xaynet_core::message::{Flags, MessageBuffer, Tag}; /// /// let mut bytes = vec![0x11; 64]; // message signature /// bytes.extend(vec![0x22; 32]); // participant public signing key /// bytes.extend(vec![0x33; 32]); // coordinator public encrypt key /// bytes.extend(&200_u32.to_be_bytes()); // Length field /// bytes.push(0x01); // tag (sum message) /// bytes.push(0x00); // flags (not a multipart message) /// bytes.extend(vec![0x00, 0x00]); // reserved /// /// // Payload: a sum message contains a signature and an ephemeral public key /// bytes.extend(vec![0xaa; 32]); // signature /// bytes.extend(vec![0xbb; 32]); // public key /// /// let buffer = MessageBuffer::new(&bytes).unwrap(); /// assert_eq!(buffer.signature(), vec![0x11; 64].as_slice()); /// assert_eq!(buffer.participant_pk(), vec![0x22; 32].as_slice()); /// assert_eq!(buffer.coordinator_pk(), vec![0x33; 32].as_slice()); /// assert_eq!(Tag::try_from(buffer.tag()).unwrap(), Tag::Sum); /// assert_eq!(Flags::try_from(buffer.flags()).unwrap(), Flags::empty()); /// assert_eq!( /// buffer.payload(), /// [vec![0xaa; 32], vec![0xbb; 32]].concat().as_slice() /// ); /// ``` /// /// ## Writing a sum message /// /// ```rust /// use std::convert::TryFrom; /// use xaynet_core::message::{Flags, MessageBuffer, Tag}; /// /// let mut expected = vec![0x11; 64]; // message signature /// expected.extend(vec![0x22; 32]); // participant public signing key /// expected.extend(vec![0x33; 32]); // coordinator public signing key /// expected.extend(&200_u32.to_be_bytes()); // length field /// expected.push(0x01); // tag (sum message) /// expected.push(0x00); // flags (not a multipart message) /// expected.extend(vec![0x00, 0x00]); // 
reserved /// /// // Payload: a sum message contains a signature and an ephemeral public key /// expected.extend(vec![0xaa; 32]); // signature /// expected.extend(vec![0xbb; 32]); // public key /// /// let mut bytes = vec![0; expected.len()]; /// let mut buffer = MessageBuffer::new_unchecked(&mut bytes); /// buffer /// .signature_mut() /// .copy_from_slice(vec![0x11; 64].as_slice()); /// buffer /// .participant_pk_mut() /// .copy_from_slice(vec![0x22; 32].as_slice()); /// buffer /// .coordinator_pk_mut() /// .copy_from_slice(vec![0x33; 32].as_slice()); /// buffer.set_length(200 as u32); /// buffer.set_tag(Tag::Sum.into()); /// buffer.set_flags(Flags::empty()); /// buffer /// .payload_mut() /// .copy_from_slice([vec![0xaa; 32], vec![0xbb; 32]].concat().as_slice()); /// assert_eq!(expected, bytes); /// ``` pub struct MessageBuffer { inner: T, } impl> MessageBuffer { pub fn inner(&self) -> &T { &self.inner } pub fn as_ref(&self) -> MessageBuffer<&T> { MessageBuffer::new_unchecked(self.inner()) } /// Performs bound checks for the various message fields on `bytes` and returns a new /// [`MessageBuffer`]. /// /// # Errors /// Fails if the `bytes` are smaller than a minimal-sized message buffer. pub fn new(bytes: T) -> Result { let buffer = Self { inner: bytes }; buffer .check_buffer_length() .context("not a valid MessageBuffer")?; Ok(buffer) } /// Returns a [`MessageBuffer`] without performing any bound checks. /// /// This means accessing the various fields may panic if the data /// is invalid. pub fn new_unchecked(bytes: T) -> Self { Self { inner: bytes } } /// Performs bound checks to ensure the fields can be accessed /// without panicking. 
pub fn check_buffer_length(&self) -> Result<(), DecodeError> { let len = self.inner.as_ref().len(); if len < HEADER_LENGTH { return Err(anyhow!( "invalid buffer length: {} < {}", len, HEADER_LENGTH )); } let expected_len = self.length() as usize; let actual_len = self.inner.as_ref().len(); if actual_len < expected_len { return Err(anyhow!( "invalid message length: length field says {}, but buffer is {} bytes long", expected_len, actual_len )); } Ok(()) } /// Gets the tag field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn tag(&self) -> u8 { self.inner.as_ref()[ranges::TAG] } /// Gets the flags field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn flags(&self) -> Flags { Flags::from_bits_truncate(self.inner.as_ref()[ranges::FLAGS]) } /// Gets the length field /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn length(&self) -> u32 { // Unwrapping is OK, as the slice is guaranteed to be 4 bytes // long u32::from_be_bytes(self.inner.as_ref()[ranges::LENGTH].try_into().unwrap()) } } impl<'a, T: AsRef<[u8]> + ?Sized> MessageBuffer<&'a T> { /// Gets the message signature field /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn signature(&self) -> &'a [u8] { &self.inner.as_ref()[ranges::SIGNATURE] } /// Gets the participant public key field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn participant_pk(&self) -> &'a [u8] { &self.inner.as_ref()[ranges::PARTICIPANT_PK] } /// Gets the coordinator public key field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn coordinator_pk(&self) -> &'a [u8] { &self.inner.as_ref()[ranges::COORDINATOR_PK] } /// Gets the rest of the message. 
/// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn payload(&self) -> &'a [u8] { &self.inner.as_ref()[HEADER_LENGTH..] } /// Parse the signature and public signing key, and check the /// message signature. pub fn check_signature(&self) -> Result<(), DecodeError> { let signature = Signature::from_byte_slice(&self.signature()) .context("cannot parse the signature field")?; let participant_pk = PublicSigningKey::from_byte_slice(&self.participant_pk()) .context("cannot part the public key field")?; if participant_pk.verify_detached(&signature, self.signed_data()) { Ok(()) } else { Err(anyhow!("invalid message signature")) } } /// Return the portion of the message used to compute the /// signature, ie the entire message except the signature field /// itself. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn signed_data(&self) -> &'a [u8] { let signed_data_range = ranges::SIGNATURE.end..self.length() as usize; &self.inner.as_ref()[signed_data_range] } } impl + AsRef<[u8]>> MessageBuffer { /// Sets the tag field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn set_tag(&mut self, value: u8) { self.inner.as_mut()[ranges::TAG] = value; } /// Sets the flags field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn set_flags(&mut self, value: Flags) { self.inner.as_mut()[ranges::FLAGS] = value.bits(); } /// Sets the length field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn set_length(&mut self, value: u32) { let bytes = value.to_be_bytes(); self.inner.as_mut()[ranges::LENGTH].copy_from_slice(&bytes[..]); } /// Gets a mutable reference to the message signature field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. 
pub fn signature_mut(&mut self) -> &mut [u8] { &mut self.inner.as_mut()[ranges::SIGNATURE] } /// Gets a mutable reference to the participant public key field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn participant_pk_mut(&mut self) -> &mut [u8] { &mut self.inner.as_mut()[ranges::PARTICIPANT_PK] } /// Gets a mutable reference to the coordinator public key field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn coordinator_pk_mut(&mut self) -> &mut [u8] { &mut self.inner.as_mut()[ranges::COORDINATOR_PK] } /// Gets a mutable reference to the rest of the message. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn payload_mut(&mut self) -> &mut [u8] { &mut self.inner.as_mut()[HEADER_LENGTH..] } /// Gets a mutable reference to the portion of the message used to /// compute the signature, ie the entire message except the /// signature field itself. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn signed_data_mut(&mut self) -> &mut [u8] { let signed_data_range = ranges::SIGNATURE.end..self.length() as usize; &mut self.inner.as_mut()[signed_data_range] } } bitflags::bitflags! { /// A bitmask that defines flags for a [`Message`]. pub struct Flags: u8 { /// Indicates whether this message is a multipart message const MULTIPART = 1 << 0; } } #[derive(Copy, Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] /// A tag that indicates the type of the [`Message`]. 
pub enum Tag { /// A tag for [`Sum`] messages Sum, /// A tag for [`Update`] messages Update, /// A tag for [`Sum2`] messages Sum2, } impl TryFrom for Tag { type Error = DecodeError; fn try_from(value: u8) -> Result { Ok(match value { 1 => Tag::Sum, 2 => Tag::Update, 3 => Tag::Sum2, _ => return Err(anyhow!("invalid tag {}", value)), }) } } impl From for u8 { fn from(tag: Tag) -> Self { match tag { Tag::Sum => 1, Tag::Update => 2, Tag::Sum2 => 3, } } } #[derive(Debug, Eq, PartialEq, Clone)] /// A header common to all messages. pub struct Message { /// Message signature. This can be `None` if it hasn't been /// computed yet. pub signature: Option, /// The participant public key, used to verify the message /// signature. pub participant_pk: PublicSigningKey, /// The coordinator public key pub coordinator_pk: PublicEncryptKey, /// Wether this is a multipart message pub is_multipart: bool, /// The type of message. This information is partially redundant /// with the `payload` field. So when serializing the message, /// this field is ignored if the payload is a [`Payload::Sum`], /// [`Payload::Update`], or [`Payload::Sum2`]. However, it is /// taken as is for [`Payload::Chunk`]. pub tag: Tag, /// Message payload pub payload: Payload, } impl Message { /// Create a new sum message with the given participant and /// coordinator public keys. pub fn new_sum( participant_pk: PublicSigningKey, coordinator_pk: PublicEncryptKey, message: Sum, ) -> Self { Self { signature: None, participant_pk, coordinator_pk, is_multipart: false, tag: Tag::Sum, payload: message.into(), } } /// Create a new sum2 message with the given participant and /// coordinator public keys. pub fn new_sum2( participant_pk: PublicSigningKey, coordinator_pk: PublicEncryptKey, message: Sum2, ) -> Self { Self { signature: None, participant_pk, coordinator_pk, is_multipart: false, tag: Tag::Sum2, payload: message.into(), } } /// Create a new update message with the given participant and /// coordinator public keys. 
pub fn new_update( participant_pk: PublicSigningKey, coordinator_pk: PublicEncryptKey, message: Update, ) -> Self { Self { signature: None, participant_pk, coordinator_pk, is_multipart: false, tag: Tag::Update, payload: message.into(), } } /// Create a new multipart message with the given participant and /// coordinator public keys. pub fn new_multipart( participant_pk: PublicSigningKey, coordinator_pk: PublicEncryptKey, message: Chunk, tag: Tag, ) -> Self { Self { signature: None, participant_pk, coordinator_pk, is_multipart: true, tag, payload: message.into(), } } /// Parse the given message **without** verifying the /// signature. If you need to check the signature, call /// [`MessageBuffer.verify_signature`] before parsing the message. pub fn from_byte_slice>(buffer: &T) -> Result { let reader = MessageBuffer::new(buffer.as_ref())?; let signature = Signature::from_byte_slice(&reader.signature()).context("failed to parse signature")?; let participant_pk = PublicSigningKey::from_byte_slice(&reader.participant_pk()) .context("failed to parse public key")?; let coordinator_pk = PublicEncryptKey::from_byte_slice(&reader.coordinator_pk()) .context("failed to parse public key")?; let tag = reader.tag().try_into()?; let is_multipart = reader.flags().contains(Flags::MULTIPART); let payload = if is_multipart { Chunk::from_byte_slice(&reader.payload()).map(Into::into) } else { match tag { Tag::Sum => Sum::from_byte_slice(&reader.payload()).map(Into::into), Tag::Update => Update::from_byte_slice(&reader.payload()).map(Into::into), Tag::Sum2 => Sum2::from_byte_slice(&reader.payload()).map(Into::into), } } .context("failed to parse message payload")?; Ok(Self { participant_pk, coordinator_pk, signature: Some(signature), payload, is_multipart, tag, }) } /// Serialize this message. If the `signature` attribute is /// `Some`, the signature will be directly inserted in the message /// header. Otherwise it will be computed. 
/// /// # Panic /// /// This method panics if the given buffer is too small for the /// message to fit. pub fn to_bytes + AsRef<[u8]> + ?Sized>( &self, buffer: &mut T, sk: &SecretSigningKey, ) { let mut writer = MessageBuffer::new(buffer.as_mut()).unwrap(); self.participant_pk .to_bytes(&mut writer.participant_pk_mut()); self.coordinator_pk .to_bytes(&mut writer.coordinator_pk_mut()); let flags = if self.is_multipart { Flags::MULTIPART } else { Flags::empty() }; writer.set_flags(flags); self.payload.to_bytes(&mut writer.payload_mut()); // Determine the tag from the payload type if // possible. Otherwise, use the self.tag field. let tag = match self.payload { Payload::Sum(_) => Tag::Sum, Payload::Update(_) => Tag::Update, Payload::Sum2(_) => Tag::Sum2, Payload::Chunk(_) => self.tag, }; writer.set_tag(tag.into()); writer.set_length(self.buffer_length() as u32); // insert the signature last. If the message contains one, use // it. Otherwise compute it. let signature = match self.signature { Some(signature) => signature, None => sk.sign_detached(writer.signed_data_mut()), }; signature.to_bytes(&mut writer.signature_mut()); } pub fn buffer_length(&self) -> usize { self.payload.buffer_length() + HEADER_LENGTH } } #[cfg(test)] mod tests { use std::convert::TryFrom; use super::*; use crate::{ message::{Message, Tag}, testutils::messages as helpers, }; fn sum_message() -> (Message, Vec) { helpers::message(helpers::sum::payload) } #[test] fn buffer_read() { let bytes = sum_message().1; let buffer = MessageBuffer::new(&bytes).unwrap(); assert_eq!(Tag::try_from(buffer.tag()).unwrap(), Tag::Sum); assert_eq!(buffer.signature(), helpers::signature().1.as_slice()); assert_eq!( buffer.participant_pk(), helpers::participant_pk().1.as_slice() ); assert_eq!( buffer.coordinator_pk(), helpers::coordinator_pk().1.as_slice() ); assert_eq!(buffer.length() as usize, bytes.len()); assert_eq!(buffer.payload(), helpers::sum::payload().1.as_slice()); } #[test] fn buffer_write() { let expected = 
sum_message().1; let mut bytes = vec![0; expected.len()]; let mut buffer = MessageBuffer::new_unchecked(&mut bytes); buffer .signature_mut() .copy_from_slice(helpers::signature().1.as_slice()); buffer .participant_pk_mut() .copy_from_slice(helpers::participant_pk().1.as_slice()); buffer .coordinator_pk_mut() .copy_from_slice(helpers::coordinator_pk().1.as_slice()); buffer.set_tag(Tag::Sum.into()); buffer.set_length(expected.len() as u32); buffer .payload_mut() .copy_from_slice(helpers::sum::payload().1.as_slice()); assert_eq!(bytes, expected); } } ================================================ FILE: rust/xaynet-core/src/message/mod.rs ================================================ //! The messages of the PET protocol. //! //! # The sum message //! The [`Sum`] message is an abstraction for the values which a sum participant communicates to //! XayNet during the sum phase of the PET protocol. It contains the following values: //! - The sum signature proves the eligibility of the participant for the sum task. //! - The ephemeral public key is used by update participants to encrypt mask seeds in the update //! phase for the process of mask aggregation in the sum2 phase. //! //! # The update message //! The [`Update`] message is an abstraction for the values which an update participant communicates //! to XayNet during the update phase of the PET protocol. It contains the following values: //! - The sum signature proves the ineligibility of the participant for the sum task. //! - The update signature proves the eligibility of the participant for the update task. //! - The masked model is the encrypted local update to the global model, which is trained on the //! local data of the update participant. //! - The local seed dictionary stores the encrypted mask seed, which generates the local mask for //! the local model, which is encrypted by the ephemeral public keys of the sum participants. //! //! # The sum2 message //! 
The [`Sum2`] message is an abstraction for the values which a sum participant communicates to //! XayNet during the sum2 phase of the PET protocol. It contains the following values: //! - The sum signature proves the eligibility of the participant for the sum task. //! - The global mask is used by XayNet to unmask the aggregated global model. #[allow(clippy::module_inception)] pub(crate) mod message; pub(crate) mod payload; pub(crate) mod traits; pub(crate) mod utils; pub use self::{ message::{ Flags, Message, MessageBuffer, Tag, HEADER_LENGTH as MESSAGE_HEADER_LENGTH, SUM_COUNT_MIN, UPDATE_COUNT_MIN, }, payload::{ chunk::{Chunk, ChunkBuffer}, sum::{Sum, SumBuffer}, sum2::{Sum2, Sum2Buffer}, update::{Update, UpdateBuffer}, Payload, }, traits::{FromBytes, LengthValueBuffer, ToBytes}, }; /// An error that signals a failure when trying to decrypt and parse a message. /// /// This is kept generic on purpose to not reveal to the sender what specifically failed during /// decryption or parsing. pub type DecodeError = anyhow::Error; ================================================ FILE: rust/xaynet-core/src/message/payload/chunk.rs ================================================ use std::convert::TryInto; use anyhow::{anyhow, Context}; use crate::message::{ traits::{FromBytes, ToBytes}, DecodeError, }; pub(crate) mod ranges { use crate::message::utils::range; use std::ops::Range; /// Byte range corresponding to the chunk ID in a chunk message pub const ID: Range = range(0, 2); /// Byte range corresponding to the message ID in a chunk message pub const MESSAGE_ID: Range = range(ID.end, 2); /// Byte range corresponding to the flags in a chunk message pub const FLAGS: usize = MESSAGE_ID.end; /// Byte range reserved for future use pub const RESERVED: Range = range(FLAGS + 1, 3); } /// Length in bytes of a chunk message header const HEADER_LENGTH: usize = ranges::RESERVED.end; /// A message chunk. 
#[derive(Eq, PartialEq, Debug, Clone)] pub struct Chunk { /// Chunk ID pub id: u16, /// ID of the message this chunk belongs to pub message_id: u16, /// `true` if this is the last chunk of the message, `false` otherwise pub last: bool, /// Data contained in this chunk. pub data: Vec, } bitflags::bitflags! { /// A bitmask that defines flags for a [`Chunk`]. pub struct Flags: u8 { /// Indicates whether this message is the last chunk of a /// multipart message const LAST_CHUNK = 1 << 0; } } /// ```no_rust /// 0 1 2 3 /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ /// | id | message_id | /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ /// | flags | reserved | /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ /// | | /// + data (variable length) + /// | | /// ``` /// /// - `id`: ID of the chunk /// - `message_id`: ID of the message this chunk belong to /// - `flags`: currently the only supported flag indicates whether /// this is the last chunk or not pub struct ChunkBuffer { inner: T, } impl> ChunkBuffer { /// Performs bound checks for the various message fields on `bytes` and returns a new /// [`ChunkBuffer`]. /// /// # Errors /// Fails if the `bytes` are smaller than a minimal-sized message buffer. pub fn new(bytes: T) -> Result { let buffer = Self { inner: bytes }; buffer .check_buffer_length() .context("not a valid ChunkBuffer")?; Ok(buffer) } /// Returns a [`ChunkBuffer`] without performing any bound checks. /// /// This means accessing the various fields may panic if the data /// is invalid. pub fn new_unchecked(bytes: T) -> Self { Self { inner: bytes } } /// Performs bound checks to ensure the fields can be accessed /// without panicking. 
pub fn check_buffer_length(&self) -> Result<(), DecodeError> { let len = self.inner.as_ref().len(); if len < HEADER_LENGTH { return Err(anyhow!( "invalid buffer length: {} < {}", len, HEADER_LENGTH )); } Ok(()) } /// Gets the flags field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn flags(&self) -> Flags { Flags::from_bits_truncate(self.inner.as_ref()[ranges::FLAGS]) } /// Gets the chunk ID field /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn id(&self) -> u16 { // Unwrapping is OK, as the slice is guaranteed to be 4 bytes // long u16::from_be_bytes(self.inner.as_ref()[ranges::ID].try_into().unwrap()) } /// Gets the message ID field /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn message_id(&self) -> u16 { // Unwrapping is OK, as the slice is guaranteed to be 4 bytes // long u16::from_be_bytes(self.inner.as_ref()[ranges::MESSAGE_ID].try_into().unwrap()) } } impl<'a, T: AsRef<[u8]> + ?Sized> ChunkBuffer<&'a T> { /// Gets the rest of the message. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn payload(&self) -> &'a [u8] { &self.inner.as_ref()[HEADER_LENGTH..] } } impl + AsRef<[u8]>> ChunkBuffer { /// Sets the flags field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn set_flags(&mut self, value: Flags) { self.inner.as_mut()[ranges::FLAGS] = value.bits(); } /// Sets the chunk ID field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn set_id(&mut self, value: u16) { let bytes = value.to_be_bytes(); self.inner.as_mut()[ranges::ID].copy_from_slice(&bytes[..]); } /// Sets the message ID field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. 
pub fn set_message_id(&mut self, value: u16) { let bytes = value.to_be_bytes(); self.inner.as_mut()[ranges::MESSAGE_ID].copy_from_slice(&bytes[..]); } /// Gets a mutable reference to the rest of the message. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn payload_mut(&mut self) -> &mut [u8] { &mut self.inner.as_mut()[HEADER_LENGTH..] } } impl FromBytes for Chunk { fn from_byte_slice>(buffer: &T) -> Result { let reader = ChunkBuffer::new(buffer.as_ref()).context("Invalid chunk buffer")?; Ok(Self { last: reader.flags().contains(Flags::LAST_CHUNK), id: reader.id(), message_id: reader.message_id(), data: reader.payload().to_vec(), }) } fn from_byte_stream + ExactSizeIterator>( iter: &mut I, ) -> Result { if iter.len() < HEADER_LENGTH { return Err(anyhow!("byte stream exhausted")); } let id = u16::from_byte_stream(iter).context("cannot parse id")?; let message_id = u16::from_byte_stream(iter).context("cannot parse message id")?; let flags = Flags::from_bits_truncate(iter.next().unwrap()); let data: Vec = iter.skip(3).collect(); Ok(Self { id, message_id, data, last: flags.contains(Flags::LAST_CHUNK), }) } } impl ToBytes for Chunk { fn buffer_length(&self) -> usize { HEADER_LENGTH + self.data.len() } fn to_bytes + AsRef<[u8]>>(&self, buffer: &mut T) { let mut writer = ChunkBuffer::new(buffer.as_mut()).unwrap(); let flags = if self.last { Flags::LAST_CHUNK } else { Flags::empty() }; writer.set_flags(flags); writer.set_id(self.id); writer.set_message_id(self.message_id); writer.payload_mut()[..self.data.len()].copy_from_slice(self.data.as_slice()); } } #[cfg(test)] mod tests { use super::*; fn flags() -> (u8, Flags) { let flags = Flags::LAST_CHUNK; (flags.bits(), flags) } fn id() -> (Vec, u16) { let value = 0xdddd_u16; (value.to_be_bytes().to_vec(), value) } fn message_id() -> (Vec, u16) { let value = 0xeeee_u16; (value.to_be_bytes().to_vec(), value) } fn data() -> Vec { vec![0xff; 10] } fn chunk() -> (Vec, Chunk) { let 
mut bytes = vec![]; bytes.extend(id().0); bytes.extend(message_id().0); bytes.push(flags().0); bytes.extend(vec![0x00, 0x00, 0x00]); bytes.extend(data()); let message = Chunk { id: id().1, message_id: message_id().1, last: flags().1.contains(Flags::LAST_CHUNK), data: data(), }; (bytes, message) } #[test] fn buffer_read() { let bytes = chunk().0; let buffer = ChunkBuffer::new(&bytes).unwrap(); assert_eq!(buffer.id(), id().1); assert_eq!(buffer.message_id(), message_id().1); assert_eq!(buffer.flags(), flags().1); assert_eq!(buffer.payload(), &data()[..]); } #[test] fn stream_parse() { let (bytes, expected) = chunk(); let actual = Chunk::from_byte_stream(&mut bytes.into_iter()).unwrap(); assert_eq!(actual, expected); } #[test] fn buffer_write() { let expected = chunk().0; let mut bytes = vec![0; expected.len()]; let mut buffer = ChunkBuffer::new_unchecked(&mut bytes); buffer.set_id(id().1); buffer.set_message_id(message_id().1); buffer.set_flags(flags().1); buffer.payload_mut().copy_from_slice(data().as_slice()); assert_eq!(bytes, expected); } } ================================================ FILE: rust/xaynet-core/src/message/payload/mod.rs ================================================ //! Message payloads. //! //! See the [message module] documentation since this is a private module anyways. //! //! [message module]: crate::message pub(crate) mod chunk; pub(crate) mod sum; pub(crate) mod sum2; pub(crate) mod update; use derive_more::From; use crate::message::{ payload::{chunk::Chunk, sum::Sum, sum2::Sum2, update::Update}, traits::ToBytes, }; /// The payload of a [`Message`]. /// /// [`Message`]: crate::message::Message #[derive(From, Eq, PartialEq, Debug, Clone)] pub enum Payload { /// The payload of a [`Sum`] message. Sum(Sum), /// The payload of an [`Update`] message. Update(Update), /// The payload of a [`Sum2`] message. Sum2(Sum2), /// The payload of a [`Chunk`] message. 
Chunk(Chunk), } impl Payload { pub fn is_sum(&self) -> bool { matches!(self, Self::Sum(_)) } pub fn is_update(&self) -> bool { matches!(self, Self::Update(_)) } pub fn is_sum2(&self) -> bool { matches!(self, Self::Sum2(_)) } pub fn is_chunk(&self) -> bool { matches!(self, Self::Chunk(_)) } } impl ToBytes for Payload { fn buffer_length(&self) -> usize { match self { Payload::Sum(m) => m.buffer_length(), Payload::Sum2(m) => m.buffer_length(), Payload::Update(m) => m.buffer_length(), Payload::Chunk(m) => m.buffer_length(), } } fn to_bytes + AsRef<[u8]>>(&self, buffer: &mut T) { match self { Payload::Sum(m) => m.to_bytes(buffer), Payload::Sum2(m) => m.to_bytes(buffer), Payload::Update(m) => m.to_bytes(buffer), Payload::Chunk(m) => m.to_bytes(buffer), } } } ================================================ FILE: rust/xaynet-core/src/message/payload/sum.rs ================================================ //! Sum message payloads. //! //! See the [message module] documentation since this is a private module anyways. //! //! [message module]: crate::message use std::ops::Range; use anyhow::{anyhow, Context}; use crate::{ crypto::ByteObject, message::{ traits::{FromBytes, ToBytes}, utils::range, DecodeError, }, ParticipantTaskSignature, SumParticipantEphemeralPublicKey, }; const SUM_SIGNATURE_RANGE: Range = range(0, ParticipantTaskSignature::LENGTH); const EPHM_PK_RANGE: Range = range( SUM_SIGNATURE_RANGE.end, SumParticipantEphemeralPublicKey::LENGTH, ); #[derive(Clone, Debug, Eq, PartialEq, Hash)] /// A wrapper around a buffer that contains a [`Sum`] message. /// /// It provides getters and setters to access the different fields of the message safely. 
///
/// # Examples
/// ## Decoding a sum message
///
/// ```rust
/// # use xaynet_core::message::SumBuffer;
/// let sum_signature = vec![0x11; 64];
/// let ephm_pk = vec![0x22; 32];
/// let bytes = [sum_signature.as_slice(), ephm_pk.as_slice()].concat();
/// let buffer = SumBuffer::new(&bytes).unwrap();
/// assert_eq!(buffer.sum_signature(), sum_signature.as_slice());
/// assert_eq!(buffer.ephm_pk(), ephm_pk.as_slice());
/// ```
///
/// ## Encoding a sum message
///
/// ```rust
/// # use xaynet_core::message::SumBuffer;
/// let sum_signature = vec![0x11; 64];
/// let ephm_pk = vec![0x22; 32];
/// let mut storage = vec![0xff; 96];
/// let mut buffer = SumBuffer::new_unchecked(&mut storage);
/// buffer
///     .sum_signature_mut()
///     .copy_from_slice(&sum_signature[..]);
/// buffer.ephm_pk_mut().copy_from_slice(&ephm_pk[..]);
/// assert_eq!(&storage[..64], sum_signature.as_slice());
/// assert_eq!(&storage[64..], ephm_pk.as_slice());
/// ```
pub struct SumBuffer<T> {
    inner: T,
}

impl<T: AsRef<[u8]>> SumBuffer<T> {
    /// Wraps `bytes` in a [`SumBuffer`] after verifying that every field of a sum
    /// message fits into it.
    ///
    /// # Errors
    /// Returns an error when `bytes` is shorter than the smallest valid sum message.
    pub fn new(bytes: T) -> Result<Self, DecodeError> {
        let checked = Self::new_unchecked(bytes);
        checked
            .check_buffer_length()
            .context("not a valid SumBuffer")?;
        Ok(checked)
    }

    /// Wraps `bytes` in a [`SumBuffer`] without verifying its length.
    ///
    /// Reading or writing a field of a too-short buffer will then panic.
    pub fn new_unchecked(bytes: T) -> Self {
        SumBuffer { inner: bytes }
    }

    /// Verifies that the buffer reaches at least to the end of the last field, so that
    /// no accessor can panic.
    ///
    /// # Errors
    /// Returns an error reporting the actual and the required length otherwise.
    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {
        let required = EPHM_PK_RANGE.end;
        let actual = self.inner.as_ref().len();
        if actual >= required {
            return Ok(());
        }
        Err(anyhow!("invalid buffer length: {} < {}", actual, required))
    }
}

impl<T: AsMut<[u8]>> SumBuffer<T> {
    /// Mutable view of the sum signature field.
    ///
    /// # Panics
    /// Panics if the buffer is too short, i.e. it was never checked.
    pub fn sum_signature_mut(&mut self) -> &mut [u8] {
        let bytes = self.inner.as_mut();
        &mut bytes[SUM_SIGNATURE_RANGE]
    }

    /// Mutable view of the sum participant's ephemeral public key field.
    ///
    /// # Panics
    /// Panics if the buffer is too short, i.e. it was never checked.
    pub fn ephm_pk_mut(&mut self) -> &mut [u8] {
        let bytes = self.inner.as_mut();
        &mut bytes[EPHM_PK_RANGE]
    }
}

impl<'a, T: AsRef<[u8]> + ?Sized> SumBuffer<&'a T> {
    /// Read-only view of the sum signature field, borrowed for the buffer's lifetime.
    ///
    /// # Panics
    /// Panics if the buffer is too short, i.e. it was never checked.
    pub fn sum_signature(&self) -> &'a [u8] {
        &self.inner.as_ref()[SUM_SIGNATURE_RANGE]
    }

    /// Read-only view of the sum participant's ephemeral public key field, borrowed for
    /// the buffer's lifetime.
    ///
    /// # Panics
    /// Panics if the buffer is too short, i.e. it was never checked.
    pub fn ephm_pk(&self) -> &'a [u8] {
        &self.inner.as_ref()[EPHM_PK_RANGE]
    }
}

#[derive(Debug, Eq, PartialEq, Clone)]
/// A high level representation of a sum message.
///
/// Sum participants send these messages during the sum phase.
/// /// # Examples /// ## Decoding a message /// /// ```rust /// # use xaynet_core::{crypto::ByteObject, message::{FromBytes, Sum}, ParticipantTaskSignature, SumParticipantEphemeralPublicKey}; /// let signature = vec![0x11; 64]; /// let ephm_pk = vec![0x22; 32]; /// let bytes = [signature.as_slice(), ephm_pk.as_slice()].concat(); /// let parsed = Sum::from_byte_slice(&bytes).unwrap(); /// let expected = Sum{ /// sum_signature: ParticipantTaskSignature::from_slice(&signature[..]).unwrap(), /// ephm_pk: SumParticipantEphemeralPublicKey::from_slice(&ephm_pk[..]).unwrap(), /// }; /// assert_eq!(parsed, expected); /// ``` /// /// ## Encoding a message /// /// ```rust /// # use xaynet_core::{crypto::ByteObject, message::{ToBytes, Sum}, ParticipantTaskSignature, SumParticipantEphemeralPublicKey}; /// let sum_signature = ParticipantTaskSignature::from_slice(vec![0x11; 64].as_slice()).unwrap(); /// let ephm_pk = SumParticipantEphemeralPublicKey::from_slice(vec![0x22; 32].as_slice()).unwrap(); /// let msg = Sum { /// sum_signature, /// ephm_pk, /// }; /// // we need a 96 bytes long buffer to serialize that message /// assert_eq!(msg.buffer_length(), 96); /// // create a buffer with enough space and encode the message /// let mut buf = vec![0xff; 96]; /// msg.to_bytes(&mut buf); /// /// assert_eq!(buf, [vec![0x11; 64].as_slice(), vec![0x22; 32].as_slice()].concat()); /// ``` pub struct Sum { /// The signature of the round seed and the word "sum". /// /// This is used to determine whether a participant is selected for the sum task. pub sum_signature: ParticipantTaskSignature, /// An ephemeral public key generated by a sum participant for the current round. 
pub ephm_pk: SumParticipantEphemeralPublicKey, } impl ToBytes for Sum { fn buffer_length(&self) -> usize { EPHM_PK_RANGE.end } fn to_bytes + AsRef<[u8]>>(&self, buffer: &mut T) { let mut writer = SumBuffer::new(buffer.as_mut()).unwrap(); self.sum_signature.to_bytes(&mut writer.sum_signature_mut()); self.ephm_pk.to_bytes(&mut writer.ephm_pk_mut()); } } impl FromBytes for Sum { fn from_byte_slice>(buffer: &T) -> Result { let reader = SumBuffer::new(buffer.as_ref())?; let sum_signature = ParticipantTaskSignature::from_byte_slice(&reader.sum_signature()) .context("invalid sum signature")?; let ephm_pk = SumParticipantEphemeralPublicKey::from_byte_slice(&reader.ephm_pk()) .context("invalid ephemeral public key")?; Ok(Self { sum_signature, ephm_pk, }) } fn from_byte_stream + ExactSizeIterator>( iter: &mut I, ) -> Result { let sum_signature = ParticipantTaskSignature::from_byte_stream(iter).context("invalid sum signature")?; let ephm_pk = SumParticipantEphemeralPublicKey::from_byte_stream(iter) .context("invalid ephemeral public key")?; Ok(Self { sum_signature, ephm_pk, }) } } #[cfg(test)] pub(in crate::message) mod tests { use super::*; use crate::testutils::messages::sum as helpers; #[test] fn buffer_read() { let bytes = helpers::payload().1; let buffer = SumBuffer::new(&bytes).unwrap(); assert_eq!(buffer.sum_signature(), &helpers::sum_task_signature().1[..]); assert_eq!(buffer.ephm_pk(), &helpers::ephm_pk().1[..]); } #[test] fn buffer_read_invalid() { assert!(SumBuffer::new(&helpers::payload().1[1..]).is_err()); } #[test] fn buffer_write() { let mut buffer = vec![0xff; EPHM_PK_RANGE.end]; let mut writer = SumBuffer::new_unchecked(&mut buffer); writer .sum_signature_mut() .copy_from_slice(helpers::sum_task_signature().1.as_slice()); writer .ephm_pk_mut() .copy_from_slice(helpers::ephm_pk().1.as_slice()); } #[test] fn encode() { let (sum, bytes) = helpers::payload(); assert_eq!(sum.buffer_length(), bytes.len()); let mut buf = vec![0xff; sum.buffer_length()]; 
sum.to_bytes(&mut buf); assert_eq!(buf, bytes); } #[test] fn decode() { let (expected, bytes) = helpers::payload(); let parsed = Sum::from_byte_slice(&bytes).unwrap(); assert_eq!(parsed, expected); } #[test] fn stream_parse() { let (expected, bytes) = helpers::payload(); let parsed = Sum::from_byte_stream(&mut bytes.into_iter()).unwrap(); assert_eq!(parsed, expected); } } ================================================ FILE: rust/xaynet-core/src/message/payload/sum2.rs ================================================ //! Sum2 message payloads. //! //! See the [message module] documentation since this is a private module anyways. //! //! [message module]: crate::message use std::ops::Range; use anyhow::{anyhow, Context}; use crate::{ crypto::ByteObject, mask::object::{serialization::MaskObjectBuffer, MaskObject}, message::{ traits::{FromBytes, ToBytes}, utils::range, DecodeError, }, ParticipantTaskSignature, }; const SUM_SIGNATURE_RANGE: Range = range(0, ParticipantTaskSignature::LENGTH); #[derive(Clone, Debug, Eq, PartialEq, Hash)] /// A wrapper around a buffer that contains a [`Sum2`] message. /// /// It provides getters and setters to access the different fields of the message safely. pub struct Sum2Buffer { inner: T, } impl> Sum2Buffer { /// Performs bound checks for the various message fields on `bytes` and returns a new /// [`Sum2Buffer`]. /// /// # Errors /// Fails if the `bytes` are smaller than a minimal-sized sum2 message buffer. pub fn new(bytes: T) -> Result { let buffer = Self { inner: bytes }; buffer .check_buffer_length() .context("not a valid Sum2Buffer")?; Ok(buffer) } /// Returns a `Sum2Buffer` with the given `bytes` without performing bound checks. /// /// This means that accessing the message fields may panic. pub fn new_unchecked(bytes: T) -> Self { Self { inner: bytes } } /// Performs bound checks for the various message fields on this buffer. 
pub fn check_buffer_length(&self) -> Result<(), DecodeError> { let len = self.inner.as_ref().len(); if len < SUM_SIGNATURE_RANGE.end { return Err(anyhow!( "invalid buffer length: {} < {}", len, SUM_SIGNATURE_RANGE.end )); } // check the length of the mask field MaskObjectBuffer::new(&self.inner.as_ref()[self.model_mask_offset()..]) .context("invalid mask field")?; Ok(()) } /// Gets the offset of the model mask field. fn model_mask_offset(&self) -> usize { SUM_SIGNATURE_RANGE.end } } impl + AsMut<[u8]>> Sum2Buffer { /// Gets a mutable reference to the sum signature field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn sum_signature_mut(&mut self) -> &mut [u8] { &mut self.inner.as_mut()[SUM_SIGNATURE_RANGE] } /// Gets a mutable reference to the model mask field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn model_mask_mut(&mut self) -> &mut [u8] { let offset = self.model_mask_offset(); &mut self.inner.as_mut()[offset..] } } impl<'a, T: AsRef<[u8]> + ?Sized> Sum2Buffer<&'a T> { /// Gets a reference to the sum signature field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn sum_signature(&self) -> &'a [u8] { &self.inner.as_ref()[SUM_SIGNATURE_RANGE] } /// Gets a reference to the model mask field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn model_mask(&self) -> &'a [u8] { let offset = self.model_mask_offset(); &self.inner.as_ref()[offset..] } } #[derive(Eq, PartialEq, Clone, Debug)] /// A high level representation of a sum2 message. /// /// These messages are sent by sum participants during the sum2 phase. pub struct Sum2 { /// The signature of the round seed and the word "sum". /// /// This is used to determine whether a participant is selected for the sum task. pub sum_signature: ParticipantTaskSignature, /// A model mask computed by the participant. 
pub model_mask: MaskObject, } impl ToBytes for Sum2 { fn buffer_length(&self) -> usize { SUM_SIGNATURE_RANGE.end + self.model_mask.buffer_length() } fn to_bytes + AsRef<[u8]>>(&self, buffer: &mut T) { let mut writer = Sum2Buffer::new_unchecked(buffer.as_mut()); self.sum_signature.to_bytes(&mut writer.sum_signature_mut()); self.model_mask.to_bytes(&mut writer.model_mask_mut()); } } impl FromBytes for Sum2 { fn from_byte_slice>(buffer: &T) -> Result { let reader = Sum2Buffer::new(buffer.as_ref())?; Ok(Self { sum_signature: ParticipantTaskSignature::from_byte_slice(&reader.sum_signature()) .context("invalid sum signature")?, model_mask: MaskObject::from_byte_slice(&reader.model_mask()) .context("invalid mask")?, }) } fn from_byte_stream + ExactSizeIterator>( iter: &mut I, ) -> Result { Ok(Self { sum_signature: ParticipantTaskSignature::from_byte_stream(iter) .context("invalid sum signature")?, model_mask: MaskObject::from_byte_stream(iter).context("invalid mask object")?, }) } } #[cfg(test)] pub mod tests { use crate::testutils::messages::sum2 as helpers; use super::*; #[test] fn buffer_read() { let bytes = helpers::payload().1; let buffer = Sum2Buffer::new(&bytes).unwrap(); assert_eq!(buffer.sum_signature(), &helpers::sum_task_signature().1[..]); let expected_mask = helpers::mask_object().1; let expected_length = expected_mask.len(); let actual_mask = &buffer.model_mask()[..expected_length]; assert_eq!(actual_mask, expected_mask); } #[test] fn buffer_write() { // length = 64 (signature) + 42 (mask) = 106 let mut bytes = vec![0xff; 106]; { let mut buffer = Sum2Buffer::new_unchecked(&mut bytes); buffer .sum_signature_mut() .copy_from_slice(&helpers::sum_task_signature().1[..]); let mask = helpers::mask_object().1; buffer.model_mask_mut()[..mask.len()].copy_from_slice(&mask[..]); } assert_eq!(&bytes[..], &helpers::payload().1[..]); } #[test] fn encode() { let (sum2, bytes) = helpers::payload(); assert_eq!(sum2.buffer_length(), bytes.len()); let mut buf = vec![0xff; 
sum2.buffer_length()]; sum2.to_bytes(&mut buf); assert_eq!(buf, bytes); } #[test] fn decode() { let (sum2, bytes) = helpers::payload(); let parsed = Sum2::from_byte_slice(&bytes).unwrap(); assert_eq!(parsed, sum2); } #[test] fn stream_parse() { let (sum2, bytes) = helpers::payload(); let parsed = Sum2::from_byte_stream(&mut bytes.into_iter()).unwrap(); assert_eq!(parsed, sum2); } } ================================================ FILE: rust/xaynet-core/src/message/payload/update.rs ================================================ //! Update message payloads. //! //! See the [message module] documentation since this is a private module anyways. //! //! [message module]: crate::message use std::ops::Range; use anyhow::{anyhow, Context}; use crate::{ crypto::ByteObject, mask::object::{serialization::MaskObjectBuffer, MaskObject}, message::{ traits::{FromBytes, LengthValueBuffer, ToBytes}, utils::range, DecodeError, }, LocalSeedDict, ParticipantTaskSignature, }; const SUM_SIGNATURE_RANGE: Range = range(0, ParticipantTaskSignature::LENGTH); const UPDATE_SIGNATURE_RANGE: Range = range(SUM_SIGNATURE_RANGE.end, ParticipantTaskSignature::LENGTH); #[derive(Clone, Debug)] /// A wrapper around a buffer that contains an [`Update`] message. /// /// It provides getters and setters to access the different fields of the message safely. pub struct UpdateBuffer { inner: T, } impl> UpdateBuffer { /// Performs bound checks for the various message fields on `bytes` and returns a new /// [`UpdateBuffer`]. /// /// # Errors /// Fails if the `bytes` are smaller than a minimal-sized update message buffer. pub fn new(bytes: T) -> Result { let buffer = Self { inner: bytes }; buffer .check_buffer_length() .context("invalid UpdateBuffer")?; Ok(buffer) } /// Returns an [`UpdateBuffer`] without performing any bound checks. /// /// This means accessing the various fields may panic if the data is invalid. 
pub fn new_unchecked(bytes: T) -> Self { Self { inner: bytes } } /// Performs bound checks to ensure the fields can be accessed without panicking. pub fn check_buffer_length(&self) -> Result<(), DecodeError> { let len = self.inner.as_ref().len(); // First, check the fixed size portion of the // header. UPDATE_SIGNATURE_RANGE is the last field if len < UPDATE_SIGNATURE_RANGE.end { return Err(anyhow!( "invalid buffer length: {} < {}", len, UPDATE_SIGNATURE_RANGE.end )); } // Check length of the masked object field MaskObjectBuffer::new(&self.inner.as_ref()[self.masked_model_offset()..]) .context("invalid masked object field")?; // Check the length of the local seed dictionary field let _ = LengthValueBuffer::new(&self.inner.as_ref()[self.local_seed_dict_offset()..]) .context("invalid local seed dictionary length")?; Ok(()) } /// Gets the offset of the masked model field. fn masked_model_offset(&self) -> usize { UPDATE_SIGNATURE_RANGE.end } /// Gets the offset of the local seed dictionary field. /// /// # Panics /// Computing the offset may panic if the buffer has not been checked before. fn local_seed_dict_offset(&self) -> usize { let masked_model = MaskObjectBuffer::new_unchecked(&self.inner.as_ref()[self.masked_model_offset()..]); self.masked_model_offset() + masked_model.len() } } impl<'a, T: AsRef<[u8]> + ?Sized> UpdateBuffer<&'a T> { /// Gets the sum signature field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn sum_signature(&self) -> &'a [u8] { &self.inner.as_ref()[SUM_SIGNATURE_RANGE] } /// Gets the update signature field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn update_signature(&self) -> &'a [u8] { &self.inner.as_ref()[UPDATE_SIGNATURE_RANGE] } /// Gets a slice that starts at the beginning of the masked model field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. 
pub fn masked_model(&self) -> &'a [u8] { let offset = self.masked_model_offset(); &self.inner.as_ref()[offset..] } /// Gets a slice that starts at the beginning og the local seed dictionary field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn local_seed_dict(&self) -> &'a [u8] { let offset = self.local_seed_dict_offset(); &self.inner.as_ref()[offset..] } } impl + AsMut<[u8]>> UpdateBuffer { /// Gets a mutable reference to the sum signature field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn sum_signature_mut(&mut self) -> &mut [u8] { &mut self.inner.as_mut()[SUM_SIGNATURE_RANGE] } /// Gets a mutable reference to the update signature field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn update_signature_mut(&mut self) -> &mut [u8] { &mut self.inner.as_mut()[UPDATE_SIGNATURE_RANGE] } /// Gets a mutable slice that starts at the beginning of the masked model field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn masked_model_mut(&mut self) -> &mut [u8] { let offset = self.masked_model_offset(); &mut self.inner.as_mut()[offset..] } /// Gets a mutable slice that starts at the beginning of the local seed dictionary field. /// /// # Panics /// Accessing the field may panic if the buffer has not been checked before. pub fn local_seed_dict_mut(&mut self) -> &mut [u8] { let offset = self.local_seed_dict_offset(); &mut self.inner.as_mut()[offset..] } } #[derive(Debug, Eq, PartialEq, Clone)] /// A high level representation of an update message. /// /// These messages are sent by update participants during the update phase. pub struct Update { /// The signature of the round seed and the word "sum". /// /// This is used to determine whether a participant is selected for the sum task. 
pub sum_signature: ParticipantTaskSignature,
    /// Signature of the round seed and the word "update".
    ///
    /// This is used to determine whether a participant is selected for the update task.
    pub update_signature: ParticipantTaskSignature,
    /// A model trained by an update participant.
    ///
    /// The model is masked with randomness derived from the participant seed.
    pub masked_model: MaskObject,
    /// A dictionary that contains the seed used to mask `masked_model`.
    ///
    /// The seed is encrypted with the ephemeral public key of each sum participant.
    pub local_seed_dict: LocalSeedDict,
}

impl ToBytes for Update {
    fn buffer_length(&self) -> usize {
        // Fixed-size signatures, followed by the masked model and the local
        // seed dictionary.
        UPDATE_SIGNATURE_RANGE.end
            + self.masked_model.buffer_length()
            + self.local_seed_dict.buffer_length()
    }

    fn to_bytes + AsRef<[u8]>>(&self, buffer: &mut T) {
        let mut writer = UpdateBuffer::new_unchecked(buffer.as_mut());
        self.sum_signature.to_bytes(&mut writer.sum_signature_mut());
        self.update_signature
            .to_bytes(&mut writer.update_signature_mut());
        // The masked model has to be written before the seed dictionary:
        // `local_seed_dict_mut()` locates its offset by reading the encoded
        // masked model from the buffer.
        self.masked_model.to_bytes(&mut writer.masked_model_mut());
        self.local_seed_dict
            .to_bytes(&mut writer.local_seed_dict_mut());
    }
}

impl FromBytes for Update {
    fn from_byte_slice>(buffer: &T) -> Result {
        // NOTE(review): `UpdateBuffer::new` is defined outside this view; it
        // presumably runs `check_buffer_length()`, which makes the accessor
        // calls below panic-free.
        let reader = UpdateBuffer::new(buffer.as_ref())?;
        Ok(Self {
            sum_signature: ParticipantTaskSignature::from_byte_slice(&reader.sum_signature())
                .context("invalid sum signature")?,
            update_signature: ParticipantTaskSignature::from_byte_slice(&reader.update_signature())
                .context("invalid update signature")?,
            masked_model: MaskObject::from_byte_slice(&reader.masked_model())
                .context("invalid masked model")?,
            local_seed_dict: LocalSeedDict::from_byte_slice(&reader.local_seed_dict())
                .context("invalid local seed dictionary")?,
        })
    }

    fn from_byte_stream + ExactSizeIterator>(
        iter: &mut I,
    ) -> Result {
        // Fields are consumed from the stream in wire order.
        Ok(Self {
            sum_signature: ParticipantTaskSignature::from_byte_stream(iter)
                .context("invalid sum signature")?,
            update_signature: ParticipantTaskSignature::from_byte_stream(iter)
                .context("invalid update
signature")?,
            masked_model: MaskObject::from_byte_stream(iter).context("invalid masked model")?,
            local_seed_dict: LocalSeedDict::from_byte_stream(iter)
                .context("invalid local seed dictionary")?,
        })
    }
}

#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::testutils::messages::update as helpers;

    #[test]
    fn buffer_read() {
        let bytes = helpers::payload().1;
        let buffer = UpdateBuffer::new(&bytes).unwrap();
        assert_eq!(
            buffer.sum_signature(),
            helpers::sum_task_signature().1.as_slice()
        );
        assert_eq!(
            buffer.update_signature(),
            helpers::update_task_signature().1.as_slice()
        );
        // `masked_model()` runs to the end of the buffer, so only compare the
        // bytes that belong to the mask object.
        let expected = helpers::mask_object().1;
        assert_eq!(&buffer.masked_model()[..expected.len()], &expected[..]);
        assert_eq!(buffer.local_seed_dict(), &helpers::local_seed_dict().1[..]);
    }

    #[test]
    fn decode_invalid_seed_dict() {
        let mut invalid = helpers::local_seed_dict().1;
        // This truncates the last entry of the seed dictionary: the length
        // field goes from 0xe4 (228) to 0xe3 (227), one byte short of two
        // full entries.
        invalid[3] = 0xe3;
        let mut bytes = vec![];
        bytes.extend(helpers::sum_task_signature().1);
        bytes.extend(helpers::update_task_signature().1);
        bytes.extend(helpers::mask_object().1);
        bytes.extend(invalid);

        let e = Update::from_byte_slice(&bytes).unwrap_err();
        let cause = e.source().unwrap().to_string();
        assert_eq!(
            cause,
            "invalid local seed dictionary: trailing bytes".to_string()
        );
    }

    #[test]
    fn decode() {
        let (update, bytes) = helpers::payload();
        let parsed = Update::from_byte_slice(&bytes).unwrap();
        assert_eq!(parsed, update);
    }

    #[test]
    fn stream_parse() {
        let (update, bytes) = helpers::payload();
        let parsed = Update::from_byte_stream(&mut bytes.into_iter()).unwrap();
        assert_eq!(parsed, update);
    }

    #[test]
    fn encode() {
        let (update, bytes) = helpers::payload();
        assert_eq!(update.buffer_length(), bytes.len());

        let mut buf = vec![0xff; update.buffer_length()];
        update.to_bytes(&mut buf);
        // The order in which the hashmap is serialized is not
        // guaranteed, but we chose our key/values such that they are
        // sorted.
// // First compute the offset at which the local seed dict value // starts: two signature (64 bytes), the masked model (32 // bytes), the length field (4 bytes), the masked scalar (10 bytes) let offset = 64 * 2 + 32 + 4 + 10; // Sort the end of the buffer (&mut buf[offset..]).sort_unstable(); assert_eq!(buf, bytes); } } ================================================ FILE: rust/xaynet-core/src/message/traits.rs ================================================ //! Message traits. //! //! See the [message module] documentation since this is a private module anyways. //! //! [message module]: crate::message use std::{ convert::TryInto, io::{Cursor, Write}, iter::{ExactSizeIterator, Iterator}, ops::Range, }; use anyhow::{anyhow, Context}; use crate::{ crypto::ByteObject, mask::seed::EncryptedMaskSeed, message::{utils::ChunkableIterator, DecodeError}, LocalSeedDict, SumParticipantPublicKey, }; /// An interface for serializable message types. /// /// See also [`FromBytes`] for deserialization. pub trait ToBytes { /// The length of the buffer for encoding the type. fn buffer_length(&self) -> usize; /// Serialize the type in the given buffer. /// /// # Panics /// This method may panic if the given buffer is too small. Thus, [`buffer_length()`] must be /// called prior to calling this, and a large enough buffer must be provided. /// /// [`buffer_length()`]: ToBytes::buffer_length fn to_bytes + AsRef<[u8]>>(&self, buffer: &mut T); } /// An interface for deserializable message types. /// /// See also [`ToBytes`] for serialization. pub trait FromBytes: Sized { /// Deserialize the type from the given buffer. /// /// # Errors /// May fail if certain parts of the deserialized buffer don't pass message validity checks. 
fn from_byte_slice>(buffer: &T) -> Result;

    /// Deserialize the type by consuming bytes from the given iterator.
    ///
    /// # Errors
    /// May fail if the stream does not hold enough bytes or if the consumed bytes don't pass
    /// message validity checks.
    fn from_byte_stream + ExactSizeIterator>(
        iter: &mut I,
    ) -> Result;
}

impl FromBytes for T
where
    T: ByteObject,
{
    fn from_byte_slice>(buffer: &U) -> Result {
        Self::from_slice(buffer.as_ref())
            .ok_or_else(|| anyhow!("failed to deserialize byte object"))
    }

    fn from_byte_stream + ExactSizeIterator>(
        iter: &mut I,
    ) -> Result {
        // Consume at most `LENGTH` bytes. If the stream ends early, the shorter
        // buffer is handed to `from_byte_slice`, which is expected to reject it.
        let buf: Vec = iter.take(Self::LENGTH).collect();
        Self::from_byte_slice(&buf)
    }
}

impl ToBytes for T
where
    T: ByteObject,
{
    fn buffer_length(&self) -> usize {
        self.as_slice().len()
    }

    fn to_bytes + AsRef<[u8]>>(&self, buffer: &mut U) {
        // `copy_from_slice` panics unless `buffer` is exactly `buffer_length()` bytes long.
        buffer.as_mut().copy_from_slice(self.as_slice())
    }
}

/// A helper for encoding and decoding Length-Value (LV) fields.
///
/// Note that the 4-byte [`length()`] field gives the length of the *total* Length-Value field,
/// _i.e._ the length of the value, plus the 4 extra bytes of the length field itself.
///
/// # Examples
/// ## Decoding a LV field
///
/// ```rust
/// # use xaynet_core::message::LengthValueBuffer;
/// let bytes = vec![
///     0x00, 0x00, 0x00, 0x05, // Length = 5
///     0xff, // Value = 0xff
///     0x11, 0x22, // Extra bytes
/// ];
/// let buffer = LengthValueBuffer::new(&bytes).unwrap();
/// assert_eq!(buffer.length(), 5);
/// assert_eq!(buffer.value_length(), 1);
/// assert_eq!(buffer.value(), &[0xff][..]);
/// ```
///
/// ## Encoding a LV field
///
/// ```rust
/// # use xaynet_core::message::LengthValueBuffer;
/// let mut bytes = vec![0xff; 9];
/// let mut buffer = LengthValueBuffer::new_unchecked(&mut bytes);
/// // It is important to set the length field before setting the value, otherwise, `value_mut()` will panic.
/// buffer.set_length(8);
/// buffer.value_mut().copy_from_slice(&[0, 1, 2, 3][..]);
/// let expected = vec![
///     0x00, 0x00, 0x00, 0x08, // Length = 8
///     0x00, 0x01, 0x02, 0x03, // Value
///     0xff, // unchanged
/// ];
///
/// assert_eq!(bytes, expected);
/// ```
///
/// [`length()`]: LengthValueBuffer::length
pub struct LengthValueBuffer {
    // Raw bytes of the item; they are bound-checked by `new` but not by `new_unchecked`.
    inner: T,
}

/// The byte range of the length field at the start of a Length-Value item.
///
/// `LENGTH_FIELD.end` doubles as the size of the length field (4 bytes).
const LENGTH_FIELD: Range = 0..4;

impl> LengthValueBuffer {
    /// Returns a new [`LengthValueBuffer`].
    ///
    /// # Errors
    /// This method performs bound checks and returns an error if the given buffer is not a valid
    /// Length-Value item.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use xaynet_core::message::LengthValueBuffer;
    /// // truncated length:
    /// assert!(LengthValueBuffer::new(&vec![0x00, 0x00, 0x00]).is_err());
    ///
    /// // truncated value:
    /// let bytes = vec![
    ///     0x00, 0x00, 0x00, 0x08, // length: 8
    ///     0x11, 0x22, 0x33, // value
    /// ];
    /// assert!(LengthValueBuffer::new(&bytes).is_err());
    ///
    /// // valid Length-Value item
    /// let bytes = vec![
    ///     0x00, 0x00, 0x00, 0x08, // length: 8
    ///     0x11, 0x22, 0x33, 0x44, // value
    ///     0xaa, 0xbb, // extra bytes are ignored
    /// ];
    /// let buf = LengthValueBuffer::new(&bytes).unwrap();
    /// assert_eq!(buf.length(), 8);
    /// assert_eq!(buf.value(), &[0x11, 0x22, 0x33, 0x44][..]);
    /// ```
    pub fn new(bytes: T) -> Result {
        let buffer = Self { inner: bytes };
        buffer
            .check_buffer_length()
            .context("not a valid LengthValueBuffer")?;
        Ok(buffer)
    }

    /// Create a new [`LengthValueBuffer`] without any bound checks.
    ///
    /// The accessors may panic if the buffer turns out not to be a valid Length-Value item.
    pub fn new_unchecked(bytes: T) -> Self {
        Self { inner: bytes }
    }

    /// Check that the buffer is a valid Length-Value item.
pub fn check_buffer_length(&self) -> Result<(), DecodeError> { let len = self.inner.as_ref().len(); if len < LENGTH_FIELD.end { return Err(anyhow!( "invalid buffer length: {} < {}", len, LENGTH_FIELD.end )); } if (self.length() as usize) < LENGTH_FIELD.end { return Err(anyhow!( "invalid length value: {} (should be >= {})", len, LENGTH_FIELD.end )); } if len < self.length() as usize { return Err(anyhow!( "invalid buffer length: {} < {}", len, self.length(), )); } Ok(()) } /// Returns the length field. Note that the value of the length /// field includes the length of the field itself (4 bytes). /// /// # Panics /// This method may panic if buffer is not a valid Length-Value item. pub fn length(&self) -> u32 { // unwrap safe: the slice is exactly 4 bytes long u32::from_be_bytes(self.inner.as_ref()[LENGTH_FIELD].try_into().unwrap()) } /// Returns the length of the value. pub fn value_length(&self) -> usize { self.length() as usize - LENGTH_FIELD.end } /// Returns the range corresponding to the value. fn value_range(&self) -> Range { let offset = LENGTH_FIELD.end; let value_length = self.value_length(); offset..offset + value_length } } impl> LengthValueBuffer { /// Sets the length field to the given value. /// /// # Panics /// This method may panic if buffer is not a valid Length-Value item. pub fn set_length(&mut self, value: u32) { self.inner.as_mut()[LENGTH_FIELD].copy_from_slice(&value.to_be_bytes()); } } impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> LengthValueBuffer<&'a mut T> { /// Gets a mutable reference to the value field. /// /// # Panics /// This method may panic if buffer is not a valid Length-Value item. pub fn value_mut(&mut self) -> &mut [u8] { let range = self.value_range(); &mut self.inner.as_mut()[range] } /// Gets a mutable reference to the underlying buffer. /// /// # Panics /// This method may panic if buffer is not a valid Length-Value item. 
pub fn bytes_mut(&mut self) -> &mut [u8] {
        self.inner.as_mut()
    }
}

impl<'a, T: AsRef<[u8]> + ?Sized> LengthValueBuffer<&'a T> {
    /// Gets a reference to the value field.
    ///
    /// # Panics
    /// This method may panic if buffer is not a valid Length-Value item.
    pub fn value(&self) -> &'a [u8] {
        &self.inner.as_ref()[self.value_range()]
    }

    /// Gets a reference to the Length-Value item: the length field followed by
    /// the value. Bytes past the end of the value are not included.
    ///
    /// # Panics
    /// This method may panic if buffer is not a valid Length-Value item.
    pub fn bytes(self) -> &'a [u8] {
        let range = self.value_range();
        &self.inner.as_ref()[..range.end]
    }
}

/// Length of one serialized seed dictionary entry: a sum participant public key
/// followed by an encrypted mask seed.
const ENTRY_LENGTH: usize = SumParticipantPublicKey::LENGTH + EncryptedMaskSeed::LENGTH;

impl ToBytes for LocalSeedDict {
    fn buffer_length(&self) -> usize {
        LENGTH_FIELD.end + self.len() * ENTRY_LENGTH
    }

    fn to_bytes + AsRef<[u8]>>(&self, buffer: &mut T) {
        let mut writer = Cursor::new(buffer.as_mut());
        // The length field covers the whole item, including itself.
        let length = self.buffer_length() as u32;
        // `write_all` (unlike `write`) never silently performs a short write: if
        // the buffer is too small it returns an error and the `unwrap` panics, as
        // the `ToBytes::to_bytes` contract documents.
        writer.write_all(&length.to_be_bytes()).unwrap();
        for (key, value) in self {
            writer.write_all(key.as_slice()).unwrap();
            writer.write_all(value.as_ref()).unwrap();
        }
    }
}

impl FromBytes for LocalSeedDict {
    fn from_byte_slice>(buffer: &T) -> Result {
        let reader = LengthValueBuffer::new(buffer.as_ref())?;

        let mut dict = LocalSeedDict::new();
        let key_length = SumParticipantPublicKey::LENGTH;
        let mut entries = reader.value().chunks_exact(ENTRY_LENGTH);
        for chunk in &mut entries {
            // safe unwraps: lengths of slices are guaranteed
            // by constants.
let key = SumParticipantPublicKey::from_slice(&chunk[..key_length]).unwrap();
            let value = EncryptedMaskSeed::from_slice(&chunk[key_length..]).unwrap();
            if dict.insert(key, value).is_some() {
                return Err(anyhow!("invalid local seed dictionary: duplicated key"));
            }
        }
        // `chunks_exact` leaves an incomplete trailing entry in the remainder.
        if !entries.remainder().is_empty() {
            return Err(anyhow!("invalid local seed dictionary: trailing bytes"));
        }
        Ok(dict)
    }

    fn from_byte_stream + ExactSizeIterator>(
        iter: &mut I,
    ) -> Result {
        let len = u32::from_byte_stream(iter).context("cannot parse length field")? as usize;
        // The length field counts its own bytes, so anything shorter is malformed.
        if len < LENGTH_FIELD.end {
            return Err(anyhow!("invalid length field"));
        }
        // Number of bytes taken up by the entries themselves.
        let value_len = len - LENGTH_FIELD.end;
        if iter.len() < value_len {
            return Err(anyhow!(
                "expected {} bytes, but only {} left",
                value_len,
                iter.len()
            ));
        }

        let mut dict = LocalSeedDict::new();
        let entries = iter.take(value_len).chunks(ENTRY_LENGTH);
        for mut chunk in entries.into_iter() {
            let key = SumParticipantPublicKey::from_byte_stream(&mut chunk)
                .context("invalid entry: cannot parse public key")?;
            let value = EncryptedMaskSeed::from_byte_stream(&mut chunk)
                .context("invalid entry: cannot parse encrypted mask seed")?;

            // This should really not happen, but it's worth checking
            // because our chunkable iterator panics if the chunks are
            // not fully consumed.
            if chunk.len() > 0 {
                return Err(anyhow!(
                    "unknown error while parsing seed dict entry: entry buffer not fully consumed"
                ));
            }

            if dict.insert(key, value).is_some() {
                return Err(anyhow!("duplicated key"));
            }
        }
        Ok(dict)
    }
}

impl FromBytes for u16 {
    fn from_byte_slice>(buffer: &T) -> Result {
        // Big-endian; the slice must be exactly 2 bytes long.
        Ok(u16::from_be_bytes(
            buffer
                .as_ref()
                .try_into()
                .context("failed to parse u16: invalid length")?,
        ))
    }

    fn from_byte_stream + ExactSizeIterator>(
        iter: &mut I,
    ) -> Result {
        fn err() -> DecodeError {
            anyhow!("cannot read u16: byte stream exhausted")
        }
        let b1 = (iter.next().ok_or_else(err)? as u16) << 8;
        let b2 = iter.next().ok_or_else(err)?
as u16;
        Ok(b1 | b2)
    }
}

impl FromBytes for u32 {
    fn from_byte_slice>(buffer: &T) -> Result {
        // Big-endian; the slice must be exactly 4 bytes long.
        Ok(u32::from_be_bytes(
            buffer
                .as_ref()
                .try_into()
                .context("failed to parse u32: invalid length")?,
        ))
    }

    fn from_byte_stream + ExactSizeIterator>(
        iter: &mut I,
    ) -> Result {
        fn err() -> DecodeError {
            anyhow!("cannot read u32: byte stream exhausted")
        }
        // Assemble the four bytes in big-endian order, most significant first.
        let b1 = (iter.next().ok_or_else(err)? as u32) << 24;
        let b2 = (iter.next().ok_or_else(err)? as u32) << 16;
        let b3 = (iter.next().ok_or_else(err)? as u32) << 8;
        let b4 = iter.next().ok_or_else(err)? as u32;
        Ok(b1 | b2 | b3 | b4)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decode_length_value_buffer() {
        let bytes = vec![
            0x00, 0x00, 0x00, 0x05, // Length = 5
            0xff, // Value = 0xff
            0x11, 0x22, // Extra bytes
        ];
        let buffer = LengthValueBuffer::new(&bytes).unwrap();
        assert_eq!(buffer.length(), 5);
        assert_eq!(buffer.value_length(), 1);
        assert_eq!(buffer.value(), &[0xff][..]);
    }

    #[test]
    fn decode_empty_value() {
        // A length of 4 covers only the length field itself: the value is empty.
        let bytes = vec![0x00, 0x00, 0x00, 0x04];
        let buffer = LengthValueBuffer::new(&bytes).unwrap();
        assert_eq!(buffer.length(), 4);
        assert_eq!(buffer.value_length(), 0);
    }

    #[test]
    fn decode_length_value_buffer_buffer_exhausted() {
        let bytes = vec![
            0x00, 0x00, 0x00, 0x08, // Length = 8
            0x11, 0x22, // Only 2 bytes
        ];
        assert!(LengthValueBuffer::new(bytes).is_err());
    }

    #[test]
    fn decode_length_value_buffer_invalid_length() {
        // Missing bytes
        let bytes = vec![0x00, 0x00, 0x00];
        assert!(LengthValueBuffer::new(bytes).is_err());

        // Length field invalid: 3 is smaller than the length field itself
        let bytes = vec![0x00, 0x00, 0x00, 0x03];
        assert!(LengthValueBuffer::new(bytes).is_err());
    }

    #[test]
    fn encode_length_value_buffer() {
        let mut bytes = vec![0xff; 7];
        let mut buffer = LengthValueBuffer::new_unchecked(&mut bytes);
        buffer.set_length(6);
        buffer.value_mut().copy_from_slice(&[0x11, 0x22][..]);
        let expected = vec![
            0x00, 0x00, 0x00, 0x06, // Length = 6
            0x11, 0x22, // Value
            0xff, // unchanged
        ];
        assert_eq!(bytes, expected);
    }

    #[test]
    fn
encode_length_value_buffer_emty() {
        let mut bytes = vec![0xff; 5];
        let mut buffer = LengthValueBuffer::new_unchecked(&mut bytes);
        buffer.set_length(4);
        buffer.value_mut().copy_from_slice(&[][..]);
        let expected = vec![
            0x00, 0x00, 0x00, 0x04, // Length = 4 (empty value)
            0xff, // unchanged
        ];
        assert_eq!(bytes, expected);
    }

    #[test]
    fn parse_u16() {
        let buf = vec![0x12, 0x34];
        assert_eq!(u16::from_byte_slice(&buf.as_slice()).unwrap(), 0x1234);
        assert_eq!(u16::from_byte_stream(&mut buf.into_iter()).unwrap(), 0x1234);
    }
}
================================================ FILE: rust/xaynet-core/src/message/utils/chunkable_iterator.rs ================================================
//! This module provides an extension to the [`Iterator`] trait that allows iterating by chunks. One
//! important property of our chunks is that they implement [`ExactSizeIterator`], which is
//! required by the [`FromBytes`] trait.
//!
//! [`Iterator`]: std::iter::Iterator
//! [`ExactSizeIterator`]: std::iter::ExactSizeIterator
//! [`FromBytes`]: crate::message::FromBytes

use std::{
    cell::RefCell,
    cmp,
    fmt,
    iter::{ExactSizeIterator, Iterator},
    ops::Range,
};

/// Extension trait that lets any iterator be split into chunks that are consumed in order.
pub trait ChunkableIterator: Iterator + Sized {
    /// Return an _iterable_ that can chunk the iterator.
    ///
    /// Yield subiterators (chunks) that each yield a fixed number of
    /// elements, determined by `size`. The last chunk will be shorter
    /// if there aren't enough elements.
    ///
    /// Note that the chunks *must* be fully consumed in the order
    /// they are yielded. Otherwise, they will panic.
///
    /// # Examples
    ///
    /// ```compile_fail
    /// # // private items can't be tested with doc tests
    /// let chunks = vec![0, 1, 2, 3, 4].into_iter().chunks(2);
    /// let mut chunks_iter = chunks.into_iter();
    ///
    /// let mut chunk_1 = chunks_iter.next().unwrap();
    /// assert_eq!(chunk_1.next().unwrap(), 0);
    /// assert_eq!(chunk_1.next().unwrap(), 1);
    /// assert!(chunk_1.next().is_none());
    ///
    /// let mut chunk_2 = chunks_iter.next().unwrap();
    /// assert_eq!(chunk_2.next().unwrap(), 2);
    /// assert_eq!(chunk_2.next().unwrap(), 3);
    /// assert!(chunk_2.next().is_none());
    ///
    /// let mut chunk_3 = chunks_iter.next().unwrap();
    /// assert_eq!(chunk_3.next().unwrap(), 4);
    /// assert!(chunk_3.next().is_none());
    ///
    /// assert!(chunks_iter.next().is_none());
    /// ```
    ///
    /// Attempting to consume chunks out of order fails:
    ///
    /// ```compile_fail
    /// # // private items can't be tested with doc tests
    /// let chunks = vec![0, 1, 2, 3, 4].into_iter().chunks(2);
    /// let mut chunks_iter = chunks.into_iter();
    ///
    /// let mut chunk_1 = chunks_iter.next().unwrap();
    /// let mut chunk_2 = chunks_iter.next().unwrap();
    ///
    /// chunk_2.next(); // panics because chunk_1 was not consumed
    /// ```
    ///
    /// Similarly, not _fully_ consuming the chunks fails:
    ///
    /// ```compile_fail
    /// # // private items can't be tested with doc tests
    /// let chunks = vec![0, 1, 2, 3, 4].into_iter().chunks(2);
    /// let mut chunks_iter = chunks.into_iter();
    ///
    /// let mut chunk_1 = chunks_iter.next().unwrap();
    /// let _ = chunk_1.next().unwrap();
    /// let mut chunk_2 = chunks_iter.next().unwrap();
    ///
    /// chunk_2.next(); // panics because chunk_1 was not fully consumed
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if size is 0.
    fn chunks(self, size: usize) -> IntoChunks;
}

// Blanket implementation: every iterator can be chunked.
impl ChunkableIterator for I
where
    I: Iterator,
{
    fn chunks(self, size: usize) -> IntoChunks {
        IntoChunks::new(self, size)
    }
}

/// State shared by all the chunks of an [`IntoChunks`]: the chunked iterator,
/// plus a one-item lookahead used to detect exhaustion.
struct Inner
where
    I: Iterator,
{
    /// The iterator we're chunking
    iter: I,
    /// Size of each chunk.
Note that the last chunk may be smaller chunk_size: usize, /// Number of chunks that have been yielded nb_chunks: usize, /// Next item from `iter`. By buffering it, we can know when `iter` /// is exhausted. next: Option<(usize, I::Item)>, } impl fmt::Debug for Inner where I: Iterator + fmt::Debug, I::Item: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Inner") .field("iter", &self.iter) .field("chunk_size", &self.chunk_size) .field("nb_chunks", &self.nb_chunks) .field("next", &self.next) .finish() } } impl Inner where I: ExactSizeIterator, { /// Number of items left in `self.iter` fn remaining(&self) -> usize { self.next.as_ref().map(|_| 1).unwrap_or(0) + self.iter.len() } } impl Inner where I: Iterator, { /// Return a new `Inner` with the given iterator and chunk size fn new(mut iter: I, chunk_size: usize) -> Self { if chunk_size == 0 { panic!("invalid chunk size (must be > 0)") } let next = iter.next().map(|elt| (0, elt)); Self { iter, chunk_size, nb_chunks: 0, next, } } /// Get the `index`-th item from the underlying iterator. See /// [`IntoChunks::get`]. fn get(&mut self, index: usize) -> Option { self.next.as_ref()?; let current_index = self.next.as_ref().unwrap().0; if index < current_index { return None; } if index == current_index { let res = Some(self.next.take().unwrap().1); // Buffer the next element self.next = self.iter.next().map(|elt| (index + 1, elt)); res } else { panic!("previous chunks must be consumed"); } } } /// A type that can be turned into an `Iterator>`. pub struct IntoChunks where I: Iterator, { /// `inner` is just a mutable `Inner`. 
inner: RefCell>,
}

impl fmt::Debug for IntoChunks
where
    I: Iterator + fmt::Debug,
    I::Item: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IntoChunks")
            .field("inner", &self.inner)
            .finish()
    }
}

impl IntoChunks
where
    I: Iterator,
{
    /// Return a new `IntoChunks` that splits `iter` into chunks of `chunk_size` items.
    ///
    /// # Panics
    /// Panics if `chunk_size` is 0.
    pub fn new(iter: I, chunk_size: usize) -> Self {
        Self {
            inner: RefCell::new(Inner::new(iter, chunk_size)),
        }
    }

    /// Get the range of the next chunk
    fn next_chunk_range(&self) -> Range {
        // Borrow once; the guard is released when this method returns.
        let inner = self.inner.borrow();
        let start = inner.nb_chunks * inner.chunk_size;
        let end = start + inner.chunk_size;
        start..end
    }

    /// Return `true` if the iterator we're chunking is exhausted
    fn exhausted(&self) -> bool {
        self.inner.borrow().next.is_none()
    }

    /// Get the `index`-th item from the underlying iterator. If the
    /// iterator already advanced beyond `index`, `None` is
    /// returned. If the requested `index` hasn't been reached yet,
    /// this method panics. This is to enforce the invariant that all
    /// chunks must be consumed in order.
///
    /// # Examples
    ///
    /// ```compile_fail
    /// # // private items can't be tested with doc tests
    /// let iter = vec![0, 1, 2, 3, 4, 5].into_iter();
    /// let chunk_size = 2;
    /// let chunks = IntoChunks::new(iter, chunk_size);
    /// assert_eq!(chunks.get(0), Some(0));
    /// assert_eq!(chunks.get(1), Some(1));
    /// // calling `get` for an index that has been consumed already
    /// assert_eq!(chunks.get(1), None);
    /// // this panics, because the expected index is `2`
    /// chunks.get(3);
    /// ```
    pub fn get(&self, index: usize) -> Option {
        self.inner.borrow_mut().get(index)
    }
}

impl IntoChunks
where
    I: ExactSizeIterator,
{
    /// Number of items left in the iterator we're chunking, including the one
    /// that is buffered but not yet handed out
    fn remaining(&self) -> usize {
        self.inner.borrow().remaining()
    }
}

impl<'a, I> IntoIterator for &'a IntoChunks
where
    I: Iterator,
{
    type Item = Chunk<'a, I>;
    type IntoIter = Chunks<'a, I>;

    fn into_iter(self) -> Self::IntoIter {
        Chunks { parent: self }
    }
}

/// An iterator that yields chunks
pub struct Chunks<'a, I>
where
    I: Iterator,
{
    parent: &'a IntoChunks,
}

impl<'a, I> fmt::Debug for Chunks<'a, I>
where
    I: Iterator + fmt::Debug,
    I::Item: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Chunks")
            .field("parent", &self.parent)
            .finish()
    }
}

impl<'a, I> Iterator for Chunks<'a, I>
where
    I: Iterator,
{
    type Item = Chunk<'a, I>;

    fn next(&mut self) -> Option> {
        if self.parent.exhausted() {
            return None;
        }
        // The chunk counter advances as soon as a chunk is created, even before
        // it is consumed; consuming chunks out of order then panics in `get`.
        let chunk = Chunk {
            range: self.parent.next_chunk_range(),
            chunks: self.parent,
        };
        self.parent.inner.borrow_mut().nb_chunks += 1;
        Some(chunk)
    }
}

/// A chunk: yields the items of the chunked iterator whose indices fall into
/// `range`.
pub struct Chunk<'a, I>
where
    I: Iterator,
{
    range: Range,
    chunks: &'a IntoChunks,
}

impl<'a, I> fmt::Debug for Chunk<'a, I>
where
    I: Iterator + fmt::Debug,
    I::Item: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Chunk")
            .field("range", &self.range)
            .field("chunks", &self.chunks)
            .finish()
    }
}

impl<'a, I> Iterator for Chunk<'a, I>
where
    I: Iterator,
{
    type Item = I::Item;
fn next(&mut self) -> Option {
        if self.range.start >= self.range.end {
            return None;
        }
        match self.chunks.get(self.range.start) {
            Some(elt) => {
                self.range.start += 1;
                Some(elt)
            }
            None => {
                // `get` had nothing for this index (the underlying iterator is
                // exhausted or the item was already handed out): mark the chunk
                // as fully consumed.
                self.range.start = self.range.end;
                None
            }
        }
    }
}

// NOTE(review): `Iterator::size_hint` is not overridden for `Chunk`, so it
// reports the default `(0, None)` even though `len()` below is exact.
impl<'a, I> ExactSizeIterator for Chunk<'a, I>
where
    I: Iterator + ExactSizeIterator,
{
    fn len(&self) -> usize {
        // Bounded both by what is left of this chunk's range and by what the
        // underlying iterator still holds (relevant for the last, shorter chunk).
        cmp::min(self.chunks.remaining(), self.range.end - self.range.start)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn full_chunks_1() {
        let iter = vec![0, 1, 2].into_iter();
        let chunks = IntoChunks::new(iter, 1);
        let mut chunks_iter = chunks.into_iter();

        let mut c = chunks_iter.next().unwrap();
        assert_eq!(c.len(), 1);
        assert_eq!(c.next().unwrap(), 0);
        assert_eq!(c.len(), 0);
        assert!(c.next().is_none());

        let mut c = chunks_iter.next().unwrap();
        assert_eq!(c.len(), 1);
        assert_eq!(c.next().unwrap(), 1);
        assert_eq!(c.len(), 0);
        assert!(c.next().is_none());

        let mut c = chunks_iter.next().unwrap();
        assert_eq!(c.len(), 1);
        assert_eq!(c.next().unwrap(), 2);
        assert_eq!(c.len(), 0);
        assert!(c.next().is_none());

        assert!(chunks_iter.next().is_none());
    }

    #[test]
    fn full_chunks_2() {
        let iter = vec![0, 1, 2, 3, 4, 5].into_iter();
        let chunks = IntoChunks::new(iter, 2);
        let mut chunks_iter = chunks.into_iter();

        let mut c = chunks_iter.next().unwrap();
        assert_eq!(c.len(), 2);
        assert_eq!(c.next().unwrap(), 0);
        assert_eq!(c.len(), 1);
        assert_eq!(c.next().unwrap(), 1);
        assert_eq!(c.len(), 0);
        assert!(c.next().is_none());

        let mut c = chunks_iter.next().unwrap();
        assert_eq!(c.len(), 2);
        assert_eq!(c.next().unwrap(), 2);
        assert_eq!(c.len(), 1);
        assert_eq!(c.next().unwrap(), 3);
        assert_eq!(c.len(), 0);
        assert!(c.next().is_none());

        let mut c = chunks_iter.next().unwrap();
        assert_eq!(c.len(), 2);
        assert_eq!(c.next().unwrap(), 4);
        assert_eq!(c.len(), 1);
        assert_eq!(c.next().unwrap(), 5);
        assert_eq!(c.len(), 0);
        assert!(c.next().is_none());

        assert!(chunks_iter.next().is_none());
    }

    #[test]
    fn partial_chunk() {
        // 3 items in chunks of 2: the second chunk only holds one item.
        let iter = vec![0, 1,
2].into_iter();
        let chunks = IntoChunks::new(iter, 2);
        let mut chunks_iter = chunks.into_iter();

        let mut c = chunks_iter.next().unwrap();
        assert_eq!(c.len(), 2);
        assert_eq!(c.next().unwrap(), 0);
        assert_eq!(c.len(), 1);
        assert_eq!(c.next().unwrap(), 1);
        assert_eq!(c.len(), 0);
        assert!(c.next().is_none());

        let mut c = chunks_iter.next().unwrap();
        assert_eq!(c.len(), 1);
        assert_eq!(c.next().unwrap(), 2);
        assert_eq!(c.len(), 0);
        assert!(c.next().is_none());
    }

    #[test]
    #[should_panic(expected = "previous chunks must be consumed")]
    fn chunks_consumed_out_of_order() {
        let iter = vec![0, 1, 2, 3, 4, 5].into_iter();
        let chunks = IntoChunks::new(iter, 2);
        let mut chunks_iter = chunks.into_iter();

        let mut c1 = chunks_iter.next().unwrap();
        assert_eq!(c1.next().unwrap(), 0);
        assert_eq!(c1.next().unwrap(), 1);
        assert!(c1.next().is_none());

        // c2 (items 2 and 3) is never consumed, so reading from c3 must panic.
        let _c2 = chunks_iter.next().unwrap();
        let mut c3 = chunks_iter.next().unwrap();
        assert_eq!(c3.next().unwrap(), 4);
    }

    // This test case illustrates a weird behavior of our iterator:
    // everything being lazy, we can create chunks that start *beyond*
    // what our main iterator can provide in theory. Attempting to
    // consume such iterators should panic
    #[test]
    #[should_panic(expected = "previous chunks must be consumed")]
    fn weird() {
        let iter = vec![0, 1, 2].into_iter();
        let chunks = IntoChunks::new(iter, 1);
        let mut chunks_iter = chunks.into_iter();

        let mut c1 = chunks_iter.next().unwrap();
        let mut c2 = chunks_iter.next().unwrap();
        let mut c3 = chunks_iter.next().unwrap();
        // This chunk starts at index 3, which we don't even have
        let mut c4 = chunks_iter.next().unwrap();
        // Panics here: item 0 is still buffered, so index 3 is ahead of it.
        assert!(c4.next().is_none());
        assert!(c1.next().is_none());
        assert!(c2.next().is_none());
        assert!(c3.next().is_none());
    }
}
================================================ FILE: rust/xaynet-core/src/message/utils/mod.rs ================================================
//! Message utilities.
//!
//! See the [message module] documentation since this is a private module anyways.
//! //! [message module]: crate::message mod chunkable_iterator; pub use chunkable_iterator::{Chunk, ChunkableIterator, Chunks, IntoChunks}; use std::ops::Range; /// Creates a range from `start` to `start + length`. pub(crate) const fn range(start: usize, length: usize) -> Range { start..(start + length) } ================================================ FILE: rust/xaynet-core/src/testutils/messages.rs ================================================ //! This module provides helpers for generating messages or messages //! parts such as signatures, cryptographic keys, or mask objects. use std::convert::TryFrom; use num::BigUint; use crate::{ crypto::{ByteObject, PublicEncryptKey, PublicSigningKey, Signature}, mask::EncryptedMaskSeed, message::{Message, Payload, Sum, Sum2, Tag, Update}, LocalSeedDict, }; // A message adds 136 bytes of overhead: // // - a signature (64 bytes) // - the participant pk (32 bytes) // - the coordinator pk (32 bytes) // - a length field (4 bytes) // - a tag (1 byte) // - flags (1 byte) // - a reserved field (2 bytes) pub const HEADER_LENGTH: usize = 136; pub fn signature() -> (Signature, Vec) { let bytes = vec![0x1a; 64]; let signature = Signature::from_slice(bytes.as_slice()).unwrap(); (signature, bytes) } pub fn participant_pk() -> (PublicSigningKey, Vec) { let bytes = vec![0xbb; 32]; let pk = PublicSigningKey::from_slice(&bytes).unwrap(); (pk, bytes) } pub fn coordinator_pk() -> (PublicEncryptKey, Vec) { let bytes = vec![0xcc; 32]; let pk = PublicEncryptKey::from_slice(&bytes).unwrap(); (pk, bytes) } pub fn message(f: F) -> (Message, Vec) where F: Fn() -> (P, Vec), P: Into, { let (payload, payload_bytes) = f(); let payload: Payload = payload.into(); let tag = match payload { Payload::Sum(_) => Tag::Sum, Payload::Update(_) => Tag::Update, Payload::Sum2(_) => Tag::Sum2, _ => panic!("chunks not supported"), }; let message = Message { signature: Some(signature().0), participant_pk: participant_pk().0, coordinator_pk: coordinator_pk().0, 
payload,
        is_multipart: false,
        tag,
    };

    let mut buf = signature().1;
    buf.extend(participant_pk().1);
    buf.extend(coordinator_pk().1);
    // The length field covers the whole message: header plus payload.
    let length = payload_bytes.len() + HEADER_LENGTH;
    buf.extend(&(length as u32).to_be_bytes());
    buf.push(tag.into());
    // Flags (1 byte; all zero, presumably because `is_multipart` is false)
    // followed by the 2-byte reserved field.
    buf.extend(vec![0, 0, 0]);
    buf.extend(payload_bytes);
    (message, buf)
}

pub mod sum {
    //! This module provides helpers for generating sum payloads

    use super::*;

    /// Return a fake sum task signature and its serialized version
    pub fn sum_task_signature() -> (Signature, Vec) {
        let bytes = vec![0x11; 64];
        let signature = Signature::from_slice(&bytes[..]).unwrap();
        (signature, bytes)
    }

    /// Return a fake ephemeral public key and its serialized version
    pub fn ephm_pk() -> (PublicEncryptKey, Vec) {
        let bytes = vec![0x22; 32];
        let pk = PublicEncryptKey::from_slice(&bytes[..]).unwrap();
        (pk, bytes)
    }

    /// Return a sum payload with its serialized version
    pub fn payload() -> (Sum, Vec) {
        let mut bytes = sum_task_signature().1;
        bytes.extend(ephm_pk().1);

        let sum = Sum {
            sum_signature: sum_task_signature().0,
            ephm_pk: ephm_pk().0,
        };
        (sum, bytes)
    }
}

pub mod update {
    //!
//! This module provides helpers for generating update payloads

    pub use mask::{mask_object, mask_unit, mask_vect};
    pub use sum::sum_task_signature;

    use super::*;

    /// Return a fake update task signature and its serialized version
    pub fn update_task_signature() -> (Signature, Vec) {
        let bytes = vec![0x14; 64];
        let signature = Signature::from_slice(&bytes[..]).unwrap();
        (signature, bytes)
    }

    /// Return a local seed dictionary with two entries with its
    /// expected serialized version.
    ///
    /// The fill bytes increase from field to field (0x55, 0x66, 0x77, 0x88), so
    /// the expected serialization is already in ascending byte order. Tests
    /// that sort the encoded entries (whose order is not guaranteed) therefore
    /// compare equal to it.
    pub fn local_seed_dict() -> (LocalSeedDict, Vec) {
        let mut local_seed_dict = LocalSeedDict::new();
        let mut bytes = vec![];

        // Length (32+80) * 2 + 4 = 228
        bytes.extend(vec![0x00, 0x00, 0x00, 0xe4]);

        // First entry
        bytes.extend(vec![0x55; PublicSigningKey::LENGTH]);
        bytes.extend(vec![0x66; EncryptedMaskSeed::LENGTH]);
        local_seed_dict.insert(
            PublicSigningKey::from_slice(vec![0x55; 32].as_slice()).unwrap(),
            EncryptedMaskSeed::try_from(vec![0x66; EncryptedMaskSeed::LENGTH]).unwrap(),
        );

        // Second entry
        bytes.extend(vec![0x77; PublicSigningKey::LENGTH]);
        bytes.extend(vec![0x88; EncryptedMaskSeed::LENGTH]);
        local_seed_dict.insert(
            PublicSigningKey::from_slice(vec![0x77; 32].as_slice()).unwrap(),
            EncryptedMaskSeed::try_from(vec![0x88; EncryptedMaskSeed::LENGTH]).unwrap(),
        );

        (local_seed_dict, bytes)
    }

    /// Return an update payload with its serialized version
    pub fn payload() -> (Update, Vec) {
        let mut bytes = sum_task_signature().1;
        bytes.extend(update_task_signature().1);
        bytes.extend(mask_object().1);
        bytes.extend(local_seed_dict().1);

        let update = Update {
            sum_signature: sum_task_signature().0,
            update_signature: update_task_signature().0,
            masked_model: mask_object().0,
            local_seed_dict: local_seed_dict().0,
        };
        (update, bytes)
    }
}

pub mod sum2 {
    //!
This module provides helpers for generating sum2 payloads pub use mask::{mask_object, mask_unit, mask_vect}; pub use sum::sum_task_signature; use super::*; /// Return a sum2 payload and its serialized version pub fn payload() -> (Sum2, Vec) { let (sum_signature, sum_signature_bytes) = sum_task_signature(); let (model_mask, model_mask_bytes) = mask_object(); let bytes = [sum_signature_bytes.as_slice(), model_mask_bytes.as_slice()].concat(); let sum2 = Sum2 { sum_signature, model_mask, }; (sum2, bytes) } } pub mod mask { //! This module provides helpers for generating mask objects use crate::mask::{ BoundType, DataType, GroupType, MaskConfig, MaskObject, MaskUnit, MaskVect, ModelType, }; use super::*; /// Return a mask config and its serialized version pub fn mask_config() -> (MaskConfig, Vec) { // config.order() = 20_000_000_000_001 with this config, so the data // should be stored on 6 bytes. let config = MaskConfig { group_type: GroupType::Integer, data_type: DataType::I32, bound_type: BoundType::B0, model_type: ModelType::M3, }; let bytes = vec![0x00, 0x02, 0x00, 0x03]; (config, bytes) } /// Return a masked vector and its serialized version pub fn mask_vect() -> (MaskVect, Vec) { let (config, mut bytes) = mask_config(); let data = vec![ BigUint::from(1_u8), BigUint::from(2_u8), BigUint::from(3_u8), BigUint::from(4_u8), ]; let mask_vect = MaskVect::new(config, data).unwrap(); bytes.extend(vec![ // number of elements 0x00, 0x00, 0x00, 0x04, // data (1 weight => 6 bytes with this config, least significant byte first) 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // 1 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, // 2 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // 3 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, // 4 ]); (mask_vect, bytes) } /// Return a masked scalar and its serialized version pub fn mask_unit() -> (MaskUnit, Vec) { let (config, mut bytes) = mask_config(); let data = BigUint::from(1_u8); let mask_unit = MaskUnit::new(config, data).unwrap(); bytes.extend(vec![ 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // data: 1 ]);
(mask_unit, bytes) } /// Return a mask object, containing a masked vector and a masked /// scalar, and its serialized version pub fn mask_object() -> (MaskObject, Vec) { let (mask_vect, mask_vect_bytes) = mask_vect(); let (mask_unit, mask_unit_bytes) = mask_unit(); let obj = MaskObject::new_unchecked(mask_vect, mask_unit); let bytes = [mask_vect_bytes.as_slice(), mask_unit_bytes.as_slice()].concat(); (obj, bytes) } } #[cfg(test)] mod tests { use super::*; // This tests is just so that if something changes, we catch it // and can update the helpers accordingly #[test] fn check_object_lengths() { assert_eq!(Signature::LENGTH, 64); assert_eq!(PublicEncryptKey::LENGTH, 32); assert_eq!(PublicSigningKey::LENGTH, 32); assert_eq!(EncryptedMaskSeed::LENGTH, 80); } } ================================================ FILE: rust/xaynet-core/src/testutils/mod.rs ================================================ pub mod messages; pub mod multipart; ================================================ FILE: rust/xaynet-core/src/testutils/multipart.rs ================================================ use num::BigUint; use crate::{ crypto::{ByteObject, PublicSigningKey, Signature}, mask::{ BoundType, DataType, EncryptedMaskSeed, GroupType, MaskConfig, MaskObject, MaskUnit, MaskVect, ModelType, }, message::{Message, ToBytes, Update}, testutils::messages, LocalSeedDict, }; /// Return a seed dict that has the given length `len` once /// serialized. `len - 4` must be multiple of 112. pub fn local_seed_dict(len: usize) -> LocalSeedDict { // a public key is 32 bytes and an encrypted mask seed 80. 
let entry_len = 32 + 80; if ((len - 4) % entry_len) != 0 { panic!("invalid length for seed dict"); } let nb_entries = (len - 4) / entry_len; let mut dict = LocalSeedDict::new(); for i in 0..nb_entries { let bytes = (i as u64).to_be_bytes(); let pk_bytes = bytes.iter().cycle().take(32).copied().collect::>(); let seed_bytes = bytes.iter().cycle().take(80).copied().collect::>(); let pk = PublicSigningKey::from_slice(pk_bytes.as_slice()).unwrap(); let mask_seed = EncryptedMaskSeed::from_slice(seed_bytes.as_slice()).unwrap(); dict.insert(pk, mask_seed); } // Check that our calculations are correct assert_eq!(dict.buffer_length(), len); dict } pub fn mask_object(len: usize) -> MaskObject { // The model contains 2 sub mask objects: // - the masked model, which has: // - 4 bytes for the config // - 4 bytes for the number of weights // - 6 bytes (with our config) for each weight // - the masked scalar: // - 4 bytes for the config // - 6 bytes (with our config) for the scalar // // The only parameter we control to make the length vary is // the number of weights. The lengths is then: // // len = (4 + 4 + n_weights * 6) + (4 + 6) = 18 + 6 * n_weights // // So we must have: (len - 18) % 6 = 0 if (len - 18) % 6 != 0 { panic!("invalid masked model length") } let n_weights = (len - 18) / 6; // Let's not be too crazy, it makes no sense to test with too // many weights assert!(n_weights < u32::MAX as usize); let mut weights = vec![]; for i in 0..n_weights { weights.push(BigUint::from(i)); } let masked_model = MaskVect::new(mask_config(), weights).unwrap(); let masked_scalar = MaskUnit::new(mask_config(), BigUint::from(0_u32)).unwrap(); let obj = MaskObject::new_unchecked(masked_model, masked_scalar); // Check that our calculations are correct assert_eq!(obj.buffer_length(), len); obj } pub fn mask_config() -> MaskConfig { // config.order() = 20_000_000_000_001 with this config, so the data // should be stored on 6 bytes. 
MaskConfig { group_type: GroupType::Integer, data_type: DataType::I32, bound_type: BoundType::B0, model_type: ModelType::M3, } } pub fn task_signatures() -> (Signature, Signature) { ( messages::sum::sum_task_signature().0, messages::update::update_task_signature().0, ) } /// Create an update payload with a seed dictionary of length /// `dict_len` and a mask object of length `mask_obj_len`. For a payload /// of size `S`, the following must hold true (otherwise `mask_object()` or /// `local_seed_dict()` panics): /// /// ```no_rust /// (mask_obj_len - 18) % 6 = 0 /// (dict_len - 4) % 112 = 0 /// S = dict_len + mask_obj_len + 64*2 /// ``` pub fn update(dict_len: usize, mask_obj_len: usize) -> Update { // An update message is made of: // - 2 signatures of 64 bytes each // - a mask object of variable length // - a seed dictionary of variable length // // The `Message` overhead is 136 bytes (see `HEADER_LENGTH`, // as used by `messages::message`). So a message with // `dict_len` = 116 and `mask_obj_len` = 24 will be: // // 116 + 24 + 64*2 + 136 = 404 bytes let (sum_signature, update_signature) = task_signatures(); let payload = Update { sum_signature, update_signature, masked_model: mask_object(mask_obj_len), local_seed_dict: local_seed_dict(dict_len), }; assert_eq!(payload.buffer_length(), mask_obj_len + dict_len + 64 * 2); payload } /// Create an update message with a seed dictionary of length /// `dict_len` and a mask object of length `mask_obj_len`. For a message /// of size `S`, the following must hold true (otherwise `mask_object()` or /// `local_seed_dict()` panics): /// /// ```no_rust /// (mask_obj_len - 18) % 6 = 0 /// (dict_len - 4) % 112 = 0 /// S = dict_len + mask_obj_len + 64*2 + 136 /// ``` pub fn message(dict_len: usize, mask_obj_len: usize) -> Message { let (message, _) = messages::message(|| { let payload = update(dict_len, mask_obj_len); let dummy_buf = vec![]; (payload, dummy_buf) }); message } ================================================ FILE: rust/xaynet-mobile/.cargo/config.toml ================================================ # These settings reduce the size of the libraries a lot!
# See: https://github.com/johnthagen/min-sized-rust [profile.release] lto = true codegen-units = 1 opt-level = 'z' ================================================ FILE: rust/xaynet-mobile/.gitignore ================================================ ffi_test.o.dSYM ffi_test.o test_participant_save_and_restore.txt ================================================ FILE: rust/xaynet-mobile/Cargo.toml ================================================ [package] name = "xaynet-mobile" version = "0.1.0" authors = ["Xayn Engineering "] edition = "2018" description = "The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant." readme = "README.md" homepage = "https://xaynet.dev/" repository = "https://github.com/xaynetwork/xaynet/" license-file = "../../LICENSE" keywords = ["federated-learning", "fl", "ai", "machine-learning"] categories = ["science", "cryptography"] [dependencies] async-trait = "0.1.57" bincode = "1.3.3" ffi-support = "0.4.4" futures = "0.3.24" reqwest = { version = "0.11.10", default-features = false, features = ["rustls-tls"]} sodiumoxide = "0.2.7" thiserror = "1.0.32" tracing = "0.1.36" tokio = { version = "1.20.1", default-features = false, features = ["rt"] } xaynet-core = { path = "../xaynet-core", version = "0.2.0" } xaynet-sdk = { path = "../xaynet-sdk", default-features = false, version = "0.1.0", features = ["reqwest-client"]} zeroize = "1.5.7" [build-dependencies] cbindgen = "=0.17.0" [lib] name = "xaynet_mobile" crate-type = ["staticlib", "cdylib", "rlib"] [features] default = [] ================================================ FILE: rust/xaynet-mobile/README.md ================================================ # Xaynet FFI ## Generate C-Header File To generate the header files, run `cargo build`. 
## Run tests ### macOS ``` cc -o tests/ffi_test.o -Wl,-dead_strip -I. tests/ffi_test.c ../target/debug/libxaynet_mobile.a -framework Security -framework Foundation ./tests/ffi_test.o ``` ### Linux ``` gcc \ tests/ffi_test.c -Wall \ -I. \ -pthread -Wl,--no-as-needed -lm -ldl \ ../target/debug/libxaynet_mobile.a \ -o tests/ffi_test.o ./tests/ffi_test.o ``` To check for memory leaks, you can use Valgrind: ``` valgrind --tool=memcheck --leak-check=full --show-leak-kinds=all -s ./tests/ffi_test.o ``` ================================================ FILE: rust/xaynet-mobile/build.rs ================================================ use std::{ env, fs::read_dir, path::{Path, PathBuf}, }; use cbindgen::{generate_with_config, Config}; // cargo doesn't check directories recursively so we have to do it by hand, also emitting a // rerun-if line cancels the default rerun for changes in the crate directory fn cargo_rerun_if_changed(entry: impl AsRef) { let entry = entry.as_ref(); if entry.is_dir() { for entry in read_dir(entry).expect("Failed to read dir.") { cargo_rerun_if_changed(entry.expect("Failed to read entry.").path()); } } else { println!("cargo:rerun-if-changed={}", entry.display()); } } fn main() { let crate_dir = PathBuf::from( env::var("CARGO_MANIFEST_DIR").expect("Failed to read CARGO_MANIFEST_DIR env."), ); let bind_config = crate_dir.join("cbindgen.toml"); let bind_file = crate_dir.join("xaynet_ffi.h"); cargo_rerun_if_changed(crate_dir.join("src")); cargo_rerun_if_changed(crate_dir.join("Cargo.toml")); cargo_rerun_if_changed(bind_config.as_path()); let config = Config::from_file(bind_config).expect("Failed to read config."); generate_with_config(crate_dir, config) .expect("Failed to generate bindings.") .write_to_file(bind_file); } ================================================ FILE: rust/xaynet-mobile/cbindgen.toml ================================================ language = "C" autogen_warning = "/* Warning, this file is autogenerated by cbindgen. 
Don't modify this manually. */" include_version = true [export] exclude = ["_xaynet_ffi_settings_destroy", "_xaynet_ffi_participant_destroy", "_xaynet_ffi_local_model_config_destroy"] [parse] parse_deps = true include = ["ffi-support"] [enum] rename_variants = "ScreamingSnakeCase" prefix_with_name = true ================================================ FILE: rust/xaynet-mobile/src/ffi/config.rs ================================================ use crate::ffi::{ERR_NULLPTR, OK}; use std::os::raw::c_int; use xaynet_core::mask::DataType; mod pv { use super::LocalModelConfig; ffi_support::define_box_destructor!(LocalModelConfig, _xaynet_ffi_local_model_config_destroy); } /// Destroy the model configuration created by [`xaynet_ffi_participant_local_model_config()`]. /// /// # Return value /// /// - [`OK`] on success /// - [`ERR_NULLPTR`] if `local_model_config` is NULL /// /// # Safety /// /// 1. When calling this method, you have to ensure that *either* the pointer is NULL /// *or* all of the following is true: /// - The pointer must be properly [aligned]. /// - It must be "dereferencable" in the sense defined in the [`std::ptr`] module /// documentation. /// 2. After destroying the `LocalModelConfig`, the pointer becomes invalid and must not be /// used. /// 3. This function should only be called on a pointer that has been created by /// [`xaynet_ffi_participant_local_model_config()`]. 
/// /// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety /// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment /// [`xaynet_ffi_participant_local_model_config()`]: crate::ffi::xaynet_ffi_participant_local_model_config #[no_mangle] pub unsafe extern "C" fn xaynet_ffi_local_model_config_destroy( local_model_config: *mut LocalModelConfig, ) -> c_int { if local_model_config.is_null() { return ERR_NULLPTR; } pv::_xaynet_ffi_local_model_config_destroy(local_model_config); OK } #[repr(C)] /// The model configuration of the model that is expected in [`xaynet_ffi_participant_set_model()`]. /// /// [`xaynet_ffi_participant_set_model()`]: crate::ffi::xaynet_ffi_participant_set_model pub struct LocalModelConfig { /// The expected data type of the model. pub data_type: ModelDataType, /// the expected length of the model. pub len: u64, } impl From for LocalModelConfig { fn from(lmc: xaynet_sdk::LocalModelConfig) -> Self { LocalModelConfig { data_type: lmc.data_type.into(), len: lmc.len as u64, } } } #[repr(u8)] /// The original primitive data type of the numerical values to be masked. pub enum ModelDataType { /// Numbers of type f32. F32 = 0, /// Numbers of type f64. F64 = 1, /// Numbers of type i32. I32 = 2, /// Numbers of type i64. I64 = 3, } impl From for ModelDataType { fn from(dt: DataType) -> Self { match dt { DataType::F32 => ModelDataType::F32, DataType::F64 => ModelDataType::F64, DataType::I32 => ModelDataType::I32, DataType::I64 => ModelDataType::I64, } } } ================================================ FILE: rust/xaynet-mobile/src/ffi/mod.rs ================================================ #![allow(unused_unsafe)] mod participant; pub use participant::*; mod settings; pub use settings::*; mod config; pub use config::*; pub use ffi_support::{ByteBuffer, FfiStr}; use std::os::raw::c_int; /// Destroy the given `ByteBuffer` and free its memory. 
This function must only be /// called on `ByteBuffer`s that have been created on the Rust side of the FFI. If you /// have created a `ByteBuffer` on the other side of the FFI, do not use this function, /// use `free()` instead. /// /// # Return value /// /// - [`OK`] on success /// - [`ERR_NULLPTR`] if `buf` is NULL /// /// # Safety /// /// 1. When calling this method, you have to ensure that *either* the pointer is NULL /// *or* all of the following is true: /// - The pointer must be properly [aligned]. /// - It must be "dereferencable" in the sense defined in the [`std::ptr`] module /// documentation. /// 2. After destroying the `ByteBuffer` the pointer becomes invalid and must not be /// used. /// 3. Calling this function on a `ByteBuffer` that has not been created on the Rust /// side of the FFI is UB. /// /// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety /// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment #[no_mangle] pub unsafe extern "C" fn xaynet_ffi_byte_buffer_destroy( // Note that we use a *const instead of a *mut here. The reason is // that the functions that create byte buffers return *const // pointers. Taking a *mut here would trigger a // -Wdiscarded-qualifiers warning from C. Forcing users to use // *const pointers brings some safety, and casting back to *mut // here is no big deal since the pointer becomes invalid afterward // anyway. buf: *const ByteBuffer, ) -> c_int { if buf.is_null() { return ERR_NULLPTR; } Box::from_raw(buf as *mut ByteBuffer).destroy(); OK } /// Initialize the crypto library. This method must be called before instantiating a /// participant with [`xaynet_ffi_participant_new()`] or before generating new keys with /// [`xaynet_ffi_generate_key_pair()`]. 
/// /// # Return value /// /// - [`OK`] if the initialization succeeded /// - [`ERR_CRYPTO_INIT`] if the initialization failed /// /// # Safety /// /// This function is safe to call. #[no_mangle] pub unsafe extern "C" fn xaynet_ffi_crypto_init() -> c_int { if sodiumoxide::init().is_err() { ERR_CRYPTO_INIT } else { OK } } /// Return value upon success pub const OK: c_int = 0; /// NULL pointer argument pub const ERR_NULLPTR: c_int = 1; /// Invalid coordinator URL pub const ERR_INVALID_URL: c_int = 2; /// Invalid settings: coordinator URL is not set pub const ERR_SETTINGS_URL: c_int = 3; /// Invalid settings: signing keys are not set pub const ERR_SETTINGS_KEYS: c_int = 4; /// Invalid settings: scalar is out of bounds pub const ERR_SETTINGS_SCALAR: c_int = 5; /// Failed to set the local model: invalid model pub const ERR_SETMODEL_MODEL: c_int = 6; /// Failed to set the local model: invalid data type pub const ERR_SETMODEL_DATATYPE: c_int = 7; /// Failed to initialize the crypto library pub const ERR_CRYPTO_INIT: c_int = 8; /// Invalid secret signing key pub const ERR_CRYPTO_SECRET_KEY: c_int = 9; /// Invalid public signing key pub const ERR_CRYPTO_PUBLIC_KEY: c_int = 10; /// No global model is currently available pub const GLOBALMODEL_NONE: c_int = 11; /// Failed to get the global model: communication with the coordinator failed pub const ERR_GLOBALMODEL_IO: c_int = 12; /// Failed to get the global model: invalid data type pub const ERR_GLOBALMODEL_DATATYPE: c_int = 13; /// Failed to get the global model: invalid buffer length pub const ERR_GLOBALMODEL_LEN: c_int = 14; /// Failed to get the global model: invalid model pub const ERR_GLOBALMODEL_CONVERT: c_int = 15; ================================================ FILE: rust/xaynet-mobile/src/ffi/participant.rs ================================================ use std::{ convert::TryFrom, os::raw::{c_int, c_uchar, c_uint, c_void}, ptr, slice, }; use ffi_support::{ByteBuffer, FfiStr}; use xaynet_core::mask::{DataType,
FromPrimitives, IntoPrimitives, Model}; use super::{ LocalModelConfig, ERR_GLOBALMODEL_CONVERT, ERR_GLOBALMODEL_DATATYPE, ERR_GLOBALMODEL_IO, ERR_GLOBALMODEL_LEN, ERR_NULLPTR, ERR_SETMODEL_DATATYPE, ERR_SETMODEL_MODEL, GLOBALMODEL_NONE, OK, }; use crate::{into_primitives, Participant, Settings, Task}; mod pv { use super::Participant; ffi_support::define_box_destructor!(Participant, _xaynet_ffi_participant_destroy); } /// Destroy the participant created by [`xaynet_ffi_participant_new()`] or /// [`xaynet_ffi_participant_restore()`]. /// /// # Return value /// /// - [`OK`] on success /// - [`ERR_NULLPTR`] if `participant` is NULL /// /// # Safety /// /// 1. When calling this method, you have to ensure that *either* the pointer is NULL /// *or* all of the following is true: /// - The pointer must be properly [aligned]. /// - It must be "dereferencable" in the sense defined in the [`std::ptr`] module /// documentation. /// 2. After destroying the `Participant`, the pointer becomes invalid and must not be /// used. /// 3. 
This function should only be called on a pointer that has been created by /// [`xaynet_ffi_participant_new()`] or [`xaynet_ffi_participant_restore()`]. /// /// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety /// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment #[no_mangle] pub unsafe extern "C" fn xaynet_ffi_participant_destroy(participant: *mut Participant) -> c_int { if participant.is_null() { return ERR_NULLPTR; } pv::_xaynet_ffi_participant_destroy(participant); OK } /// The participant is not taking part in the sum or update task. /// /// Note: this flag has the same value (1) as [`ERR_NULLPTR`]. pub const PARTICIPANT_TASK_NONE: c_int = 1; /// The participant is taking part in the sum task pub const PARTICIPANT_TASK_SUM: c_int = 1 << 1; /// The participant is taking part in the update task pub const PARTICIPANT_TASK_UPDATE: c_int = 1 << 2; /// The participant is expected to set the model it trained pub const PARTICIPANT_SHOULD_SET_MODEL: c_int = 1 << 3; /// The participant's internal state machine made progress during the last tick pub const PARTICIPANT_MADE_PROGRESS: c_int = 1 << 4; /// A new global model is available pub const PARTICIPANT_NEW_GLOBALMODEL: c_int = 1 << 5; /// Instantiate a new participant with the given settings. The participant must be /// destroyed with [`xaynet_ffi_participant_destroy`]. /// /// # Return value /// /// - a NULL pointer if `settings` is NULL or if the participant creation failed /// - a valid pointer to a [`Participant`] otherwise /// /// # Safety /// /// When calling this method, you have to ensure that *either* the pointer is NULL *or* /// all of the following is true: /// /// - The pointer must be properly [aligned]. /// - It must be "dereferencable" in the sense defined in the [`std::ptr`] module /// documentation. /// /// After destroying the participant with [`xaynet_ffi_participant_destroy`], the /// pointer becomes invalid and must not be used.
/// /// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety /// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment #[no_mangle] pub unsafe extern "C" fn xaynet_ffi_participant_new(settings: *const Settings) -> *mut Participant { let settings = match unsafe { settings.as_ref() } { Some(settings) => settings.clone(), None => return std::ptr::null_mut(), }; match Participant::new(settings) { Ok(participant) => Box::into_raw(Box::new(participant)), Err(_) => std::ptr::null_mut(), } } /// Drive the participant internal state machine. Every tick, the state machine /// attempts to perform a small work unit. /// /// # Return value /// /// - [`ERR_NULLPTR`] if `participant` is NULL. Note that [`ERR_NULLPTR`] and /// [`PARTICIPANT_TASK_NONE`] have the same value (1), so a return value of 1 is /// ambiguous: callers must make sure that `participant` is not NULL. /// - a bitflag otherwise, with the following flags: /// - [`PARTICIPANT_MADE_PROGRESS`]: if set, this flag indicates that the participant's /// internal state machine was able to make some progress, and that the participant /// state changed. This information can be used, for instance, as an indication for /// saving the participant state. If the flag is not set, the state machine was /// not able to make progress. There are many potential causes for this, including: /// - the participant is not taking part in the current training round and is just /// waiting for a new one to start /// - the Xaynet coordinator is not reachable or has not published some /// information the participant is waiting for /// - the state machine is waiting for the model to be set (see /// [`xaynet_ffi_participant_set_model()`]) /// - [`PARTICIPANT_TASK_NONE`], [`PARTICIPANT_TASK_SUM`] and /// [`PARTICIPANT_TASK_UPDATE`]: these flags are mutually exclusive, and indicate /// which task the participant has been selected for, for the current round. If /// [`PARTICIPANT_TASK_NONE`] is set, then the participant will just wait for a new /// round to start.
If [`PARTICIPANT_TASK_UPDATE`] is set, then the participant has /// been selected to update the global model, and should prepare to provide a new /// model once the [`PARTICIPANT_SHOULD_SET_MODEL`] flag is set. /// - [`PARTICIPANT_SHOULD_SET_MODEL`]: if set, then the participant should set its /// model, by calling [`xaynet_ffi_participant_set_model()`] /// - [`PARTICIPANT_NEW_GLOBALMODEL`]: if set, the participant can fetch the new global /// model, by calling [`xaynet_ffi_participant_global_model()`] /// /// # Safety /// /// When calling this method, you have to ensure that *either* the pointer is NULL *or* /// all of the following is true: /// /// - The pointer must be properly [aligned]. /// - It must be "dereferencable" in the sense defined in the [`std::ptr`] module /// documentation. /// /// After destroying the participant with [`xaynet_ffi_participant_destroy`] becomes /// invalid and must not be used. /// /// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety /// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment #[no_mangle] pub unsafe extern "C" fn xaynet_ffi_participant_tick(participant: *mut Participant) -> c_int { let participant = match unsafe { participant.as_mut() } { Some(participant) => participant, None => return ERR_NULLPTR, }; participant.tick(); let mut flags: c_int = 0; match participant.task() { Task::None => flags |= PARTICIPANT_TASK_NONE, Task::Sum => flags |= PARTICIPANT_TASK_SUM, Task::Update => flags |= PARTICIPANT_TASK_UPDATE, }; if participant.should_set_model() { flags |= PARTICIPANT_SHOULD_SET_MODEL; } if participant.made_progress() { flags |= PARTICIPANT_MADE_PROGRESS; } if participant.new_global_model() { flags |= PARTICIPANT_NEW_GLOBALMODEL; } flags } /// Serialize the participant state and return a buffer that contains the serialized /// participant. /// /// # Safety /// /// 1. 
When calling this method, you have to ensure that *either* the pointer is NULL /// *or* all of the following is true: /// - The pointer must be properly [aligned]. /// - It must be "dereferencable" in the sense defined in the [`std::ptr`] module /// documentation. /// 2. The `ByteBuffer` created by this function must be destroyed with /// [`xaynet_ffi_byte_buffer_destroy()`]. Attempting to free the memory from the other /// side of the FFI is UB. /// 3. This function destroys the participant. Therefore, **the pointer becomes invalid /// and must not be used anymore**. Instead, a new participant should be created, /// either with [`xaynet_ffi_participant_new()`] or /// [`xaynet_ffi_participant_restore()`] /// /// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety /// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment /// [`xaynet_ffi_byte_buffer_destroy()`]: crate::ffi::xaynet_ffi_byte_buffer_destroy /// /// # Return value /// /// - a NULL pointer if `participant` is NULL /// - a pointer to a `ByteBuffer` containing the serialized participant otherwise /// /// # Example /// /// To save the participant into a file: /// /// ```c /// const ByteBuffer *save_buf = xaynet_ffi_participant_save(participant); /// assert(save_buf); /// /// char *path = "./participant.bin"; /// FILE *f = fopen(path, "wb"); /// fwrite(save_buf->data, 1, save_buf->len, f); /// fclose(f); /// xaynet_ffi_byte_buffer_destroy(save_buf); /// ``` #[no_mangle] pub unsafe extern "C" fn xaynet_ffi_participant_save( participant: *mut Participant, ) -> *const ByteBuffer { let participant: Participant = match unsafe { participant.as_mut() } { Some(ptr) => unsafe { *Box::from_raw(ptr) }, None => return std::ptr::null(), }; Box::into_raw(Box::new(ByteBuffer::from_vec(participant.save()))) } /// Restore the participant from a buffer that contained its serialized state. /// /// # Return value /// /// - a NULL pointer on failure /// - a pointer to the restored participant on success /// /// # Safety /// /// When calling this method, you have to ensure that *either* the pointers are NULL /// *or* all of the following is true: /// - The pointers must be properly [aligned].
/// - They must be "dereferencable" in the sense defined in the [`std::ptr`] module /// documentation. /// /// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety /// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment /// /// # Example /// /// To restore a participant from a file: /// /// ```c /// f = fopen("./participant.bin", "r"); /// fseek(f, 0L, SEEK_END); /// int fsize = ftell(f); /// fseek(f, 0L, SEEK_SET); /// ByteBuffer buf = { /// .len = fsize, /// .data = (uint8_t *)malloc(fsize), /// }; /// int n_read = fread(buf.data, 1, fsize, f); /// assert(n_read == fsize); /// fclose(f); /// Participant *restored = /// xaynet_ffi_participant_restore("http://localhost:8081", &buf); /// free(buf.data); /// ``` #[no_mangle] pub unsafe extern "C" fn xaynet_ffi_participant_restore( url: FfiStr, buffer: *const ByteBuffer, ) -> *mut Participant { let url = match url.as_opt_str() { Some(url) => url, None => return ptr::null_mut(), }; let buffer: &ByteBuffer = match unsafe { buffer.as_ref() } { Some(ptr) => ptr, None => return ptr::null_mut(), }; if let Ok(participant) = Participant::restore(buffer.as_slice(), url) { Box::into_raw(Box::new(participant)) } else { ptr::null_mut() } } /// Set the participant's model. Usually this should be called when the value returned /// by [`xaynet_ffi_participant_tick()`] contains the [`PARTICIPANT_SHOULD_SET_MODEL`] /// flag, but it can be called anytime. The model just won't be sent to the coordinator /// until it's time. /// /// - `buffer` should be a pointer to a buffer that contains the model /// - `data_type` specifies the type of the model weights (see [`DataType`]). The C header /// file generated by this crate provides an enum corresponding to the parameters: `DataType`. 
/// - `len` is the number of weights the model has /// /// # Return value /// /// - [`OK`] if the model is set successfully /// - [`ERR_NULLPTR`] if `participant` is NULL /// - [`ERR_SETMODEL_DATATYPE`] if the datatype is invalid /// - [`ERR_SETMODEL_MODEL`] if the model is invalid /// /// # Safety /// /// 1. When calling this method, you have to ensure that *either* the pointer is NULL /// *or* all of the following is true: /// - The pointer must be properly [aligned]. /// - It must be "dereferencable" in the sense defined in the [`std::ptr`] module /// documentation. /// 2. If `len` or `data_type` do not match the model in `buffer`, this method will /// result in a buffer over-read. /// /// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety /// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment #[no_mangle] pub unsafe extern "C" fn xaynet_ffi_participant_set_model( participant: *mut Participant, buffer: *const c_void, data_type: c_uchar, len: c_uint, ) -> c_int { let participant = match unsafe { participant.as_mut() } { Some(participant) => participant, None => return ERR_NULLPTR, }; if buffer.is_null() { return ERR_NULLPTR; } let data_type = match DataType::try_from(data_type) { Ok(data_type) => data_type, Err(_) => return ERR_SETMODEL_DATATYPE, }; let len = len as usize; let model = match data_type { DataType::F32 => { let buffer = unsafe { slice::from_raw_parts(buffer as *const f32, len) }; // we map the error so that we get an uniform error type Model::from_primitives(buffer.iter().copied()).map_err(|_| ()) } DataType::F64 => { let buffer = unsafe { slice::from_raw_parts(buffer as *const f64, len) }; Model::from_primitives(buffer.iter().copied()).map_err(|_| ()) } DataType::I32 => { let buffer = unsafe { slice::from_raw_parts(buffer as *const i32, len) }; Model::from_primitives(buffer.iter().copied()).map_err(|_| ()) } DataType::I64 => { let buffer = unsafe { slice::from_raw_parts(buffer as *const i64, len) }; 
Model::from_primitives(buffer.iter().copied()).map_err(|_| ()) } }; if let Ok(m) = model { participant.set_model(m); OK } else { ERR_SETMODEL_MODEL } } /// Return the latest global model from the coordinator. /// /// - `buffer` is the array in which the global model should be copied. /// - `data_type` specifies the type of the model weights (see [`DataType`]). The C header /// file generated by this crate provides an enum corresponding to the parameters: `DataType`. /// - `len` is the number of weights the model has /// /// # Return Value /// /// - [`OK`] if the global model was copied into `buffer` successfully /// - [`ERR_NULLPTR`] if `participant` or the `buffer` is NULL /// - [`GLOBALMODEL_NONE`] if no model exists /// - [`ERR_GLOBALMODEL_IO`] if the communication with the coordinator failed /// - [`ERR_GLOBALMODEL_DATATYPE`] if the datatype is invalid /// - [`ERR_GLOBALMODEL_LEN`] if `len` does not match the number of weights of the global model /// - [`ERR_GLOBALMODEL_CONVERT`] if the conversion of the model to `data_type` failed /// /// # Note /// /// It is **not** guaranteed that the model configuration returned by /// [`xaynet_ffi_participant_local_model_config`] corresponds to the configuration of /// the global model. This means that the global model can have a different length / data type /// than the one defined in the model configuration. Both model configurations are only /// guaranteed to be the same if the model config **never** changes on the coordinator side. /// /// # Safety /// /// 1. When calling this method, you have to ensure that *either* the pointer is NULL /// *or* all of the following is true: /// - The pointer must be properly [aligned]. /// - It must be "dereferencable" in the sense defined in the [`std::ptr`] module /// documentation. /// 2. `buffer` must be valid for writes of `len` elements of the type given by /// `data_type`. If it is smaller, this method will write past its end (buffer overflow).
///
/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
#[no_mangle]
pub unsafe extern "C" fn xaynet_ffi_participant_global_model(
    participant: *mut Participant,
    buffer: *mut c_void,
    data_type: c_uchar,
    len: c_uint,
) -> c_int {
    let participant = match unsafe { participant.as_mut() } {
        Some(participant) => participant,
        None => return ERR_NULLPTR,
    };
    if buffer.is_null() {
        return ERR_NULLPTR;
    }

    // Validate the requested data type *before* fetching the global model:
    // `Participant::global_model()` sends a request to the coordinator and, on
    // success, clears the participant's "new global model" flag. Rejecting a bad
    // argument only afterwards would waste the request and silently consume that flag.
    let data_type = match DataType::try_from(data_type) {
        Ok(data_type) => data_type,
        Err(_) => return ERR_GLOBALMODEL_DATATYPE,
    };

    let global_model = match participant.global_model() {
        Ok(Some(model)) => model,
        Ok(None) => return GLOBALMODEL_NONE,
        Err(_) => return ERR_GLOBALMODEL_IO,
    };

    // The length can only be checked once the model has been fetched.
    let len = len as usize;
    if len != global_model.len() {
        return ERR_GLOBALMODEL_LEN;
    }

    match data_type {
        DataType::F32 => into_primitives!(global_model, buffer, f32, len),
        DataType::F64 => into_primitives!(global_model, buffer, f64, len),
        DataType::I32 => into_primitives!(global_model, buffer, i32, len),
        DataType::I64 => into_primitives!(global_model, buffer, i64, len),
    }
}

/// Copy `$global_model` into `$buffer`, converting every weight to `$data_type`.
///
/// Evaluates to `OK` on success, or to `ERR_GLOBALMODEL_CONVERT` if at least one
/// weight cannot be represented as `$data_type`. In the error case nothing is written,
/// because all weights are converted before the copy.
///
/// `$buffer` must be valid for writes of `$len` elements of `$data_type`, and `$len`
/// must equal the length of `$global_model` (otherwise `copy_from_slice` panics).
#[macro_export]
macro_rules! into_primitives {
    ($global_model:expr, $buffer:expr, $data_type:ty, $len:expr) => {{
        if let Ok(global_model) = $global_model
            .into_primitives()
            .collect::<Result<Vec<$data_type>, _>>()
        {
            let buffer = unsafe { slice::from_raw_parts_mut($buffer as *mut $data_type, $len) };
            buffer.copy_from_slice(global_model.as_slice());
            OK
        } else {
            ERR_GLOBALMODEL_CONVERT
        }
    }};
}

/// Return the local model configuration of the model that is expected in the
/// [`xaynet_ffi_participant_set_model()`] function.
///
/// # Return value
///
/// - a pointer to a newly allocated `LocalModelConfig`; release it with
///   `xaynet_ffi_local_model_config_destroy()`
/// - NULL if `participant` is NULL
///
/// # Safety
///
/// 1. When calling this method, you have to ensure that *either* the pointer is NULL
///    *or* all of the following is true:
///    - The pointer must be properly [aligned].
///    - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
///      documentation.
/// /// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment #[no_mangle] pub unsafe extern "C" fn xaynet_ffi_participant_local_model_config( participant: *const Participant, ) -> *mut LocalModelConfig { let participant = match unsafe { participant.as_ref() } { Some(ptr) => ptr, None => return std::ptr::null_mut(), }; Box::into_raw(Box::new(participant.local_model_config().into())) } ================================================ FILE: rust/xaynet-mobile/src/ffi/settings.rs ================================================ use std::os::raw::{c_double, c_int}; use ffi_support::{ByteBuffer, FfiStr}; use xaynet_core::crypto::{ByteObject, PublicSigningKey, SecretSigningKey, SigningKeyPair}; use zeroize::Zeroize; use super::{ ERR_CRYPTO_PUBLIC_KEY, ERR_CRYPTO_SECRET_KEY, ERR_INVALID_URL, ERR_NULLPTR, ERR_SETTINGS_KEYS, ERR_SETTINGS_SCALAR, ERR_SETTINGS_URL, OK, }; use crate::{Settings, SettingsError}; mod pv { use super::Settings; ffi_support::define_box_destructor!(Settings, _xaynet_ffi_settings_destroy); } /// Destroy the settings created by [`xaynet_ffi_settings_new()`]. /// /// # Return value /// /// - [`OK`] on success /// - [`ERR_NULLPTR`] if `buf` is NULL /// /// # Safety /// /// 1. When calling this method, you have to ensure that *either* the pointer is NULL /// *or* all of the following is true: /// - The pointer must be properly [aligned]. /// - It must be "dereferencable" in the sense defined in the [`std::ptr`] module /// documentation. /// 2. After destroying the `Settings`, the pointer becomes invalid and must not be /// used. /// 3. This function should only be called on a pointer that has been created by /// [`xaynet_ffi_settings_new`]. 
///
/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
#[no_mangle]
pub unsafe extern "C" fn xaynet_ffi_settings_destroy(settings: *mut Settings) -> c_int {
    if settings.is_null() {
        ERR_NULLPTR
    } else {
        // Hand the pointer back to the destructor generated by `define_box_destructor!`.
        unsafe { pv::_xaynet_ffi_settings_destroy(settings) };
        OK
    }
}

/// Allocate new, empty [`Settings`] and return an owning pointer to them.
///
/// # Safety
///
/// The `Settings` created by this function must be destroyed with
/// [`xaynet_ffi_settings_destroy()`]. Attempting to free the memory from the other side
/// of the FFI is UB.
#[no_mangle]
pub unsafe extern "C" fn xaynet_ffi_settings_new() -> *mut Settings {
    let settings = Box::new(Settings::new());
    Box::into_raw(settings)
}

/// Set the scalar used for masking.
///
/// The value is stored as-is; whether it lies within the range of a scalar is only
/// reported later, by [`xaynet_ffi_check_settings()`] (`ERR_SETTINGS_SCALAR`).
///
/// # Return value
///
/// - [`OK`] if successful
/// - [`ERR_NULLPTR`] if `settings` is `NULL`
///
/// # Safety
///
/// When calling this method, you have to ensure that *either* the pointer is NULL *or*
/// all of the following is true:
/// - The pointer must be properly [aligned].
/// - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
///   documentation.
///
/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
#[no_mangle]
pub unsafe extern "C" fn xaynet_ffi_settings_set_scalar(
    settings: *mut Settings,
    scalar: c_double,
) -> c_int {
    let settings = match unsafe { settings.as_mut() } {
        Some(settings) => settings,
        None => return ERR_NULLPTR,
    };
    settings.set_scalar(scalar);
    OK
}

/// Set the URL of the Xaynet coordinator.
///
/// # Return value
///
/// - [`OK`] if successful
/// - [`ERR_INVALID_URL`] if `url` is not a valid string
/// - [`ERR_NULLPTR`] if `settings` is `NULL`
///
/// # Safety
///
/// When calling this method, you have to ensure that *either* the pointers are NULL
/// *or* all of the following is true:
/// - The pointers must be properly [aligned].
/// - They must be "dereferencable" in the sense defined in the [`std::ptr`] module
///   documentation.
///
/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
#[no_mangle]
pub unsafe extern "C" fn xaynet_ffi_settings_set_url(
    settings: *mut Settings,
    url: FfiStr,
) -> c_int {
    // `url` is validated first, so an invalid URL is reported even if `settings` is NULL.
    let url = match url.as_opt_str() {
        Some(url) => url,
        None => return ERR_INVALID_URL,
    };
    match unsafe { settings.as_mut() } {
        Some(settings) => {
            settings.set_url(url.to_string());
            OK
        }
        None => ERR_NULLPTR,
    }
}

// TODO: add a way to save the key pair
/// A signing key pair.
///
/// Key pairs are created by [`xaynet_ffi_generate_key_pair()`] and must be released
/// with [`xaynet_ffi_forget_key_pair()`].
pub struct KeyPair {
    public: ByteBuffer,
    secret: ByteBuffer,
}

/// Generate a new signing key pair that can be used in the [`Settings`]. **Before
/// calling this function you must initialize the crypto library with
/// [`xaynet_ffi_crypto_init()`]**.
///
/// The returned value contains a pointer to the secret key. For security reasons, you
/// must keep this buffer alive for as short a time as possible, and call
/// [`xaynet_ffi_forget_key_pair`] to destroy it.
///
/// [`xaynet_ffi_crypto_init()`]: crate::ffi::xaynet_ffi_crypto_init
///
/// # Safety
///
/// This function is safe to call
#[no_mangle]
pub unsafe extern "C" fn xaynet_ffi_generate_key_pair() -> *const KeyPair {
    let SigningKeyPair { public, secret } = SigningKeyPair::generate();
    let public_vec = public.as_slice().to_vec();
    let secret_vec = secret.as_slice().to_vec();
    let keys = KeyPair {
        public: ByteBuffer::from_vec(public_vec),
        // under the hood, ByteBuffer takes ownership of the memory
        // without copying/leaking anything. There's no need to zero
        // out anything yet
        secret: ByteBuffer::from_vec(secret_vec),
    };
    Box::into_raw(Box::new(keys))
}

/// De-allocate the buffers that contain the signing keys, and zero out the content of
/// the buffer that contains the secret key.
///
/// # Return value
///
/// - [`ERR_NULLPTR`] if `key_pair` is NULL
/// - [`OK`] otherwise
///
/// # Safety
///
/// 1. When calling this method, you have to ensure that *either* the pointer is NULL
///    *or* all of the following is true:
///    - The pointer must be properly [aligned].
///    - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
///      documentation.
/// 2. A non-NULL pointer must have been returned by [`xaynet_ffi_generate_key_pair()`]
///    and must not have been passed to this function before: ownership is reclaimed with
///    `Box::from_raw`, so a second call is a double free.
/// 3. After this call, the pointer is invalid and must not be used anymore.
///
/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
#[no_mangle]
pub unsafe extern "C" fn xaynet_ffi_forget_key_pair(key_pair: *const KeyPair) -> c_int {
    if key_pair.is_null() {
        return ERR_NULLPTR;
    }
    // Take back ownership of the allocation made in `xaynet_ffi_generate_key_pair()`.
    let key_pair = unsafe { Box::from_raw(key_pair as *mut KeyPair) };
    // IMPORTANT: we need to free the ByteBuffer memory, since it does
    // not implement drop. We also take care of zero-ing the memory
    // for the secret key.
    key_pair.secret.destroy_into_vec().zeroize();
    key_pair.public.destroy_into_vec();
    OK
}

/// Set participant signing keys.
///
/// The keys are copied into `settings`, so `key_pair` can (and should) be released with
/// [`xaynet_ffi_forget_key_pair()`] afterwards.
///
/// # Return value
///
/// - [`OK`] if successful
/// - [`ERR_NULLPTR`] if `settings` or `key_pair` is `NULL`
/// - [`ERR_CRYPTO_PUBLIC_KEY`] if the given `key_pair` contains an invalid public key
/// - [`ERR_CRYPTO_SECRET_KEY`] if the given `key_pair` contains an invalid secret key
///
/// # Safety
///
/// When calling this method, you have to ensure that *either* the pointers are NULL
/// *or* all of the following is true:
/// - The pointers must be properly [aligned].
/// - They must be "dereferencable" in the sense defined in the [`std::ptr`] module
///   documentation.
/// - `key_pair` must not have been passed to [`xaynet_ffi_forget_key_pair()`] yet.
///
/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
#[no_mangle]
pub unsafe extern "C" fn xaynet_ffi_settings_set_keys(
    settings: *mut Settings,
    key_pair: *const KeyPair,
) -> c_int {
    let key_pair = match unsafe { key_pair.as_ref() } {
        Some(key_pair) => key_pair,
        None => return ERR_NULLPTR,
    };

    // The length checks make the `*_unchecked` constructors below sound.
    let secret_slice = key_pair.secret.as_slice();
    if secret_slice.len() != SecretSigningKey::LENGTH {
        return ERR_CRYPTO_SECRET_KEY;
    }
    let secret = SecretSigningKey::from_slice_unchecked(secret_slice);

    let public_slice = key_pair.public.as_slice();
    if public_slice.len() != PublicSigningKey::LENGTH {
        return ERR_CRYPTO_PUBLIC_KEY;
    }
    let public = PublicSigningKey::from_slice_unchecked(public_slice);

    match unsafe { settings.as_mut() } {
        Some(settings) => {
            settings.set_keys(SigningKeyPair { public, secret });
            OK
        }
        None => ERR_NULLPTR,
    }
}

/// Check whether the given settings are valid and can be used to instantiate a
/// participant (see [`xaynet_ffi_participant_new()`]).
///
/// # Return value
///
/// - [`OK`] on success
/// - [`ERR_NULLPTR`] if `settings` is NULL
/// - [`ERR_SETTINGS_URL`] if the URL has not been set
/// - [`ERR_SETTINGS_KEYS`] if the signing keys have not been set
/// - [`ERR_SETTINGS_SCALAR`] if the scalar is out of bounds
///
/// # Safety
///
/// When calling this method, you have to ensure that *either* the pointer is NULL *or*
/// all of the following is true:
///
/// - The pointer must be properly [aligned].
/// - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
///   documentation.
/// /// [`xaynet_ffi_participant_new()`]: crate::ffi::xaynet_ffi_participant_new /// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety /// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment #[no_mangle] pub unsafe extern "C" fn xaynet_ffi_check_settings(settings: *const Settings) -> c_int { match unsafe { settings.as_ref() } { Some(settings) => match settings.check() { Ok(()) => OK, Err(SettingsError::MissingUrl) => ERR_SETTINGS_URL, Err(SettingsError::MissingKeys) => ERR_SETTINGS_KEYS, Err(SettingsError::OutOfScalarRange(_)) => ERR_SETTINGS_SCALAR, }, None => ERR_NULLPTR, } } ================================================ FILE: rust/xaynet-mobile/src/lib.rs ================================================ #![cfg_attr( doc, forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links) )] #![doc( html_logo_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png", html_favicon_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png", issue_tracker_base_url = "https://github.com/xaynetwork/xaynet/issues" )] //! This crates provides a mobile friendly implementation of a Xaynet Federated Learning //! participant, along with FFI C bindings for building applications in languages that //! can use C bindings. //! //! The [`Participant`] provided by this crate is mobile friendly because the caller has //! a lot of control on how to drive the participant execution. You can regularly pause //! the execution of the participant, save it, and later restore it and continue the //! execution. When running on a device that is low on battery or does not have access //! to Wi-Fi for instance, it can be useful to be able to pause the participant. //! //! This control comes at a complexity cost though. Usually, a participant is split two: //! - a task that executes a state machine that implements the PET protocol and emit //! notifications. //! 
- a task that react to these events, for instance by downloading the latest global //! model at the end of a round, or trains a new model when the participant has been //! selected for the update task. //! //! The task that executes the PET protocol usually runs in background and we have //! little control over it. This is a problem on mobile environment: //! - first, the app may be killed at any moment and we'd lose the participant state //! - second we don't really want a background task to potentially perform CPU heavy or //! network heavy operations without having a say since it may drain the battery or //! consume too much data. //! //! To solve this problem, the [`Participant`] provided in this crate embeds the PET //! state machine, and it's the caller responsibility to drive its execution (see //! [`Participant::tick()`]) #[macro_use] extern crate ffi_support; #[macro_use] extern crate async_trait; #[macro_use] extern crate tracing; mod participant; mod settings; pub use self::{ participant::{Event, Events, InitError, Notifier, Participant, Task}, settings::{Settings, SettingsError}, }; pub mod ffi; mod reqwest_client; pub(crate) use reqwest_client::new_client; pub use reqwest_client::ClientError; ================================================ FILE: rust/xaynet-mobile/src/participant.rs ================================================ //! 
Participant implementation use std::{convert::TryInto, sync::Arc}; use futures::future::FutureExt; use thiserror::Error; use tokio::{ runtime::Runtime, sync::{mpsc, Mutex}, }; use xaynet_core::mask::Model; use xaynet_sdk::{ client::Client, LocalModelConfig, ModelStore, Notify, SerializableState, StateMachine, TransitionOutcome, XaynetClient, }; use crate::{ new_client, settings::{Settings, SettingsError}, ClientError, }; /// Event emitted by the participant internal state machine as it advances through the /// PET protocol pub enum Event { /// Event emitted when the participant is selected for the update task Update, /// Event emitted when the participant is selected for the sum task Sum, /// Event emitted when the participant is done with its task Idle, /// Event emitted when a new round starts NewRound, /// Event emitted when the participant should load its model. This only happens if /// the participant has been selected for the update task LoadModel, } /// Event sender that is passed to the participant internal state machine for emitting /// notification pub struct Notifier(mpsc::Sender); impl Notifier { fn notify(&mut self, event: Event) { if let Err(e) = self.0.try_send(event) { warn!("failed to notify participant: {}", e); } } } /// A receiver for events emitted by the participant internal state machine pub struct Events(mpsc::Receiver); impl Events { /// Create a new event sender and receiver. fn new() -> (Self, Notifier) { let (tx, rx) = mpsc::channel(10); (Self(rx), Notifier(tx)) } /// Pop the next event. If no event has been received, return `None`. fn next(&mut self) -> Option { // Note `try_recv` (tokio 0.2.x) or `recv().now_or_never()` (tokio 1.x) // has an implementation bug where previously sent messages may not be // available immediately. // Related issue: https://github.com/tokio-rs/tokio/issues/3350 // However, that should not be an issue for us. 
let next = self.0.recv().now_or_never()?; if next.is_none() { // if next is `none`, the channel is closed // This can happen if: // 1. the state machine crashed. In that case it's OK to crash. // 2. `next` was called whereas the state machine was // dropped, which is an error. So crashing is OK as // well. panic!("notifier dropped") } next } } impl Notify for Notifier { fn new_round(&mut self) { self.notify(Event::NewRound) } fn sum(&mut self) { self.notify(Event::Sum) } fn update(&mut self) { self.notify(Event::Update) } fn load_model(&mut self) { self.notify(Event::LoadModel) } fn idle(&mut self) { self.notify(Event::Idle) } } /// A store shared between by the participant and its internal state machine. When the /// state machine emits a [`Event::LoadModel`] event, the participant is expected to /// load its model into the store. See [`Participant::set_model()`]. #[derive(Clone)] struct Store(Arc>>); impl Store { /// Create a new model store. fn new() -> Self { Self(Arc::new(Mutex::new(None))) } } #[async_trait] impl ModelStore for Store { type Model = Model; type Error = std::convert::Infallible; async fn load_model(&mut self) -> Result, Self::Error> { Ok(self.0.lock().await.take()) } } /// Represent the participant current task #[derive(Clone, Debug, Copy)] pub enum Task { /// The participant is taking part in the sum task Sum, /// The participant is taking part in the update task Update, /// The participant is not taking part in any task None, } /// A participant. It embeds an internal state machine that executes the PET /// protocol. However, it is the caller's responsibility to drive this state machine by /// calling [`Participant::tick()`], and to take action when the participant state /// changes. pub struct Participant { /// Internal state machine state_machine: Option, /// Receiver for the events emitted by the state machine events: Events, /// Model store where the participant should load its model, when /// `self.should_set_model` is `true`. 
store: Store, /// Async runtime to execute the state machine runtime: Runtime, /// Xaynet client client: Client, /// Whether the participant state changed after the last call to /// [`Participant::tick()`] made_progress: bool, /// Whether the participant should load its model into the store. should_set_model: bool, /// Whether a new global model is available. new_global_model: bool, /// The participant current task task: Task, } /// Error that can occur when instantiating a new [`Participant`], either with /// [`Participant::new()`] or [`Participant::restore()`] #[derive(Error, Debug)] pub enum InitError { #[error("failed to deserialize the participant state {:?}", _0)] Deserialization(#[from] Box), #[error("failed to initialize the participant runtime {:?}", _0)] Runtime(std::io::Error), #[error("failed to initialize HTTP client {:?}", _0)] Client(#[from] ClientError), #[error("invalid participant settings {:?}", _0)] InvalidSettings(#[from] SettingsError), } #[derive(Error, Debug)] #[error("failed to fetch global model: {}", self.0)] pub struct GetGlobalModelError(xaynet_sdk::client::ClientError); impl Participant { /// Create a new participant with the given settings pub fn new(settings: Settings) -> Result { let (url, pet_settings) = settings.try_into()?; let client = new_client(url.as_str(), None, None)?; let (events, notifier) = Events::new(); let store = Store::new(); let state_machine = StateMachine::new(pet_settings, client.clone(), store.clone(), notifier); Self::init(state_machine, client, events, store) } /// Restore a participant from it's serialized state. The coordinator client that /// the participant uses internally is not part of the participant state, so the /// `url` is used to instantiate a new one. 
pub fn restore(state: &[u8], url: &str) -> Result { let state: SerializableState = bincode::deserialize(state)?; let (events, notifier) = Events::new(); let store = Store::new(); let client = new_client(url, None, None)?; let state_machine = StateMachine::restore(state, client.clone(), store.clone(), notifier); Self::init(state_machine, client, events, store) } fn init( state_machine: StateMachine, client: Client, events: Events, store: Store, ) -> Result { let mut participant = Self { runtime: Self::runtime()?, state_machine: Some(state_machine), events, store, client, task: Task::None, made_progress: true, should_set_model: false, new_global_model: false, }; participant.process_events(); Ok(participant) } fn runtime() -> Result { tokio::runtime::Builder::new_current_thread() .enable_all() .build() .map_err(InitError::Runtime) } /// Serialize the participant state and return the corresponding buffer. pub fn save(self) -> Vec { // UNWRAP_SAFE: the state machine is always set. let state_machine = self.state_machine.unwrap().save(); bincode::serialize(&state_machine).unwrap() } /// Drive the participant internal state machine. /// /// After calling this method, the caller should check whether the participant state /// changed, by calling [`Participant::made_progress()`]. If the state changed, the /// caller should perform the following checks and react appropriately: /// /// - whether the participant is taking part to any task by calling /// [`Participant::task()`] /// - whether the participant should load its model into the store by calling /// [`Participant::should_set_model()`] pub fn tick(&mut self) { // UNWRAP_SAFE: the state machine is always set. 
let state_machine = self.state_machine.take().unwrap(); let outcome = self .runtime .block_on(async { state_machine.transition().await }); match outcome { TransitionOutcome::Pending(new_state_machine) => { self.made_progress = false; self.state_machine = Some(new_state_machine); } TransitionOutcome::Complete(new_state_machine) => { self.made_progress = true; self.state_machine = Some(new_state_machine) } }; self.process_events(); } fn process_events(&mut self) { loop { match self.events.next() { Some(Event::Idle) => { self.task = Task::None; } Some(Event::Update) => { self.task = Task::Update; } Some(Event::Sum) => { self.task = Task::Sum; } Some(Event::NewRound) => { self.should_set_model = false; self.new_global_model = true; } Some(Event::LoadModel) => { self.should_set_model = true; } None => break, } } } /// Check whether the participant internal state machine made progress while /// executing the PET protocol. If so, the participant state likely changed. pub fn made_progress(&self) -> bool { self.made_progress } /// Check whether the participant internal state machine is waiting for the /// participant to load its model into the store. If this method returns `true`, the /// caller should make sure to call [`Participant::set_model()`] at some point. pub fn should_set_model(&self) -> bool { self.should_set_model } /// Check whether a new global model is available. If this method returns `true`, the /// caller can call [`Participant::global_model()`] to fetch the new global model. pub fn new_global_model(&self) -> bool { self.new_global_model } /// Return the participant current task pub fn task(&self) -> Task { self.task } /// Load the given model into the store, so that the participant internal state /// machine can process it. pub fn set_model(&mut self, model: Model) { let Self { ref mut runtime, ref store, .. 
} = self; runtime.block_on(async { let mut stored_model = store.0.lock().await; *stored_model = Some(model) }); self.should_set_model = false; } /// Retrieve the current global model, if available. pub fn global_model(&mut self) -> Result, GetGlobalModelError> { let Self { ref mut runtime, ref mut client, .. } = self; let global_model = runtime.block_on(async { client.get_model().await.map_err(GetGlobalModelError) }); if global_model.is_ok() { self.new_global_model = false; } global_model } /// Return the local model configuration of the model that is expected in the /// [`Participant::set_model`] method. pub fn local_model_config(&self) -> LocalModelConfig { // UNWRAP_SAFE: the state machine is always set. let state_machine = self.state_machine.as_ref().unwrap(); state_machine.local_model_config() } } ================================================ FILE: rust/xaynet-mobile/src/reqwest_client.rs ================================================ use std::{fs::File, io::Read}; use thiserror::Error; use xaynet_sdk::client::Client; /// Error returned upon failing to instantiate a new [`xaynet_sdk::client::Client`] #[derive(Debug, Error)] pub enum ClientError { #[error("invalid URL: {0}")] InvalidUrl(String), #[error("failed to read trust anchor {0}: {1}")] TrustAnchor(String, String), #[error("failed to read client certificate {0}: {1}")] ClientCert(String, String), #[error("{0}")] Other(String), } impl ClientError { fn trust_anchor(path: String, e: E) -> Self { Self::TrustAnchor(path, format!("{}", e)) } fn client_cert(path: String, e: E) -> Self { Self::ClientCert(path, format!("{}", e)) } fn other(e: E) -> Self { Self::Other(format!("{}", e)) } } /// Build a new [`xaynet_sdk::client::Client`] /// /// # Args /// /// - `address`: URL of the Xaynet coordinator to connect to /// - `trust_anchor_path`: path the to root certificate for TLS server authentication. The /// certificate must be PEM encoded. 
/// - `client_cert_path`: path to the client certificate to use for TLS client authentication. The /// certificate must be PEM encoded. pub fn new_client( address: &str, trust_anchor_path: Option, client_cert_path: Option, ) -> Result, ClientError> { let builder = reqwest::ClientBuilder::new(); let builder = if let Some(path) = trust_anchor_path { let mut buf = Vec::new(); File::open(&path) .map_err(|e| ClientError::trust_anchor(path.clone(), e))? .read_to_end(&mut buf) .map_err(|e| ClientError::trust_anchor(path.clone(), e))?; let root_cert = reqwest::Certificate::from_pem(&buf).map_err(|e| ClientError::trust_anchor(path, e))?; builder.use_rustls_tls().add_root_certificate(root_cert) } else { builder }; let builder = if let Some(path) = client_cert_path { let mut buf = Vec::new(); File::open(&path) .map_err(|e| ClientError::client_cert(path.clone(), e))? .read_to_end(&mut buf) .map_err(|e| ClientError::client_cert(path.clone(), e))?; let identity = reqwest::Identity::from_pem(&buf).map_err(|e| ClientError::client_cert(path, e))?; builder.use_rustls_tls().identity(identity) } else { builder }; let reqwest_client = builder.build().map_err(ClientError::other)?; let xaynet_client = Client::new(reqwest_client, address) .map_err(|_| ClientError::InvalidUrl(address.to_string()))?; Ok(xaynet_client) } ================================================ FILE: rust/xaynet-mobile/src/settings.rs ================================================ //! This module provides utilities to configure a [`Participant`]. //! //! [`Participant`]: crate::Participant use std::convert::TryInto; use thiserror::Error; use xaynet_core::{ crypto::SigningKeyPair, mask::{FromPrimitive, PrimitiveCastError, Scalar}, }; use xaynet_sdk::settings::{MaxMessageSize, PetSettings}; /// A participant settings #[derive(Clone, Debug)] pub struct Settings { /// The Xaynet coordinator URL. url: Option, /// The participant signing keys. keys: Option, /// The scalar used for masking. 
scalar: Result>, /// The maximum possible size of a message. max_message_size: MaxMessageSize, } impl Default for Settings { fn default() -> Self { Self::new() } } impl Settings { /// Create new empty settings. pub fn new() -> Self { Self { url: None, keys: None, scalar: Ok(Scalar::unit()), max_message_size: MaxMessageSize::default(), } } /// Set the participant signing keys pub fn set_keys(&mut self, keys: SigningKeyPair) { self.keys = Some(keys); } /// Set the scalar to use for masking pub fn set_scalar(&mut self, scalar: f64) { self.scalar = Scalar::from_primitive(scalar) } /// Set the Xaynet coordinator address pub fn set_url(&mut self, url: String) { self.url = Some(url); } /// Sets the maximum possible size of a message. pub fn set_max_message_size(&mut self, size: MaxMessageSize) { self.max_message_size = size; } /// Check whether the settings are complete and valid pub fn check(&self) -> Result<(), SettingsError> { if self.url.is_none() { Err(SettingsError::MissingUrl) } else if self.keys.is_none() { Err(SettingsError::MissingKeys) } else if let Err(e) = &self.scalar { Err(e.clone().into()) } else { Ok(()) } } } /// Error returned when the settings are invalid #[derive(Debug, Error)] pub enum SettingsError { #[error("the Xaynet coordinator URL must be specified")] MissingUrl, #[error("the participant signing key pair must be specified")] MissingKeys, #[error("float not within range of scalar: {0}")] OutOfScalarRange(#[from] PrimitiveCastError), } impl TryInto<(String, PetSettings)> for Settings { type Error = SettingsError; fn try_into(self) -> Result<(String, PetSettings), Self::Error> { let Settings { keys, url, scalar, max_message_size, } = self; let url = url.ok_or(SettingsError::MissingUrl)?; let keys = keys.ok_or(SettingsError::MissingKeys)?; let scalar = scalar.map_err(SettingsError::OutOfScalarRange)?; let pet_settings = PetSettings { keys, scalar, max_message_size, }; Ok((url, pet_settings)) } } ================================================ 
FILE: rust/xaynet-mobile/tests/ffi_test.c ================================================ #include #include #include #include #include #include "minunit.h" #include "xaynet_ffi.h" static char *test_settings_new() { Settings *settings = xaynet_ffi_settings_new(); xaynet_ffi_settings_destroy(settings); return 0; } static char *test_settings_set_keys() { mu_assert("failed to init crypto", xaynet_ffi_crypto_init() == OK); Settings *settings = xaynet_ffi_settings_new(); const KeyPair *keys = xaynet_ffi_generate_key_pair(); int err = xaynet_ffi_settings_set_keys(settings, keys); mu_assert("failed to set keys", !err); xaynet_ffi_forget_key_pair(keys); xaynet_ffi_settings_destroy(settings); return 0; } static char *test_settings_set_url() { Settings *settings = xaynet_ffi_settings_new(); int err = xaynet_ffi_settings_set_url(settings, NULL); mu_assert("settings invalid URL should fail", err == ERR_INVALID_URL); char *url = "http://localhost:1234"; err = xaynet_ffi_settings_set_url(settings, url); mu_assert("failed to set url", !err); char *url2 = strdup(url); err = xaynet_ffi_settings_set_url(settings, url2); mu_assert("failed to set url from allocated string", !err); // cleanup free(url2); xaynet_ffi_settings_destroy(settings); return 0; } void with_keys(Settings *settings) { const KeyPair *keys = xaynet_ffi_generate_key_pair(); int err = xaynet_ffi_settings_set_keys(settings, keys); assert(!err); xaynet_ffi_forget_key_pair(keys); } void with_url(Settings *settings) { int err = xaynet_ffi_settings_set_url(settings, "http://localhost:1234"); assert(!err); } static char *test_settings() { Settings *settings = xaynet_ffi_settings_new(); with_keys(settings); int err = xaynet_ffi_check_settings(settings); mu_assert("expected missing url error", err == ERR_SETTINGS_URL); xaynet_ffi_settings_destroy(settings); settings = xaynet_ffi_settings_new(); with_url(settings); err = xaynet_ffi_check_settings(settings); mu_assert("expected missing keys error", err == ERR_SETTINGS_KEYS); 
xaynet_ffi_settings_destroy(settings); return 0; } static char *test_global_model() { Settings *settings = xaynet_ffi_settings_new(); with_keys(settings); with_url(settings); xaynet_ffi_settings_set_url(settings, "http://localhost:8081"); Participant *participant = xaynet_ffi_participant_new(settings); mu_assert("failed to create participant", participant != NULL); LocalModelConfig *local_model_config = xaynet_ffi_participant_local_model_config(participant); float* buffer = (float *)malloc(sizeof(float) * local_model_config->len); int err = xaynet_ffi_participant_global_model(NULL, buffer, local_model_config->data_type, local_model_config->len); mu_assert("expected participant is null error", err == ERR_NULLPTR); err = xaynet_ffi_participant_global_model(participant, NULL, local_model_config->data_type, local_model_config->len); mu_assert("expected buffer is null error", err == ERR_NULLPTR); err = xaynet_ffi_participant_global_model(participant, buffer, local_model_config->data_type, local_model_config->len); mu_assert("expected io error (cannot connect to coordinator)", err == ERR_GLOBALMODEL_IO); free(buffer); xaynet_ffi_local_model_config_destroy(local_model_config); xaynet_ffi_participant_destroy(participant); xaynet_ffi_settings_destroy(settings); return 0; } static char *test_participant_save_and_restore() { Settings *settings = xaynet_ffi_settings_new(); with_keys(settings); with_url(settings); Participant *participant = xaynet_ffi_participant_new(settings); mu_assert("failed to create participant", participant != NULL); xaynet_ffi_settings_destroy(settings); // save the participant const ByteBuffer *save_buf = xaynet_ffi_participant_save(participant); mu_assert("failed to save participant", save_buf != NULL); // write the serialized participant to a file char *path = "./test_participant_save_and_restore.txt"; FILE *f = fopen(path, "w"); fwrite(save_buf->data, 1, save_buf->len, f); fclose(f); int err = xaynet_ffi_byte_buffer_destroy(save_buf); assert(!err); 
// read the serialized participant from the file f = fopen(path, "r"); fseek(f, 0L, SEEK_END); int fsize = ftell(f); fseek(f, 0L, SEEK_SET); ByteBuffer restore_buf = { .len = fsize, .data = (uint8_t *)malloc(fsize), }; int n_read = fread(restore_buf.data, 1, fsize, f); mu_assert("failed to read serialized participant", n_read == fsize); fclose(f); // restore the participant Participant *restored = xaynet_ffi_participant_restore("http://localhost:8081", &restore_buf); mu_assert("failed to restore participant", restored != NULL); // free memory free(restore_buf.data); xaynet_ffi_participant_destroy(restored); return 0; } static char *test_participant_tick() { Settings *settings = xaynet_ffi_settings_new(); with_keys(settings); with_url(settings); Participant *participant = xaynet_ffi_participant_new(settings); mu_assert("failed to create participant", participant != NULL); int status = xaynet_ffi_participant_tick(participant); mu_assert("missing no task flag", (status & PARTICIPANT_TASK_NONE)); mu_assert("unexpected sum task flag", !(status & PARTICIPANT_TASK_SUM)); mu_assert("unexpected update task flag", !(status & PARTICIPANT_TASK_UPDATE)); mu_assert("unexpected set model flag", !(status & PARTICIPANT_SHOULD_SET_MODEL)); mu_assert("unexpected made progress flag", !(status & PARTICIPANT_MADE_PROGRESS)); // free memory xaynet_ffi_settings_destroy(settings); xaynet_ffi_participant_destroy(participant); return 0; } static char *all_tests() { mu_run_test(test_settings_new); mu_run_test(test_settings_set_keys); mu_run_test(test_settings_set_url); mu_run_test(test_settings); mu_run_test(test_global_model); mu_run_test(test_participant_save_and_restore); mu_run_test(test_participant_tick); return 0; } int tests_run = 0; int main(int argc, char **argv) { assert(xaynet_ffi_crypto_init() == OK); char *result = all_tests(); if (result != 0) { fprintf(stderr, RED "ERROR: %s\n" RESET, result); } else { printf(GREEN "ALL TESTS PASSED\n" RESET); } printf("Tests run: %d\n", 
tests_run); return result != 0; } ================================================ FILE: rust/xaynet-mobile/tests/minunit.h ================================================ #define RESET "\033[0m" #define BLACK "\033[30m" /* Black */ #define RED "\033[31m" /* Red */ #define GREEN "\033[32m" /* Green */ #define mu_assert(message, test) \ do \ { \ if (!(test)) \ return message; \ } while (0) #define mu_run_test(test) \ do \ { \ char *message = test(); \ tests_run++; \ if (message) \ return message; \ } while (0) extern int tests_run; ================================================ FILE: rust/xaynet-mobile/xaynet_ffi.h ================================================ /* Generated with cbindgen:0.17.0 */ /* Warning, this file is autogenerated by cbindgen. Don't modify this manually. */ #include #include #include #include /** * Return value upon success */ #define OK 0 /** * NULL pointer argument */ #define ERR_NULLPTR 1 /** * Invalid coordinator URL */ #define ERR_INVALID_URL 2 /** * Invalid settings: coordinator URL is not set */ #define ERR_SETTINGS_URL 3 /** * Invalid settings: signing keys are not set */ #define ERR_SETTINGS_KEYS 4 /** * Invalid settings: scalar is out of bounds */ #define ERR_SETTINGS_SCALAR 5 /** * Failed to set the local model: invalid model */ #define ERR_SETMODEL_MODEL 6 /** * Failed to set the local model: invalid data type */ #define ERR_SETMODEL_DATATYPE 7 /** * Failed to initialized the crypto library */ #define ERR_CRYPTO_INIT 8 /** * Invalid secret signing key */ #define ERR_CRYPTO_SECRET_KEY 9 /** * Invalid public signing key */ #define ERR_CRYPTO_PUBLIC_KEY 10 /** * No global model is currently available */ #define GLOBALMODEL_NONE 11 /** * Failed to get the global model: communication with the coordinator failed */ #define ERR_GLOBALMODEL_IO 12 /** * Failed to get the global model: invalid data type */ #define ERR_GLOBALMODEL_DATATYPE 13 /** * Failed to get the global model: invalid buffer length */ #define ERR_GLOBALMODEL_LEN 14 /** 
* Failed to get the global model: invalid model
 */
#define ERR_GLOBALMODEL_CONVERT 15

/**
 * The participant is not taking part in the sum or update task
 */
#define PARTICIPANT_TASK_NONE 1

/**
 * The participant is taking part in the sum task
 */
#define PARTICIPANT_TASK_SUM (1 << 1)

/**
 * The participant is taking part in the update task
 */
#define PARTICIPANT_TASK_UPDATE (1 << 2)

/**
 * The participant is expected to set the model it trained
 */
#define PARTICIPANT_SHOULD_SET_MODEL (1 << 3)

/**
 * The participant's internal state machine made progress, i.e. the participant
 * state changed (this can be used as a hint to save the participant)
 */
#define PARTICIPANT_MADE_PROGRESS (1 << 4)

/**
 * A new global model is available
 */
#define PARTICIPANT_NEW_GLOBALMODEL (1 << 5)

/**
 * The original primitive data type of the numerical values to be masked.
 */
enum ModelDataType {
  /**
   * Numbers of type f32.
   */
  MODEL_DATA_TYPE_F32 = 0,
  /**
   * Numbers of type f64.
   */
  MODEL_DATA_TYPE_F64 = 1,
  /**
   * Numbers of type i32.
   */
  MODEL_DATA_TYPE_I32 = 2,
  /**
   * Numbers of type i64.
   */
  MODEL_DATA_TYPE_I64 = 3,
};
typedef uint8_t ModelDataType;

/**
 * A signing key pair
 */
typedef struct KeyPair KeyPair;

/**
 * A participant. It embeds an internal state machine that executes the PET
 * protocol. However, it is the caller's responsibility to drive this state machine by
 * calling [`Participant::tick()`], and to take action when the participant state
 * changes.
 */
typedef struct Participant Participant;

/**
 * The participant settings
 */
typedef struct Settings Settings;

/**
 * ByteBuffer is a struct that represents an array of bytes to be sent over the FFI boundaries.
 * There are several cases when you might want to use this, but the primary one for us
 * is for returning protobuf-encoded data to Swift and Java. The type is currently rather
 * limited (implementing almost no functionality), however in the future it may be
 * more expanded.
 *
 * ## Caveats
 *
 * Note that the order of the fields is `len` (an i64) then `data` (a `*mut u8`), getting
 * this wrong on the other side of the FFI will cause memory corruption and crashes.
 * `i64` is used for the length instead of `u64` and `usize` because JNA has interop
 * issues with both these types.
 *
 * ### `Drop` is not implemented
 *
 * ByteBuffer does not implement Drop. This is intentional. Memory passed into it will
 * be leaked if it is not explicitly destroyed by calling [`ByteBuffer::destroy`], or
 * [`ByteBuffer::destroy_into_vec`]. This is for two reasons:
 *
 * 1. In the future, we may allow it to be used for data that is not managed by
 *    the Rust allocator\*, and `ByteBuffer` assuming it's okay to automatically
 *    deallocate this data with the Rust allocator.
 *
 * 2. Automatically running destructors in unsafe code is a
 *    [frequent footgun](https://without.boats/blog/two-memory-bugs-from-ringbahn/)
 *    (among many similar issues across many crates).
 *
 * Note that calling `destroy` manually is often not needed, as usually you should
 * be passing these to the function defined by [`define_bytebuffer_destructor!`] from
 * the other side of the FFI.
 *
 * Because this type is essentially *only* useful in unsafe or FFI code (and because
 * the most common usage pattern does not require manually managing the memory), it
 * does not implement `Drop`.
 *
 * \* Note: in the case of multiple Rust shared libraries loaded at the same time,
 * there may be multiple instances of "the Rust allocator" (one per shared library),
 * in which case we're referring to whichever instance is active for the code using
 * the `ByteBuffer`. Note that this doesn't occur on all platforms or build
 * configurations, but treating allocators in different shared libraries as fully
 * independent is always safe.
 *
 * ## Layout/fields
 *
 * This struct's fields are not `pub` (mostly so that we can soundly implement `Send`, but also so
 * that we can verify rust users are constructing them appropriately), the fields, their types, and
 * their order are *very much* a part of the public API of this type. Consumers on the other side
 * of the FFI will need to know its layout.
 *
 * If this were a C struct, it would look like
 *
 * ```c,no_run
 * struct ByteBuffer {
 *     // Note: This should never be negative, but values above
 *     // INT64_MAX / i64::MAX are not allowed.
 *     int64_t len;
 *     // Note: nullable!
 *     uint8_t *data;
 * };
 * ```
 *
 * In rust, there are two fields, in this order: `len: i64`, and `data: *mut u8`.
 *
 * For clarity, the fact that the data pointer is nullable means that `Option<ByteBuffer>` is not
 * the same size as ByteBuffer, and additionally is not FFI-safe (the latter point is not
 * currently guaranteed anyway as of the time of writing this comment).
 *
 * ### Description of fields
 *
 * `data` is a pointer to an array of `len` bytes. Note that data can be a null pointer and therefore
 * should be checked.
 *
 * The bytes array is allocated on the heap and must be freed on it as well. Critically, if there
 * are multiple rust shared libraries being used in the same application, it *must be freed
 * on the same heap that allocated it*, or you will corrupt both heaps.
 *
 * Typically, this object is managed on the other side of the FFI (on the "FFI consumer"), which
 * means you must expose a function to release the resources of `data` which can be done easily
 * using the [`define_bytebuffer_destructor!`] macro provided by this crate. For buffers returned
 * by this library, that function is [`xaynet_ffi_byte_buffer_destroy()`].
 */
typedef struct ByteBuffer {
  int64_t len;
  uint8_t *data;
} ByteBuffer;

/**
 * `FfiStr<'a>` is a safe (`#[repr(transparent)]`) wrapper around a
 * nul-terminated `*const c_char` (e.g. a C string). Conceptually, it is
 * similar to [`std::ffi::CStr`], except that it may be used in the signatures
 * of extern "C" functions.
*
 * Functions accepting strings should use this instead of accepting a C string
 * directly. This allows us to write those functions using safe code without
 * allowing safe Rust to cause memory unsafety.
 *
 * A single function for constructing these from Rust ([`FfiStr::from_raw`])
 * has been provided. Most of the time, this should not be necessary, and users
 * should accept `FfiStr` in the parameter list directly.
 *
 * ## Caveats
 *
 * An effort has been made to make this struct hard to misuse, however it is
 * still possible, if the `'static` lifetime is manually specified in the
 * struct. E.g.
 *
 * ```rust,no_run
 * # use ffi_support::FfiStr;
 * // NEVER DO THIS
 * #[no_mangle]
 * extern "C" fn never_do_this(s: FfiStr<'static>) {
 *     // save `s` somewhere, and access it after this
 *     // function returns.
 * }
 * ```
 *
 * Instead, one of the following patterns should be used:
 *
 * ```
 * # use ffi_support::FfiStr;
 * #[no_mangle]
 * extern "C" fn valid_use_1(s: FfiStr<'_>) {
 *     // Use of `s` after this function returns is impossible
 * }
 * // Alternative:
 * #[no_mangle]
 * extern "C" fn valid_use_2(s: FfiStr) {
 *     // Use of `s` after this function returns is impossible
 * }
 * ```
 */
typedef const char *FfiStr;

/**
 * The model configuration of the model that is expected in [`xaynet_ffi_participant_set_model()`].
 *
 * [`xaynet_ffi_participant_set_model()`]: crate::ffi::xaynet_ffi_participant_set_model
 */
typedef struct LocalModelConfig {
  /**
   * The expected data type of the model.
   */
  ModelDataType data_type;
  /**
   * The expected length of the model, i.e. its number of weights.
   */
  uint64_t len;
} LocalModelConfig;

/**
 * Destroy the given `ByteBuffer` and free its memory. This function must only be
 * called on `ByteBuffer`s that have been created on the Rust side of the FFI. If you
 * have created a `ByteBuffer` on the other side of the FFI, do not use this function,
 * use `free()` instead.
 *
 * # Return value
 *
 * - [`OK`] on success
 * - [`ERR_NULLPTR`] if `buf` is NULL
 *
 * # Safety
 *
 * 1. When calling this method, you have to ensure that *either* the pointer is NULL
 *    *or* all of the following is true:
 *    - The pointer must be properly [aligned].
 *    - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *      documentation.
 * 2. After destroying the `ByteBuffer` the pointer becomes invalid and must not be
 *    used.
 * 3. Calling this function on a `ByteBuffer` that has not been created on the Rust
 *    side of the FFI is UB.
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
int xaynet_ffi_byte_buffer_destroy(const struct ByteBuffer *buf);

/**
 * Initialize the crypto library. This method must be called before instantiating a
 * participant with [`xaynet_ffi_participant_new()`] or before generating new keys with
 * [`xaynet_ffi_generate_key_pair()`].
 *
 * # Return value
 *
 * - [`OK`] if the initialization succeeded
 * - -[`ERR_CRYPTO_INIT`] if the initialization failed
 *
 * # Safety
 *
 * This function is safe to call
 */
int xaynet_ffi_crypto_init(void);

/**
 * Destroy the participant created by [`xaynet_ffi_participant_new()`] or
 * [`xaynet_ffi_participant_restore()`].
 *
 * # Return value
 *
 * - [`OK`] on success
 * - [`ERR_NULLPTR`] if `participant` is NULL
 *
 * # Safety
 *
 * 1. When calling this method, you have to ensure that *either* the pointer is NULL
 *    *or* all of the following is true:
 *    - The pointer must be properly [aligned].
 *    - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *      documentation.
 * 2. After destroying the `Participant`, the pointer becomes invalid and must not be
 *    used.
 * 3. This function should only be called on a pointer that has been created by
 *    [`xaynet_ffi_participant_new()`] or [`xaynet_ffi_participant_restore()`], and
 *    not on a participant that has been passed to [`xaynet_ffi_participant_save()`],
 *    which already destroys it.
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
int xaynet_ffi_participant_destroy(struct Participant *participant);

/**
 * Instantiate a new participant with the given settings. The participant must be
 * destroyed with [`xaynet_ffi_participant_destroy`] (or consumed by
 * [`xaynet_ffi_participant_save()`]).
 *
 * # Return value
 *
 * - a NULL pointer if `settings` is NULL or if the participant creation failed
 * - a valid pointer to a [`Participant`] otherwise
 *
 * # Safety
 *
 * When calling this method, you have to ensure that *either* the pointer is NULL *or*
 * all of the following is true:
 *
 * - The pointer must be properly [aligned].
 * - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *   documentation.
 *
 * After the participant has been destroyed with [`xaynet_ffi_participant_destroy`],
 * its pointer becomes invalid and must not be used.
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
struct Participant *xaynet_ffi_participant_new(const struct Settings *settings);

/**
 * Drive the participant internal state machine. Every tick, the state machine
 * attempts to perform a small work unit.
 *
 * # Return value
 *
 * - [`ERR_NULLPTR`] if `participant` is NULL
 * - a bitflag otherwise, with the following flags:
 *   - [`PARTICIPANT_MADE_PROGRESS`]: if set, this flag indicates that the participant
 *     internal state machine was able to make some progress, and that the participant
 *     state changed. This information can be used as an indication for saving the
 *     participant state for instance. If the flag is not set, the state machine was
 *     not able to make progress. There are many potential causes for this, including:
 *       - the participant is not taking part in the current training round and is just
 *         waiting for a new one to start
 *       - the Xaynet coordinator is not reachable or has not published some
 *         information the participant is waiting for
 *       - the state machine is waiting for the model to be set (see
 *         [`xaynet_ffi_participant_set_model()`])
 *   - [`PARTICIPANT_TASK_NONE`], [`PARTICIPANT_TASK_SUM`] and
 *     [`PARTICIPANT_TASK_UPDATE`]: these flags are mutually exclusive, and indicate
 *     which task the participant has been selected for, for the current round. If
 *     [`PARTICIPANT_TASK_NONE`] is set, then the participant will just wait for a new
 *     round to start. If [`PARTICIPANT_TASK_UPDATE`] is set, then the participant has
 *     been selected to update the global model, and should prepare to provide a new
 *     model once the [`PARTICIPANT_SHOULD_SET_MODEL`] flag is set.
 *   - [`PARTICIPANT_SHOULD_SET_MODEL`]: if set, then the participant should set its
 *     model, by calling [`xaynet_ffi_participant_set_model()`]
 *   - [`PARTICIPANT_NEW_GLOBALMODEL`]: if set, the participant can fetch the new global
 *     model, by calling [`xaynet_ffi_participant_global_model()`]
 *
 * Note that [`ERR_NULLPTR`] and [`PARTICIPANT_TASK_NONE`] have the same value (1).
 *
 * # Safety
 *
 * When calling this method, you have to ensure that *either* the pointer is NULL *or*
 * all of the following is true:
 *
 * - The pointer must be properly [aligned].
 * - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *   documentation.
 *
 * After the participant has been destroyed with [`xaynet_ffi_participant_destroy`],
 * its pointer becomes invalid and must not be used.
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
int xaynet_ffi_participant_tick(struct Participant *participant);

/**
 * Serialize the participant state and return a buffer that contains the serialized
 * participant.
 *
 * # Safety
 *
 * 1. When calling this method, you have to ensure that *either* the pointer is NULL
 *    *or* all of the following is true:
 *    - The pointer must be properly [aligned].
 *    - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *      documentation.
 * 2. The `ByteBuffer` created by this function must be destroyed with
 *    [`xaynet_ffi_byte_buffer_destroy`]. Attempting to free the memory from the other
 *    side of the FFI is UB.
 * 3. This function destroys the participant. Therefore, **the pointer becomes invalid
 *    and must not be used anymore**. Instead, a new participant should be created,
 *    either with [`xaynet_ffi_participant_new()`] or
 *    [`xaynet_ffi_participant_restore()`]
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 *
 * # Example
 *
 * To save the participant into a file (opened in binary mode, since the serialized
 * participant is not text):
 *
 * ```c
 * const ByteBuffer *save_buf = xaynet_ffi_participant_save(participant);
 * assert(save_buf);
 *
 * char *path = "./participant.bin";
 * FILE *f = fopen(path, "wb");
 * fwrite(save_buf->data, 1, save_buf->len, f);
 * fclose(f);
 * xaynet_ffi_byte_buffer_destroy(save_buf);
 * ```
 */
const struct ByteBuffer *xaynet_ffi_participant_save(struct Participant *participant);

/**
 * Restore the participant from a buffer that contained its serialized state.
 *
 * # Return value
 *
 * - a NULL pointer on failure
 * - a pointer to the restored participant on success
 *
 * # Safety
 *
 * When calling this method, you have to ensure that *either* the pointers are NULL
 * *or* all of the following is true:
 * - The pointers must be properly [aligned].
 * - They must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *   documentation.
*
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 *
 * # Example
 *
 * To restore a participant from a file (opened in binary mode, since the serialized
 * participant is not text):
 *
 * ```c
 * f = fopen("./participant.bin", "rb");
 * fseek(f, 0L, SEEK_END);
 * long fsize = ftell(f);
 * fseek(f, 0L, SEEK_SET);
 * ByteBuffer buf = {
 *     .len = fsize,
 *     .data = (uint8_t *)malloc(fsize),
 * };
 * size_t n_read = fread(buf.data, 1, fsize, f);
 * assert(n_read == (size_t)fsize);
 * fclose(f);
 * Participant *restored =
 *     xaynet_ffi_participant_restore("http://localhost:8081", &buf);
 * free(buf.data);
 * ```
 */
struct Participant *xaynet_ffi_participant_restore(FfiStr url, const struct ByteBuffer *buffer);

/**
 * Set the participant's model. Usually this should be called when the value returned
 * by [`xaynet_ffi_participant_tick()`] contains the [`PARTICIPANT_SHOULD_SET_MODEL`]
 * flag, but it can be called anytime. The model just won't be sent to the coordinator
 * until it's time.
 *
 * - `buffer` should be a pointer to a buffer that contains the model
 * - `data_type` specifies the type of the model weights (see [`DataType`]). The C header
 *   file generated by this crate provides an enum corresponding to the parameters:
 *   `ModelDataType`.
 * - `len` is the number of weights the model has
 *
 * # Return value
 *
 * - [`OK`] if the model is set successfully
 * - [`ERR_NULLPTR`] if `participant` is NULL
 * - [`ERR_SETMODEL_DATATYPE`] if the datatype is invalid
 * - [`ERR_SETMODEL_MODEL`] if the model is invalid
 *
 * # Safety
 *
 * 1. When calling this method, you have to ensure that *either* the pointer is NULL
 *    *or* all of the following is true:
 *    - The pointer must be properly [aligned].
 *    - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *      documentation.
 * 2. If `len` or `data_type` do not match the model in `buffer`, this method will
 *    result in a buffer over-read.
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
int xaynet_ffi_participant_set_model(struct Participant *participant, const void *buffer, unsigned char data_type, unsigned int len);

/**
 * Return the latest global model from the coordinator.
 *
 * - `buffer` is the array in which the global model should be copied.
 * - `data_type` specifies the type of the model weights (see [`DataType`]). The C header
 *   file generated by this crate provides an enum corresponding to the parameters:
 *   `ModelDataType`.
 * - `len` is the number of weights the model has
 *
 * # Return Value
 *
 * - [`OK`] if the global model was copied into `buffer` successfully
 * - [`ERR_NULLPTR`] if `participant` or the `buffer` is NULL
 * - [`GLOBALMODEL_NONE`] if no model exists
 * - [`ERR_GLOBALMODEL_IO`] if the communication with the coordinator failed
 * - [`ERR_GLOBALMODEL_DATATYPE`] if the datatype is invalid
 * - [`ERR_GLOBALMODEL_LEN`] if the length of the buffer does not match the length of the model
 * - [`ERR_GLOBALMODEL_CONVERT`] if the conversion of the model failed
 *
 * # Note
 *
 * It is **not** guaranteed that the model configuration returned by
 * [`xaynet_ffi_participant_local_model_config`] corresponds to the configuration of
 * the global model. This means that the global model can have a different length / data type
 * than the one defined in the model configuration. Both model configurations are only
 * guaranteed to be the same if the model config **never** changes on the coordinator side.
 *
 * # Safety
 *
 * 1. When calling this method, you have to ensure that *either* the pointer is NULL
 *    *or* all of the following is true:
 *    - The pointer must be properly [aligned].
 *    - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *      documentation.
 * 2. `buffer` must be large enough to hold `len` weights of type `data_type`: the
 *    global model is copied into it, so if `len` or `data_type` do not match the
 *    actual size of `buffer`, this method will result in a buffer overflow
 *    (out-of-bounds write).
*
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
int xaynet_ffi_participant_global_model(struct Participant *participant, void *buffer, unsigned char data_type, unsigned int len);

/**
 * Return the local model configuration of the model that is expected in the
 * [`xaynet_ffi_participant_set_model()`] function.
 *
 * The returned `LocalModelConfig` must be destroyed with
 * [`xaynet_ffi_local_model_config_destroy()`].
 *
 * # Safety
 *
 * 1. When calling this method, you have to ensure that *either* the pointer is NULL
 *    *or* all of the following is true:
 *    - The pointer must be properly [aligned].
 *    - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *      documentation.
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
struct LocalModelConfig *xaynet_ffi_participant_local_model_config(const struct Participant *participant);

/**
 * Destroy the settings created by [`xaynet_ffi_settings_new()`].
 *
 * # Return value
 *
 * - [`OK`] on success
 * - [`ERR_NULLPTR`] if `settings` is NULL
 *
 * # Safety
 *
 * 1. When calling this method, you have to ensure that *either* the pointer is NULL
 *    *or* all of the following is true:
 *    - The pointer must be properly [aligned].
 *    - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *      documentation.
 * 2. After destroying the `Settings`, the pointer becomes invalid and must not be
 *    used.
 * 3. This function should only be called on a pointer that has been created by
 *    [`xaynet_ffi_settings_new`].
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
int xaynet_ffi_settings_destroy(struct Settings *settings);

/**
 * Create new [`Settings`] and return a pointer to it.
 *
 * # Safety
 *
 * The `Settings` created by this function must be destroyed with
 * [`xaynet_ffi_settings_destroy()`]. Attempting to free the memory from the other side
 * of the FFI is UB.
 */
struct Settings *xaynet_ffi_settings_new(void);

/**
 * Set scalar setting.
 *
 * # Return value
 *
 * - [`OK`] if successful
 * - [`ERR_NULLPTR`] if `settings` is `NULL`
 *
 * # Safety
 *
 * When calling this method, you have to ensure that *either* the pointer is NULL *or*
 * all of the following is true:
 * - The pointer must be properly [aligned].
 * - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *   documentation.
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
int xaynet_ffi_settings_set_scalar(struct Settings *settings, double scalar);

/**
 * Set coordinator URL.
 *
 * # Return value
 *
 * - [`OK`] if successful
 * - [`ERR_INVALID_URL`] if `url` is not a valid string
 * - [`ERR_NULLPTR`] if `settings` is `NULL`
 *
 * # Safety
 *
 * When calling this method, you have to ensure that *either* the pointers are NULL
 * *or* all of the following is true:
 * - The pointers must be properly [aligned].
 * - They must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *   documentation.
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
int xaynet_ffi_settings_set_url(struct Settings *settings, FfiStr url);

/**
 * Generate a new signing key pair that can be used in the [`Settings`]. **Before
 * calling this function you must initialize the crypto library with
 * [`xaynet_ffi_crypto_init()`]**.
 *
 * The returned value contains a pointer to the secret key. For security reasons, you
 * must make sure that this buffer's lifetime is as short as possible, and call
 * [`xaynet_ffi_forget_key_pair`] to destroy it.
 *
 * [`xaynet_ffi_crypto_init()`]: crate::ffi::xaynet_ffi_crypto_init
 *
 * # Safety
 *
 * This function is safe to call
 */
const struct KeyPair *xaynet_ffi_generate_key_pair(void);

/**
 * De-allocate the buffers that contain the signing keys, and zero out the content of
 * the buffer that contains the secret key.
 *
 * # Return value
 *
 * - [`ERR_NULLPTR`] if `key_pair` is NULL
 * - [`OK`] otherwise
 *
 * # Safety
 *
 * When calling this method, you have to ensure that *either* the pointer is NULL *or*
 * all of the following is true:
 * - The pointer must be properly [aligned].
 * - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *   documentation.
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
int xaynet_ffi_forget_key_pair(const struct KeyPair *key_pair);

/**
 * Set participant signing keys.
 *
 * # Return value
 *
 * - [`OK`] if successful
 * - [`ERR_NULLPTR`] if `settings` or `key_pair` is `NULL`
 * - [`ERR_CRYPTO_PUBLIC_KEY`] if the given `key_pair` contains an invalid public key
 * - [`ERR_CRYPTO_SECRET_KEY`] if the given `key_pair` contains an invalid secret key
 *
 * # Safety
 *
 * When calling this method, you have to ensure that *either* the pointers are NULL
 * *or* all of the following is true:
 * - The pointers must be properly [aligned].
 * - They must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *   documentation.
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
int xaynet_ffi_settings_set_keys(struct Settings *settings, const struct KeyPair *key_pair);

/**
 * Check whether the given settings are valid and can be used to instantiate a
 * participant (see [`xaynet_ffi_participant_new()`]).
 *
 * # Return value
 *
 * - [`OK`] on success
 * - [`ERR_SETTINGS_URL`] if the URL has not been set
 * - [`ERR_SETTINGS_KEYS`] if the signing keys have not been set
 * - [`ERR_SETTINGS_SCALAR`] if the scalar is out of bounds
 *
 * # Safety
 *
 * When calling this method, you have to ensure that *either* the pointer is NULL *or*
 * all of the following is true:
 *
 * - The pointer must be properly [aligned].
 * - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *   documentation.
 *
 * [`xaynet_ffi_participant_new()`]: crate::ffi::xaynet_ffi_participant_new
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 */
int xaynet_ffi_check_settings(const struct Settings *settings);

/**
 * Destroy the model configuration created by [`xaynet_ffi_participant_local_model_config()`].
 *
 * # Return value
 *
 * - [`OK`] on success
 * - [`ERR_NULLPTR`] if `local_model_config` is NULL
 *
 * # Safety
 *
 * 1. When calling this method, you have to ensure that *either* the pointer is NULL
 *    *or* all of the following is true:
 *    - The pointer must be properly [aligned].
 *    - It must be "dereferencable" in the sense defined in the [`std::ptr`] module
 *      documentation.
 * 2. After destroying the `LocalModelConfig`, the pointer becomes invalid and must not be
 *    used.
 * 3. This function should only be called on a pointer that has been created by
 *    [`xaynet_ffi_participant_local_model_config()`].
 *
 * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety
 * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment
 * [`xaynet_ffi_participant_local_model_config()`]: crate::ffi::xaynet_ffi_participant_local_model_config
 */
int xaynet_ffi_local_model_config_destroy(struct LocalModelConfig *local_model_config);

================================================
FILE: rust/xaynet-sdk/Cargo.toml
================================================
[package]
name = "xaynet-sdk"
version = "0.1.0"
authors = ["Xayn Engineering "]
edition = "2018"
description = "The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant."
readme = "../../README.md" homepage = "https://xaynet.dev/" repository = "https://github.com/xaynetwork/xaynet/" license-file = "../../LICENSE" keywords = ["federated-learning", "fl", "ai", "machine-learning"] categories = ["science", "cryptography"] [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] [dependencies] async-trait = "0.1.57" base64 = "0.13.0" bincode = "1.3.3" derive_more = { version = "0.99.17", default-features = false, features = ["from"] } # TODO: remove once concurrent_futures.rs was moved to the e2e package futures = "0.3.24" paste = "1.0.8" serde = { version = "1.0.144", features = ["derive"] } sodiumoxide = "0.2.7" thiserror = "1.0.32" # TODO: move to dev-dependencies once concurrent_futures.rs was moved to the e2e package tokio = { version = "1.20.1", features = ["rt", "macros"] } tracing = "0.1.36" url = "2.2.2" xaynet-core = { path = "../xaynet-core", version = "0.2.0" } # feature: reqwest client reqwest = { version = "0.11.10", default-features = false, optional = true } # This has to match the version used by reqwest. 
# It would be nice if
# reqwest just re-exported it
bytes = { version = "1.0.1", optional = true }
rand = "0.8.5"

[dev-dependencies]
mockall = "0.11.2"
num = { version = "0.4.0", features = ["serde"] }
serde_json = "1.0.85"
tokio-test = "0.4.1"
xaynet-core = { path = "../xaynet-core", features = ["testutils"] }

[features]
default = []
reqwest-client = ["reqwest", "bytes"]

================================================
FILE: rust/xaynet-sdk/src/client.rs
================================================
use async_trait::async_trait;
use thiserror::Error;
use url::Url;

use crate::XaynetClient;
use xaynet_core::{
    common::RoundParameters,
    crypto::{ByteObject, PublicSigningKey},
    mask::Model,
    SumDict,
    UpdateSeedDict,
};

// NOTE(review): generic parameter lists appear to be missing from this copy of
// the file (e.g. `impl From for ClientError`, `body: Vec`, `-> Result, ...`);
// verify the signatures against the repository before editing the code.

/// Error returned by [`Client`] operations and by [`XaynetHttpClient`]
/// implementations.
///
/// Note that building a [`Client`] fails with [`InvalidBaseUrl`], not with this
/// type.
#[derive(Debug, Error)]
pub enum ClientError {
    /// Data could not be deserialized (this is also what the `bincode::Error`
    /// and `std::num::ParseIntError` conversions below produce).
    #[error("failed to deserialize data: {0}")]
    Deserialize(String),
    /// An HTTP request failed (see `ClientError::http_error`).
    #[error("HTTP request failed: {0}")]
    Http(String),
    /// Any other failure, e.g. an empty round parameters response.
    #[error("{0}")]
    Other(String),
    #[error("Reading from file failed: {0}")]
    Io(#[from] std::io::Error),
    /// The server answered with an unexpected HTTP status, carried as its
    /// numeric code.
    #[error("Unexpected response")]
    UnexpectedResponse(u16),
    #[error("Unexpected certificate extension")]
    UnexpectedCertificate,
    #[error("No certificate found")]
    NoCertificate,
}

// `http_error` is only used by the `reqwest-client` backend below, hence the
// `dead_code` allowance when that feature is disabled.
#[cfg_attr(not(feature = "reqwest-client"), allow(dead_code))]
impl ClientError {
    /// Wrap any displayable error into [`ClientError::Http`].
    fn http_error(e: E) -> Self {
        Self::Http(format!("{}", e))
    }
}

impl From for ClientError {
    fn from(e: bincode::Error) -> Self {
        Self::Deserialize(format!("{}", e))
    }
}

impl From for ClientError {
    fn from(e: std::num::ParseIntError) -> Self {
        Self::Deserialize(format!("{}", e))
    }
}

/// A basic HTTP interface that [`Client`] HTTP backends must implement.
#[async_trait]
pub trait XaynetHttpClient {
    /// Error type for all the trait's methods
    // NOTE(review): the methods below actually return `ClientError`, not
    // `Self::Error`.
    type Error: std::error::Error;

    /// Response type for `GET` requests
    type GetResponse: AsRef<[u8]>;

    /// Perform an HTTP `GET` on the given URL.
    ///
    /// If the response is `NO_CONTENT`, the implementor must return `Ok(None)`. Otherwise, the
    /// response body must be returned.
    async fn get(&mut self, url: &str) -> Result, ClientError>;

    /// Perform an HTTP `POST` on the given URL, with the given body.
    async fn post(&mut self, url: &str, body: Vec) -> Result<(), ClientError>;
}

#[derive(Debug, Clone)]
/// A client that communicates with the coordinator's API via HTTP(S).
pub struct Client {
    /// HTTP(S) client
    client: C,
    /// Coordinator URL
    base_url: Url,
}

/// Error returned when trying to create a [`Client`] with an invalid
/// address for the Xaynet coordinator.
#[derive(Debug, Error)]
#[error("Invalid base URL: {}", .0)]
pub struct InvalidBaseUrl(String);

impl Client
where
    C: XaynetHttpClient,
{
    /// Create a new client.
    ///
    /// # Args
    ///
    /// - `http_client` is the HTTP client that will be used to perform the HTTP requests. Any
    ///   HTTP client can be used, as long as it implements the [`XaynetHttpClient`] trait.
    /// - `base_url` is the URL to the Xaynet coordinator
    ///
    /// # Errors
    ///
    /// An [`InvalidBaseUrl`] error is returned if `base_url` is not a valid URL, or if it
    /// cannot be used as a base URL.
    pub fn new(http_client: C, base_url: &str) -> Result {
        let base_url = Url::parse(base_url).map_err(|e| InvalidBaseUrl(format!("{}", e)))?;
        if base_url.cannot_be_a_base() {
            return Err(InvalidBaseUrl(String::from("cannot be a base URL")));
        }
        Ok(Self {
            client: http_client,
            base_url,
        })
    }

    /// Append the given segment to the client base URL
    fn url(&self, segment: &str) -> Url {
        let mut url = self.base_url.clone();
        // `new()` rejects cannot-be-a-base URLs, which (per the `url` crate docs) is
        // the only case in which `path_segments_mut()` returns an error.
        url.path_segments_mut().unwrap().push(segment);
        url
    }

    /// `GET` the given URL and decode the response body with `bincode`.
    ///
    /// Returns `Ok(None)` when the HTTP backend reports that there is no content.
    async fn get(&mut self, url: &Url) -> Result, ClientError>
    where
        T: for<'a> serde::Deserialize<'a>,
    {
        Ok(match self.client.get(url.as_str()).await? {
            Some(data) => Some(bincode::deserialize::(data.as_ref())?),
            None => None,
        })
    }

    /// `POST` the given body to the given URL.
    async fn post(&mut self, url: &Url, data: Vec) -> Result<(), ClientError> {
        self.client.post(url.as_str(), data).await
    }
}

#[async_trait]
impl XaynetClient for Client
where
    C: XaynetHttpClient + Send,
{
    type Error = ClientError;

    // Round parameters must always be present: an empty response is an error.
    async fn get_round_params(&mut self) -> Result {
        let url = self.url("params");
        let round_params: Option = self.get(&url).await?;
        round_params.ok_or_else(|| {
            ClientError::Other("failed to fetch round parameters: empty response".to_string())
        })
    }

    async fn get_sums(&mut self) -> Result, Self::Error> {
        let url = self.url("sums");
        Ok(self.get(&url).await?)
    }

    async fn get_seeds(
        &mut self,
        pk: PublicSigningKey,
    ) -> Result, Self::Error> {
        let mut url = self.url("seeds");
        // the given public signing key is sent base64-encoded in the `pk` query parameter
        url.query_pairs_mut()
            .append_pair("pk", &base64::encode(pk.as_slice()));
        self.get(&url).await
    }

    async fn get_model(&mut self) -> Result, Self::Error> {
        let url = self.url("model");
        Ok(self.get(&url).await?)
    }

    async fn send_message(&mut self, msg: Vec) -> Result<(), Self::Error> {
        let url = self.url("message");
        self.post(&url, msg).await
    }
}

#[cfg(feature = "reqwest-client")]
#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-client")))]
#[async_trait]
impl XaynetHttpClient for reqwest::Client {
    type Error = reqwest::Error;
    type GetResponse = bytes::Bytes;

    // Error statuses are turned into `ClientError::Http` by `error_for_status()`;
    // of the remaining statuses only 200 (body) and 204 (no body) are accepted.
    async fn get(&mut self, url: &str) -> Result, ClientError> {
        let resp = reqwest::Client::get(self, url)
            .send()
            .await
            .map_err(ClientError::http_error)?
            .error_for_status()
            .map_err(ClientError::http_error)?;
        match resp.status() {
            reqwest::StatusCode::OK => {
                Ok(Some(resp.bytes().await.map_err(ClientError::http_error)?))
            }
            reqwest::StatusCode::NO_CONTENT => Ok(None),
            status => Err(ClientError::UnexpectedResponse(status.as_u16())),
        }
    }

    async fn post(&mut self, url: &str, body: Vec) -> Result<(), ClientError> {
        let _resp = reqwest::Client::post(self, url)
            .body(body)
            .send()
            .await
            .map_err(ClientError::http_error)?
.error_for_status() .map_err(ClientError::http_error)?; Ok(()) } } ================================================ FILE: rust/xaynet-sdk/src/lib.rs ================================================ #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr( doc, forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links) )] #![doc( html_logo_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png", html_favicon_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png", issue_tracker_base_url = "https://github.com/xaynetwork/xaynet/issues" )] //! This crate provides building blocks for implementing participants for //! the [Xaynet Federated Learning platform](https://www.xaynet.dev/). //! //! The PET protocol states that in any given round of federated learning, //! each participant of the protocol may be selected to carry out one of //! two tasks: //! //! - **update**: participants selected for the update task //! (a.k.a. _update participants_) are responsible for sending a machine //! learning model they trained //! - **sum**: participants selected for the sum task (a.k.a. _sum //! participants_) are responsible for computing a global mask from local mask seeds sent by //! the update participants //! //! Participants may also not be selected for any of these tasks, in which //! case they simply wait for the next round. //! //! # Running a participant //! //! The communication with the Xaynet coordinator is managed by a //! background task that runs the PET protocol. We call it the PET //! agent. In practice, the agent is a simple wrapper around the //! [`StateMachine`]. //! //! To run a participant, you need to start an agent, and //! interact with it. There are two types of interactions: //! //! - reacting to notifications for the agents, which include: //! - start of a new round of training //! - selection for the sum task //! - selection for the update task //! - end of a task //! 
- providing the agent with a Machine Learning model and a corresponding //! scalar for aggregation when the participant takes part in the update task //! //! ## Implementing an agent //! //! A simple agent can be implemented as a function. //! //! ``` //! use std::time::Duration; //! //! use tokio::time::sleep; //! use xaynet_sdk::{StateMachine, TransitionOutcome}; //! //! async fn run_agent(mut state_machine: StateMachine, tick: Duration) { //! loop { //! state_machine = match state_machine.transition().await { //! // The state machine is stuck waiting for some data, //! // either from the coordinator or from the //! // participant. Let's wait a little and try again //! TransitionOutcome::Pending(state_machine) => { //! sleep(tick).await; //! state_machine //! } //! // The state machine moved forward in the PET protocol. //! // We simply continue looping, trying to make more progress. //! TransitionOutcome::Complete(state_machine) => state_machine, //! }; //! } //! } //! ``` //! //! This agent needs to be fed a [`StateMachine`] in order to run. A //! state machine requires four components: //! //! - a cryptographic key identifying the participant, see [`PetSettings`] //! - a store from which it can load a model when the participant is //! selected for the update task. This can be any type that //! implements the [`ModelStore`] trait. In our case, we'll use a //! dummy in-memory store that always returns the same model. //! - a client to talk with the Xaynet coordinator. This can be any //! type that implements the [`XaynetClient`] trait, like the [`Client`]. //! For this we're going to use the trait implementation for the `reqwest` //! client that is available when compiling with `--features reqwest-client`. //! - a notifier that the state machine can use to send //! notifications. This can be any type that implements the //! [`Notify`] trait. We'll use channels for this. //! //! [`PetSettings`]: crate::settings::PetSettings //! [`Client`]: crate::client::Client //!
//! Finally we can start our agent and log the events it emits. Here //! is the full code: //! //! ```no_run //! # #[cfg(all(feature = "reqwest-client", feature = "tokio/rt-muli-thread"))] //! # mod feature_reqwest_client { //! use std::{ //! sync::{mpsc, Arc}, //! time::Duration, //! }; //! //! use async_trait::async_trait; //! use reqwest::Client as ReqwestClient; //! use tokio::time::sleep; //! //! use xaynet_core::{ //! crypto::SigningKeyPair, //! mask::{BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Model, ModelType}, //! }; //! use xaynet_sdk::{ //! client::Client, //! settings::PetSettings, //! ModelStore, //! Notify, //! StateMachine, //! TransitionOutcome, //! }; //! //! async fn run_agent(mut state_machine: StateMachine, tick: Duration) { //! loop { //! state_machine = match state_machine.transition().await { //! TransitionOutcome::Pending(state_machine) => { //! sleep(tick.clone()).await; //! state_machine //! } //! TransitionOutcome::Complete(state_machine) => state_machine, //! }; //! } //! } //! //! #[derive(Debug)] //! enum Event { //! // event sent by the state machine when the participant is //! // selected for the update task //! Update, //! // event sent by the state machine when the participant is //! // selected for the sum task //! Sum, //! // event sent by the state machine when a new round starts //! NewRound, //! // event sent by the state machine when the participant //! // becomes inactive (after finishing a task for instance) //! Idle, //! // event sent by the state machine when the participant //! // is supposed to populate the model store //! LoadModel, //! } //! //! // Our notifier is a simple wrapper around a channel. //! struct Notifier(mpsc::Sender); //! //! impl Notify for Notifier { //! fn new_round(&mut self) { //! self.0.send(Event::NewRound).unwrap(); //! } //! fn sum(&mut self) { //! self.0.send(Event::Sum).unwrap(); //! } //! fn update(&mut self) { //! self.0.send(Event::Update).unwrap(); //! } //! 
fn idle(&mut self) { //! self.0.send(Event::Idle).unwrap(); //! } //! fn load_model(&mut self) { //! self.0.send(Event::LoadModel).unwrap(); //! } //! } //! //! // Our store will always load the same model. //! // In practice the model should be updated with //! // the model the participant trains when it is selected //! // for the update task. //! struct LocalModel(Arc); //! //! #[async_trait] //! impl ModelStore for LocalModel { //! type Model = Arc; //! type Error = std::convert::Infallible; //! //! async fn load_model(&mut self) -> Result, Self::Error> { //! Ok(Some(self.0.clone())) //! } //! } //! //! #[tokio::main] //! async fn main() -> Result<(), std::convert::Infallible> { //! let keys = SigningKeyPair::generate(); //! let settings = PetSettings::new(keys); //! let xaynet_client = Client::new(ReqwestClient::new(), "http://localhost:8081").unwrap(); //! let (tx, rx) = mpsc::channel::(); //! let notifier = Notifier(tx); //! let model = Model::from_primitives(vec![0; 100].into_iter()).unwrap(); //! let model_store = LocalModel(Arc::new(model)); //! //! let mut state_machine = StateMachine::new(settings, xaynet_client, model_store, notifier); //! // Start the agent //! tokio::spawn(async move { //! run_agent(state_machine, Duration::from_secs(1)).await; //! }); //! //! loop { //! println!("{:?}", rx.recv().unwrap()); //! } //! } //! # } //! # fn main() {} // don't actually run anything, because the client never terminates //! 
``` pub mod client; mod message_encoder; pub mod settings; mod state_machine; mod traits; pub(crate) mod utils; pub(crate) use self::message_encoder::MessageEncoder; pub use self::traits::{ModelStore, Notify, XaynetClient}; pub use state_machine::{LocalModelConfig, SerializableState, StateMachine, TransitionOutcome}; ================================================ FILE: rust/xaynet-sdk/src/message_encoder/chunker.rs ================================================ #![allow(dead_code)] use std::cmp; /// Default chunk size, for [`Chunker`] pub const DEFAULT_CHUNK_SIZE: usize = 4096; /// A struct that yields chunks of the given data. pub struct Chunker<'a, T: AsRef<[u8]>> { data: &'a T, max_chunk_size: usize, } impl<'a, T> Chunker<'a, T> where T: AsRef<[u8]>, { /// Create a new [`Chunker`] that yields chunks of `T` of size /// `max_chunk_size`. If `max_chunk_size` is `0`, then the max /// chunk size will be set to [`DEFAULT_CHUNK_SIZE`]. pub fn new(data: &'a T, max_chunk_size: usize) -> Self { let max_chunk_size = if max_chunk_size == 0 { DEFAULT_CHUNK_SIZE } else { max_chunk_size }; Self { data, max_chunk_size, } } /// Get the total number of chunks pub fn nb_chunks(&self) -> usize { let data_len = self.data.as_ref().len(); ceiling_div(data_len, self.max_chunk_size) } /// Get the chunk with the given ID. /// /// # Panics /// /// This method panics if the given `id` is bigger than `self.nb_chunks()`. pub fn get_chunk(&self, id: usize) -> &'a [u8] { if id >= self.nb_chunks() { panic!("no chunk with ID {}", id); } let start = id * self.max_chunk_size; let end = cmp::min(start + self.max_chunk_size, self.data.as_ref().len()); let range = start..end; &self.data.as_ref()[range] } } /// A helper that performs division with ceil. /// /// # Panic /// /// This function panic if `d` is 0. 
fn ceiling_div(n: usize, d: usize) -> usize { (n + d - 1) / d } #[cfg(test)] mod tests { use super::*; #[test] #[should_panic(expected = "no chunk with ID 0")] fn test_0() { let data = vec![]; let chunker = Chunker::new(&data, 0); assert_eq!(chunker.nb_chunks(), 0); chunker.get_chunk(0); } #[test] #[should_panic(expected = "no chunk with ID 5")] fn test_1() { let data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let chunker = Chunker::new(&data, 2); assert_eq!(chunker.nb_chunks(), 5); assert_eq!(chunker.get_chunk(0), &[0, 1]); assert_eq!(chunker.get_chunk(1), &[2, 3]); assert_eq!(chunker.get_chunk(2), &[4, 5]); assert_eq!(chunker.get_chunk(3), &[6, 7]); assert_eq!(chunker.get_chunk(4), &[8, 9]); chunker.get_chunk(5); } #[test] #[should_panic(expected = "no chunk with ID 5")] fn test_2() { let data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8]; let chunker = Chunker::new(&data, 2); assert_eq!(chunker.nb_chunks(), 5); assert_eq!(chunker.get_chunk(0), &[0, 1]); assert_eq!(chunker.get_chunk(1), &[2, 3]); assert_eq!(chunker.get_chunk(2), &[4, 5]); assert_eq!(chunker.get_chunk(3), &[6, 7]); assert_eq!(chunker.get_chunk(4), &[8]); chunker.get_chunk(5); } #[test] #[should_panic(expected = "no chunk with ID 4")] fn test_3() { let data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let chunker = Chunker::new(&data, 3); assert_eq!(chunker.nb_chunks(), 4); assert_eq!(chunker.get_chunk(0), &[0, 1, 2]); assert_eq!(chunker.get_chunk(1), &[3, 4, 5]); assert_eq!(chunker.get_chunk(2), &[6, 7, 8]); assert_eq!(chunker.get_chunk(3), &[9]); chunker.get_chunk(4); } #[test] #[should_panic(expected = "no chunk with ID 1")] fn test_4() { let data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let chunker = Chunker::new(&data, 10); assert_eq!(chunker.nb_chunks(), 1); assert_eq!(chunker.get_chunk(0), data.as_slice()); chunker.get_chunk(1); } #[test] #[should_panic(expected = "no chunk with ID 1")] fn test_5() { let data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let chunker = Chunker::new(&data, 0); assert_eq!(chunker.max_chunk_size, 
DEFAULT_CHUNK_SIZE); assert_eq!(chunker.nb_chunks(), 1); assert_eq!(chunker.get_chunk(0), data.as_slice()); chunker.get_chunk(1); } } ================================================ FILE: rust/xaynet-sdk/src/message_encoder/encoder.rs ================================================ use serde::{Deserialize, Serialize}; use thiserror::Error; use super::Chunker; use xaynet_core::{ crypto::{PublicEncryptKey, SecretSigningKey, SigningKeyPair}, message::{Chunk, Message, Payload, Tag, ToBytes}, }; /// An encoder for multipart messages. It implements /// `Iterator>`, which yields message parts ready to be /// sent over the wire. #[derive(Serialize, Deserialize, Debug)] pub struct MultipartEncoder { keys: SigningKeyPair, /// The coordinator public key. It should be the key used to /// encrypt the message. coordinator_pk: PublicEncryptKey, /// Serialized message payload. data: Vec, /// Next chunk ID to be produced by the iterator id: u16, /// Message tag tag: Tag, /// The maximum size allowed for the payload. `self.data` is split /// in chunks of this size. payload_size: usize, /// A random ID common to all the message chunks. 
message_id: u16, } /// Overhead induced by wrapping the data in [`Payload::Chunk`] pub const CHUNK_OVERHEAD: usize = 8; pub const MIN_PAYLOAD_SIZE: usize = CHUNK_OVERHEAD + 1; impl Iterator for MultipartEncoder { type Item = Vec; fn next(&mut self) -> Option { let chunker = Chunker::new(&self.data, self.payload_size - CHUNK_OVERHEAD); if self.id as usize >= chunker.nb_chunks() { return None; } let chunk = Chunk { id: self.id, message_id: self.message_id, last: self.id as usize == chunker.nb_chunks() - 1, data: chunker.get_chunk(self.id as usize).to_vec(), }; self.id += 1; let message = Message { // The signature is computed when serializing the message signature: None, participant_pk: self.keys.public, is_multipart: true, tag: self.tag, payload: Payload::Chunk(chunk), coordinator_pk: self.coordinator_pk, }; let data = serialize_message(&message, &self.keys.secret); Some(data) } } /// An encoder for a [`Payload`] representing a sum, update or sum2 /// message. If the [`Payload`] is small enough, a [`Message`] header /// is added, and the message is serialized and signed. If /// the [`Payload`] is too large to fit in a single message, it is /// split in chunks which are also serialized and signed. #[derive(Serialize, Deserialize, Debug)] pub enum MessageEncoder { /// Encoder for a payload that fits in a single message. Simple(Option>), /// Encoder for a large payload that needs to be split in several /// parts. 
Multipart(MultipartEncoder), } impl Iterator for MessageEncoder { type Item = Vec; fn next(&mut self) -> Option { match self { MessageEncoder::Simple(ref mut data) => data.take(), MessageEncoder::Multipart(ref mut multipart_encoder) => multipart_encoder.next(), } } } #[derive(Error, Debug)] pub enum InvalidEncodingInput { #[error("only sum, update, and sum2 messages can be encoded")] Payload, #[error("the max payload size is too small")] PayloadSize, } impl MessageEncoder { // NOTE: the only reason we need to consume the payload is because creating the Message // consumes it. /// Create a new encoder for the given payload. The `participant` /// is used to sign the message(s). If the serialized payload is /// larger than `max_payload_size`, the message will we split in /// multiple chunks. If `max_payload_size` is `0`, the message /// will not be split. /// /// # Errors /// /// An [`InvalidEncodingInput`] error is returned when `payload` is of /// type [`Payload::Chunk`]. Only [`Payload::Sum`], /// [`Payload::Update`], [`Payload::Sum2`] are accepted. pub fn new( keys: SigningKeyPair, payload: Payload, coordinator_pk: PublicEncryptKey, max_payload_size: usize, ) -> Result { // Reject payloads of type Payload::Chunk. 
It is the job of the encoder to produce those if // the payload is deemed to big to be sent in a single message if payload.is_chunk() { return Err(InvalidEncodingInput::Payload); } if max_payload_size != 0 && max_payload_size <= MIN_PAYLOAD_SIZE { return Err(InvalidEncodingInput::PayloadSize); } if max_payload_size != 0 && payload.buffer_length() > max_payload_size { Ok(Self::new_multipart( keys, coordinator_pk, payload, max_payload_size, )) } else { Ok(Self::new_simple(keys, coordinator_pk, payload)) } } fn new_simple( keys: SigningKeyPair, coordinator_pk: PublicEncryptKey, payload: Payload, ) -> Self { let message = Message { // The signature is computed when serializing the message signature: None, participant_pk: keys.public, is_multipart: false, coordinator_pk, tag: Self::get_tag_from_payload(&payload), payload, }; let data = serialize_message(&message, &keys.secret); Self::Simple(Some(data)) } fn new_multipart( keys: SigningKeyPair, coordinator_pk: PublicEncryptKey, payload: Payload, payload_size: usize, ) -> Self { let tag = Self::get_tag_from_payload(&payload); let mut data = vec![0; payload.buffer_length()]; payload.to_bytes(&mut data); Self::Multipart(MultipartEncoder { keys, data, id: 0, tag, coordinator_pk, payload_size, message_id: rand::random::(), }) } fn get_tag_from_payload(payload: &Payload) -> Tag { match payload { Payload::Sum(_) => Tag::Sum, Payload::Update(_) => Tag::Update, Payload::Sum2(_) => Tag::Sum2, Payload::Chunk(_) => panic!("no tag associated to Payload::Chunk"), } } } #[cfg(test)] mod tests { use xaynet_core::{ crypto::{ByteObject, EncryptKeyPair, EncryptKeySeed, SigningKeyPair, SigningKeySeed}, message::{FromBytes, Update}, testutils::multipart as helpers, }; use super::*; fn participant_keys() -> SigningKeyPair { let seed = SigningKeySeed::from_slice(vec![0; 32].as_slice()).unwrap(); SigningKeyPair::derive_from_seed(&seed) } fn coordinator_keys() -> EncryptKeyPair { let seed = EncryptKeySeed::from_slice(vec![0; 
32].as_slice()).unwrap(); EncryptKeyPair::derive_from_seed(&seed) } fn message(dict_len: usize, mask_obj_len: usize) -> Message { let payload = helpers::update(dict_len, mask_obj_len).into(); Message { signature: None, participant_pk: participant_keys().public, is_multipart: false, tag: Tag::Update, payload, coordinator_pk: coordinator_keys().public, } } fn small_message() -> Message { let dict_len = 80 + 32 + 4; // 116 => dict with a single entry let model_len = 6 + 18; // 24 => masked model with single weight let message = message(dict_len, model_len); let payload_len = dict_len + model_len + 64 * 2; // 268 let message_len = payload_len + 136; // 404 assert_eq!(message.payload.buffer_length(), payload_len); assert_eq!(message.buffer_length(), message_len); message } #[test] fn no_chunk() { let msg = small_message(); let mut enc = MessageEncoder::new( participant_keys(), msg.clone().payload, msg.coordinator_pk, 272, ) .unwrap(); let data = enc.next().unwrap(); let parsed = Message::from_byte_slice(&data.as_slice()).unwrap(); assert!(!parsed.is_multipart); assert_eq!(parsed.payload, msg.payload); assert!(enc.next().is_none()); } #[test] fn two_chunks() { let msg = small_message(); let mut enc = MessageEncoder::new( participant_keys(), msg.clone().payload, msg.coordinator_pk, 200, ) .unwrap(); let data = enc.next().unwrap(); // The payload should be 200 bytes + 136 bytes for the // message header. // // 8 of these 200 payload bytes are for the Chunk payload // header. So this chunk actually only contains 192 bytes (out // of 268) from the Update payload. So 76 bytes remain. 
assert_eq!(data.len(), 200 + 136); let parsed = Message::from_byte_slice(&data.as_slice()).unwrap(); assert!(parsed.is_multipart); let chunk1 = extract_chunk(parsed); assert!(!chunk1.last); assert_eq!(chunk1.id, 0); assert_eq!(chunk1.data.len(), 192); let data = enc.next().unwrap(); // The payload should be 76 bytes + 8 bytes of CHUNK_OVERHEAD, // plus 136 byte for the message header assert_eq!(data.len(), 84 + 136); let parsed = Message::from_byte_slice(&data.as_slice()).unwrap(); assert!(parsed.is_multipart); let chunk2 = extract_chunk(parsed); assert!(chunk2.last); assert_eq!(chunk2.id, 1); assert_eq!(chunk2.data.len(), 76); let payload_data: Vec = [chunk1.data, chunk2.data].concat(); let update = Update::from_byte_slice(&payload_data).unwrap(); assert_eq!(update, extract_update(msg)); } fn extract_chunk(message: Message) -> Chunk { if let Payload::Chunk(c) = message.payload { c } else { panic!("not a chunk message"); } } fn extract_update(message: Message) -> Update { if let Payload::Update(u) = message.payload { u } else { panic!("not an update message"); } } } fn serialize_message(message: &Message, sk: &SecretSigningKey) -> Vec { let mut buf = vec![0; message.buffer_length()]; message.to_bytes(&mut buf, sk); buf } ================================================ FILE: rust/xaynet-sdk/src/message_encoder/mod.rs ================================================ mod chunker; mod encoder; use chunker::Chunker; pub use encoder::MessageEncoder; ================================================ FILE: rust/xaynet-sdk/src/settings/max_message_size.rs ================================================ use serde::{de::Error as SerdeError, Deserialize, Deserializer, Serialize}; use thiserror::Error; pub use xaynet_core::message::MESSAGE_HEADER_LENGTH; /// The minimum message payload size pub const MINIMUM_PAYLOAD_SIZE: usize = 1; /// Length of the encryption header in encrypted messages pub const ENCRYPTION_HEADER_LENGTH: usize = xaynet_core::crypto::SEALBYTES; /// The 
minimum size a message can have pub const MIN_MESSAGE_SIZE: usize = MESSAGE_HEADER_LENGTH + ENCRYPTION_HEADER_LENGTH + MINIMUM_PAYLOAD_SIZE; /// Invalid [`MaxMessageSize`] value #[derive(Debug, Error)] #[error("max message size must be at least {}", MIN_MESSAGE_SIZE)] pub struct InvalidMaxMessageSize; /// Represent the maximum size messages sent by a participant can /// have. If a larger message needs to be sent, it will be chunked and /// sent in several parts. Note that messages have a minimal size of /// [`MIN_MESSAGE_SIZE`]. #[derive(Serialize, Deserialize, Clone, Copy, Debug)] pub struct MaxMessageSize(#[serde(deserialize_with = "deserialize")] Option); impl Default for MaxMessageSize { fn default() -> Self { MaxMessageSize(Some( 4096 - MESSAGE_HEADER_LENGTH - ENCRYPTION_HEADER_LENGTH, )) } } impl MaxMessageSize { /// An arbitrary large maximum message size. With this setting, /// messages will never be split. pub fn unlimited() -> Self { MaxMessageSize(None) } /// Create a max message size of `size`. /// /// # Errors /// /// This method returns an [`InvalidMaxMessageSize`] error if /// `size` is smaller than [`MIN_MESSAGE_SIZE`]; pub fn capped(size: usize) -> Result { if size >= MIN_MESSAGE_SIZE { Ok(MaxMessageSize(Some(size))) } else { Err(InvalidMaxMessageSize) } } /// Get the maximum payload size corresponding to the maximum /// message size. `None` means that the payload size is unlimited. 
pub fn max_payload_size(&self) -> Option { self.0 .map(|size| size - MESSAGE_HEADER_LENGTH - ENCRYPTION_HEADER_LENGTH) } } fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { let value: Option = Option::deserialize(deserializer)?; match value { Some(size) => { if size >= MIN_MESSAGE_SIZE { Ok(Some(size)) } else { Err(SerdeError::custom(format!( "max_message_size must be at least {} (got {})", MIN_MESSAGE_SIZE, size ))) } } None => Ok(None), } } #[cfg(test)] mod tests { use serde_json::json; use super::*; #[test] fn max_message_size_deserialization_capped() { let input = r#"{"some":1000}"#; let expected = json!({"some": MaxMessageSize::capped(1000).unwrap()}); let actual: serde_json::Value = serde_json::from_str(input).unwrap(); assert_eq!(expected, actual); } #[test] fn max_message_size_deserialization_unlimited() { let input = r#"{"none":null}"#; let expected = json!({ "none": MaxMessageSize::unlimited() }); let actual: serde_json::Value = serde_json::from_str(input).unwrap(); assert_eq!(expected, actual); } #[test] fn max_message_size_deserialization_err() { // Use a dummy struct, otherwise, serde deserializes the value // as an integer. 
#[derive(Deserialize, Serialize, Debug)] struct Dummy { mms: MaxMessageSize, } let input = r#"{"mms":123}"#; let expected = "max_message_size must be at least 185 (got 123) at line 1 column 11".to_string(); let actual = serde_json::from_str::(input).unwrap_err(); assert_eq!(expected, format!("{}", actual)); } #[test] fn max_message_size_serialization_capped() { let input = json!({"some": MaxMessageSize::capped(1000).unwrap()}); let expected = r#"{"some":1000}"#; let actual = serde_json::to_string(&input).unwrap(); assert_eq!(expected, actual); } #[test] fn max_message_size_serialization_unlimited() { let input = json!({ "none": MaxMessageSize::unlimited() }); let expected = r#"{"none":null}"#; let actual = serde_json::to_string(&input).unwrap(); assert_eq!(expected, actual); } } ================================================ FILE: rust/xaynet-sdk/src/settings/mod.rs ================================================ mod max_message_size; use serde::{Deserialize, Serialize}; pub use max_message_size::{InvalidMaxMessageSize, MaxMessageSize, MIN_MESSAGE_SIZE}; use xaynet_core::{crypto::SigningKeyPair, mask::Scalar}; #[derive(Serialize, Deserialize, Debug)] pub struct PetSettings { pub keys: SigningKeyPair, pub scalar: Scalar, pub max_message_size: MaxMessageSize, } impl PetSettings { pub fn new(keys: SigningKeyPair) -> Self { PetSettings { keys, scalar: Scalar::unit(), max_message_size: MaxMessageSize::default(), } } } ================================================ FILE: rust/xaynet-sdk/src/state_machine/io.rs ================================================ use std::error::Error; use async_trait::async_trait; use xaynet_core::{ common::RoundParameters, mask::Model, SumDict, SumParticipantPublicKey, UpdateSeedDict, }; use crate::{ModelStore, Notify, XaynetClient}; /// Returned a dynamically dispatched [`IO`] object pub(crate) fn boxed_io( xaynet_client: X, model_store: M, notifier: N, ) -> Box + Send>>> where X: XaynetClient + Send + 'static, M: ModelStore + Send + 
'static, N: Notify + Send + 'static, { Box::new(StateMachineIO::new(xaynet_client, model_store, notifier)) } #[cfg(test)] type DynModel = Box<(dyn std::convert::AsRef + Send)>; /// A trait that gathers all the [`Notify`], [`XaynetClient`] and [`ModelStore`] /// methods. /// /// This trait is intended not to be exposed. It is a convenience for avoiding the /// proliferation of generic parameters in the state machine: instead of three traits, /// we now have only one. /// /// Note that by having only one trait, we can also use dynamic dispatch and actually /// get rid of all the generic parameters in the state machine. /// /// ```compile_fail /// Box // allowed /// Box // not allowed /// ``` #[cfg_attr(test, mockall::automock(type Model=DynModel;))] #[async_trait] pub(crate) trait IO: Send + 'static { type Model; /// Attempt to load the model from the store. async fn load_model(&mut self) -> Result, Box>; /// Fetch the round parameters from the coordinator async fn get_round_params(&mut self) -> Result>; /// Fetch the sum dictionary from the coordinator async fn get_sums(&mut self) -> Result, Box>; /// Fetch the seed dictionary for the given sum participant from the coordinator async fn get_seeds( &mut self, pk: SumParticipantPublicKey, ) -> Result, Box>; /// Fetch the latest global model from the coordinator async fn get_model(&mut self) -> Result, Box>; /// Send the given signed and encrypted PET message to the coordinator async fn send_message(&mut self, msg: Vec) -> Result<(), Box>; /// Notify the participant that a new round started fn notify_new_round(&mut self); /// Notify the participant that they have been selected for the sum task for the current /// round fn notify_sum(&mut self); /// Notify the participant that it is selected for the update task for the current /// round fn notify_update(&mut self); /// Notify the participant that is done with its current task and it waiting for /// being selected for a task fn notify_idle(&mut self); /// Notify the 
participant that is is expected to provide a model to the state /// machine by loading it into the store fn notify_load_model(&mut self); } /// Internal struct that implements the [`IO`] trait. It is not used as is in the state /// machine. Instead, we box it and use it as a `dyn IO` object. struct StateMachineIO { xaynet_client: X, model_store: M, notifier: N, } impl StateMachineIO { /// Create a new `StateMachineIO` pub fn new(xaynet_client: X, model_store: M, notifier: N) -> Self { Self { xaynet_client, model_store, notifier, } } } #[async_trait] impl IO for StateMachineIO where X: XaynetClient + Send + 'static, M: ModelStore + Send + 'static, N: Notify + Send + 'static, { type Model = Box + Send>; async fn load_model(&mut self) -> Result, Box> { self.model_store .load_model() .await .map_err(|e| Box::new(e) as Box) .map(|opt| opt.map(|model| Box::new(model) as Box + Send>)) } async fn get_round_params(&mut self) -> Result> { self.xaynet_client .get_round_params() .await .map_err(|e| Box::new(e) as Box) } async fn get_sums(&mut self) -> Result, Box> { self.xaynet_client .get_sums() .await .map_err(|e| Box::new(e) as Box) } async fn get_seeds( &mut self, pk: SumParticipantPublicKey, ) -> Result, Box> { self.xaynet_client .get_seeds(pk) .await .map_err(|e| Box::new(e) as Box) } async fn get_model(&mut self) -> Result, Box> { self.xaynet_client .get_model() .await .map_err(|e| Box::new(e) as Box) } async fn send_message(&mut self, msg: Vec) -> Result<(), Box> { self.xaynet_client .send_message(msg) .await .map_err(|e| Box::new(e) as Box) } fn notify_new_round(&mut self) { self.notifier.new_round() } fn notify_sum(&mut self) { self.notifier.sum() } fn notify_update(&mut self) { self.notifier.update() } fn notify_idle(&mut self) { self.notifier.idle() } fn notify_load_model(&mut self) { self.notifier.load_model() } } #[async_trait] impl IO for Box + Send>>> { type Model = Box + Send>; async fn load_model(&mut self) -> Result, Box> { self.as_mut().load_model().await } 
// Delegate the remaining `IO` methods to the boxed trait object.
    async fn get_round_params(&mut self) -> Result<RoundParameters, Box<dyn std::error::Error>> {
        self.as_mut().get_round_params().await
    }

    async fn get_sums(&mut self) -> Result<Option<SumDict>, Box<dyn std::error::Error>> {
        self.as_mut().get_sums().await
    }

    async fn get_seeds(
        &mut self,
        pk: SumParticipantPublicKey,
    ) -> Result<Option<UpdateSeedDict>, Box<dyn std::error::Error>> {
        self.as_mut().get_seeds(pk).await
    }

    async fn get_model(&mut self) -> Result<Option<Model>, Box<dyn std::error::Error>> {
        self.as_mut().get_model().await
    }

    async fn send_message(&mut self, msg: Vec<u8>) -> Result<(), Box<dyn std::error::Error>> {
        self.as_mut().send_message(msg).await
    }

    fn notify_new_round(&mut self) {
        self.as_mut().notify_new_round()
    }

    fn notify_sum(&mut self) {
        self.as_mut().notify_sum()
    }

    fn notify_update(&mut self) {
        self.as_mut().notify_update()
    }

    fn notify_idle(&mut self) {
        self.as_mut().notify_idle()
    }

    fn notify_load_model(&mut self) {
        self.as_mut().notify_load_model()
    }
}

================================================
FILE: rust/xaynet-sdk/src/state_machine/mod.rs
================================================
// Important: the `macro_use` modules must be declared first for their
// macros to be usable in the other modules (until declarative macros are stable).
#[macro_use]
mod phase;
mod io;
mod phases;
#[allow(clippy::module_inception)]
mod state_machine;

// It is useful to re-export everything within this module because
// there are a lot of interdependencies between all the sub-modules.
#[cfg(test)]
use self::io::MockIO;
use self::{
    io::{boxed_io, IO},
    phase::{IntoPhase, Phase, PhaseIo, Progress, SharedState, State, Step},
    phases::{Awaiting, NewRound, SendingSum, SendingSum2, SendingUpdate, Sum, Sum2, Update},
};
pub use self::{
    phase::{LocalModelConfig, SerializableState},
    state_machine::{StateMachine, TransitionOutcome},
};

#[cfg(test)]
pub mod tests;

================================================
FILE: rust/xaynet-sdk/src/state_machine/phase.rs
================================================
use async_trait::async_trait;
use derive_more::From;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::{debug, error, info, warn};

use super::{Awaiting, NewRound, SendingSum, SendingSum2, SendingUpdate, Sum, Sum2, Update, IO};
use crate::{
    settings::{MaxMessageSize, PetSettings},
    state_machine::{StateMachine, TransitionOutcome},
    MessageEncoder,
};
use xaynet_core::{
    common::{RoundParameters, RoundSeed},
    crypto::{ByteObject, PublicEncryptKey, SigningKeyPair},
    mask::{self, DataType, MaskConfig, Model, Scalar},
    message::Payload,
};

/// State of the state machine
#[derive(Debug, Serialize, Deserialize)]
pub struct State<P> {
    /// data specific to the current phase
    pub private: Box<P>,
    /// data common to most of the phases
    pub shared: Box<SharedState>,
}

impl<P> State<P> {
    /// Create a new state
    pub fn new(shared: Box<SharedState>, private: Box<P>) -> Self {
        Self { private, shared }
    }
}

/// A dynamically dispatched [`IO`] object.
pub(crate) type PhaseIo = Box<dyn IO<Model = Box<dyn AsRef<Model> + Send>>>;

/// Represent the state machine in a specific phase
pub struct Phase<P> {
    /// State of the phase.
    pub(super) state: State<P>,
    /// Opaque client for performing IO tasks: talking with the
    /// coordinator API, loading models, etc.
    pub(super) io: PhaseIo,
}

impl<P> std::fmt::Debug for Phase<P>
where
    P: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // the IO object is an opaque trait object, so only a placeholder is printed
        f.debug_struct("Phase")
            .field("state", &self.state)
            .field("io", &"PhaseIo")
            .finish()
    }
}

/// Store for all the data that are common to all the phases
#[derive(Serialize, Deserialize, Debug)]
pub struct SharedState {
    /// Keys that identify the participant. They are used to sign the
    /// PET message sent by the participant.
    pub keys: SigningKeyPair,
    /// Scalar used for masking
    pub scalar: Scalar,
    /// Maximum message size the participant can send. Messages larger
    /// than `message_size` are split in several parts.
    pub message_size: MaxMessageSize,
    /// Current round parameters
    pub round_params: RoundParameters,
}

/// Get arbitrary round parameters. These round parameters are never used, we just
/// temporarily use them in the [`SharedState`] when creating a new state machine. The
/// first thing the state machine does when it runs, is to fetch the real round
/// parameters from the coordinator.
fn dummy_round_parameters() -> RoundParameters {
    RoundParameters {
        pk: PublicEncryptKey::zeroed(),
        sum: 0.0,
        update: 0.0,
        seed: RoundSeed::zeroed(),
        mask_config: MaskConfig {
            group_type: mask::GroupType::Integer,
            data_type: mask::DataType::F32,
            bound_type: mask::BoundType::B0,
            model_type: mask::ModelType::M3,
        }
        .into(),
        model_length: 0,
    }
}

impl SharedState {
    /// Create the shared state from the PET settings. The round parameters are
    /// placeholders until the first round freshness check replaces them.
    pub fn new(settings: PetSettings) -> Self {
        Self {
            keys: settings.keys,
            scalar: settings.scalar,
            message_size: settings.max_message_size,
            round_params: dummy_round_parameters(),
        }
    }
}

/// A trait that each `Phase<P>` implements. When `Step::step` is called, the phase
/// tries to do a small piece of work.
#[async_trait]
pub trait Step {
    /// Represent an attempt to make progress within a phase. If the step results in a
    /// change in the phase state, the updated state machine is returned as
    /// `TransitionOutcome::Complete`. If no progress can be made, the state machine is
    /// returned unchanged as `TransitionOutcome::Pending`.
    async fn step(mut self) -> TransitionOutcome;
}

/// Unwrap a [`Progress`] inside a `Step::step` implementation: return early on
/// `Stuck` (as `Pending`) and `Updated` (as `Complete`), and evaluate to the phase
/// on `Continue`.
#[macro_export]
macro_rules! try_progress {
    ($progress:expr) => {{
        use $crate::state_machine::{Progress, TransitionOutcome};
        match $progress {
            // No progress can be made. Return the state machine as is
            Progress::Stuck(phase) => return TransitionOutcome::Pending(phase.into()),
            // Further progress can be made but requires more work, so don't return
            Progress::Continue(phase) => phase,
            // Progress has been made, return the updated state machine
            Progress::Updated(state_machine) => return TransitionOutcome::Complete(state_machine),
        }
    }};
}

/// Represent the presence or absence of progress being made during a phase.
#[derive(Debug)]
pub enum Progress<P> {
    /// No progress can be made currently.
    Stuck(Phase<P>),
    /// More work needs to be done for progress to be made.
    Continue(Phase<P>),
    /// Progress has been made and resulted in this new state machine.
    Updated(StateMachine),
}

impl<P> Phase<P>
where
    Phase<P>: Step + Into<StateMachine>,
{
    /// Try to make some progress in the execution of the PET protocol. There are three
    /// possible outcomes:
    ///
    /// 1. no progress can currently be made and the phase state is unchanged
    /// 2. progress is made but the state machine does not transition to a new
    ///    phase. Internally, the phase state is changed though.
    /// 3. progress is made and the state machine transitions to a new phase.
    ///
    /// In case `1.`, the state machine is returned unchanged, wrapped in
    /// [`TransitionOutcome::Pending`] to indicate to the caller that the state machine
    /// wasn't updated. In case `2.` and `3.` the updated state machine is returned
    /// wrapped in [`TransitionOutcome::Complete`].
    pub async fn step(mut self) -> TransitionOutcome {
        match self.check_round_freshness().await {
            RoundFreshness::Unknown => TransitionOutcome::Pending(self.into()),
            RoundFreshness::Outdated => {
                info!("a new round started: updating the round parameters and resetting the state machine");
                self.io.notify_new_round();
                // `Phase::new` is used directly (not `into_phase`), so the new round
                // notification above is emitted only once.
                TransitionOutcome::Complete(
                    Phase::<NewRound>::new(
                        State::new(self.state.shared, Box::new(NewRound)),
                        self.io,
                    )
                    .into(),
                )
            }
            RoundFreshness::Fresh => {
                debug!("round is still fresh, continuing from where we left off");
                // call the phase specific `Step::step`, not this inherent method
                <Self as Step>::step(self).await
            }
        }
    }

    /// Check whether the coordinator has published new round parameters. In other
    /// words, this checks whether a new round has started.
    ///
    /// When the parameters changed, the shared state is updated with the new ones.
    async fn check_round_freshness(&mut self) -> RoundFreshness {
        match self.io.get_round_params().await {
            Err(e) => {
                warn!("failed to fetch round parameters {:?}", e);
                RoundFreshness::Unknown
            }
            Ok(params) => {
                if params == self.state.shared.round_params {
                    debug!("round parameters didn't change");
                    RoundFreshness::Fresh
                } else {
                    info!("fetched fresh round parameters");
                    self.state.shared.round_params = params;
                    RoundFreshness::Outdated
                }
            }
        }
    }
}

/// Trait for building [`Phase<P>`] from a [`State<P>`].
///
/// Note that we could just use [`Phase::new`] for this. However we want to be able to
/// customize the conversion for each phase. For instance, when building a
/// `Phase<Update>` from an `Update`, we want to emit some events with the `io`
/// object. It is cleaner to wrap this custom logic in a trait impl.
pub(crate) trait IntoPhase<P> {
    /// Build the phase with the given `io` object
    fn into_phase(self, io: PhaseIo) -> Phase<P>;
}

impl<P> Phase<P> {
    /// Build a new phase with the given state and io object. This should not be called
    /// directly. Instead, use the [`IntoPhase`] trait to construct a phase.
    pub(crate) fn new(state: State<P>, io: PhaseIo) -> Self {
        Phase { state, io }
    }

    /// Instantiate a message encoder for the given payload.
    ///
    /// The encoder takes care of converting the given `payload` into one or several
    /// signed and encrypted PET messages.
    pub fn message_encoder(&self, payload: Payload) -> MessageEncoder {
        MessageEncoder::new(
            self.state.shared.keys.clone(),
            payload,
            self.state.shared.round_params.pk,
            // NOTE(review): `0` presumably tells the encoder there is no size
            // limit when `max_payload_size()` returns `None` — confirm.
            self.state
                .shared
                .message_size
                .max_payload_size()
                .unwrap_or(0),
        )
        // the encoder rejects Chunk payload, but in the state
        // machine, we never manually create such payloads so
        // unwrapping is fine
        .unwrap()
    }

    /// Return the local model configuration of the model that is expected in the update phase.
    pub fn local_model_config(&self) -> LocalModelConfig {
        LocalModelConfig {
            data_type: self.state.shared.round_params.mask_config.vect.data_type,
            len: self.state.shared.round_params.model_length,
        }
    }

    /// Replace the IO object by a mock configured by `f`.
    #[cfg(test)]
    pub(crate) fn with_io_mock<F>(&mut self, f: F)
    where
        F: FnOnce(&mut super::MockIO),
    {
        let mut mock = super::MockIO::new();
        f(&mut mock);
        self.io = Box::new(mock);
    }

    /// Verify the expectations of the current IO mock.
    #[cfg(test)]
    pub(crate) fn check_io_mock(&mut self) {
        // dropping the mock forces the checks to run. We replace it
        // by an empty one, so that we detect if a method is called
        // unexpectedly afterwards
        let _ = std::mem::replace(&mut self.io, Box::new(super::MockIO::new()));
    }
}

#[derive(Debug)]
/// The local model configuration of the model that is expected in the update phase.
pub struct LocalModelConfig {
    /// The expected data type of the local model.
    // In the current state it is not possible to configure a coordinator in which
    // the scalar data type and the model data type are different. Therefore, we assume here
    // that the scalar data type is the same as the model data type.
    pub data_type: DataType,
    /// The expected length of the local model.
    pub len: usize,
}

#[derive(Error, Debug)]
#[error("failed to send a PET message")]
pub struct SendMessageError;

/// Round freshness indicator
pub enum RoundFreshness {
    /// A new round started. The current round is outdated
    Outdated,
    /// We were not able to check whether a new round started
    Unknown,
    /// The current round is still going
    Fresh,
}

/// A serializable representation of a phase state.
///
/// We cannot serialize the state directly, even though it implements `Serialize`,
/// because deserializing it would require knowing its type in advance:
///
/// ```compile_fail
/// // `buf` is a Vec<u8> that contains a serialized state that we want to deserialize.
/// // Which phase type goes in place of `???` cannot be known before deserializing.
/// let state: State<???> = State::deserialize(&buf[..]).unwrap();
/// ```
#[derive(Serialize, Deserialize, From, Debug)]
pub enum SerializableState {
    NewRound(State<NewRound>),
    Awaiting(State<Awaiting>),
    Sum(State<Sum>),
    Update(State<Update>),
    Sum2(State<Sum2>),
    SendingSum(State<SendingSum>),
    SendingUpdate(State<SendingUpdate>),
    SendingSum2(State<SendingSum2>),
}

impl<P> From<Phase<P>> for SerializableState
where
    State<P>: Into<SerializableState>,
{
    /// Keep only the phase state; the IO object is dropped.
    fn from(phase: Phase<P>) -> Self {
        phase.state.into()
    }
}

================================================
FILE: rust/xaynet-sdk/src/state_machine/phases/awaiting.rs
================================================
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tracing::info;

use crate::state_machine::{IntoPhase, Phase, PhaseIo, State, Step, TransitionOutcome};

/// The state of the awaiting phase: the participant has no task to do and
/// stays idle until a new round is detected.
#[derive(Serialize, Deserialize, Debug)]
pub struct Awaiting;

#[async_trait]
impl Step for Phase<Awaiting> {
    async fn step(mut self) -> TransitionOutcome {
        info!("awaiting task");
        TransitionOutcome::Pending(self.into())
    }
}

impl IntoPhase<Awaiting> for State<Awaiting> {
    /// Entering the awaiting phase emits the `idle` notification.
    fn into_phase(self, mut io: PhaseIo) -> Phase<Awaiting> {
        io.notify_idle();
        Phase::<_>::new(self, io)
    }
}

================================================
FILE: rust/xaynet-sdk/src/state_machine/phases/mod.rs
================================================
mod awaiting;
mod new_round;
mod sending;
mod sum;
mod sum2;
mod update;

pub use self::{
    awaiting::Awaiting,
    new_round::NewRound,
    sending::{SendingSum, SendingSum2, SendingUpdate},
    sum::Sum,
    sum2::Sum2,
    update::Update,
};

================================================
FILE: rust/xaynet-sdk/src/state_machine/phases/new_round.rs
================================================
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tracing::info;
use xaynet_core::crypto::{ByteObject, Signature};

use crate::state_machine::{
    Awaiting, IntoPhase, Phase, PhaseIo, State, Step, Sum, TransitionOutcome, Update,
};

/// The state of the new round phase, in which the participant checks which task,
/// if any, it is eligible for.
#[derive(Serialize, Deserialize, Debug)]
pub struct NewRound;

impl IntoPhase<NewRound> for State<NewRound> {
    fn into_phase(self, mut io: PhaseIo) -> Phase<NewRound> {
        io.notify_new_round();
        Phase::<_>::new(self, io)
    }
}

#[async_trait]
impl Step for Phase<NewRound> {
    async fn step(mut self) -> TransitionOutcome {
        info!("new_round task");

        info!("checking eligibility for sum task");
        let sum_signature = self.sign(b"sum");
        if sum_signature.is_eligible(self.state.shared.round_params.sum) {
            info!("eligible for sum task");
            return TransitionOutcome::Complete(self.into_sum(sum_signature).into());
        }

        info!("not eligible for sum task, checking eligibility for update task");
        let update_signature = self.sign(b"update");
        if update_signature.is_eligible(self.state.shared.round_params.update) {
            info!("eligible for update task");
            return TransitionOutcome::Complete(
                self.into_update(sum_signature, update_signature).into(),
            );
        }

        info!("not eligible for update task, going to sleep until next round");
        let awaiting: Phase<Awaiting> = self.into();
        TransitionOutcome::Complete(awaiting.into())
    }
}

impl From<Phase<NewRound>> for Phase<Awaiting> {
    fn from(new_round: Phase<NewRound>) -> Self {
        State::new(new_round.state.shared, Box::new(Awaiting)).into_phase(new_round.io)
    }
}

impl Phase<NewRound> {
    /// Sign the current round seed followed by `data` with the participant's
    /// secret signing key.
    fn sign(&self, data: &[u8]) -> Signature {
        let sk = &self.state.shared.keys.secret;
        let seed = self.state.shared.round_params.seed.as_slice();
        sk.sign_detached(&[seed, data].concat())
    }

    /// Move on to the sum phase.
    fn into_sum(self, sum_signature: Signature) -> Phase<Sum> {
        let sum = Box::new(Sum::new(sum_signature));
        let state = State::new(self.state.shared, sum);
        state.into_phase(self.io)
    }

    /// Move on to the update phase.
    fn into_update(self, sum_signature: Signature, update_signature: Signature) -> Phase<Update> {
        let update = Box::new(Update::new(sum_signature, update_signature));
        let state = State::new(self.state.shared, update);
        state.into_phase(self.io)
    }
}

================================================
FILE: rust/xaynet-sdk/src/state_machine/phases/sending.rs
================================================
use async_trait::async_trait;
use paste::paste;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, info};

use crate::{
    state_machine::{
        phases::Sum2, Awaiting, IntoPhase, Phase, PhaseIo, Progress, State, Step,
        TransitionOutcome, IO,
    },
    MessageEncoder,
};

/// Implements the `SendingSum`, `SendingUpdate` and `SendingSum2` phases and transitions.
///
/// - `$Phase`: the phase whose message is sent (`Sending$Phase` is generated)
/// - `$Next`: the phase to transition to once everything has been sent
/// - `$phase` / `$next`: the names of those phases, used in log messages
macro_rules! impl_sending {
    ($Phase: ty, $Next: ty, $phase: expr, $next: expr) => {
        paste! {
            #[doc = "The state of the " $phase " sending phase."]
            #[derive(Serialize, Deserialize, Debug)]
            pub struct [<Sending $Phase>] {
                /// The message to send.
                message: MessageEncoder,
                /// Chunk that couldn't be sent and should be tried again.
                failed: Option<Vec<u8>>,
                /// State of the phase to transition to, after this one completes.
                next: $Next,
            }

            impl [<Sending $Phase>] {
                #[doc = "Creates a new " $phase " sending state."]
                pub fn new(message: MessageEncoder, next: $Next) -> Self {
                    Self {
                        message,
                        failed: None,
                        next,
                    }
                }
            }

            impl IntoPhase<[<Sending $Phase>]> for State<[<Sending $Phase>]> {
                fn into_phase(self, io: PhaseIo) -> Phase<[<Sending $Phase>]> {
                    Phase::<_>::new(self, io)
                }
            }

            #[async_trait]
            impl Step for Phase<[<Sending $Phase>]> {
                async fn step(mut self) -> TransitionOutcome {
                    info!("sending {} message", $phase);
                    self = try_progress!(self.send_next().await);

                    info!("done sending {} message, going to {} phase", $phase, $next);
                    let phase: Phase<$Next> = self.into();
                    TransitionOutcome::Complete(phase.into())
                }
            }

            impl From<Phase<[<Sending $Phase>]>> for Phase<$Next> {
                fn from(sending: Phase<[<Sending $Phase>]>) -> Self {
                    State::new(sending.state.shared, Box::new(sending.state.private.next))
                        .into_phase(sending.io)
                }
            }

            impl Phase<[<Sending $Phase>]> {
                #[doc = "Tries to send a " $phase " message and reports back on the progress made."]
                async fn try_send(mut self, data: Vec<u8>) -> Progress<[<Sending $Phase>]> {
                    info!("sending {} message (size = {})", $phase, data.len());
                    // keep a copy so that the chunk can be retried if sending fails
                    if let Err(e) = self.io.send_message(data.clone()).await {
                        error!("failed to send {} message: {:?}", $phase, e);
                        self.state.private.failed = Some(data);
                        Progress::Stuck(self)
                    } else {
                        Progress::Updated(self.into())
                    }
                }

                #[doc = "Sends the next " $phase " message and reports back on the progress made.\n"
                    "\n"
                    "Retries to send a previously failed message. Otherwise, tries to send the "
                    "next message."
                ]
                async fn send_next(mut self) -> Progress<[<Sending $Phase>]> {
                    if let Some(data) = self.state.private.failed.take() {
                        // the failed chunk was already encrypted before the first attempt
                        debug!(
                            "retrying to send {} message that couldn't be sent previously",
                            $phase
                        );
                        self.try_send(data).await
                    } else {
                        match self.state.private.message.next() {
                            Some(data) => {
                                let data = self.state.shared.round_params.pk.encrypt(data.as_slice());
                                self.try_send(data).await
                            }
                            None => {
                                debug!("nothing left to send");
                                Progress::Continue(self)
                            }
                        }
                    }
                }
            }
        }
    }
}

impl_sending!(Sum, Sum2, "sum", "sum2");
impl_sending!(Update, Awaiting, "update", "awaiting");
impl_sending!(Sum2, Awaiting, "sum2", "awaiting");

================================================
FILE: rust/xaynet-sdk/src/state_machine/phases/sum.rs
================================================
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};

use crate::{
    state_machine::{IntoPhase, Phase, PhaseIo, SendingSum, State, Step, Sum2, TransitionOutcome},
    MessageEncoder,
};
use xaynet_core::{
    crypto::{EncryptKeyPair, Signature},
    message::Sum as SumMessage,
};

use super::Awaiting;

/// The state of the sum phase.
#[derive(Serialize, Deserialize, Debug)]
pub struct Sum {
    /// The sum participant ephemeral keys. They are used to decrypt
    /// the encrypted mask seeds.
    pub ephm_keys: EncryptKeyPair,
    /// Signature that proves that the participant has been selected
    /// for the sum task.
    pub sum_signature: Signature,
}

impl Sum {
    /// Creates a new sum state.
/// A fresh ephemeral key pair is generated for every new sum state.
    pub fn new(sum_signature: Signature) -> Self {
        Sum {
            ephm_keys: EncryptKeyPair::generate(),
            sum_signature,
        }
    }
}

impl IntoPhase<Sum> for State<Sum> {
    fn into_phase(self, mut io: PhaseIo) -> Phase<Sum> {
        io.notify_sum();
        Phase::<_>::new(self, io)
    }
}

#[async_trait]
impl Step for Phase<Sum> {
    async fn step(mut self) -> TransitionOutcome {
        info!("sum task");
        let sending: Phase<SendingSum> = self.into();
        TransitionOutcome::Complete(sending.into())
    }
}

impl From<Phase<Sum>> for Phase<SendingSum> {
    fn from(sum: Phase<Sum>) -> Self {
        debug!("composing sum message");
        let message = sum.compose_message();

        debug!("going to sending phase");
        // the ephemeral keys are carried over: sum2 needs them to decrypt the seeds
        let sum2 = Sum2::new(sum.state.private.ephm_keys, sum.state.private.sum_signature);
        let sending = Box::new(SendingSum::new(message, sum2));
        let state = State::new(sum.state.shared, sending);
        state.into_phase(sum.io)
    }
}

impl From<Phase<Sum>> for Phase<Awaiting> {
    fn from(sum: Phase<Sum>) -> Self {
        State::new(sum.state.shared, Box::new(Awaiting)).into_phase(sum.io)
    }
}

impl Phase<Sum> {
    /// Creates and encodes the sum message from the sum state.
    pub fn compose_message(&self) -> MessageEncoder {
        let sum = SumMessage {
            sum_signature: self.state.private.sum_signature,
            ephm_pk: self.state.private.ephm_keys.public,
        };
        self.message_encoder(sum.into())
    }
}

================================================
FILE: rust/xaynet-sdk/src/state_machine/phases/sum2.rs
================================================
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, info, warn};
use xaynet_core::{
    crypto::{EncryptKeyPair, Signature},
    mask::{Aggregation, MaskObject, MaskSeed},
    message::Sum2 as Sum2Message,
    UpdateSeedDict,
};

use crate::{
    state_machine::{
        IntoPhase, Phase, PhaseIo, Progress, SendingSum2, State, Step, TransitionOutcome, IO,
    },
    MessageEncoder,
};

use super::Awaiting;

/// The state of the sum2 phase.
#[derive(Serialize, Deserialize, Debug)]
pub struct Sum2 {
    /// The sum participant ephemeral keys. They are used to decrypt
    /// the encrypted mask seeds.
    pub ephm_keys: EncryptKeyPair,
    /// Signature that proves that the participant has been selected
    /// for the sum task.
    pub sum_signature: Signature,
    /// Dictionary containing the encrypted mask seed of every update
    /// participant.
    pub seed_dict: Option<UpdateSeedDict>,
    /// The decrypted mask seeds
    pub seeds: Option<Vec<MaskSeed>>,
    /// The global mask, obtained by aggregating the masks derived
    /// from the mask seeds.
    pub mask: Option<MaskObject>,
}

impl Sum2 {
    /// Creates a new sum2 state.
    pub fn new(ephm_keys: EncryptKeyPair, sum_signature: Signature) -> Self {
        Self {
            ephm_keys,
            sum_signature,
            seed_dict: None,
            seeds: None,
            mask: None,
        }
    }

    /// Checks if the seed dict has already been fetched.
    ///
    /// Later steps consume (`take()`) earlier results, so each check also
    /// considers the following steps as proof of completion.
    fn has_fetched_seed_dict(&self) -> bool {
        self.seed_dict.is_some() || self.has_decrypted_seeds()
    }

    /// Checks if the seeds have already been decrypted.
    fn has_decrypted_seeds(&self) -> bool {
        self.seeds.is_some() || self.has_aggregated_masks()
    }

    /// Checks if the masks have already been aggregated.
    fn has_aggregated_masks(&self) -> bool {
        self.mask.is_some()
    }
}

impl IntoPhase<Sum2> for State<Sum2> {
    fn into_phase(self, io: PhaseIo) -> Phase<Sum2> {
        Phase::<_>::new(self, io)
    }
}

#[async_trait]
impl Step for Phase<Sum2> {
    async fn step(mut self) -> TransitionOutcome {
        info!("sum2 task");

        self = try_progress!(self.fetch_seed_dict().await);
        self = try_progress!(self.decrypt_seeds());
        self = try_progress!(self.aggregate_masks());

        let sending: Phase<SendingSum2> = self.into();
        TransitionOutcome::Complete(sending.into())
    }
}

impl From<Phase<Sum2>> for Phase<SendingSum2> {
    fn from(mut sum2: Phase<Sum2>) -> Self {
        debug!("composing sum2 message");
        let message = sum2.compose_message();

        debug!("going to sending phase");
        let sending = Box::new(SendingSum2::new(message, Awaiting));
        let state = State::new(sum2.state.shared, sending);
        state.into_phase(sum2.io)
    }
}

impl From<Phase<Sum2>> for Phase<Awaiting> {
    fn from(sum2: Phase<Sum2>) -> Self {
        State::new(sum2.state.shared, Box::new(Awaiting)).into_phase(sum2.io)
    }
}

impl Phase<Sum2> {
    /// Retrieve the encrypted mask seeds.
/// The phase is `Stuck` while the seeds are not available or cannot be fetched.
    pub(crate) async fn fetch_seed_dict(mut self) -> Progress<Sum2> {
        if self.state.private.has_fetched_seed_dict() {
            return Progress::Continue(self);
        }
        debug!("polling for update seeds");
        match self.io.get_seeds(self.state.shared.keys.public).await {
            Err(e) => {
                warn!("failed to fetch seeds: {}", e);
                Progress::Stuck(self)
            }
            Ok(None) => {
                debug!("seeds not available yet");
                Progress::Stuck(self)
            }
            Ok(Some(seeds)) => {
                self.state.private.seed_dict = Some(seeds);
                Progress::Updated(self.into())
            }
        }
    }

    /// Decrypt the mask seeds that the update participants generated.
    ///
    /// If any seed cannot be decrypted, the participant abandons the sum2 task
    /// and goes to the awaiting phase.
    pub(crate) fn decrypt_seeds(mut self) -> Progress<Sum2> {
        if self.state.private.has_decrypted_seeds() {
            return Progress::Continue(self);
        }

        let keys = &self.state.private.ephm_keys;
        // UNWRAP_SAFE: the seed dict is set in
        // `self.fetch_seed_dict()` which is called before this method
        let seeds: Result<Vec<MaskSeed>, ()> = self
            .state
            .private
            .seed_dict
            .take()
            .unwrap()
            .into_iter()
            .map(|(_, seed)| seed.decrypt(&keys.public, &keys.secret).map_err(|_| ()))
            .collect();

        match seeds {
            Ok(seeds) => {
                self.state.private.seeds = Some(seeds);
                Progress::Updated(self.into())
            }
            Err(_) => {
                warn!("failed to decrypt mask seeds, going back to waiting phase");
                // No explicit `notify_idle()` here: converting into `Phase<Awaiting>`
                // goes through `IntoPhase<Awaiting>`, which already emits it. Calling
                // it here too would notify the participant twice.
                let awaiting: Phase<Awaiting> = self.into();
                Progress::Updated(awaiting.into())
            }
        }
    }

    /// Derive the masks from the decrypted mask seeds, and aggregate
    /// them. The resulting mask will later be added to the sum2
    /// message to be sent to the coordinator.
pub(crate) fn aggregate_masks(mut self) -> Progress { if self.state.private.has_aggregated_masks() { return Progress::Continue(self); } info!("aggregating masks"); let config = self.state.shared.round_params.mask_config; let mask_len = self.state.shared.round_params.model_length; let mut mask_agg = Aggregation::new(config, mask_len as usize); // UNWRAP_SAFE: the seeds are set in `decrypt_seeds()` which is called before this method for seed in self.state.private.seeds.take().unwrap().into_iter() { let mask = seed.derive_mask(mask_len as usize, config); if let Err(e) = mask_agg.validate_aggregation(&mask) { error!("sum2 phase failed: cannot aggregate masks: {}", e); error!("going to awaiting phase"); let awaiting: Phase = self.into(); return Progress::Updated(awaiting.into()); } else { mask_agg.aggregate(mask); } } self.state.private.mask = Some(mask_agg.into()); Progress::Updated(self.into()) } /// Creates and encodes the sum2 message from the sum2 state. pub fn compose_message(&mut self) -> MessageEncoder { let sum2 = Sum2Message { sum_signature: self.state.private.sum_signature, // UNWRAP_SAFE: the mask set in `aggregate_masks()` which is called before this method model_mask: self.state.private.mask.take().unwrap(), }; self.message_encoder(sum2.into()) } } ================================================ FILE: rust/xaynet-sdk/src/state_machine/phases/update.rs ================================================ use std::ops::Deref; use async_trait::async_trait; use derive_more::From; use serde::{Deserialize, Serialize}; use tracing::{debug, info, warn}; use xaynet_core::{ crypto::Signature, mask::{MaskObject, MaskSeed, Masker, Model}, message::Update as UpdateMessage, LocalSeedDict, ParticipantTaskSignature, SumDict, }; use crate::{ state_machine::{ Awaiting, IntoPhase, Phase, PhaseIo, Progress, SendingUpdate, State, Step, TransitionOutcome, IO, }, MessageEncoder, }; #[derive(From)] pub enum LocalModel { Dyn(Box + Send>), Owned(Model), } impl std::fmt::Debug for 
LocalModel { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { match self { LocalModel::Dyn(_) => fmt.debug_tuple("LocalModel::Dyn"), LocalModel::Owned(_) => fmt.debug_tuple("LocalModel::Owned"), } .field(&"...") .finish() } } impl AsRef for LocalModel { fn as_ref(&self) -> &Model { match self { LocalModel::Dyn(model) => model.deref().as_ref(), LocalModel::Owned(model) => model, } } } impl serde::ser::Serialize for LocalModel { fn serialize(&self, serializer: S) -> Result where S: serde::ser::Serializer, { match self { LocalModel::Dyn(model) => model.as_ref().as_ref().serialize(serializer), LocalModel::Owned(model) => model.serialize(serializer), } } } impl<'de> serde::de::Deserialize<'de> for LocalModel { fn deserialize(deserializer: D) -> Result where D: serde::de::Deserializer<'de>, { let model = ::deserialize(deserializer)?; Ok(LocalModel::Owned(model)) } } /// The state of the update phase. #[derive(Serialize, Deserialize, Debug)] pub struct Update { pub sum_signature: ParticipantTaskSignature, pub update_signature: ParticipantTaskSignature, pub sum_dict: Option, pub seed_dict: Option, pub model: Option, pub mask: Option<(MaskSeed, MaskObject)>, } impl Update { /// Creates a new update state. 
/// All the intermediate results start out unset.
    pub fn new(sum_signature: Signature, update_signature: Signature) -> Self {
        Update {
            sum_signature,
            update_signature,
            sum_dict: None,
            seed_dict: None,
            model: None,
            mask: None,
        }
    }

    /// Checks if the sum dictionary has already been fetched.
    ///
    /// Later steps consume (`take()`) earlier results, so each check also
    /// considers the following steps as proof of completion.
    fn has_fetched_sum_dict(&self) -> bool {
        self.sum_dict.is_some() || self.has_loaded_model()
    }

    /// Checks if the local model has already been loaded.
    fn has_loaded_model(&self) -> bool {
        self.model.is_some() || self.has_masked_model()
    }

    /// Checks if the local model has already been masked.
    fn has_masked_model(&self) -> bool {
        self.mask.is_some() || self.has_built_seed_dict()
    }

    /// Checks if the local seed dictionary has already been built.
    fn has_built_seed_dict(&self) -> bool {
        self.seed_dict.is_some()
    }
}

impl IntoPhase<Update> for State<Update> {
    fn into_phase(self, mut io: PhaseIo) -> Phase<Update> {
        io.notify_update();
        // only ask for a model if we don't have one (or something derived from it) yet
        if !self.private.has_loaded_model() {
            io.notify_load_model();
        }
        Phase::<_>::new(self, io)
    }
}

#[async_trait]
impl Step for Phase<Update> {
    async fn step(mut self) -> TransitionOutcome {
        self = try_progress!(self.fetch_sum_dict().await);
        self = try_progress!(self.load_model().await);
        self = try_progress!(self.mask_model());
        self = try_progress!(self.build_seed_dict());
        let sending: Phase<SendingUpdate> = self.into();
        TransitionOutcome::Complete(sending.into())
    }
}

impl From<Phase<Update>> for Phase<SendingUpdate> {
    fn from(mut update: Phase<Update>) -> Self {
        debug!("composing update message");
        let message = update.compose_message();

        debug!("going to sending phase");
        let sending = Box::new(SendingUpdate::new(message, Awaiting));
        let state = State::new(update.state.shared, sending);
        state.into_phase(update.io)
    }
}

impl From<Phase<Update>> for Phase<Awaiting> {
    fn from(update: Phase<Update>) -> Self {
        State::new(update.state.shared, Box::new(Awaiting)).into_phase(update.io)
    }
}

impl Phase<Update> {
    /// Fetch the sum dictionary from the coordinator. The phase is `Stuck` while
    /// it is not available or cannot be fetched.
    pub(crate) async fn fetch_sum_dict(mut self) -> Progress<Update> {
        if self.state.private.has_fetched_sum_dict() {
            debug!("already fetched the sum dictionary, continuing");
            return Progress::Continue(self);
        }

        debug!("fetching sum dictionary");
        match self.io.get_sums().await {
            Ok(Some(dict)) => {
                self.state.private.sum_dict = Some(dict);
                Progress::Updated(self.into())
            }
            Ok(None) => {
                debug!("sum dictionary is not available yet");
                Progress::Stuck(self)
            }
            Err(e) => {
                warn!("failed to fetch sum dictionary: {:?}", e);
                Progress::Stuck(self)
            }
        }
    }

    /// Load the local model from the model store. The phase is `Stuck` while the
    /// model is not ready or cannot be loaded.
    pub(crate) async fn load_model(mut self) -> Progress<Update> {
        if self.state.private.has_loaded_model() {
            debug!("already loaded the model, continuing");
            return Progress::Continue(self);
        }

        debug!("loading local model");
        match self.io.load_model().await {
            Ok(Some(model)) => {
                self.state.private.model = Some(model.into());
                Progress::Updated(self.into())
            }
            Ok(None) => {
                debug!("model is not ready");
                Progress::Stuck(self)
            }
            Err(e) => {
                warn!("failed to load model: {:?}", e);
                Progress::Stuck(self)
            }
        }
    }

    /// Generate a mask seed and mask a local model.
    pub(crate) fn mask_model(mut self) -> Progress<Update> {
        if self.state.private.has_masked_model() {
            debug!("already computed the masked model, continuing");
            return Progress::Continue(self);
        }

        info!("computing masked model");
        let config = self.state.shared.round_params.mask_config;
        let masker = Masker::new(config);
        // UNWRAP_SAFE: the model is set in `load_model()`, which is called before this method
        let model = self.state.private.model.take().unwrap();
        let scalar = self.state.shared.scalar.clone();
        self.state.private.mask = Some(masker.mask(scalar, model.as_ref()));
        Progress::Updated(self.into())
    }

    /// Create a local seed dictionary from a sum dictionary: the mask seed is
    /// encrypted with the ephemeral public key of every sum participant.
    pub(crate) fn build_seed_dict(mut self) -> Progress<Update> {
        if self.state.private.has_built_seed_dict() {
            debug!("already built the seed dictionary, continuing");
            return Progress::Continue(self);
        }
        // UNWRAP_SAFE: the mask is set in `mask_model()` which is called before this method
        let mask_seed = &self.state.private.mask.as_ref().unwrap().0;
        info!("building local seed dictionary");
        // UNWRAP_SAFE: the sum dict is set in `fetch_sum_dict()` which is called before this method
        let seeds = self
            .state
            .private
            .sum_dict
            .take()
            .unwrap()
            .into_iter()
            .map(|(pk, ephm_pk)| (pk, mask_seed.encrypt(&ephm_pk)))
            .collect();
        self.state.private.seed_dict = Some(seeds);
        Progress::Updated(self.into())
    }

    /// Creates and encodes the update message from the update state.
/// The mask and the local seed dictionary are taken out of the state, so
    /// calling this a second time on the same state would panic on the `unwrap`s.
    pub fn compose_message(&mut self) -> MessageEncoder {
        let update = UpdateMessage {
            sum_signature: self.state.private.sum_signature,
            update_signature: self.state.private.update_signature,
            // UNWRAP_SAFE: the mask is set in `mask_model()` which is called before this method
            masked_model: self.state.private.mask.take().unwrap().1,
            // UNWRAP_SAFE: the dict is set in `build_seed_dict()` which is called before this method
            local_seed_dict: self.state.private.seed_dict.take().unwrap(),
        };
        self.message_encoder(update.into())
    }
}

================================================
FILE: rust/xaynet-sdk/src/state_machine/state_machine.rs
================================================
use derive_more::From;

use super::{
    boxed_io, Awaiting, IntoPhase, LocalModelConfig, NewRound, Phase, SendingSum, SendingSum2,
    SendingUpdate, SerializableState, SharedState, State, Sum, Sum2, Update,
};
use crate::{settings::PetSettings, ModelStore, Notify, XaynetClient};

/// Outcome of a state machine transition attempt.
#[derive(Debug)]
pub enum TransitionOutcome {
    /// Outcome when the state machine cannot make immediate progress. The state machine
    /// is returned unchanged.
    Pending(StateMachine),
    /// Outcome when a transition occurred and the state machine was updated.
    Complete(StateMachine),
}

/// PET state machine.
#[derive(From, Debug)] pub enum StateMachine { /// PET state machine in the "new round" phase NewRound(Phase), /// PET state machine in the "awaiting" phase Awaiting(Phase), /// PET state machine in the "sum" phase Sum(Phase), /// PET state machine in the "update" phase Update(Phase), /// PET state machine in the "sum2" phase Sum2(Phase), /// PET state machine in the "sending sum message" phase SendingSum(Phase), /// PET state machine in the "sending update message" phase SendingUpdate(Phase), /// PET state machine in the "sending sum2 message" phase SendingSum2(Phase), } impl StateMachine { /// Try to make progress in the PET protocol pub async fn transition(self) -> TransitionOutcome { match self { StateMachine::NewRound(phase) => phase.step().await, StateMachine::Awaiting(phase) => phase.step().await, StateMachine::Sum(phase) => phase.step().await, StateMachine::Update(phase) => phase.step().await, StateMachine::Sum2(phase) => phase.step().await, StateMachine::SendingSum(phase) => phase.step().await, StateMachine::SendingUpdate(phase) => phase.step().await, StateMachine::SendingSum2(phase) => phase.step().await, } } /// Convert the state machine into a serializable data structure so /// that it can be saved. pub fn save(self) -> SerializableState { match self { StateMachine::NewRound(phase) => phase.state.into(), StateMachine::Awaiting(phase) => phase.state.into(), StateMachine::Sum(phase) => phase.state.into(), StateMachine::Update(phase) => phase.state.into(), StateMachine::Sum2(phase) => phase.state.into(), StateMachine::SendingSum(phase) => phase.state.into(), StateMachine::SendingUpdate(phase) => phase.state.into(), StateMachine::SendingSum2(phase) => phase.state.into(), } } /// Return the local model configuration of the model that is expected in the update phase. 
pub fn local_model_config(&self) -> LocalModelConfig { match self { StateMachine::NewRound(ref phase) => phase.local_model_config(), StateMachine::Awaiting(ref phase) => phase.local_model_config(), StateMachine::Sum(ref phase) => phase.local_model_config(), StateMachine::Update(ref phase) => phase.local_model_config(), StateMachine::Sum2(ref phase) => phase.local_model_config(), StateMachine::SendingSum(ref phase) => phase.local_model_config(), StateMachine::SendingUpdate(ref phase) => phase.local_model_config(), StateMachine::SendingSum2(ref phase) => phase.local_model_config(), } } } impl StateMachine { /// Instantiate a new PET state machine. /// /// # Args /// /// - `settings`: PET settings /// - `xaynet_client`: a client for communicating with the Xaynet coordinator /// - `model_store`: a store from which the trained model can be /// loaded, when the participant is selected for the update task /// - `notifier`: a type that the state machine can use to emit notifications pub fn new( settings: PetSettings, xaynet_client: X, model_store: M, notifier: N, ) -> Self where X: XaynetClient + Send + 'static, M: ModelStore + Send + 'static, N: Notify + Send + 'static, { let io = boxed_io(xaynet_client, model_store, notifier); let state = State::new(Box::new(SharedState::new(settings)), Box::new(Awaiting)); state.into_phase(io).into() } /// Restore the PET state machine from the given `state`. 
pub fn restore( state: SerializableState, xaynet_client: X, model_store: M, notifier: N, ) -> Self where X: XaynetClient + Send + 'static, M: ModelStore + Send + 'static, N: Notify + Send + 'static, { let io = boxed_io(xaynet_client, model_store, notifier); match state { SerializableState::NewRound(state) => state.into_phase(io).into(), SerializableState::Awaiting(state) => state.into_phase(io).into(), SerializableState::Sum(state) => state.into_phase(io).into(), SerializableState::Sum2(state) => state.into_phase(io).into(), SerializableState::Update(state) => state.into_phase(io).into(), SerializableState::SendingSum(state) => state.into_phase(io).into(), SerializableState::SendingUpdate(state) => state.into_phase(io).into(), SerializableState::SendingSum2(state) => state.into_phase(io).into(), } } } ================================================ FILE: rust/xaynet-sdk/src/state_machine/tests/mod.rs ================================================ mod phases; pub mod utils; ================================================ FILE: rust/xaynet-sdk/src/state_machine/tests/phases/mod.rs ================================================ mod new_round; mod sum; mod sum2; mod update; ================================================ FILE: rust/xaynet-sdk/src/state_machine/tests/phases/new_round.rs ================================================ use crate::{ state_machine::{ tests::utils::{shared_state, SelectFor}, IntoPhase, MockIO, NewRound, Phase, State, }, unwrap_step, }; #[tokio::test] async fn test_selected_for_sum() { let mut io = MockIO::new(); io.expect_notify_sum().return_const(()); let phase = make_phase(SelectFor::Sum, io); unwrap_step!(phase, complete, sum); } #[tokio::test] async fn test_selected_for_update() { let mut io = MockIO::new(); io.expect_notify_update().times(1).return_const(()); io.expect_notify_load_model().times(1).return_const(()); let phase = make_phase(SelectFor::Update, io); unwrap_step!(phase, complete, update); } #[tokio::test] async fn 
test_not_selected() { let mut io = MockIO::new(); io.expect_notify_idle().times(1).return_const(()); let phase = make_phase(SelectFor::None, io); unwrap_step!(phase, complete, awaiting); } /// Instantiate a new round phase. /// /// - `task` is the task we want the simulated participant to be selected for. If you want a /// sum participant, pass `SelectedFor::Sum` for example. /// - `io` is the mock the test wants to use. It should contains all the test expectations. The /// reason for settings the mocked IO object in this helper is that once the phase is /// created, `phase.io` is a `Box`, not a `MockIO`. Therefore, it doesn't have any of /// the mock methods (`expect_xxx()`, `checkpoint()`, etc.) so we cannot set any expectation /// a posteriori fn make_phase(task: SelectFor, io: MockIO) -> Phase { let shared = shared_state(task); // Check IntoPhase implementation let mut mock = MockIO::new(); mock.expect_notify_new_round().times(1).return_const(()); let mut phase: Phase = State::new(shared, Box::new(NewRound)).into_phase(Box::new(mock)); // Set `phase.io` to the mock the test wants to use. Note that this drops the `mock` we // created above, so the expectations we set on `mock` run now. let _ = std::mem::replace(&mut phase.io, Box::new(io)); phase } ================================================ FILE: rust/xaynet-sdk/src/state_machine/tests/phases/sum.rs ================================================ use thiserror::Error; use xaynet_core::crypto::{ByteObject, EncryptKeyPair, EncryptKeySeed}; use crate::{ state_machine::{ tests::utils::{shared_state, SelectFor}, IntoPhase, MockIO, Phase, SharedState, State, Sum, }, unwrap_step, }; /// Instantiate a sum phase. 
fn make_phase(io: MockIO) -> Phase {
    let shared = shared_state(SelectFor::Sum);
    let sum = make_sum(&shared);

    // Check IntoPhase implementation: entering the sum phase is expected to
    // emit exactly one sum notification.
    let mut mock = MockIO::new();
    mock.expect_notify_sum().times(1).return_const(());
    let mut phase: Phase = State::new(shared, sum).into_phase(Box::new(mock));

    // Set `phase.io` to the mock the test wants to use. Note that this drops the `mock` we
    // created above, so the expectations we set on `mock` are verified now.
    let _ = std::mem::replace(&mut phase.io, Box::new(io));
    phase
}

/// Build the private state of a sum participant. The ephemeral encryption keys
/// are derived from an all-zero seed, and the sum signature is computed over
/// `round seed || "sum"`.
fn make_sum(shared: &SharedState) -> Box {
    let ephm_keys = EncryptKeyPair::derive_from_seed(&EncryptKeySeed::zeroed());
    let sk = &shared.keys.secret;
    let seed = shared.round_params.seed.as_slice();
    let signature = sk.sign_detached(&[seed, b"sum"].concat());
    Box::new(Sum {
        ephm_keys,
        sum_signature: signature,
    })
}

#[tokio::test]
async fn test_phase() {
    // The mock has no expectations, so the transition to the sending sum
    // phase must not perform any IO.
    let io = MockIO::new();
    let phase = make_phase(io);
    let _phase = unwrap_step!(phase, complete, sending_sum);
}

// NOTE(review): `DummyErr` is not used anywhere in this file. Consider removing
// it, together with the `thiserror::Error` import, if nothing else needs it.
#[derive(Error, Debug)]
#[error("error")]
struct DummyErr;
================================================
FILE: rust/xaynet-sdk/src/state_machine/tests/phases/sum2.rs
================================================
use mockall::Sequence;
use xaynet_core::{
    crypto::{ByteObject, EncryptKeyPair, EncryptKeySeed, PublicEncryptKey},
    mask::{FromPrimitives, MaskConfigPair, MaskObject, MaskSeed, Masker, Model, Scalar},
    UpdateSeedDict,
};

use crate::{
    state_machine::{
        tests::utils::{shared_state, SelectFor, SigningKeyGenerator},
        IntoPhase, MockIO, Phase, SendingSum2, SharedState, State, Sum2,
    },
    unwrap_progress_continue, unwrap_step,
};

/// Instantiate a sum2 phase.
fn make_phase() -> Phase {
    let shared = shared_state(SelectFor::Sum);
    let sum2 = make_sum2(&shared);

    // Check IntoPhase implementation: the mock has no expectations, so
    // entering the sum2 phase is expected not to perform any IO.
    let mock = MockIO::new();
    let mut phase: Phase = State::new(shared, sum2).into_phase(Box::new(mock));
    phase.check_io_mock();
    phase
}

/// Build the private state of a sum2 participant. It uses the same ephemeral
/// keys and `round seed || "sum"` signature as the sum phase. The seed
/// dictionary, the decrypted seeds and the aggregated mask are not computed yet.
fn make_sum2(shared: &SharedState) -> Box {
    let ephm_keys = EncryptKeyPair::derive_from_seed(&EncryptKeySeed::zeroed());
    let sk = &shared.keys.secret;
    let seed = shared.round_params.seed.as_slice();
    let signature = sk.sign_detached(&[seed, b"sum"].concat());
    Box::new(Sum2 {
        ephm_keys,
        sum_signature: signature,
        seed_dict: None,
        seeds: None,
        mask: None,
    })
}

/// Build a seed dictionary with 4 entries, keyed by distinct generated public
/// signing keys. Every entry holds the same mask seed encrypted with `ephm_pk`.
fn make_seed_dict(mask_config: MaskConfigPair, ephm_pk: PublicEncryptKey) -> UpdateSeedDict {
    let (seed, _mask) = make_masked_model(mask_config);
    let mut key_gen = SigningKeyGenerator::new();
    let mut dict = UpdateSeedDict::new();
    for _ in 0..4 {
        let pk = key_gen.next().public;
        dict.insert(pk, seed.encrypt(&ephm_pk));
    }
    dict
}

/// A small model with four weights.
fn make_model() -> Model {
    Model::from_primitives(vec![1.0, 2.0, 3.0, 4.0].into_iter()).unwrap()
}

/// Mask `make_model()` with a unit scalar, returning the mask seed and the
/// masked model.
fn make_masked_model(mask_config: MaskConfigPair) -> (MaskSeed, MaskObject) {
    let masker = Masker::new(mask_config);
    let scalar = Scalar::unit();
    let model = make_model();
    masker.mask(scalar, &model)
}

/// Step 1: fetch the seed dictionary. The first attempt finds none and the
/// second one succeeds.
async fn step1_fetch_seed_dict(mut phase: Phase) -> Phase {
    let mask_config = phase.state.shared.round_params.mask_config;
    let ephm_pk = phase.state.private.ephm_keys.public;

    phase.with_io_mock(move |mock| {
        let mut seq = Sequence::new();
        // The first time the state machine fetches the seed dict,
        // pretend it's not published yet
        mock.expect_get_seeds()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|_| Ok(None));
        // The second time, return it
        mock.expect_get_seeds()
            .times(1)
            .in_sequence(&mut seq)
            .returning(move |_| Ok(Some(make_seed_dict(mask_config, ephm_pk))));
    });

    // First time: no progress should be made, since we didn't
    // fetch any seed dict yet
    let phase = unwrap_step!(phase, pending, sum2);

    // Second time: now the state machine should have made progress
    let phase = unwrap_step!(phase, complete, sum2);

    // Calling `fetch_seed_dict` again should return Progress::Continue
    let mut phase = unwrap_progress_continue!(phase, fetch_seed_dict, async);
    phase.check_io_mock();
    phase
}

/// Step 2: decrypt the seeds from the seed dictionary.
async fn step2_decrypt_seeds(phase: Phase) -> Phase {
    let phase = unwrap_step!(phase, complete, sum2);
    assert!(phase.state.private.seeds.is_some());
    // Make sure this step consumes the seed dict.
    assert!(phase.state.private.seed_dict.is_none());
    phase
}

/// Step 3: aggregate the masks derived from the decrypted seeds.
async fn step3_aggregate_masks(phase: Phase) -> Phase {
    let phase = unwrap_step!(phase, complete, sum2);
    assert!(phase.state.private.mask.is_some());
    // Make sure this step consumes the seeds.
    assert!(phase.state.private.seeds.is_none());
    phase
}

/// Step 4: transition into the sending sum2 phase.
async fn step4_into_sending_phase(phase: Phase) -> Phase {
    let phase = unwrap_step!(phase, complete, sending_sum2);
    phase
}

/// Drive a sum2 phase from fetching the seed dictionary up to the sending sum2 phase.
#[tokio::test]
async fn test_phase() {
    let phase = make_phase();
    let phase = step1_fetch_seed_dict(phase).await;
    let phase = step2_decrypt_seeds(phase).await;
    let phase = step3_aggregate_masks(phase).await;
    let _phase = step4_into_sending_phase(phase).await;
}
================================================
FILE: rust/xaynet-sdk/src/state_machine/tests/phases/update.rs
================================================
use mockall::Sequence;
use xaynet_core::{
    crypto::ByteObject,
    mask::{FromPrimitives, Model},
    SumDict,
};

use crate::{
    save_and_restore,
    state_machine::{
        tests::utils::{shared_state, EncryptKeyGenerator, SelectFor, SigningKeyGenerator},
        IntoPhase, MockIO, Phase, SendingUpdate, SharedState, State, Update,
    },
    unwrap_progress_continue, unwrap_step,
};

/// Instantiate an update phase.
fn make_phase() -> Phase {
    let shared = shared_state(SelectFor::Update);
    let update = make_update(&shared);

    // Check IntoPhase implementation: entering the update phase is expected to
    // emit one update notification and one load-model notification.
    let mut mock = MockIO::new();
    mock.expect_notify_update().times(1).return_const(());
    mock.expect_notify_load_model().times(1).return_const(());
    let mut phase: Phase = State::new(shared, update).into_phase(Box::new(mock));
    phase.check_io_mock();
    phase
}

/// Build the private state of an update participant. The sum and update
/// signatures are computed over `round seed || "sum"` and
/// `round seed || "update"`. Nothing is fetched, loaded or computed yet.
fn make_update(shared: &SharedState) -> Box {
    let sk = &shared.keys.secret;
    let seed = shared.round_params.seed.as_slice();
    let sum_signature = sk.sign_detached(&[seed, b"sum"].concat());
    let update_signature = sk.sign_detached(&[seed, b"update"].concat());
    Box::new(Update {
        sum_signature,
        update_signature,
        sum_dict: None,
        seed_dict: None,
        model: None,
        mask: None,
    })
}

/// A small model with four weights.
fn make_model() -> Model {
    let weights: Vec = vec![1.1, 2.2, 3.3, 4.4];
    Model::from_primitives(weights.into_iter()).unwrap()
}

/// A sum dictionary with two entries, built from generated signing and
/// encryption keys.
fn make_sum_dict() -> SumDict {
    let mut dict = SumDict::new();

    let mut signing_keys = SigningKeyGenerator::new();
    let mut encrypt_keys = EncryptKeyGenerator::new();

    dict.insert(signing_keys.next().public, encrypt_keys.next().public);
    dict.insert(signing_keys.next().public, encrypt_keys.next().public);

    dict
}

/// Step 1: fetch the sum dictionary. The first attempt finds none and the
/// second one succeeds.
async fn step1_fetch_sum_dict(mut phase: Phase) -> Phase {
    phase.with_io_mock(|mock| {
        let mut seq = Sequence::new();
        // The first time the state machine fetches the sum dict,
        // pretend it's not published yet
        mock.expect_get_sums()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| Ok(None));
        // The second time, return a sum dictionary.
        mock.expect_get_sums()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| Ok(Some(make_sum_dict())));
    });

    // First time: no progress should be made, since we didn't
    // fetch any sum dict yet
    let phase = unwrap_step!(phase, pending, update);

    // Second time: now the state machine should have made progress
    let phase = unwrap_step!(phase, complete, update);

    // Calling `fetch_sum_dict` again should return Progress::Continue
    let mut phase = unwrap_progress_continue!(phase, fetch_sum_dict, async);
    phase.check_io_mock();
    phase
}

/// Step 2: load the local model. The first attempt finds none and the second
/// one succeeds.
async fn step2_load_model(mut phase: Phase) -> Phase {
    phase.with_io_mock(|mock| {
        let mut seq = Sequence::new();
        // The first time the state machine loads the model,
        // pretend it's not available yet
        mock.expect_load_model()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| Ok(None));
        // The second time, return a model.
        mock.expect_load_model()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| Ok(Some(Box::new(make_model()))));
    });

    // First time: no progress should be made, since we didn't
    // load any model
    let phase = unwrap_step!(phase, pending, update);

    // Second time: now the state machine should have made progress
    let phase = unwrap_step!(phase, complete, update);

    // Calling `load_model` again should return Progress::Continue
    let mut phase = unwrap_progress_continue!(phase, load_model, async);
    phase.check_io_mock();
    phase
}

/// Step 3: mask the model. Calling `mask_model` again must return
/// `Progress::Continue`.
async fn step3_mask_model(phase: Phase) -> Phase {
    let phase = unwrap_step!(phase, complete, update);
    let mut phase = unwrap_progress_continue!(phase, mask_model);
    phase.check_io_mock();
    phase
}

/// Step 4: build the local seed dictionary. Calling `build_seed_dict` again
/// must return `Progress::Continue`.
async fn step4_build_seed_dict(phase: Phase) -> Phase {
    let phase = unwrap_step!(phase, complete, update);
    let mut phase = unwrap_progress_continue!(phase, build_seed_dict);
    phase.check_io_mock();
    phase
}

/// Step 5: transition into the sending update phase.
async fn step5_into_sending_phase(phase: Phase) -> Phase {
    let phase = unwrap_step!(phase, complete, sending_update);
    phase
}

#[tokio::test]
async fn test_update_phase() {
    let phase = make_phase();
    let phase = step1_fetch_sum_dict(phase).await;
    let phase = step2_load_model(phase).await;
    let phase = step3_mask_model(phase).await;
    let phase = step4_build_seed_dict(phase).await;
    let _phase = step5_into_sending_phase(phase).await;
}

/// Same steps as `test_update_phase`, with a save/restore cycle after each step.
#[tokio::test]
async fn test_save_and_restore() {
    let phase = make_phase();
    let mut phase = step1_fetch_sum_dict(phase).await;

    // Restoring before the model is loaded is expected to emit the update
    // notification followed by the load-model notification.
    phase.with_io_mock(|mock| {
        let mut seq = Sequence::new();
        mock.expect_notify_update()
            .times(1)
            .in_sequence(&mut seq)
            .return_const(());
        mock.expect_notify_load_model()
            .times(1)
            .in_sequence(&mut seq)
            .return_const(());
    });
    let phase = save_and_restore!(phase, Update);
    let mut phase = step2_load_model(phase).await;

    // Once the model is loaded, restoring is expected to emit only the
    // update notification.
    phase.with_io_mock(|mock| {
        mock.expect_notify_update().times(1).return_const(());
    });
    let phase = save_and_restore!(phase, Update);
    let mut phase = step3_mask_model(phase).await;

    phase.with_io_mock(|mock| {
        mock.expect_notify_update().times(1).return_const(());
    });
    let phase = save_and_restore!(phase, Update);
    let mut phase = step4_build_seed_dict(phase).await;

    phase.with_io_mock(|mock| {
        mock.expect_notify_update().times(1).return_const(());
    });
    let _phase = save_and_restore!(phase, Update);
}
================================================
FILE: rust/xaynet-sdk/src/state_machine/tests/utils.rs
================================================
use xaynet_core::{
    common::{RoundParameters, RoundSeed},
    crypto::{ByteObject, EncryptKeyPair, EncryptKeySeed, SigningKeyPair, SigningKeySeed},
    mask::{self, MaskConfig, Scalar},
};

use crate::{settings::MaxMessageSize, state_machine::SharedState};

/// Unwrap `$e`, which must be the `$p` variant. Otherwise, panic with a message
/// naming the expected variant and showing the actual value.
#[macro_export]
macro_rules! unwrap_as {
    ($e:expr, $p:path) => {
        match $e {
            $p(s) => s,
            x => panic!("Not a {}: {:?}", stringify!($p), x),
        }
    };
}

/// Run one `Step::step` on `$phase` and return the unwrapped phase. The macro
/// checks both the transition outcome (`complete` or `pending`) and the
/// resulting state machine variant.
///
/// The shorthand names (`awaiting`, `sum`, `sum2`, `update`, `sending_sum`,
/// `sending_update`, `sending_sum2`) map to the matching `StateMachine` variants.
#[macro_export]
macro_rules! unwrap_step {
    ($phase:expr, complete, $state_machine:tt) => {
        unwrap_step!(
            $phase,
            $crate::state_machine::TransitionOutcome::Complete,
            $state_machine
        )
    };
    ($phase:expr, pending, $state_machine:tt) => {
        unwrap_step!(
            $phase,
            $crate::state_machine::TransitionOutcome::Pending,
            $state_machine
        )
    };
    ($phase:expr, $transition_outcome:path, awaiting) => {
        unwrap_step!(
            $phase,
            $transition_outcome,
            $crate::state_machine::StateMachine::Awaiting
        )
    };
    ($phase:expr, $transition_outcome:path, sum) => {
        unwrap_step!(
            $phase,
            $transition_outcome,
            $crate::state_machine::StateMachine::Sum
        )
    };
    ($phase:expr, $transition_outcome:path, sum2) => {
        unwrap_step!(
            $phase,
            $transition_outcome,
            $crate::state_machine::StateMachine::Sum2
        )
    };
    ($phase:expr, $transition_outcome:path, update) => {
        unwrap_step!(
            $phase,
            $transition_outcome,
            $crate::state_machine::StateMachine::Update
        )
    };
    ($phase:expr, $transition_outcome:path, sending_sum) => {
        unwrap_step!(
            $phase,
            $transition_outcome,
            $crate::state_machine::StateMachine::SendingSum
        )
    };
    ($phase:expr, $transition_outcome:path, sending_update) => {
        unwrap_step!(
            $phase,
            $transition_outcome,
            $crate::state_machine::StateMachine::SendingUpdate
        )
    };
    ($phase:expr, $transition_outcome:path, sending_sum2) => {
        unwrap_step!(
            $phase,
            $transition_outcome,
            $crate::state_machine::StateMachine::SendingSum2
        )
    };
    // Terminal rule: perform the step, then unwrap the outcome and the variant.
    ($phase:expr, $transition_outcome:path, $state_machine:path) => {{
        let x = $crate::unwrap_as!(
            $crate::state_machine::Step::step($phase).await,
            $transition_outcome
        );
        $crate::unwrap_as!(x, $state_machine)
    }};
}

/// Assert that an expression evaluates to `Progress::Continue` and return the
/// wrapped phase. The two- and three-argument forms call `$phase.$method()`
/// instead of taking an expression. With the `async` marker, the call is awaited.
#[macro_export]
macro_rules! unwrap_progress_continue {
    ($expr:expr) => {
        $crate::unwrap_as!($expr, $crate::state_machine::Progress::Continue)
    };
    ($phase:expr, $method:tt) => {
        unwrap_progress_continue!($phase.$method())
    };
    ($phase:expr, $method:tt, async) => {
        unwrap_progress_continue!($phase.$method().await)
    };
}

/// Simulate saving and restoring a phase. `$phase` is converted into a
/// `SerializableState`, unwrapped as the `$state` variant and turned back into
/// a phase. Before the conversion, the phase's IO mock is swapped out (a fresh
/// `MockIO` takes its place). The original mock is then handed to the restored
/// phase, so expectations set on it apply to the restore.
#[macro_export]
macro_rules! save_and_restore {
    ($phase:expr, $state:tt) => {{
        let mut phase = $phase;
        let io_mock = std::mem::replace(&mut phase.io, Box::new(MockIO::new()));
        let serializable_state = Into::<$crate::state_machine::SerializableState>::into(phase);
        // TODO: actually serialize the state here
        let state = $crate::unwrap_as!(
            serializable_state,
            $crate::state_machine::SerializableState::$state
        );
        let mut phase = $crate::state_machine::IntoPhase::<$state>::into_phase(state, io_mock);
        phase.check_io_mock();
        phase
    }};
}

/// Task for which the round parameters should be generated.
#[derive(Debug, PartialEq, Eq)]
pub enum SelectFor {
    /// Create round parameters that always select participants for the sum task
    Sum,
    /// Create round parameters that always select participants for the update task
    Update,
    /// Create round parameters that never select participants
    None,
}

/// Mask configuration used by the test round parameters.
pub fn mask_config() -> MaskConfig {
    MaskConfig {
        group_type: mask::GroupType::Prime,
        data_type: mask::DataType::F32,
        bound_type: mask::BoundType::B0,
        model_type: mask::ModelType::M3,
    }
}

/// Round parameters with a zeroed round seed and a coordinator key derived
/// from an all-zero seed. The `sum`/`update` values are 1.0 or 0.0, depending on `task`.
pub fn round_params(task: SelectFor) -> RoundParameters {
    RoundParameters {
        pk: EncryptKeySeed::zeroed().derive_encrypt_key_pair().0,
        sum: if task == SelectFor::Sum { 1.0 } else { 0.0 },
        update: if task == SelectFor::Update { 1.0 } else { 0.0 },
        seed: RoundSeed::zeroed(),
        mask_config: mask_config().into(),
        model_length: 0,
    }
}

/// Shared state made of the following:
/// - signing keys derived from an all-zero seed
/// - a unit scalar
/// - an unlimited message size
/// - `round_params(task)`
pub fn shared_state(task: SelectFor) -> Box {
    Box::new(SharedState {
        keys: SigningKeyPair::derive_from_seed(&SigningKeySeed::zeroed()),
        scalar: Scalar::unit(),
        message_size: MaxMessageSize::unlimited(),
        round_params: round_params(task),
    })
}

/// Deterministic generator of distinct encryption key pairs, starting from the
/// all-zero seed.
pub struct EncryptKeyGenerator(EncryptKeySeed);

impl EncryptKeyGenerator {
    pub fn new() -> Self {
        Self(EncryptKeySeed::zeroed())
    }

    /// Advance the seed by incrementing its first byte that is below 0xff.
    /// Earlier bytes are not reset, but every call increases the byte sum, so
    /// successive seeds are distinct. Panics once every byte is 0xff.
    fn incr_seed(&mut self) {
        let mut raw = self.0.as_slice().to_vec();
        for b in &mut raw {
            if *b < 0xff {
                *b += 1;
                return self.0 = EncryptKeySeed::from_slice(raw.as_slice()).unwrap();
            }
        }
        panic!("max seed");
    }

    /// Return the key pair for the current seed, then advance the seed.
    pub fn next(&mut self) -> EncryptKeyPair {
        let keys = EncryptKeyPair::derive_from_seed(&self.0);
        self.incr_seed();
        keys
    }
}

/// Deterministic generator of distinct signing key pairs, starting from the
/// all-zero seed.
pub struct SigningKeyGenerator(SigningKeySeed);

impl SigningKeyGenerator {
    pub fn new() -> Self {
        Self(SigningKeySeed::zeroed())
    }

    /// Advance the seed by incrementing its first byte that is below 0xff.
    /// Earlier bytes are not reset, but every call increases the byte sum, so
    /// successive seeds are distinct. Panics once every byte is 0xff.
    fn incr_seed(&mut self) {
        let mut raw = self.0.as_slice().to_vec();
        for b in &mut raw {
            if *b < 0xff {
                *b += 1;
                return self.0 = SigningKeySeed::from_slice(raw.as_slice()).unwrap();
            }
        }
        panic!("max seed");
    }

    /// Return the key pair for the current seed, then advance the seed.
    pub fn next(&mut self) -> SigningKeyPair {
        let item = SigningKeyPair::derive_from_seed(&self.0);
        self.incr_seed();
        item
    }
}
================================================
FILE: rust/xaynet-sdk/src/traits.rs
================================================
use async_trait::async_trait;
use xaynet_core::{
    common::RoundParameters,
    mask::Model,
    SumDict,
    SumParticipantPublicKey,
    UpdateSeedDict,
};

/// A trait used by the [`StateMachine`] to emit notifications upon
/// certain events.
///
/// Every method has an empty default implementation, so implementors only
/// need to override the notifications they are interested in.
///
/// [`StateMachine`]: crate::StateMachine
pub trait Notify {
    /// Emit a notification when a new round of federated learning
    /// starts
    fn new_round(&mut self) {}
    /// Emit a notification when the participant has been selected for
    /// the sum task
    fn sum(&mut self) {}
    /// Emit a notification when the participant has been selected for
    /// the update task
    fn update(&mut self) {}
    /// Emit a notification when the participant is not selected for
    /// any task and is waiting for another round to start
    fn idle(&mut self) {}
    /// Emit a notification when the participant should populate the
    /// model store (see [`ModelStore`]).
    fn load_model(&mut self) {}
}

/// A trait used by the [`StateMachine`] to load the model trained by
/// the participant, when it has been selected for the update task.
///
/// [`StateMachine`]: crate::StateMachine
#[async_trait]
pub trait ModelStore {
    type Error: std::error::Error;
    type Model: AsRef + Send;

    /// Attempt to load the model. If the model is not yet available,
    /// `Ok(None)` should be returned.
    async fn load_model(&mut self) -> Result, Self::Error>;
}

/// A trait used by the [`StateMachine`] to communicate with the
/// Xaynet coordinator.
///
/// [`StateMachine`]: crate::StateMachine
#[async_trait]
pub trait XaynetClient {
    type Error: std::error::Error;

    /// Retrieve the current round parameters.
    async fn get_round_params(&mut self) -> Result;

    /// Retrieve the current sum dictionary, if available.
    async fn get_sums(&mut self) -> Result, Self::Error>;

    /// Retrieve the current seed dictionary for the given sum
    /// participant, if available.
    async fn get_seeds(
        &mut self,
        pk: SumParticipantPublicKey,
    ) -> Result, Self::Error>;

    /// Retrieve the current global model, if available.
    async fn get_model(&mut self) -> Result, Self::Error>;

    /// Send an encrypted and signed PET message to the coordinator.
    async fn send_message(&mut self, msg: Vec) -> Result<(), Self::Error>;
}
================================================
FILE: rust/xaynet-sdk/src/utils/concurrent_futures.rs
================================================
#![allow(dead_code)]
use std::{
    collections::VecDeque,
    pin::Pin,
    task::{Context, Poll},
};

use futures::{
    stream::{FuturesUnordered, Stream},
    Future,
};
use tokio::task::{JoinError, JoinHandle};

/// `ConcurrentFutures` can keep a capped number of futures running concurrently, and yield their
/// result as they finish. When the max number of concurrent futures is reached, new tasks are
/// queued until some in-flight futures finish.
///
/// Pushed futures are only spawned (with `tokio::spawn`) when the stream is polled, and they are
/// started in the order they were pushed.
pub struct ConcurrentFutures
where
    T: Future + Send + 'static,
    T::Output: Send + 'static,
{
    /// In-flight futures.
    running: FuturesUnordered>,
    /// Buffered tasks.
    queued: VecDeque,
    /// Max number of concurrent futures.
    max_in_flight: usize,
}

impl ConcurrentFutures
where
    T: Future + Send + 'static,
    T::Output: Send + 'static,
{
    /// Create an empty `ConcurrentFutures` that runs at most `max_in_flight` futures at once.
    pub fn new(max_in_flight: usize) -> Self {
        Self {
            running: FuturesUnordered::new(),
            queued: VecDeque::new(),
            max_in_flight,
        }
    }

    /// Queue a future. It is not spawned until the stream is polled.
    pub fn push(&mut self, task: T) {
        self.queued.push_back(task)
    }
}

// Presumably sound because queued futures are never polled in place: they are
// moved into `tokio::spawn` in `poll_next`.
impl Unpin for ConcurrentFutures
where
    T: Future + Send + 'static,
    T::Output: Send + 'static,
{
}

impl Stream for ConcurrentFutures
where
    T: Future + Send + 'static,
    T::Output: Send + 'static,
{
    type Item = Result;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> {
        let this = self.get_mut();
        // Spawn queued tasks in FIFO order until `max_in_flight` are running.
        while this.running.len() < this.max_in_flight {
            if let Some(queued) = this.queued.pop_front() {
                let handle = tokio::spawn(queued);
                this.running.push(handle);
            } else {
                break;
            }
        }
        // Yield the result of the next spawned task that finishes.
        Pin::new(&mut this.running).poll_next(cx)
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::stream::StreamExt;
    use tokio::time::sleep;

    use super::*;

    // this can fail on rare occasions because of polling delays
    #[tokio::test]
    async fn test() {
        let mut stream = ConcurrentFutures:: + Send + 'static>>>::new(2);
        stream.push(Box::pin(async {
            sleep(Duration::from_millis(10_u64)).await;
            1_u8
        }));
        stream.push(Box::pin(async {
            sleep(Duration::from_millis(28_u64)).await;
            2_u8
        }));
        stream.push(Box::pin(async {
            sleep(Duration::from_millis(8_u64)).await;
            3_u8
        }));
        stream.push(Box::pin(async {
            sleep(Duration::from_millis(2_u64)).await;
            4_u8
        }));

        // poll_next() hasn't been called yet so all futures are queued
        assert_eq!(stream.running.len(), 0);
        assert_eq!(stream.queued.len(), 4);

        // futures 1 and 2 are spawned, then future 1 is ready
        assert_eq!(stream.next().await.unwrap().unwrap(), 1);

        // future 2 is pending, futures 3 and 4 are queued
        assert_eq!(stream.running.len(), 1);
        assert_eq!(stream.queued.len(), 2);

        // future 3 is spawned, then future 3 is ready
        assert_eq!(stream.next().await.unwrap().unwrap(), 3);

        // future 2 is pending, future 4 is queued
        assert_eq!(stream.running.len(), 1);
        assert_eq!(stream.queued.len(), 1);

        // future 4 is spawned, then future 4 is ready
        assert_eq!(stream.next().await.unwrap().unwrap(), 4);

        // future 2 is pending, then future 2 is ready
        assert_eq!(stream.next().await.unwrap().unwrap(), 2);

        // all futures have been resolved
        assert_eq!(stream.running.len(), 0);
        assert_eq!(stream.queued.len(), 0);
    }
}
================================================
FILE: rust/xaynet-sdk/src/utils/mod.rs
================================================
// TODO: move to the e2e package
pub mod concurrent_futures;
================================================
FILE: rust/xaynet-server/Cargo.toml
================================================
[package]
name = "xaynet-server"
version = "0.2.0"
authors = ["Xayn Engineering "]
edition = "2018"
description = "The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant."
readme = "../../README.md"
homepage = "https://xaynet.dev/"
repository = "https://github.com/xaynetwork/xaynet/"
license-file = "../../LICENSE"
keywords = ["federated-learning", "fl", "ai", "machine-learning"]
categories = ["science", "cryptography"]

# docs.rs builds the documentation with every feature enabled and with the
# `docsrs` cfg flag set.
[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[dependencies]
anyhow = "1.0.62"
async-trait = "0.1.57"
base64 = "0.13.0"
bincode = "1.3.3"
bitflags = "1.3.2"
bytes = "1.0.1"
config = "0.12.0"
chrono = "0.4.22"
derive_more = { version = "0.99.17", default-features = false, features = [
    "as_mut",
    "as_ref",
    "deref",
    "display",
    "from",
    "index",
    "index_mut",
    "into",
] }
displaydoc = "0.2.3"
futures = "0.3.24"
hex = "0.4.3"
http = "0.2.8"
influxdb = "0.5.2"
num = { version = "0.4.0", features = ["serde"] }
num_enum = "0.5.7"
once_cell = "1.13.1"
paste = "1.0.8"
rand = "0.8.5"
rand_chacha = "0.3.1"
serde = { version = "1.0.144", features = ["derive"] }
rayon = "1.5.3"
redis = { version = "0.21.6", default-features = false, features = [
    "aio",
    "connection-manager",
    "script",
    "tokio-comp",
] }
sodiumoxide = "0.2.7"
structopt = "0.3.26"
thiserror = "1.0.32"
# `signal` is needed by the coordinator binary for `signal::ctrl_c()`.
tokio = { version = "1.20.1", features = [
    "macros",
    "rt-multi-thread",
    "signal",
    "sync",
    "net",
    "time",
] }
tower = { version = "0.4.6", default-features = false, features = [ "buffer", "load-shed", "limit" ] }
tracing = "0.1.36"
tracing-futures = "0.2.5"
# `env-filter` is needed by the coordinator binary for `with_env_filter`.
tracing-subscriber = { version = "0.3.15", features = ["env-filter"] }
validator = { version = "0.16.0", features = ["derive"] }
warp = "0.3.1"
xaynet-core = { path = "../xaynet-core", version = "0.2.0" }

# feature: model-persistence
# (optional dependencies enabled through the `model-persistence` feature)
fancy-regex = { version = "0.10.0", optional = true }
rusoto_core = { version = "0.46.0", optional = true }
rusoto_s3 = { version = "0.46.0", optional = true }

[dev-dependencies]
# We can't run tarpaulin with the flag `--test-threads=1` because it can trigger a segfault:
# https://github.com/xd009642/tarpaulin/issues/317. A workaround is to use `serial_test`.
mockall = "0.11.2" serial_test = "0.8.0" tokio-test = "0.4.1" tower-test = "0.4.0" [[bin]] name = "coordinator" path = "src/bin/main.rs" [features] default = [] full = ["metrics", "model-persistence", "tls"] metrics = [] model-persistence = ["fancy-regex", "rusoto_core", "rusoto_s3"] tls = ["warp/tls"] ================================================ FILE: rust/xaynet-server/src/bin/main.rs ================================================ use std::{path::PathBuf, process}; use structopt::StructOpt; use tokio::signal; use tracing::warn; use tracing_subscriber::*; #[cfg(feature = "metrics")] use xaynet_server::{metrics, settings::InfluxSettings}; use xaynet_server::{ rest::{serve, RestError}, services, settings::{LoggingSettings, RedisSettings, Settings}, state_machine::initializer::StateMachineInitializer, storage::{coordinator_storage::redis, Storage, Store}, }; #[cfg(feature = "model-persistence")] use xaynet_server::{settings::S3Settings, storage::model_storage::s3}; #[derive(Debug, StructOpt)] #[structopt(name = "Coordinator")] struct Opt { /// Path of the configuration file #[structopt(short, parse(from_os_str))] config_path: PathBuf, } #[tokio::main] async fn main() { let opt = Opt::from_args(); let settings = Settings::new(opt.config_path).unwrap_or_else(|err| { eprintln!("{}", err); process::exit(1); }); let Settings { pet: pet_settings, mask: mask_settings, api: api_settings, log: log_settings, model: model_settings, redis: redis_settings, .. 
} = settings; init_tracing(log_settings); // This should already called internally when instantiating the // state machine but it doesn't hurt making sure the crypto layer // is correctly initialized sodiumoxide::init().unwrap(); #[cfg(feature = "metrics")] init_metrics(settings.metrics.influxdb); let store = init_store( redis_settings, #[cfg(feature = "model-persistence")] settings.s3, ) .await; let (state_machine, requests_tx, event_subscriber) = StateMachineInitializer::new( pet_settings, mask_settings, model_settings, #[cfg(feature = "model-persistence")] settings.restore, store, ) .init() .await .expect("failed to initialize state machine"); let fetcher = services::fetchers::fetcher(&event_subscriber); let message_handler = services::messages::PetMessageHandler::new(&event_subscriber, requests_tx); tokio::select! { biased; _ = signal::ctrl_c() => {} _ = state_machine.run() => { warn!("shutting down: Service terminated"); } result = serve(api_settings, fetcher, message_handler) => { match result { Ok(()) => warn!("shutting down: REST server terminated"), Err(RestError::InvalidTlsConfig) => { warn!("shutting down: invalid TLS settings for REST server"); }, } } } } fn init_tracing(settings: LoggingSettings) { let _fmt_subscriber = FmtSubscriber::builder() .with_env_filter(settings.filter) .with_ansi(true) .init(); } #[cfg(feature = "metrics")] fn init_metrics(settings: InfluxSettings) { let recorder = metrics::Recorder::new(settings); if metrics::GlobalRecorder::install(recorder).is_err() { warn!("failed to install metrics recorder"); }; } async fn init_store( redis_settings: RedisSettings, #[cfg(feature = "model-persistence")] s3_settings: S3Settings, ) -> impl Storage { let coordinator_store = redis::Client::new(redis_settings.url) .await .expect("failed to establish a connection to Redis"); let model_store = { #[cfg(not(feature = "model-persistence"))] { xaynet_server::storage::model_storage::noop::NoOp } #[cfg(feature = "model-persistence")] { let s3 = 
s3::Client::new(s3_settings).expect("failed to create S3 client"); s3.create_global_models_bucket() .await .expect("failed to create bucket for global models"); s3 } }; Store::new(coordinator_store, model_store) } ================================================ FILE: rust/xaynet-server/src/examples.rs ================================================ /*! A guide to getting started with the XayNet examples. # Examples The XayNet examples code can be found under the `rust/examples` directory of the [`xaynet`](https://github.com/xaynetwork/xaynet/) repository. This Getting Started guide will cover only the general ideas around usage of the examples. Also see the source code of the individual examples themselves, which have plenty of comments. Running an example typically requires having a *coordinator* already running, which is the core component of XayNet. # Federated Learning A federated learning session over XayNet consists of two kinds of parties - a *coordinator* and (multiple) *participants*. The two parties engage in a protocol (called PET) over a series of rounds. The over-simplified idea is that in each round: 1. The coordinator makes available a *global* model, from which selected participants will train model updates (or, *local* models) to be sent back to the coordinator. 2. As a round progresses, the coordinator aggregates these updates into a new global model. From this description, it might appear that individual local models are plainly visible to the coordinator. What if sensitive data could be extracted from them? Would this not be a violation of participants' data privacy? In fact, a key point about this process is that the updates are **not** sent in the plain! Rather, they are sent encrypted (or *masked*) so that the coordinator (and by extension, XayNet) learns almost nothing about the individual updates. Yet, it is nevertheless able to aggregate them in such a way that the resulting global model is unmasked. 
This is essentially what is meant by federated learning that is *privacy-preserving*, and is a key feature enabled by the PET protocol. ## PET Protocol It is worth describing the protocol very briefly here, if only to better understand some of the configuration settings we will meet later. It is helpful to think of each round as being divided up into several contiguous phases: **Start.** At the start of a round, the coordinator generates a collection of random *round parameters* for all participants. From these parameters, each participant is able to determine whether it is selected for the round and, if so, which of the two roles it takes: - *update* participants. - *sum* participants. **Sum.** In the Sum phase, sum participants send `sum` messages to the coordinator (the details of which are not so important here, but vital for computing `sum2` messages later). **Update.** In the Update phase, each update participant obtains the global model from the coordinator, trains a local model from it, masks it, and sends it to the coordinator in the form of `update` messages. The coordinator will internally aggregate these (masked) local models. **Sum2.** In the Sum2 phase, sum participants compute the sum of masks over all the local models, and send it to the coordinator in the form of `sum2` messages. Equipped with the sum of masks, the coordinator is able to *unmask* the aggregated global model, for the next round. This short description of the protocol skips over many details, but is sufficient for the purposes of this guide. For a much more complete specification, see the [white paper](https://uploads-ssl.webflow.com/5f0c5c0bb18a279f0a62919e/5f157004da6585f299fa542b_XayNet%20Whitepaper%202.1.pdf). # Coordinator The coordinator is configurable via various settings. The project contains various ready-made configuration files that can be used, found under the `configs` directory of the repository.
Typically they look something like the following (in TOML format): ```toml [api] bind_address = "127.0.0.1:8081" [pet.sum] prob = 0.1 count = { min = 1, max = 100 } time = { min = 5, max = 3600 } [pet.update] prob = 0.9 count = { min = 3, max = 10000 } time = { min = 10, max = 3600 } [pet.sum2] count = { min = 1, max = 100 } time = { min = 5, max = 3600 } [mask] group_type = "Prime" data_type = "F32" bound_type = "B0" model_type = "M3" [model] length = 4 ``` The actual files contain more settings than this, but we mention just the selection above because they will be the most relevant for this guide. ## Settings Going from the top, the [`ApiSettings`] include the address the coordinator should listen on for requests from participants. This address should be known to all participants. Optionally, it also contains configurations for TLS server and client authentication. The [`PetSettings`] specify various parameters of the PET protocol: - The most important are [`sum.prob`] and [`update.prob`], which are the probabilities assigned to the selection of sum and update participants, respectively (note that if a participant is selected for both roles, the *sum* role takes precedence). - The settings [`sum.count.min`], [`update.count.min`] and [`sum2.count.min`] specify, respectively, the minimum number of `sum`, `update` and `sum2` messages the coordinator should accept. Similarly, the [`sum.count.max`], [`update.count.max`] and [`sum2.count.max`] specify the maximum number of `sum`, `update` and `sum2` messages the coordinator should accept. - To complement, the settings [`sum.time.min`], [`update.time.min`] and [`sum2.time.min`] specify, respectively, the minimum amount of time (in seconds) the coordinator should wait for `sum`, `update` and `sum2` messages. To allow for more messages to be processed, increase these times. 
Similarly, the [`sum.time.max`], [`update.time.max`] and [`sum2.time.max`] specify the maximum amount of time (in seconds) the coordinator should wait for `sum`, `update` and `sum2` messages. The [`MaskSettings`] determines the masking configuration, consisting of the group type, data type, bound type and model type. The [`ModelSettings`] specify the length of the model used. Both of these settings must be agreed upon with the participants in advance. ## Running The coordinator can be run as follows: ```text $ git clone https://github.com/xaynetwork/xaynet $ cd xaynet/rust $ cargo run --bin coordinator -- -c ../configs/config.toml ``` ## Running participants You can run the example from the xaynet repository: ```text $ git clone https://github.com/xaynetwork/xaynet $ cd xaynet/rust/examples $ RUST_LOG=info cargo run --example test-drive -- -n 10 ``` [`ApiSettings`]: crate::settings::ApiSettings [`PetSettings`]: crate::settings::PetSettings [`sum.prob`]: crate::settings::PetSettingsSum::prob [`update.prob`]: crate::settings::PetSettingsUpdate::prob [`sum.count.min`]: crate::settings::PetSettingsSum::count [`update.count.min`]: crate::settings::PetSettingsUpdate::count [`sum2.count.min`]: crate::settings::PetSettingsSum2::count [`sum.count.max`]: crate::settings::PetSettingsSum::count [`update.count.max`]: crate::settings::PetSettingsUpdate::count [`sum2.count.max`]: crate::settings::PetSettingsSum2::count [`sum.time.min`]: crate::settings::PetSettingsSum::time [`update.time.min`]: crate::settings::PetSettingsUpdate::time [`sum2.time.min`]: crate::settings::PetSettingsSum2::time [`sum.time.max`]: crate::settings::PetSettingsSum::time [`update.time.max`]: crate::settings::PetSettingsUpdate::time [`sum2.time.max`]: crate::settings::PetSettingsSum2::time [`MaskSettings`]: crate::settings::MaskSettings [`ModelSettings`]: crate::settings::ModelSettings */ ================================================ FILE: rust/xaynet-server/src/lib.rs
================================================
#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(
    doc,
    forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links)
)]
#![doc(
    html_logo_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png",
    html_favicon_url = "https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png",
    issue_tracker_base_url = "https://github.com/xaynetwork/xaynet/issues"
)]
//! `xaynet_server` is a backend for federated machine learning. It
//! ensures the users' privacy using the _Privacy-Enhancing Technology_
//! (PET). Download the [whitepaper] for an introduction to the
//! protocol.
//!
//! [whitepaper]: https://uploads-ssl.webflow.com/5f0c5c0bb18a279f0a62919e/5f157004da6585f299fa542b_XayNet%20Whitepaper%202.1.pdf

pub mod examples;
pub mod metrics;
pub mod rest;
pub mod services;
pub mod settings;
pub mod state_machine;
pub mod storage;

================================================
FILE: rust/xaynet-server/src/metrics/mod.rs
================================================
//! Utils to record metrics.
//!
//! A single process-wide [`Recorder`] can be installed through
//! [`GlobalRecorder::install`]; the [`event!`] and [`metric!`] macros record
//! through it and do nothing while no recorder is installed.

pub mod recorders;

use once_cell::sync::OnceCell;

pub use self::recorders::influxdb::{Measurement, Recorder, Tags};

// Storage for the global recorder; it can be set at most once (see
// `GlobalRecorder::install`).
static RECORDER: OnceCell = OnceCell::new();

/// A wrapper around a static global metrics/events recorder.
pub struct GlobalRecorder;

impl GlobalRecorder {
    /// Gets the reference to the global recorder.
    ///
    /// Returns `None` if no recorder is set or is currently being initialized.
    /// This method never blocks.
    pub fn global() -> Option<&'static Recorder> {
        RECORDER.get()
    }

    /// Installs a new global recorder.
    ///
    /// Returns Err(Recorder) if a recorder has already been set; the rejected
    /// recorder is handed back to the caller.
    pub fn install(recorder: Recorder) -> Result<(), Recorder> {
        RECORDER.set(recorder)
    }
}

/// Records an event.
///
/// The event is silently dropped if no global recorder has been installed.
///
/// # Example
///
/// ```compile_fail
/// // An event with just a title:
/// event!("Error");
///
/// // An event with a title and a description:
/// event!("Error", "something went wrong");
///
/// // An event with a title, a description and tags:
/// event!(
///     "Error",
///     "something went wrong",
///     ["phase error", "coordinator"],
/// );
/// ```
#[macro_export]
macro_rules! event {
    // The explicit type arguments in the first two arms pin the generic
    // parameters that the `None` arguments leave undetermined.
    ($title: expr $(,)?) => {
        if let Some(recorder) = crate::metrics::GlobalRecorder::global() {
            recorder.event::<_, _, &str, _, &[_], &str>($title, None, None);
        }
    };
    ($title: expr, $description: expr $(,)?) => {
        if let Some(recorder) = crate::metrics::GlobalRecorder::global() {
            recorder.event::<_, _, _, _, &[_], &str>($title, $description, None);
        }
    };
    ($title: expr, $description: expr, [$($tags: expr),+] $(,)?) => {
        if let Some(recorder) = crate::metrics::GlobalRecorder::global() {
            recorder.event($title, $description, [$($tags),+])
        }
    };
}

/// Records a metric.
///
/// The metric is silently dropped if no global recorder has been installed.
///
/// # Example
///
/// ```compile_fail
/// // A basic metric:
/// metric!(Measurement::RoundTotalNumber, 1);
///
/// // A metric with one tag:
/// metric!(Measurement::RoundParamSum, 0.7, ("round_id", 1));
///
/// // A metric with multiple tags:
/// metric!(
///     Measurement::RoundParamSum,
///     0.7,
///     ("round_id", 1),
///     ("phase", 2),
/// );
/// ```
#[macro_export]
macro_rules! metric {
    ($measurement: expr, $value: expr $(,)?) => {
        if let Some(recorder) = crate::metrics::GlobalRecorder::global() {
            // Without tags, the tag container type has to be named explicitly.
            recorder.metric::<_, _, crate::metrics::Tags>($measurement, $value, None);
        }
    };
    ($measurement: expr, $value: expr, $(($tag: expr, $val: expr)),+ $(,)?) => {
        if let Some(recorder) = crate::metrics::GlobalRecorder::global() {
            // Collect every `(tag, value)` pair into one container.
            let mut tags = crate::metrics::Tags::new();
            $(
                tags.add($tag, $val);
            )+
            recorder.metric($measurement, $value, tags);
        }
    };
}

================================================
FILE: rust/xaynet-server/src/metrics/recorders/influxdb/dispatcher.rs
================================================
use super::models::{Event, Metric};
use derive_more::From;
use futures::future::BoxFuture;
use influxdb::{Client as InfluxClient, WriteQuery};
use std::task::{Context, Poll};
use tower::Service;
use tracing::debug;

/// A data point to be written to InfluxDB: either a metric or an event.
#[derive(From)]
pub(in crate::metrics) enum Request {
    Metric(Metric),
    Event(Event),
}

/// Converts a request into the write query of the data point it wraps.
impl From for WriteQuery {
    fn from(req: Request) -> Self {
        match req {
            Request::Metric(metric) => metric.into(),
            Request::Event(event) => event.into(),
        }
    }
}

/// A `tower` service that writes each [`Request`] to InfluxDB.
#[derive(Clone)]
pub(in crate::metrics) struct Dispatcher {
    client: InfluxClient,
}

impl Dispatcher {
    /// Creates a dispatcher writing to `database` on the InfluxDB instance at `url`.
    pub fn new(url: impl Into, database: impl Into) -> Self {
        let client = InfluxClient::new(url, database);
        Self { client }
    }
}

impl Service for Dispatcher {
    type Response = ();
    type Error = anyhow::Error;
    type Future = BoxFuture<'static, Result>;

    // The dispatcher itself never applies back-pressure.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
        Poll::Ready(Ok(()))
    }

    // Sends one write query per request. The client is cloned so the returned
    // future does not borrow `self`.
    fn call(&mut self, req: Request) -> Self::Future {
        let client = self.client.clone();
        let fut = async move {
            // NOTE(review): the log/error texts say "metric" for events too.
            debug!("dispatch metric");
            client
                .query(&WriteQuery::from(req))
                .await
                .map_err(|err| anyhow::anyhow!("failed to dispatch metric {}", err))?;
            Ok(())
        };
        Box::pin(fut)
    }
}

// Integration tests: they are `#[ignore]`d by default and, apart from
// `integration_wrong_url`, expect an InfluxDB instance at 127.0.0.1:8086.
#[cfg(test)]
mod tests {
    use tokio_test::assert_ready;
    use tower_test::mock::Spawn;

    use super::*;
    use crate::{
        metrics::{
            recorders::influxdb::models::{Event, Metric},
            Measurement,
        },
        settings::InfluxSettings,
    };

    fn influx_settings() -> InfluxSettings {
        InfluxSettings {
            url: "http://127.0.0.1:8086".to_string(),
            db: "metrics".to_string(),
        }
    }

    #[tokio::test]
    #[ignore]
    async fn integration_dispatch_metric() {
        let settings = influx_settings();
        let mut task = Spawn::new(Dispatcher::new(settings.url, settings.db));

        let metric = Metric::new(Measurement::Phase, 1);
        assert_ready!(task.poll_ready()).unwrap();
        let resp = task.call(metric.into()).await;

        assert!(resp.is_ok());
    }

    #[tokio::test]
    #[ignore]
    async fn integration_dispatch_event() {
        let settings = influx_settings();
        let mut task = Spawn::new(Dispatcher::new(settings.url, settings.db));

        let event = Event::new("event");
        assert_ready!(task.poll_ready()).unwrap();
        let resp = task.call(event.into()).await;

        assert!(resp.is_ok());
    }

    // Assumes nothing is listening on port 9998, so the write must fail.
    #[tokio::test]
    #[ignore]
    async fn integration_wrong_url() {
        let settings = influx_settings();
        let mut task = Spawn::new(Dispatcher::new("http://127.0.0.1:9998", settings.db));

        let event = Event::new("event");
        assert_ready!(task.poll_ready()).unwrap();
        let resp = task.call(event.into()).await;

        assert!(resp.is_err());
    }
}

================================================
FILE: rust/xaynet-server/src/metrics/recorders/influxdb/mod.rs
================================================
mod dispatcher;
mod models;
mod recorder;
mod service;

pub(in crate::metrics) use self::{
    dispatcher::{Dispatcher, Request},
    models::{Event, Metric},
    service::InfluxDbService,
};
pub use self::{
    models::{Measurement, Tags},
    recorder::Recorder,
};

================================================
FILE: rust/xaynet-server/src/metrics/recorders/influxdb/models.rs
================================================
use std::{borrow::Borrow, iter::IntoIterator};

use chrono::{DateTime, Utc};
use influxdb::{InfluxDbWriteable, Timestamp, Type, WriteQuery};

/// An enum that contains all supported measurements.
pub enum Measurement {
    RoundParamSum,
    RoundParamUpdate,
    Phase,
    MasksTotalNumber,
    RoundTotalNumber,
    MessageAccepted,
    MessageDiscarded,
    MessageRejected,
}

/// Maps each measurement to the measurement name written to InfluxDB.
impl From for &'static str {
    fn from(measurement: Measurement) -> &'static str {
        match measurement {
            Measurement::RoundParamSum => "round_param_sum",
            Measurement::RoundParamUpdate => "round_param_update",
            Measurement::Phase => "phase",
            Measurement::MasksTotalNumber => "masks_total_number",
            Measurement::RoundTotalNumber => "round_total_number",
            Measurement::MessageAccepted => "message_accepted",
            Measurement::MessageDiscarded => "message_discarded",
            Measurement::MessageRejected => "message_rejected",
        }
    }
}

/// Owned variant of the `&'static str` conversion above.
impl From for String {
    fn from(measurement: Measurement) -> Self {
        <&str>::from(measurement).into()
    }
}

/// A container that contains the tags of a metric.
///
/// Each `(key, value)` pair is written as an InfluxDB tag of the data point,
/// in insertion order.
pub struct Tags(Vec<(String, Type)>);

impl Tags {
    /// Creates a new empty container for tags.
    pub fn new() -> Self {
        Self(Vec::new())
    }

    /// Adds a tag to the metric.
    ///
    /// Duplicate keys are not checked; every call appends a new pair.
    pub fn add(&mut self, tag: impl Into, value: impl Into) {
        self.0.push((tag.into(), value.into()))
    }
}

impl Default for Tags {
    fn default() -> Self {
        Self::new()
    }
}

// Iterates over the `(key, value)` pairs by delegating to the inner `Vec`.
impl IntoIterator for Tags {
    type Item = as IntoIterator>::Item;
    type IntoIter = as IntoIterator>::IntoIter;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

/// A metrics data point.
///
/// Written to InfluxDB with a single field `value`, timestamped with the
/// creation time.
pub(in crate::metrics) struct Metric {
    name: Measurement,
    time: DateTime,
    value: Type,
    tags: Option,
}

impl Metric {
    /// Creates an untagged data point timestamped with `Utc::now()`.
    pub(in crate::metrics) fn new(measurement: Measurement, value: impl Into) -> Self {
        Self {
            name: measurement,
            time: Utc::now(),
            value: value.into(),
            tags: None,
        }
    }

    /// Sets the tags of the data point, replacing any previously set tags.
    pub(in crate::metrics) fn with_tags(mut self, tags: T) -> Self
    where
        T: Into>,
        I: Into,
    {
        // It is by design that this function should only be called once.
        // see `Recorder::metric`
        // Therefore, we don't cover the case where we would extend `self.tags`
        // when `self.tags` already contains tags.
        self.tags = tags.into().map(Into::into);
        self
    }
}

/// Builds the InfluxDB write query: field `value` plus one InfluxDB tag per entry.
impl From for WriteQuery {
    fn from(metric: Metric) -> Self {
        let mut query = Timestamp::from(metric.time).into_query(metric.name);
        query = query.add_field("value", metric.value);

        if let Some(tags) = metric.tags {
            for (tag, value) in tags {
                query = query.add_tag(tag, value);
            }
        }

        query
    }
}

/// An event data point.
///
/// Written to the `event` measurement. Title, optional description and
/// optional tags are all stored as fields.
pub(in crate::metrics) struct Event {
    name: &'static str,
    time: DateTime,
    title: String,
    description: Option,
    tags: Option,
}

impl Event {
    /// Creates an event with only a title, timestamped with `Utc::now()`.
    pub(in crate::metrics) fn new(title: impl Into) -> Self {
        Self {
            name: "event",
            time: Utc::now(),
            title: title.into(),
            description: None,
            tags: None,
        }
    }

    /// Sets (or, with `None`, clears) the description of the event.
    pub(in crate::metrics) fn with_description(mut self, description: D) -> Self
    where
        D: Into>,
        S: Into,
    {
        self.description = description.into().map(Into::into);
        self
    }

    /// Sets the tags of the event, joined into a single comma-separated string.
    pub(in crate::metrics) fn with_tags(mut self, tags: T) -> Self
    where
        T: Into>,
        A: AsRef<[B]>,
        B: Borrow,
    {
        // It is by design that this function should only be called once.
        // see `Recorder::metric`
        // Therefore, we don't cover the case where we would extend `self.tags`
        // when `self.tags` already contains tags.
        self.tags = tags.into().map(|tags| tags.as_ref().join(","));
        self
    }
}

/// Builds the InfluxDB write query; `description` and `tags` are only written when set.
impl From for WriteQuery {
    fn from(event: Event) -> Self {
        let mut query = Timestamp::from(event.time).into_query(event.name);
        query = query.add_field("title", event.title);

        if let Some(description) = event.description {
            query = query.add_field("description", description);
        }

        if let Some(tags) = event.tags {
            query = query.add_field("tags", tags);
        }

        query
    }
}

// The tests check the prefix of the generated line protocol only, since the
// trailing timestamp differs between runs.
#[cfg(test)]
mod tests {
    use influxdb::Query;

    use super::*;

    /// Creates key-value tags for metrics.
    macro_rules! tags {
        ($(($tag: expr, $val: expr)),+ $(,)?) => {
            {
                let mut tags = crate::metrics::Tags::new();
                $(
                    tags.add($tag, $val);
                )+
                tags
            }
        };
    }

    #[test]
    fn test_basic_metric() {
        let metric = Metric::new(Measurement::Phase, 1);
        assert!(WriteQuery::from(metric)
            .build()
            .unwrap()
            .get()
            .starts_with("phase value=1i "))
    }

    #[test]
    fn test_metric_with_tag() {
        let metric = Metric::new(Measurement::Phase, 1).with_tags(tags![("key", 42)]);
        assert!(WriteQuery::from(metric)
            .build()
            .unwrap()
            .get()
            .starts_with("phase,key=42 value=1i "))
    }

    #[test]
    fn test_metric_with_tags() {
        let metric = Metric::new(Measurement::Phase, 1).with_tags(tags![
            ("key_1", 42),
            ("key_2", "42"),
            ("key_3", 1.0f32),
        ]);
        assert!(WriteQuery::from(metric)
            .build()
            .unwrap()
            .get()
            .starts_with("phase,key_1=42,key_2=42,key_3=1 value=1i "))
    }

    #[test]
    fn test_basic_event() {
        let event = Event::new("error");
        assert!(WriteQuery::from(event)
            .build()
            .unwrap()
            .get()
            .starts_with("event title=\"error\" "))
    }

    #[test]
    fn test_event_with_description() {
        let event = Event::new("error").with_description("description");
        assert!(WriteQuery::from(event)
            .build()
            .unwrap()
            .get()
            .starts_with("event title=\"error\",description=\"description\" "))
    }

    #[test]
    fn test_event_with_description_and_tag() {
        let event = Event::new("error")
            .with_description("description")
            .with_tags(["tag"]);
        assert!(WriteQuery::from(event)
            .build()
            .unwrap()
            .get()
            .starts_with("event title=\"error\",description=\"description\",tags=\"tag\" "))
    }

    #[test]
    fn test_event_with_description_and_tags() {
        let event = Event::new("error")
            .with_description("description")
            .with_tags(["tag_1", "tag_2"]);
        assert!(WriteQuery::from(event)
            .build()
            .unwrap()
            .get()
            .starts_with("event title=\"error\",description=\"description\",tags=\"tag_1,tag_2\" "))
    }

    #[test]
    fn test_event_with_tag() {
        let event = Event::new("error").with_tags(["tag"]);
        assert!(WriteQuery::from(event)
            .build()
            .unwrap()
            .get()
            .starts_with("event title=\"error\",tags=\"tag\" "))
    }
}

================================================
FILE: 
rust/xaynet-server/src/metrics/recorders/influxdb/recorder.rs ================================================ use std::borrow::Borrow; use futures::future::poll_fn; use influxdb::Type; use tower::Service; use tracing::{error, warn}; use super::{Dispatcher, Event, InfluxDbService, Measurement, Metric, Request, Tags}; use crate::settings::InfluxSettings; /// An InfluxDB metrics / events recorder. pub struct Recorder { /// A services that dispatches the recorded metrics / events to an InfluxDB instance. service: InfluxDbService, } impl Recorder { /// Creates a new InfluxDB recorder. pub fn new(settings: InfluxSettings) -> Self { let dispatcher = Dispatcher::new(settings.url, settings.db); Self { service: InfluxDbService::new(dispatcher), } } /// Records a new metric and dispatches it to an InfluxDB instance. pub fn metric(&self, measurement: Measurement, value: V, tags: T) where V: Into, T: Into>, I: Into, { let metric = Metric::new(measurement, value).with_tags(tags); self.call(metric.into()); } /// Records a new event and dispatches it to an InfluxDB instance. 
pub fn event(&self, title: H, description: D, tags: T) where H: Into, D: Into>, S: Into, T: Into>, A: AsRef<[B]>, B: Borrow, { let event = Event::new(title) .with_description(description) .with_tags(tags); self.call(event.into()); } fn call(&self, req: Request) { let mut handle = self.service.0.clone(); tokio::spawn(async move { if let Err(err) = poll_fn(|cx| handle.poll_ready(cx)).await { error!("influx service temporarily unavailable: {}", err) } if let Err(err) = handle.call(req).await { warn!("influx service error: {}", err) } }); } } ================================================ FILE: rust/xaynet-server/src/metrics/recorders/influxdb/service.rs ================================================ use super::{Dispatcher, Request}; use tower::{buffer::Buffer, limit::ConcurrencyLimit, load_shed::LoadShed, ServiceBuilder}; pub(in crate::metrics) struct InfluxDbService( pub LoadShed, Request>>, ); impl InfluxDbService { pub fn new(dispatcher: Dispatcher) -> Self { let service = ServiceBuilder::new() .load_shed() .buffer(4048) .concurrency_limit(50) .service(dispatcher); Self(service) } } ================================================ FILE: rust/xaynet-server/src/metrics/recorders/mod.rs ================================================ pub mod influxdb; ================================================ FILE: rust/xaynet-server/src/rest.rs ================================================ //! A HTTP API for the PET protocol interactions. 
use std::convert::Infallible;
#[cfg(feature = "tls")]
use std::path::PathBuf;

use bytes::Bytes;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::{error, warn};
use warp::{
    http::{Response, StatusCode},
    reply::Reply,
    Filter,
};
#[cfg(feature = "tls")]
use warp::{Server, TlsServer};

use crate::{
    services::{fetchers::Fetcher, messages::PetMessageHandler},
    settings::ApiSettings,
};
use xaynet_core::{crypto::ByteObject, ParticipantPublicKey};

/// Query string of the `GET /seeds` route.
#[derive(Deserialize, Serialize)]
struct SeedDictQuery {
    /// Base64-encoded participant public key.
    pk: String,
}

/// Starts an HTTP server at the given address, listening to GET requests for
/// data and POST requests containing PET messages.
///
/// Routes: `POST /message`, `GET /params`, `GET /sums`, `GET /seeds?pk=...`
/// and `GET /model`.
///
/// * `api_settings`: address of the server and optional certificate and key for TLS server
/// authentication as well as trusted anchors for TLS client authentication.
/// * `fetcher`: fetcher for responding to data requests.
/// * `pet_message_handler`: handler for responding to PET messages.
///
/// # Errors
/// Fails if the TLS settings are invalid.
pub async fn serve(
    api_settings: ApiSettings,
    fetcher: F,
    pet_message_handler: PetMessageHandler,
) -> Result<(), RestError>
where
    F: Fetcher + Sync + Send + 'static + Clone,
{
    let message = warp::path!("message")
        .and(warp::post())
        .and(warp::body::bytes())
        .and(with_message_handler(pet_message_handler.clone()))
        .and_then(handle_message);

    let sum_dict = warp::path!("sums")
        .and(warp::get())
        .and(with_fetcher(fetcher.clone()))
        .and_then(handle_sums);

    let seed_dict = warp::path!("seeds")
        .and(warp::get())
        .and(warp::query::())
        .and_then(part_pk)
        .and(with_fetcher(fetcher.clone()))
        .and_then(handle_seeds);

    let round_params = warp::path!("params")
        .and(warp::get())
        .and(with_fetcher(fetcher.clone()))
        .and_then(handle_params);

    let model = warp::path!("model")
        .and(warp::get())
        .and(with_fetcher(fetcher.clone()))
        .and_then(handle_model);

    let routes = message
        .or(round_params)
        .or(sum_dict)
        .or(seed_dict)
        .or(model)
        .recover(handle_reject)
        .with(warp::log("http"));

    #[cfg(not(feature = "tls"))]
    return run_http(routes, api_settings)
        .await
        .map_err(RestError::from);
    #[cfg(feature = "tls")]
    return run_https(routes, api_settings).await;
}

/// Handles and responds to a PET message.
///
/// Handling errors are only logged: the client always receives an empty
/// `200 OK` reply.
async fn handle_message(
    body: Bytes,
    mut handler: PetMessageHandler,
) -> Result {
    if let Err(e) = handler.handle_message(body.to_vec()).await {
        warn!("failed to handle message: {:?}", e);
    }
    Ok(warp::reply())
}

/// Handles and responds to a request for the sum dictionary.
///
/// Replies `200 OK` with the bincode-encoded dictionary, `204 No Content` if
/// there is none yet, and `500 Internal Server Error` if fetching fails.
async fn handle_sums(mut fetcher: F) -> Result {
    Ok(match fetcher.sum_dict().await {
        Err(e) => {
            warn!("failed to handle sum dict request: {:?}", e);
            Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .body(Vec::new())
                .unwrap()
        }
        Ok(None) => Response::builder()
            .status(StatusCode::NO_CONTENT)
            .body(Vec::new())
            .unwrap(),
        Ok(Some(dict)) => {
            let bytes = bincode::serialize(dict.as_ref()).unwrap();
            Response::builder()
                .header("Content-Type", "application/octet-stream")
                .status(StatusCode::OK)
                .body(bytes)
                .unwrap()
        }
    })
}

/// Handles and responds to a request for the seed dictionary.
///
/// Replies `200 OK` with the bincode-encoded entry of `pk`, `204 No Content`
/// if there is no dictionary or no entry for `pk`, and
/// `500 Internal Server Error` if fetching fails.
async fn handle_seeds(
    pk: ParticipantPublicKey,
    mut fetcher: F,
) -> Result {
    Ok(match fetcher.seed_dict().await {
        Err(e) => {
            warn!("failed to handle seed dict request: {:?}", e);
            Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .body(Vec::new())
                .unwrap()
        }
        // Look the entry up once and serialize what was found, rather than
        // checking for it in a match guard and fetching it a second time.
        Ok(Some(dict)) => match dict.as_ref().get(&pk) {
            Some(seeds) => {
                let bytes = bincode::serialize(seeds).unwrap();
                Response::builder()
                    .header("Content-Type", "application/octet-stream")
                    .status(StatusCode::OK)
                    .body(bytes)
                    .unwrap()
            }
            None => Response::builder()
                .status(StatusCode::NO_CONTENT)
                .body(Vec::new())
                .unwrap(),
        },
        Ok(None) => Response::builder()
            .status(StatusCode::NO_CONTENT)
            .body(Vec::new())
            .unwrap(),
    })
}

/// Handles and responds to a request for the global model.
///
/// Replies `200 OK` with the bincode-encoded model, `204 No Content` if there
/// is none yet, and `500 Internal Server Error` if fetching fails.
async fn handle_model(mut fetcher: F) -> Result {
    Ok(match fetcher.model().await {
        // Same binary content type as the other bincode-encoded replies.
        Ok(Some(model)) => Response::builder()
            .header("Content-Type", "application/octet-stream")
            .status(StatusCode::OK)
            .body(bincode::serialize(model.as_ref()).unwrap())
            .unwrap(),
        Ok(None) => Response::builder()
            .status(StatusCode::NO_CONTENT)
            .body(Vec::new())
            .unwrap(),
        Err(e) => {
            warn!("failed to handle model request: {:?}", e);
            Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .body(Vec::new())
                .unwrap()
        }
    })
}

/// Handles and responds to a request for the round parameters.
async fn handle_params(mut fetcher: F) -> Result { Ok(match fetcher.round_params().await { Ok(params) => Response::builder() .status(StatusCode::OK) .body(bincode::serialize(¶ms).unwrap()) .unwrap(), Err(e) => { warn!("failed to handle round parameters request: {:?}", e); Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Vec::new()) .unwrap() } }) } /// Converts a PET message handler into a `warp` filter. fn with_message_handler( handler: PetMessageHandler, ) -> impl Filter + Clone { warp::any().map(move || handler.clone()) } /// Converts a data fetcher into a `warp` filter. fn with_fetcher( fetcher: F, ) -> impl Filter + Clone { warp::any().map(move || fetcher.clone()) } /// Extracts a participant public key from the url query string async fn part_pk(query: SeedDictQuery) -> Result { match base64::decode(query.pk.as_bytes()) { Ok(bytes) => { if let Some(pk) = ParticipantPublicKey::from_slice(&bytes[..]) { Ok(pk) } else { Err(warp::reject::custom(InvalidPublicKey)) } } Err(_) => Err(warp::reject::custom(InvalidPublicKey)), } } #[derive(Debug)] struct InvalidPublicKey; impl warp::reject::Reject for InvalidPublicKey {} /// Handles `warp` rejections of bad requests. async fn handle_reject(err: warp::Rejection) -> Result { let code = if err.is_not_found() { StatusCode::NOT_FOUND } else if let Some(InvalidPublicKey) = err.find() { StatusCode::BAD_REQUEST } else { error!("unhandled rejection: {:?}", err); StatusCode::INTERNAL_SERVER_ERROR }; // reply with empty body; the status code is the interesting part Ok(warp::reply::with_status(Vec::new(), code)) } #[derive(Debug, Error)] /// Errors of the rest server. pub enum RestError { #[error("invalid TLS configuration was provided")] InvalidTlsConfig, } impl From for RestError { fn from(infallible: Infallible) -> RestError { match infallible {} } } #[cfg(feature = "tls")] /// Configures a server for TLS server and client authentication. /// /// # Errors /// Fails if the TLS settings are invalid. 
fn configure_tls(
    server: Server,
    tls_certificate: Option,
    tls_key: Option,
    tls_client_auth: Option,
) -> Result, RestError>
where
    F: Filter + Clone + Send + Sync + 'static,
    F::Extract: Reply,
{
    // With no TLS setting at all there is nothing to configure; this is
    // reported as an invalid configuration.
    if tls_certificate.is_none() && tls_key.is_none() && tls_client_auth.is_none() {
        return Err(RestError::InvalidTlsConfig);
    }
    let mut server = server.tls();
    // The server certificate and its key must be given together (or not at
    // all); providing only one of them is rejected.
    match (tls_certificate, tls_key) {
        (Some(cert), Some(key)) => server = server.cert_path(cert).key_path(key),
        (None, None) => {}
        _ => return Err(RestError::InvalidTlsConfig),
    }
    // Optionally require client authentication against the trust anchor at
    // `tls_client_auth`.
    // NOTE(review): when only `tls_client_auth` is set, no server certificate
    // or key is configured here; confirm the TLS server can start in that case.
    if let Some(trust_anchor) = tls_client_auth {
        server = server.client_auth_required_path(trust_anchor);
    }
    Ok(server)
}

#[cfg(not(feature = "tls"))]
/// Runs a server with the provided filter routes.
///
/// Serves `filter` over plain HTTP on `api_settings.bind_address`. This
/// variant is only compiled without the `tls` feature and always returns
/// `Ok(())`.
async fn run_http(filter: F, api_settings: ApiSettings) -> Result<(), Infallible>
where
    F: Filter + Clone + Send + Sync + 'static,
    F::Extract: Reply,
{
    warp::serve(filter).run(api_settings.bind_address).await;
    Ok(())
}

#[cfg(feature = "tls")]
/// Runs a TLS server with the provided filter routes.
///
/// The certificate, key and client-auth trust anchor are taken from
/// `api_settings` and applied by `configure_tls`; the server listens on
/// `api_settings.bind_address`.
///
/// # Errors
/// Fails if the TLS settings are invalid.
async fn run_https(filter: F, api_settings: ApiSettings) -> Result<(), RestError>
where
    F: Filter + Clone + Send + Sync + 'static,
    F::Extract: Reply,
{
    configure_tls(
        warp::serve(filter),
        api_settings.tls_certificate,
        api_settings.tls_key,
        api_settings.tls_client_auth,
    )?
    .run(api_settings.bind_address)
    .await;
    Ok(())
}

================================================
FILE: rust/xaynet-server/src/services/fetchers/mod.rs
================================================
//! This module provides the services for serving data.
//!
//! There are multiple such services and the [`Fetcher`] trait
//! provides a single unifying interface for all of these.
mod model;
mod round_parameters;
mod seed_dict;
mod sum_dict;

use std::task::{Context, Poll};

use async_trait::async_trait;
use futures::future::poll_fn;
use tower::{layer::Layer, Service, ServiceBuilder};

pub use self::{
    model::{ModelRequest, ModelResponse, ModelService},
    round_parameters::{RoundParamsRequest, RoundParamsResponse, RoundParamsService},
    seed_dict::{SeedDictRequest, SeedDictResponse, SeedDictService},
    sum_dict::{SumDictRequest, SumDictResponse, SumDictService},
};
use crate::state_machine::events::EventSubscriber;

/// A single interface for retrieving data from the coordinator.
#[async_trait]
pub trait Fetcher {
    /// Fetch the parameters for the current round.
    async fn round_params(&mut self) -> Result;

    /// Fetch the latest global model.
    async fn model(&mut self) -> Result;

    /// Fetch the global seed dictionary. Each sum2 participant needs a
    /// different portion of that dictionary.
    async fn seed_dict(&mut self) -> Result;

    /// Fetch the sum dictionary. The update participants need this
    /// dictionary to encrypt their masking seed for each sum
    /// participant.
    async fn sum_dict(&mut self) -> Result;
}

/// An error returned by the [`Fetcher`]'s method.
pub type FetchError = anyhow::Error;

/// Wraps an error returned by one of the underlying services into a
/// [`FetchError`], embedding the error's `Debug` representation in the
/// message.
fn into_fetch_error>>(
    e: E,
) -> FetchError {
    anyhow::anyhow!("Fetcher failed: {:?}", e.into())
}

/// [`Fetcher`] implementation on top of the four wrapped services. Each
/// method first waits for its service to report readiness (`poll_ready`)
/// and then calls it; service errors are converted with
/// `into_fetch_error`.
#[async_trait]
impl Fetcher for Fetchers
where
    Self: Send + Sync + 'static,
    RoundParams: Service + Send + 'static,
    >::Future: Send + Sync + 'static,
    >::Error: Into>,
    Model: Service + Send + 'static,
    >::Future: Send + Sync + 'static,
    >::Error: Into>,
    SeedDict: Service + Send + 'static,
    >::Future: Send + Sync + 'static,
    >::Error: Into>,
    SumDict: Service + Send + 'static,
    >::Future: Send + Sync + 'static,
    >::Error: Into>,
{
    async fn round_params(&mut self) -> Result {
        // The tower `Service` contract requires readiness before `call`.
        poll_fn(|cx| {
            >::poll_ready(&mut self.round_params, cx)
        })
        .await
        .map_err(into_fetch_error)?;
        Ok(>::call(
            &mut self.round_params,
            RoundParamsRequest,
        )
        .await
        .map_err(into_fetch_error)?)
    }

    async fn model(&mut self) -> Result {
        poll_fn(|cx| >::poll_ready(&mut self.model, cx))
            .await
            .map_err(into_fetch_error)?;
        Ok(
            >::call(&mut self.model, ModelRequest)
                .await
                .map_err(into_fetch_error)?,
        )
    }

    async fn seed_dict(&mut self) -> Result {
        poll_fn(|cx| >::poll_ready(&mut self.seed_dict, cx))
            .await
            .map_err(into_fetch_error)?;
        Ok(
            >::call(&mut self.seed_dict, SeedDictRequest)
                .await
                .map_err(into_fetch_error)?,
        )
    }

    async fn sum_dict(&mut self) -> Result {
        poll_fn(|cx| >::poll_ready(&mut self.sum_dict, cx))
            .await
            .map_err(into_fetch_error)?;
        Ok(
            >::call(&mut self.sum_dict, SumDictRequest)
                .await
                .map_err(into_fetch_error)?,
        )
    }
}

/// A transparent wrapper that forwards `poll_ready` and `call` to the
/// wrapped service unchanged.
pub(in crate::services) struct FetcherService(S);

impl Service for FetcherService
where
    S: Service,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
        self.0.poll_ready(cx)
    }

    fn call(&mut self, req: R) -> Self::Future {
        self.0.call(req)
    }
}

/// [`Layer`] that wraps a service into a [`FetcherService`].
pub(in crate::services) struct FetcherLayer;

impl Layer for FetcherLayer {
    type Service = FetcherService;

    fn layer(&self, service: S) -> Self::Service {
        FetcherService(service)
    }
}

/// Bundles the services serving the round parameters, the sum
/// dictionary, the seed dictionary and the global model.
#[derive(Debug, Clone)]
pub struct Fetchers {
    round_params: RoundParams,
    sum_dict: SumDict,
    seed_dict: SeedDict,
    model: Model,
}

impl Fetchers {
    /// Creates a new [`Fetchers`] from the four underlying services.
    pub fn new(
        round_params: RoundParams,
        sum_dict: SumDict,
        seed_dict: SeedDict,
        model: Model,
    ) -> Self {
        Self {
            round_params,
            sum_dict,
            seed_dict,
            model,
        }
    }
}

/// Construct a [`Fetcher`] service
///
/// Every underlying service listens to `event_subscriber` and is wrapped,
/// outermost first, in a buffer (bound 100), a concurrency limit (100) and
/// a [`FetcherLayer`].
pub fn fetcher(event_subscriber: &EventSubscriber) -> impl Fetcher + Sync + Send + Clone + 'static {
    let round_params = ServiceBuilder::new()
        .buffer(100)
        .concurrency_limit(100)
        .layer(FetcherLayer)
        .service(RoundParamsService::new(event_subscriber));

    let model = ServiceBuilder::new()
        .buffer(100)
        .concurrency_limit(100)
        .layer(FetcherLayer)
        .service(ModelService::new(event_subscriber));

    let sum_dict = ServiceBuilder::new()
        .buffer(100)
        .concurrency_limit(100)
        .layer(FetcherLayer)
        .service(SumDictService::new(event_subscriber));

    let seed_dict = ServiceBuilder::new()
        .buffer(100)
        .concurrency_limit(100)
        .layer(FetcherLayer)
        .service(SeedDictService::new(event_subscriber));

    Fetchers::new(round_params, sum_dict, seed_dict, model)
}

================================================
FILE: rust/xaynet-server/src/services/fetchers/model.rs
================================================
use std::{
    sync::Arc,
    task::{Context, Poll},
};

use futures::future::{self, Ready};
use tower::Service;
use tracing::error_span;
use tracing_futures::{Instrument, Instrumented};

use crate::state_machine::events::{EventListener, EventSubscriber, ModelUpdate};
use xaynet_core::mask::Model;

/// [`ModelService`]'s request type
#[derive(Default, Clone, Eq, PartialEq, Debug)]
pub struct ModelRequest;

/// [`ModelService`]'s response type.
///
/// The response is `None` when no model is currently available.
pub type ModelResponse = Option>;

/// A service that serves the latest available global model
pub struct ModelService(EventListener);

impl ModelService {
    /// Creates a service reading from the subscriber's model listener.
    pub fn new(events: &EventSubscriber) -> Self {
        Self(events.model_listener())
    }
}

impl Service for ModelService {
    type Response = ModelResponse;
    type Error = std::convert::Infallible;
    type Future = Instrumented>>;

    // Always reports readiness.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
        Poll::Ready(Ok(()))
    }

    // Returns the latest model event, mapping `Invalidate` to `None`.
    fn call(&mut self, _req: ModelRequest) -> Self::Future {
        future::ready(match self.0.get_latest().event {
            ModelUpdate::Invalidate => Ok(None),
            ModelUpdate::New(model) => Ok(Some(model)),
        })
        .instrument(error_span!("model_fetch_request"))
    }
}

================================================
FILE: rust/xaynet-server/src/services/fetchers/round_parameters.rs
================================================
use std::task::{Context, Poll};

use futures::future::{self, Ready};
use tower::Service;
use tracing::error_span;
use tracing_futures::{Instrument, Instrumented};

use crate::state_machine::events::{EventListener, EventSubscriber};
use xaynet_core::common::RoundParameters;

/// [`RoundParamsService`]'s request type
#[derive(Default, Clone, Eq, PartialEq, Debug)]
pub struct RoundParamsRequest;

/// [`RoundParamsService`]'s response type
pub type RoundParamsResponse = RoundParameters;

/// A service that serves the round parameters for the current round.
pub struct RoundParamsService(EventListener);

impl RoundParamsService {
    /// Creates a service reading from the subscriber's round parameters listener.
    pub fn new(events: &EventSubscriber) -> Self {
        Self(events.params_listener())
    }
}

impl Service for RoundParamsService {
    type Response = RoundParameters;
    type Error = std::convert::Infallible;
    type Future = Instrumented>>;

    // Always reports readiness.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
        Poll::Ready(Ok(()))
    }

    // Returns the latest round parameters event.
    fn call(&mut self, _req: RoundParamsRequest) -> Self::Future {
        future::ready(Ok(self.0.get_latest().event))
            .instrument(error_span!("round_params_fetch_request"))
    }
}

================================================
FILE: rust/xaynet-server/src/services/fetchers/seed_dict.rs
================================================
use std::{
    sync::Arc,
    task::{Context, Poll},
};

use futures::future::{self, Ready};
use tower::Service;
use tracing::error_span;
use tracing_futures::{Instrument, Instrumented};

use crate::state_machine::events::{DictionaryUpdate, EventListener, EventSubscriber};
use xaynet_core::SeedDict;

/// A service that serves the seed dictionary for the current round.
pub struct SeedDictService(EventListener>);

impl SeedDictService {
    /// Creates a service reading from the subscriber's seed dictionary listener.
    pub fn new(events: &EventSubscriber) -> Self {
        Self(events.seed_dict_listener())
    }
}

/// [`SeedDictService`]'s request type
#[derive(Default, Clone, Eq, PartialEq, Debug)]
pub struct SeedDictRequest;

/// [`SeedDictService`]'s response type.
///
/// The response is `None` when no seed dictionary is currently
/// available.
pub type SeedDictResponse = Option>;

impl Service for SeedDictService {
    type Response = SeedDictResponse;
    type Error = std::convert::Infallible;
    type Future = Instrumented>>;

    // Always reports readiness.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
        Poll::Ready(Ok(()))
    }

    // Returns the latest seed dictionary event, mapping `Invalidate` to `None`.
    fn call(&mut self, _req: SeedDictRequest) -> Self::Future {
        future::ready(match self.0.get_latest().event {
            DictionaryUpdate::Invalidate => Ok(None),
            DictionaryUpdate::New(dict) => Ok(Some(dict)),
        })
        .instrument(error_span!("seed_dict_fetch_request"))
    }
}

================================================
FILE: rust/xaynet-server/src/services/fetchers/sum_dict.rs
================================================
use std::{
    sync::Arc,
    task::{Context, Poll},
};

use futures::future::{self, Ready};
use tower::Service;
use tracing::error_span;
use tracing_futures::{Instrument, Instrumented};

use crate::state_machine::events::{DictionaryUpdate, EventListener, EventSubscriber};
use xaynet_core::SumDict;

/// A service that returns the sum dictionary for the current round.
pub struct SumDictService(EventListener>);

/// [`SumDictService`]'s request type
#[derive(Default, Clone, Eq, PartialEq, Debug)]
pub struct SumDictRequest;

/// [`SumDictService`]'s response type.
///
/// The response is `None` when no sum dictionary is currently
/// available.
pub type SumDictResponse = Option>;

impl SumDictService {
    /// Creates a service reading from the subscriber's sum dictionary listener.
    pub fn new(events: &EventSubscriber) -> Self {
        Self(events.sum_dict_listener())
    }
}

impl Service for SumDictService {
    type Response = SumDictResponse;
    type Error = std::convert::Infallible;
    type Future = Instrumented>>;

    // Always reports readiness.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
        Poll::Ready(Ok(()))
    }

    // Returns the latest sum dictionary event, mapping `Invalidate` to `None`.
    fn call(&mut self, _req: SumDictRequest) -> Self::Future {
        future::ready(match self.0.get_latest().event {
            DictionaryUpdate::Invalidate => Ok(None),
            DictionaryUpdate::New(dict) => Ok(Some(dict)),
        })
        .instrument(error_span!("sum_dict_fetch_request"))
    }
}

================================================
FILE: rust/xaynet-server/src/services/messages/decryptor.rs
================================================
use std::{pin::Pin, sync::Arc, task::Poll};

use futures::{future::Future, task::Context};
use rayon::ThreadPool;
use tokio::sync::oneshot;
use tower::{
    limit::concurrency::{future::ResponseFuture, ConcurrencyLimit},
    Service,
};
use tracing::{debug, info, trace};

use crate::{
    services::messages::{BoxedServiceFuture, ServiceError},
    state_machine::events::{EventListener, EventSubscriber},
};
use xaynet_core::crypto::EncryptKeyPair;

/// A service for decrypting PET messages.
///
/// Since this is a CPU-intensive task for large messages, this
/// service offloads the processing to a `rayon` thread-pool to avoid
/// overloading the tokio thread-pool with blocking tasks.
#[derive(Clone)]
struct RawDecryptor {
    /// A listener to retrieve the latest coordinator keys. These are
    /// necessary for decrypting messages and verifying their
    /// signature.
    keys_events: EventListener,
    /// Thread-pool the CPU-intensive tasks are offloaded to.
    thread_pool: Arc,
}

impl Service for RawDecryptor
where
    T: AsRef<[u8]> + Sync + Send + 'static,
{
    type Response = Vec;
    type Error = ServiceError;
    #[allow(clippy::type_complexity)]
    type Future = Pin> + 'static + Send + Sync>>;

    // Always reports readiness; `Decryptor` bounds the number of in-flight
    // requests with a `ConcurrencyLimit`.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, data: T) -> Self::Future {
        debug!("retrieving the current keys");
        // The keys are read once, at call time, and moved into the task.
        let keys = self.keys_events.get_latest().event;
        let (tx, rx) = oneshot::channel::>();
        trace!("spawning decryption task on threadpool");
        self.thread_pool.spawn(move || {
            info!("decrypting message");
            let res = keys
                .secret
                .decrypt(data.as_ref(), &keys.public)
                .map_err(|_| ServiceError::Decrypt);
            // The receiver may already be gone; a failed send is ignored.
            let _ = tx.send(res);
        });
        Box::pin(async move {
            // If the sender is dropped without sending (e.g. the task
            // panicked), report an internal error.
            rx.await.unwrap_or_else(|_| {
                Err(ServiceError::InternalError(
                    "failed to receive response from thread-pool".to_string(),
                ))
            })
        })
    }
}

/// Decryption service: a [`RawDecryptor`] whose number of in-flight
/// requests is limited to the number of threads of its thread-pool.
#[derive(Clone)]
pub struct Decryptor(ConcurrencyLimit);

impl Decryptor {
    /// Creates a decryptor that reads the coordinator keys from
    /// `state_machine_events` and runs decryption on `thread_pool`.
    pub fn new(state_machine_events: &EventSubscriber, thread_pool: Arc) -> Self {
        let limit = thread_pool.current_num_threads();
        let keys_events = state_machine_events.keys_listener();
        let service = RawDecryptor {
            keys_events,
            thread_pool,
        };
        Self(ConcurrencyLimit::new(service, limit))
    }
}

impl Service for Decryptor
where
    T: AsRef<[u8]> + Sync + Send + 'static,
{
    type Response = Vec;
    type Error = ServiceError;
    type Future = ResponseFuture>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
        as Service>::poll_ready(&mut self.0, cx)
    }

    fn call(&mut self, data: T) -> Self::Future {
        self.0.call(data)
    }
}

#[cfg(test)]
mod tests {
    use rayon::ThreadPoolBuilder;
    use tokio_test::assert_ready;
    use tower_test::mock::Spawn;

    use crate::{
        services::tests::utils,
        state_machine::events::{EventPublisher, EventSubscriber},
    };

    use super::*;

    /// Builds event channels, a default `rayon` thread-pool and the spawned
    /// [`Decryptor`] under test.
    fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) {
        let (publisher, subscriber) = utils::new_event_channels();
        let thread_pool = Arc::new(ThreadPoolBuilder::new().build().unwrap());
        let task = Spawn::new(Decryptor::new(&subscriber, thread_pool));
        (publisher, subscriber, task)
    }

    #[tokio::test]
    async fn test_decrypt_fail() {
        let (_publisher, _subscriber, mut task) = spawn_svc();
        assert_ready!(task.poll_ready::>()).unwrap();

        // Arbitrary bytes that the test expects to fail decryption.
        let req = vec![0, 1, 2, 3, 4, 5, 6];
        match task.call(req).await {
            Err(ServiceError::Decrypt) => {}
            _ => panic!("expected decrypt error"),
        }
        // The service is expected to be ready again after a failed request.
        assert_ready!(task.poll_ready::>()).unwrap();
    }

    #[tokio::test]
    async fn test_decrypt_ok() {
        let (_publisher, subscriber, mut task) = spawn_svc();
        assert_ready!(task.poll_ready::>()).unwrap();

        let round_params = subscriber.params_listener().get_latest().event;
        let (message, participant_signing_keys) = utils::new_sum_message(&round_params);
        let serialized_message = utils::serialize_message(&message, &participant_signing_keys);
        let encrypted_message =
            utils::encrypt_message(&message, &round_params, &participant_signing_keys);

        // Call the service
        let decrypted_message = task.call(encrypted_message).await.unwrap();
        assert_eq!(decrypted_message, serialized_message);
    }
}

================================================
FILE: rust/xaynet-server/src/services/messages/error.rs
================================================
use displaydoc::Display;
use thiserror::Error;

use crate::state_machine::requests::RequestError;
use xaynet_core::message::DecodeError;

// NOTE: `displaydoc::Display` derives each variant's `Display` text from its
// `///` doc comment (`{0}` refers to the variant's field), so editing the
// variant doc comments below changes the runtime error messages.
/// Errors for the message parsing service.
#[derive(Debug, Display, Error)]
pub enum ServiceError {
    /// Failed to decrypt the message with the coordinator secret key.
    Decrypt,
    /// Failed to parse the message: {0}.
    Parsing(DecodeError),
    /// Invalid message signature.
    InvalidMessageSignature,
    /// Invalid coordinator public key.
    InvalidCoordinatorPublicKey,
    /// The message was not expected in the current phase.
    UnexpectedMessage,
    // FIXME: we need to refine the state machine errors and the
    // conversion into a service error
    /// The state machine failed to process the request: {0}.
    StateMachine(RequestError),
    /// Participant is not eligible for sum task.
    NotSumEligible,
    /// Participant is not eligible for update task.
    NotUpdateEligible,
    /// Internal error: {0}.
    InternalError(String),
}

// Recovers the original `ServiceError` if the boxed error is one; any other
// error is wrapped into `InternalError` using its `Display` text.
impl From> for ServiceError {
    fn from(e: Box) -> Self {
        match e.downcast::() {
            Ok(e) => *e,
            Err(e) => ServiceError::InternalError(format!("{}", e)),
        }
    }
}

// Casts the boxed error with `as` and delegates to another `From` impl
// (presumably the one above).
impl From> for ServiceError {
    fn from(e: Box) -> Self {
        ServiceError::from(e as Box)
    }
}

================================================
FILE: rust/xaynet-server/src/services/messages/message_parser.rs
================================================
use std::{convert::TryInto, sync::Arc, task::Poll};

use futures::{future, task::Context};
use rayon::ThreadPool;
use tokio::sync::oneshot;
use tower::{layer::Layer, limit::concurrency::ConcurrencyLimit, Service, ServiceBuilder};
use tracing::{debug, info, trace, warn};

use crate::{
    services::messages::{BoxedServiceFuture, ServiceError},
    state_machine::{
        events::{EventListener, EventSubscriber},
        phases::PhaseName,
    },
};
use xaynet_core::{
    crypto::{EncryptKeyPair, PublicEncryptKey},
    message::{FromBytes, Message, MessageBuffer, Tag},
};

/// A type that holds an unparsed message.
struct RawMessage {
    /// The buffer that contains the message to parse
    buffer: Arc>,
}

// Cloning only clones the `Arc`: the underlying buffer is shared, not copied.
impl Clone for RawMessage {
    fn clone(&self) -> Self {
        Self {
            buffer: self.buffer.clone(),
        }
    }
}

impl From> for RawMessage {
    fn from(buffer: MessageBuffer) -> Self {
        RawMessage {
            buffer: Arc::new(buffer),
        }
    }
}

/// A service that wraps a buffer `T` representing a message into a
/// [`RawMessage`]
#[derive(Debug, Clone)]
struct BufferWrapper(S);

impl Service for BufferWrapper
where
    T: AsRef<[u8]> + Send + 'static,
    S: Service, Response = Message, Error = ServiceError>,
    S::Future: Sync + Send + 'static,
{
    type Response = Message;
    type Error = ServiceError;
    type Future = BoxedServiceFuture;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
        self.0.poll_ready(cx)
    }

    fn call(&mut self, req: T) -> Self::Future {
        debug!("creating a RawMessage request");
        // A buffer rejected by `MessageBuffer::new` is reported as a parsing error.
        match MessageBuffer::new(req) {
            Ok(buffer) => {
                let fut = self.0.call(RawMessage::from(buffer));
                Box::pin(async move {
                    trace!("calling inner service");
                    fut.await
                })
            }
            Err(e) => Box::pin(future::ready(Err(ServiceError::Parsing(e)))),
        }
    }
}

/// [`Layer`] producing a [`BufferWrapper`] around the inner service.
struct BufferWrapperLayer;

impl Layer for BufferWrapperLayer {
    type Service = BufferWrapper;

    fn layer(&self, service: S) -> BufferWrapper {
        BufferWrapper(service)
    }
}

/// A service that discards messages that are not expected in the current phase
///
/// Only `Sum` messages in the sum phase, `Update` messages in the update
/// phase and `Sum2` messages in the sum2 phase are forwarded. Any other
/// combination fails with [`ServiceError::UnexpectedMessage`]; a tag that
/// cannot be converted into a [`Tag`] fails with [`ServiceError::Parsing`].
#[derive(Debug, Clone)]
struct PhaseFilter {
    /// A listener to retrieve the current phase
    phase: EventListener,
    /// Next service to be called
    next_svc: S,
}

impl Service> for PhaseFilter
where
    T: AsRef<[u8]> + Send + 'static,
    S: Service, Response = Message, Error = ServiceError>,
    S::Future: Sync + Send + 'static,
{
    type Response = Message;
    type Error = ServiceError;
    type Future = BoxedServiceFuture;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
        self.next_svc.poll_ready(cx)
    }

    fn call(&mut self, req: RawMessage) -> Self::Future {
        debug!("retrieving the current phase");
        let phase = self.phase.get_latest().event;
        match req.buffer.tag().try_into() {
            Ok(tag) => match (phase, tag) {
                (PhaseName::Sum, Tag::Sum)
                | (PhaseName::Update, Tag::Update)
                | (PhaseName::Sum2, Tag::Sum2) => {
                    let fut = self.next_svc.call(req);
                    Box::pin(async move { fut.await })
                }
                _ => Box::pin(future::ready(Err(ServiceError::UnexpectedMessage))),
            },
            Err(e) => Box::pin(future::ready(Err(ServiceError::Parsing(e)))),
        }
    }
}

/// [`Layer`] producing a [`PhaseFilter`] that reads the current phase from
/// `phase`.
struct PhaseFilterLayer {
    phase: EventListener,
}

impl Layer for PhaseFilterLayer {
    type Service = PhaseFilter;

    fn layer(&self, service: S) -> PhaseFilter {
        PhaseFilter {
            phase: self.phase.clone(),
            next_svc: service,
        }
    }
}

/// A service for verifying the signature of PET messages
///
/// Since this is a CPU-intensive task for large messages, this
/// service offloads the processing to a `rayon` thread-pool to avoid
/// overloading the tokio thread-pool with blocking tasks.
#[derive(Debug, Clone)]
struct SignatureVerifier {
    /// Thread-pool the CPU-intensive tasks are offloaded to.
    thread_pool: Arc,
    /// The service to be called after the [`SignatureVerifier`]
    next_svc: S,
}

impl Service> for SignatureVerifier
where
    T: AsRef<[u8]> + Sync + Send + 'static,
    S: Service, Response = Message, Error = ServiceError>
        + Clone
        + Sync
        + Send
        + 'static,
    S::Future: Sync + Send + 'static,
{
    type Response = Message;
    type Error = ServiceError;
    type Future = BoxedServiceFuture;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
        self.next_svc.poll_ready(cx)
    }

    fn call(&mut self, req: RawMessage) -> Self::Future {
        let (tx, rx) = oneshot::channel::>();
        // `req` is moved into the thread-pool task; the clone (sharing the
        // same `Arc` buffer) is forwarded to the next service afterwards.
        let req_clone = req.clone();
        trace!("spawning signature verification task on thread-pool");
        self.thread_pool.spawn(move || {
            let res = match req.buffer.as_ref().as_ref().check_signature() {
                Ok(()) => {
                    info!("found a valid message signature");
                    Ok(())
                }
                Err(e) => {
                    warn!("invalid message signature: {:?}", e);
                    Err(ServiceError::InvalidMessageSignature)
                }
            };
            let _ = tx.send(res);
        });
        // NOTE(review): `poll_ready` was driven on `self.next_svc`, but the
        // request is sent to a clone that was never polled for readiness;
        // tower's contract expects `poll_ready` on the same instance before
        // `call`. Confirm this is safe for the inner services.
        let mut next_svc = self.next_svc.clone();
        let fut = async move {
            // The first `?` propagates a lost thread-pool response, the
            // second an invalid signature.
            rx.await.map_err(|_| {
                ServiceError::InternalError(
                    "failed to receive response from thread-pool".to_string(),
                )
            })??;
            next_svc.call(req_clone).await
        };
        Box::pin(fut)
    }
}

/// [`Layer`] producing a [`SignatureVerifier`] wrapped in a
/// [`ConcurrencyLimit`] equal to the thread-pool's number of threads.
struct SignatureVerifierLayer {
    thread_pool: Arc,
}

impl Layer for SignatureVerifierLayer {
    type Service = ConcurrencyLimit>;

    fn layer(&self, service: S) -> Self::Service {
        let limit = self.thread_pool.current_num_threads();
        // FIXME: we actually want to limit the concurrency of just
        // the SignatureVerifier middleware. Right now we're limiting
        // the whole stack of services.
        ConcurrencyLimit::new(
            SignatureVerifier {
                thread_pool: self.thread_pool.clone(),
                next_svc: service,
            },
            limit,
        )
    }
}

/// A service that verifies the coordinator public key embedded in PET
/// messages
///
/// A message whose embedded key cannot be parsed, or differs from the
/// latest coordinator public key, fails with
/// [`ServiceError::InvalidCoordinatorPublicKey`].
#[derive(Debug, Clone)]
struct CoordinatorPublicKeyValidator {
    /// A listener to retrieve the latest coordinator keys
    keys: EventListener,
    /// Next service to be called
    next_svc: S,
}

impl Service> for CoordinatorPublicKeyValidator
where
    T: AsRef<[u8]> + Send + 'static,
    S: Service, Response = Message, Error = ServiceError>,
    S::Future: Sync + Send + 'static,
{
    type Response = Message;
    type Error = ServiceError;
    type Future = BoxedServiceFuture;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
        self.next_svc.poll_ready(cx)
    }

    fn call(&mut self, req: RawMessage) -> Self::Future {
        debug!("retrieving the current keys");
        let coord_pk = self.keys.get_latest().event.public;
        match PublicEncryptKey::from_byte_slice(&req.buffer.as_ref().as_ref().coordinator_pk()) {
            Ok(pk) => {
                if pk != coord_pk {
                    warn!("found an invalid coordinator public key");
                    Box::pin(future::ready(Err(
                        ServiceError::InvalidCoordinatorPublicKey,
                    )))
                } else {
                    info!("found a valid coordinator public key");
                    let fut = self.next_svc.call(req);
                    Box::pin(async move { fut.await })
                }
            }
            Err(_) => Box::pin(future::ready(Err(
                ServiceError::InvalidCoordinatorPublicKey,
            ))),
        }
    }
}

/// [`Layer`] producing a [`CoordinatorPublicKeyValidator`] that reads the
/// coordinator keys from `keys`.
struct CoordinatorPublicKeyValidatorLayer {
    keys: EventListener,
}

impl Layer for CoordinatorPublicKeyValidatorLayer {
    type Service = CoordinatorPublicKeyValidator;

    fn layer(&self, service: S) -> CoordinatorPublicKeyValidator {
        CoordinatorPublicKeyValidator {
            keys: self.keys.clone(),
            next_svc: service,
        }
    }
}

/// Innermost service: parses the raw message bytes into a [`Message`].
#[derive(Debug, Clone)]
struct Parser;

impl Service> for Parser
where
    T: AsRef<[u8]> + Send + 'static,
{
    type Response = Message;
    type Error = ServiceError;
    type Future = future::Ready>;

    // Always reports readiness.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: RawMessage) -> Self::Future {
        let bytes = req.buffer.inner();
        future::ready(Message::from_byte_slice(&bytes).map_err(ServiceError::Parsing))
    }
}

/// The complete parsing stack, outermost service first (built by
/// `MessageParser::new`).
type InnerService = BufferWrapper<
    PhaseFilter>>>,
>;

/// Service turning raw message bytes into a [`Message`]. The bytes are
/// wrapped into a [`RawMessage`] and filtered by phase. The signature is
/// then verified and the coordinator public key checked, and finally the
/// message is parsed.
#[derive(Debug, Clone)]
pub struct MessageParser(InnerService);

impl Service for MessageParser
where
    T: AsRef<[u8]> + Sync + Send + 'static,
{
    type Response = Message;
    type Error = ServiceError;
    type Future = BoxedServiceFuture;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
        >::poll_ready(&mut self.0, cx)
    }

    fn call(&mut self, req: T) -> Self::Future {
        let fut = self.0.call(req);
        Box::pin(async move { fut.await })
    }
}

impl MessageParser {
    /// Builds the parsing stack. Layers are listed outermost first; the
    /// phase and coordinator keys are read from `events`, and signature
    /// verification runs on `thread_pool`.
    pub fn new(events: &EventSubscriber, thread_pool: Arc) -> Self {
        let inner = ServiceBuilder::new()
            .layer(BufferWrapperLayer)
            .layer(PhaseFilterLayer {
                phase: events.phase_listener(),
            })
            .layer(SignatureVerifierLayer { thread_pool })
            .layer(CoordinatorPublicKeyValidatorLayer {
                keys: events.keys_listener(),
            })
            .service(Parser);
        Self(inner)
    }
}

#[cfg(test)]
mod tests {
    use rayon::ThreadPoolBuilder;
    use tokio_test::assert_ready;
    use tower_test::mock::Spawn;

    use super::*;
    use crate::{
        services::tests::utils,
        state_machine::events::{EventPublisher, EventSubscriber},
    };

    /// Builds event channels, a default `rayon` thread-pool and the spawned
    /// [`MessageParser`] under test.
    fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) {
        let (publisher, subscriber) = utils::new_event_channels();
        let thread_pool = Arc::new(ThreadPoolBuilder::new().build().unwrap());
        let task = Spawn::new(MessageParser::new(&subscriber, thread_pool));
        (publisher, subscriber, task)
    }

    #[tokio::test]
    async fn test_valid_request() {
        let (mut publisher, subscriber, mut task) = spawn_svc();
        assert_ready!(task.poll_ready::>()).unwrap();

        let round_params = subscriber.params_listener().get_latest().event;
        let (message, signing_keys) = utils::new_sum_message(&round_params);
        let serialized_message = utils::serialize_message(&message, &signing_keys);

        // Simulate the state machine broadcasting the sum phase
        // (otherwise the request will be rejected by the phase
        // filter)
        publisher.broadcast_phase(PhaseName::Sum);

        // Call the service
        let mut resp = task.call(serialized_message).await.unwrap();
        // The signature should be set. However in `message` it's not been
        // computed, so we just check that it's there, then set it to
        // `None` in `resp`
        assert!(resp.signature.is_some());
        resp.signature = None;
        // Now the comparison should work
        assert_eq!(resp, message);
    }

    #[tokio::test]
    async fn test_unexpected_message() {
        // The sum phase is not broadcast here, so the phase filter is
        // expected to reject the sum message.
        let (_publisher, subscriber, mut task) = spawn_svc();
        assert_ready!(task.poll_ready::>()).unwrap();

        let round_params = subscriber.params_listener().get_latest().event;
        let (message, signing_keys) = utils::new_sum_message(&round_params);
        let serialized_message = utils::serialize_message(&message, &signing_keys);

        let err = task.call(serialized_message).await.unwrap_err();
        match err {
            ServiceError::UnexpectedMessage => {}
            _ => panic!("expected ServiceError::UnexpectedMessage got {:?}", err),
        }
    }
}

================================================
FILE: rust/xaynet-server/src/services/messages/mod.rs
================================================
//! This module provides the services for processing PET messages.
//!
//! There are multiple such services and [`PetMessageHandler`]
//! provides a single unifying interface for all of these.

mod decryptor;
mod error;
mod message_parser;
mod multipart;
mod state_machine;
mod task_validator;

use std::sync::Arc;

use futures::future::poll_fn;
use rayon::ThreadPoolBuilder;
use tower::Service;

use xaynet_core::message::Message;

pub use self::error::ServiceError;
use self::{
    decryptor::Decryptor,
    message_parser::MessageParser,
    multipart::MultipartHandler,
    state_machine::StateMachine,
    task_validator::TaskValidator,
};
use crate::state_machine::{events::EventSubscriber, requests::RequestSender};

impl PetMessageHandler {
    /// Creates a handler. The decryptor and the message parser share one
    /// `rayon` thread-pool built with default settings; requests for the
    /// state machine are sent through `requests_tx`.
    pub fn new(event_subscriber: &EventSubscriber, requests_tx: RequestSender) -> Self {
        // TODO: make this configurable. Users should be able to
        // choose how many threads they want etc.
        //
        // TODO: don't unwrap
        let thread_pool = Arc::new(ThreadPoolBuilder::new().build().unwrap());
        let decryptor = Decryptor::new(event_subscriber, thread_pool.clone());
        let multipart_handler = MultipartHandler::new();
        let message_parser = MessageParser::new(event_subscriber, thread_pool);
        let task_validator = TaskValidator::new(event_subscriber);
        let state_machine = StateMachine::new(requests_tx);

        Self {
            decryptor,
            multipart_handler,
            message_parser,
            task_validator,
            state_machine,
        }
    }

    /// Decrypts `enc_data` once the decryptor is ready.
    async fn decrypt(&mut self, enc_data: Vec) -> Result, ServiceError> {
        poll_fn(|cx| >>::poll_ready(&mut self.decryptor, cx)).await?;
        self.decryptor.call(enc_data).await
    }

    /// Validates and parses decrypted bytes once the parser is ready.
    async fn parse(&mut self, data: Vec) -> Result {
        poll_fn(|cx| >>::poll_ready(&mut self.message_parser, cx))
            .await?;
        self.message_parser.call(data).await
    }

    /// Passes `message` to the multipart handler. A `None` result means
    /// there is nothing to process further (see `handle_message`).
    async fn handle_multipart(
        &mut self,
        message: Message,
    ) -> Result, ServiceError> {
        poll_fn(|cx| self.multipart_handler.poll_ready(cx)).await?;
        self.multipart_handler.call(message).await
    }

    /// Runs the [`TaskValidator`] checks on `message`.
    async fn validate_task(&mut self, message: Message) -> Result {
        poll_fn(|cx| self.task_validator.poll_ready(cx)).await?;
        self.task_validator.call(message).await
    }

    /// Hands `message` over to the state machine service.
    async fn process(&mut self, message: Message) -> Result<(), ServiceError> {
        poll_fn(|cx| self.state_machine.poll_ready(cx)).await?;
        self.state_machine.call(message).await
    }

    /// Runs an encrypted message through the whole pipeline: decryption,
    /// parsing, multipart handling, task validation and processing by the
    /// state machine.
    ///
    /// Returns `Ok(())` without validating or processing anything when the
    /// multipart handler yields no message.
    pub async fn handle_message(&mut self, enc_data: Vec) -> Result<(), ServiceError> {
        let raw_message = self.decrypt(enc_data).await?;
        let message = self.parse(raw_message).await?;
        match self.handle_multipart(message).await? {
            Some(message) => {
                let message = self.validate_task(message).await?;
                self.process(message).await
            }
            None => Ok(()),
        }
    }
}

/// A service that processes requests from the beginning to the
/// end.
///
/// `PetMessageHandler::handle_message` processes a request in these steps:
///
/// 1. The raw request (a vector of bytes representing an encrypted
///    message) is decrypted by the `Decryptor` service.
///
/// 2. The decrypted bytes go through the `MessageParser` service, which
///    validates the message and parses it.
///
/// 3. The parsed message goes through the `MultipartHandler`. If it does
///    not return a message, processing stops here.
///
/// 4. The message is passed to the `TaskValidator`, which depending on
///    the message type performs some additional checks. The
///    `TaskValidator` may also discard the message.
///
/// 5. Finally, the message is handled by the `StateMachine` service.
#[derive(Clone)]
pub struct PetMessageHandler {
    decryptor: Decryptor,
    multipart_handler: MultipartHandler,
    message_parser: MessageParser,
    task_validator: TaskValidator,
    state_machine: StateMachine,
}

/// Boxed `Send + Sync` future returned by the message services.
pub type BoxedServiceFuture = std::pin::Pin<
    Box> + 'static + Send + Sync>,
>;

================================================
FILE: rust/xaynet-server/src/services/messages/multipart/buffer.rs
================================================
use std::{
    collections::btree_map::{BTreeMap, IntoIter as BTreeMapIter},
    iter::{ExactSizeIterator, Iterator},
    vec::IntoIter as VecIter,
};

/// A data structure for reading a multipart message
///
/// The chunks are read in the iteration order of the `BTreeMap` they are
/// built from, i.e. in ascending key order.
pub struct MultipartMessageBuffer {
    /// message chunks that haven't been read yet
    remaining_chunks: BTreeMapIter>,
    /// chunk being read
    current_chunk: Option>,
    /// total length of the buffer
    initial_length: usize,
    /// number of bytes that have been read
    consumed: usize,
}

impl From>> for MultipartMessageBuffer {
    fn from(map: BTreeMap>) -> Self {
        // Total number of bytes across all chunks.
        let initial_length = map.values().fold(0, |acc, chunk| acc + chunk.len());
        Self {
            remaining_chunks: map.into_iter(),
            current_chunk: None,
            initial_length,
            consumed: 0,
        }
    }
}

// Note that this Iterator implementation could be optimized. We
// currently increment a counter for every byte consumed, but we could
// exploit the fact that `VecIter` implements `ExactSizeIterator` to avoid
// that.
impl Iterator for MultipartMessageBuffer { type Item = u8; fn next(&mut self) -> Option { if self.current_chunk.is_none() { let (_, chunk) = self.remaining_chunks.next()?; self.current_chunk = Some(chunk.into_iter()); return self.next(); } // Per `if` above, `self.current_chunk` is not None match self.current_chunk.as_mut().unwrap().next() { Some(b) => { self.consumed += 1; Some(b) } None => { self.current_chunk = None; self.next() } } } fn size_hint(&self) -> (usize, Option) { let lower_bound = self.initial_length - self.consumed; let upper_bound = self.initial_length - self.consumed; (lower_bound, Some(upper_bound)) } } impl ExactSizeIterator for MultipartMessageBuffer {} #[cfg(test)] mod tests { use super::*; #[test] fn test() { let mut map: BTreeMap> = BTreeMap::new(); map.insert(1, vec![0, 1, 2]); map.insert(2, vec![3]); map.insert(3, vec![4, 5]); let mut iter = MultipartMessageBuffer::from(map); assert_eq!(iter.consumed, 0); assert_eq!(iter.initial_length, 6); assert_eq!(iter.len(), 6); assert!(iter.current_chunk.is_none()); assert_eq!(iter.next(), Some(0)); assert_eq!(iter.consumed, 1); assert_eq!(iter.initial_length, 6); assert_eq!(iter.len(), 5); assert!(iter.current_chunk.is_some()); assert_eq!(iter.next(), Some(1)); assert_eq!(iter.consumed, 2); assert_eq!(iter.initial_length, 6); assert_eq!(iter.len(), 4); assert!(iter.current_chunk.is_some()); assert_eq!(iter.next(), Some(2)); assert_eq!(iter.consumed, 3); assert_eq!(iter.initial_length, 6); assert_eq!(iter.len(), 3); assert!(iter.current_chunk.is_some()); assert_eq!(iter.next(), Some(3)); assert_eq!(iter.consumed, 4); assert_eq!(iter.initial_length, 6); assert_eq!(iter.len(), 2); assert!(iter.current_chunk.is_some()); assert_eq!(iter.next(), Some(4)); assert_eq!(iter.consumed, 5); assert_eq!(iter.initial_length, 6); assert_eq!(iter.len(), 1); assert!(iter.current_chunk.is_some()); assert_eq!(iter.next(), Some(5)); assert_eq!(iter.consumed, 6); assert_eq!(iter.initial_length, 6); assert_eq!(iter.len(), 
0); assert!(iter.current_chunk.is_some()); } } ================================================ FILE: rust/xaynet-server/src/services/messages/multipart/mod.rs ================================================ mod buffer; mod service; use std::task::{Context, Poll}; use futures::future::TryFutureExt; use tower::{buffer::Buffer, Service, ServiceBuilder}; use crate::services::messages::ServiceError; use xaynet_core::message::Message; type Inner = Buffer; #[derive(Clone)] pub struct MultipartHandler(Inner); impl Service for MultipartHandler { type Response = Option; type Error = ServiceError; #[allow(clippy::type_complexity)] type Future = futures::future::MapErr< >::Future, fn(>::Error) -> ServiceError, >; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { >::poll_ready(&mut self.0, cx).map_err(ServiceError::from) } fn call(&mut self, req: Message) -> Self::Future { <>::Future>::map_err(self.0.call(req), ServiceError::from) } } impl MultipartHandler { pub fn new() -> Self { Self( ServiceBuilder::new() .buffer(100) .service(service::MultipartHandler::new()), ) } } ================================================ FILE: rust/xaynet-server/src/services/messages/multipart/service.rs ================================================ use std::{ collections::{BTreeMap, HashMap}, task::Poll, }; use futures::{ future::{self, Ready}, task::Context, }; use tower::Service; use tracing::{debug, trace, warn}; use crate::services::messages::{multipart::buffer::MultipartMessageBuffer, ServiceError}; use xaynet_core::{ crypto::{PublicEncryptKey, PublicSigningKey}, message::{Chunk, DecodeError, FromBytes, Message, Payload, Sum, Sum2, Tag, Update}, }; /// A `MessageBuilder` stores chunks of a multipart message. Once it /// has all the chunks, it can be consumed and turned into a /// full-blown [`Message`] (see [`into_message()`]). 
/// /// [`into_message()`]: MessageBuilder::into_message #[derive(Debug)] #[cfg_attr(test, derive(Clone))] pub struct MessageBuilder { /// Public key of the participant sending the message participant_pk: PublicSigningKey, /// Public key of the coordinator coordinator_pk: PublicEncryptKey, /// Message type tag: Tag, /// The ID of the last chunk is actually the total number of /// chunks this message is made of. last_chunk_id: Option, /// Chunks, ordered by ID data: BTreeMap>, } impl MessageBuilder { /// Create a new [`MessageBuilder`] that contains no chunk. fn new(tag: Tag, participant_pk: PublicSigningKey, coordinator_pk: PublicEncryptKey) -> Self { MessageBuilder { tag, participant_pk, coordinator_pk, data: BTreeMap::new(), last_chunk_id: None, } } /// Return `true` if the message is complete, _i.e._ if the /// builder holds all the chunks. fn has_all_chunks(&self) -> bool { self.last_chunk_id .map(|last_chunk_id| { // The IDs start at 0, hence the + 1 self.data.len() >= (last_chunk_id as usize + 1) }) .unwrap_or(false) } /// Add a chunk. fn add_chunk(&mut self, chunk: Chunk) { let Chunk { id, last, data, .. } = chunk; if last { self.last_chunk_id = Some(id); } self.data.insert(id, data); } /// Aggregate all the chunks. This method should only be called /// when all the chunks are here, otherwise the aggregated message /// will be invalid. 
fn into_message(self) -> Result { let mut bytes = MultipartMessageBuffer::from(self.data); let payload = match self.tag { Tag::Sum => Sum::from_byte_stream(&mut bytes).map(Into::into)?, Tag::Update => Update::from_byte_stream(&mut bytes).map(Into::into)?, Tag::Sum2 => Sum2::from_byte_stream(&mut bytes).map(Into::into)?, }; let message = Message { signature: None, participant_pk: self.participant_pk, coordinator_pk: self.coordinator_pk, tag: self.tag, is_multipart: false, payload, }; Ok(message) } } /// [`MessageId`] uniquely identifies a multipart message by its ID /// (which uniquely identify a message _for a given participant_), and /// the participant public key. #[derive(Debug, Hash, Eq, PartialEq, Clone)] pub struct MessageId { message_id: u16, participant_pk: PublicSigningKey, } /// A service that handles multipart messages. pub struct MultipartHandler { message_builders: HashMap, } impl MultipartHandler { #[allow(dead_code)] pub fn new() -> Self { Self { message_builders: HashMap::new(), } } } impl Service for MultipartHandler { type Response = Option; type Error = ServiceError; type Future = Ready>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, message: Message) -> Self::Future { // If the message doesn't have the multipart flag, this // service has nothing to do with it. if !message.is_multipart { trace!("message is not multipart, nothing to do"); return ready_ok(Some(message)); } debug!("handling multipart message"); if let Message { tag, participant_pk, coordinator_pk, payload: Payload::Chunk(chunk), .. } = message { let id = MessageId { message_id: chunk.message_id, participant_pk, }; // If we don't have a partial message for this ID, create // an empty one. 
let mp_message = self.message_builders.entry(id.clone()).or_insert_with(|| { debug!("new multipart message (id = {})", id.message_id); MessageBuilder::new(tag, participant_pk, coordinator_pk) }); // Add the chunk to the partial message mp_message.add_chunk(chunk); // Check if the message is complete, and if so parse it // and return it if mp_message.has_all_chunks() { debug!("received the final message chunk, now parsing the full message"); // This entry exists, because `mp_message` above // refers to it, so it's ok to unwrap. match self.message_builders.remove(&id).unwrap().into_message() { Ok(message) => { debug!("multipart message succesfully parsed"); ready_ok(Some(message)) } Err(e) => { warn!("invalid multipart message: {}", e); ready_err(ServiceError::Parsing(e)) } } } else { ready_ok(None) } } else { // This cannot happen, because parsing have fail panic!("multipart flag is set but payload is not a multipart message"); } } } fn ready_ok(t: T) -> Ready> { future::ready(Ok(t)) } fn ready_err(e: E) -> Ready> { future::ready(Err(e)) } #[cfg(test)] mod tests { use std::iter; use tokio_test::assert_ready; use tower_test::mock::Spawn; use xaynet_core::crypto::{ByteObject, PublicEncryptKey, Signature}; use super::*; fn spawn_svc() -> Spawn { Spawn::new(MultipartHandler::new()) } fn sum() -> (Vec, Sum) { let mut start_byte: u8 = 0xff; let f = move || { start_byte = start_byte.wrapping_add(1) & 0b_0001_1111; Some(start_byte) }; let bytes: Vec = iter::from_fn(f) .take(PublicEncryptKey::LENGTH + Signature::LENGTH) .collect(); let sum = Sum { sum_signature: Signature::from_slice(&bytes[..Signature::LENGTH]).unwrap(), ephm_pk: PublicEncryptKey::from_slice(&bytes[Signature::LENGTH..]).unwrap(), }; (bytes, sum) } fn message_builder() -> MessageBuilder { let participant_pk = PublicSigningKey::zeroed(); let coordinator_pk = PublicEncryptKey::zeroed(); let tag = Tag::Sum; MessageBuilder::new(tag, participant_pk, coordinator_pk) } fn chunks(mut data: Vec) -> (Chunk, Chunk, 
Chunk, Chunk, Chunk) { // Chunk 1: 1 byte // Chunk 2: 2 bytes // Chunk 3: 3 bytes // Chunk 4: 4 bytes // Chunk 5: 96 - (1 + 2 + 3 + 4) = 86 bytes assert_eq!(data.len(), 96); // 96 - 10 = 86, remains 10 let data5 = data.split_off(10); assert_eq!(data5.len(), 86); assert_eq!(data.len(), 10); // 10 - 6 = 4, remains 6 let data4 = data.split_off(6); assert_eq!(data4.len(), 4); assert_eq!(data.len(), 6); // 6 - 3 = 3, remains 3 let data3 = data.split_off(3); assert_eq!(data3.len(), 3); assert_eq!(data.len(), 3); // 3 - 1 = 2, remains 1 let data2 = data.split_off(1); assert_eq!(data2.len(), 2); assert_eq!(data.len(), 1); let chunk1 = Chunk { id: 0, message_id: 1234, last: false, data, }; let chunk2 = Chunk { id: 1, message_id: 1234, last: false, data: data2, }; let chunk3 = Chunk { id: 2, message_id: 1234, last: false, data: data3, }; let chunk4 = Chunk { id: 3, message_id: 1234, last: false, data: data4, }; let chunk5 = Chunk { id: 4, message_id: 1234, last: true, data: data5, }; (chunk1, chunk2, chunk3, chunk4, chunk5) } #[test] fn test_message_builder_in_order() { let mut msg = message_builder(); let (data, sum) = sum(); let (c1, c2, c3, c4, c5) = chunks(data); assert!(msg.data.is_empty()); assert!(msg.last_chunk_id.is_none()); msg.add_chunk(c1); assert_eq!(msg.data.len(), 1); assert!(msg.last_chunk_id.is_none()); assert!(!msg.has_all_chunks()); msg.add_chunk(c2); assert_eq!(msg.data.len(), 2); assert!(msg.last_chunk_id.is_none()); assert!(!msg.has_all_chunks()); msg.add_chunk(c3); assert_eq!(msg.data.len(), 3); assert!(msg.last_chunk_id.is_none()); assert!(!msg.has_all_chunks()); msg.add_chunk(c4); assert_eq!(msg.data.len(), 4); assert!(msg.last_chunk_id.is_none()); assert!(!msg.has_all_chunks()); msg.add_chunk(c5); assert_eq!(msg.data.len(), 5); assert_eq!(msg.last_chunk_id, Some(4)); assert!(msg.has_all_chunks()); let actual = msg.into_message().unwrap(); let expected = Message::new_sum(PublicSigningKey::zeroed(), PublicEncryptKey::zeroed(), sum); assert_eq!(actual, 
expected); } #[test] fn test_message_builder_out_of_order() { let mut msg = message_builder(); let (data, sum) = sum(); let (c1, c2, c3, c4, c5) = chunks(data); assert!(msg.data.is_empty()); assert!(msg.last_chunk_id.is_none()); msg.add_chunk(c3); assert_eq!(msg.data.len(), 1); assert!(msg.last_chunk_id.is_none()); assert!(!msg.has_all_chunks()); msg.add_chunk(c1); assert_eq!(msg.data.len(), 2); assert!(msg.last_chunk_id.is_none()); assert!(!msg.has_all_chunks()); msg.add_chunk(c5); assert_eq!(msg.data.len(), 3); assert_eq!(msg.last_chunk_id, Some(4)); assert!(!msg.has_all_chunks()); msg.add_chunk(c2); assert_eq!(msg.data.len(), 4); assert_eq!(msg.last_chunk_id, Some(4)); assert!(!msg.has_all_chunks()); msg.add_chunk(c4); assert_eq!(msg.data.len(), 5); assert_eq!(msg.last_chunk_id, Some(4)); assert!(msg.has_all_chunks()); let actual = msg.into_message().unwrap(); let expected = Message::new_sum(PublicSigningKey::zeroed(), PublicEncryptKey::zeroed(), sum); assert_eq!(actual, expected); } #[tokio::test] async fn message_handler() { let mut task = spawn_svc(); assert_ready!(task.poll_ready()).unwrap(); let coordinator_pk = PublicEncryptKey::from_slice(&[0x00; PublicSigningKey::LENGTH]).unwrap(); // The payload of the message (and therefore the chunks) will // be the same for the two participants. What must differ is // the participant public key in the header. let (data, sum) = sum(); let (c1, c2, c3, c4, c5) = chunks(data.clone()); // A signing key that identifies a first faked participant. 
let pk1 = PublicSigningKey::from_slice(&[0x11; PublicSigningKey::LENGTH]).unwrap(); // message ID for the message from our fake participant identified by `pk1` let message_id1 = MessageId { message_id: 1234, participant_pk: pk1, }; // function that take a data chunk and create Chunk message // with `pk1` as participant public key in the header let make_message1 = |chunk: &Chunk| Message::new_multipart(pk1, coordinator_pk, chunk.clone(), Tag::Sum); // Do the same thing to fake a second participant: generate a // public key, a message ID, and a function to create messages // originating from that participant let pk2 = PublicSigningKey::from_slice(&[0x22; PublicSigningKey::LENGTH]).unwrap(); let message_id2 = MessageId { message_id: 1234, participant_pk: pk2, }; let make_message2 = |chunk: &Chunk| Message::new_multipart(pk2, coordinator_pk, chunk.clone(), Tag::Sum); // Start of the actual test. Notice that we send the chunks // out of order. assert!(task.call(make_message1(&c3)).await.unwrap().is_none()); assert_eq!(task.get_ref().message_builders.len(), 1); let builder = task.get_ref().message_builders.get(&message_id1).unwrap(); assert_eq!(builder.data.len(), 1); assert!(task.call(make_message2(&c3)).await.unwrap().is_none()); assert_eq!(task.get_ref().message_builders.len(), 2); let builder = task.get_ref().message_builders.get(&message_id2).unwrap(); assert_eq!(builder.data.len(), 1); assert!(task.call(make_message1(&c5)).await.unwrap().is_none()); assert!(task.call(make_message2(&c5)).await.unwrap().is_none()); assert!(task.call(make_message1(&c1)).await.unwrap().is_none()); assert!(task.call(make_message2(&c1)).await.unwrap().is_none()); assert!(task.call(make_message1(&c4)).await.unwrap().is_none()); assert!(task.call(make_message2(&c4)).await.unwrap().is_none()); let builder = task.get_ref().message_builders.get(&message_id1).unwrap(); assert_eq!(builder.data.len(), 4); let builder = task.get_ref().message_builders.get(&message_id2).unwrap(); 
assert_eq!(builder.data.len(), 4); let res1 = task.call(make_message1(&c2)).await.unwrap().unwrap(); let res2 = task.call(make_message2(&c2)).await.unwrap().unwrap(); assert!(task.get_ref().message_builders.get(&message_id1).is_none()); assert!(task.get_ref().message_builders.get(&message_id2).is_none()); assert_eq!(res1, Message::new_sum(pk1, coordinator_pk, sum.clone())); assert_eq!(res2, Message::new_sum(pk2, coordinator_pk, sum.clone())); } } ================================================ FILE: rust/xaynet-server/src/services/messages/state_machine.rs ================================================ use std::task::Poll; use futures::task::Context; use tower::Service; use xaynet_core::message::Message; use crate::{ services::messages::{BoxedServiceFuture, ServiceError}, state_machine::requests::RequestSender, }; /// A service that hands the requests to the [`StateMachine`] that runs in the background. /// /// [`StateMachine`]: crate::state_machine::StateMachine #[derive(Debug, Clone)] pub struct StateMachine { handle: RequestSender, } impl StateMachine { /// Create a new service with the given handle for forwarding /// requests to the state machine. The handle should be obtained /// via [`init()`]. 
/// /// [`init()`]: crate::state_machine::initializer::StateMachineInitializer::init pub fn new(handle: RequestSender) -> Self { Self { handle } } } impl Service for StateMachine { type Response = (); type Error = ServiceError; type Future = BoxedServiceFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: Message) -> Self::Future { let handle = self.handle.clone(); Box::pin(async move { handle .request(req.into(), tracing::Span::none()) .await .map_err(ServiceError::StateMachine) }) } } ================================================ FILE: rust/xaynet-server/src/services/messages/task_validator.rs ================================================ use std::task::Poll; use futures::{future, task::Context}; use tower::Service; use crate::{ services::messages::ServiceError, state_machine::events::{EventListener, EventSubscriber}, }; use xaynet_core::{ common::RoundParameters, crypto::ByteObject, message::{Message, Payload}, }; /// A service for performing sanity checks and preparing incoming /// requests to be handled by the state machine. 
#[derive(Clone, Debug)] pub struct TaskValidator { params_listener: EventListener, } impl TaskValidator { pub fn new(subscriber: &EventSubscriber) -> Self { Self { params_listener: subscriber.params_listener(), } } } impl Service for TaskValidator { type Response = Message; type Error = ServiceError; type Future = future::Ready>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, message: Message) -> Self::Future { let (sum_signature, update_signature) = match message.payload { Payload::Sum(ref sum) => (sum.sum_signature, None), Payload::Update(ref update) => (update.sum_signature, Some(update.update_signature)), Payload::Sum2(ref sum2) => (sum2.sum_signature, None), _ => return future::ready(Err(ServiceError::UnexpectedMessage)), }; let params = self.params_listener.get_latest().event; let seed = params.seed.as_slice(); // Check whether the participant is eligible for the sum task let has_valid_sum_signature = message .participant_pk .verify_detached(&sum_signature, &[seed, b"sum"].concat()); let is_summer = has_valid_sum_signature && sum_signature.is_eligible(params.sum); // Check whether the participant is eligible for the update task let has_valid_update_signature = update_signature .map(|sig| { message .participant_pk .verify_detached(&sig, &[seed, b"update"].concat()) }) .unwrap_or(false); let is_updater = !is_summer && has_valid_update_signature && update_signature .map(|sig| sig.is_eligible(params.update)) .unwrap_or(false); match message.payload { Payload::Sum(_) | Payload::Sum2(_) => { if is_summer { future::ready(Ok(message)) } else { future::ready(Err(ServiceError::NotSumEligible)) } } Payload::Update(_) => { if is_updater { future::ready(Ok(message)) } else { future::ready(Err(ServiceError::NotUpdateEligible)) } } _ => future::ready(Err(ServiceError::UnexpectedMessage)), } } } #[cfg(test)] mod tests { use tokio_test::assert_ready; use tower_test::mock::Spawn; use crate::{ services::tests::utils, 
state_machine::{ events::{EventPublisher, EventSubscriber}, phases::PhaseName, }, }; use super::*; fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) { let (publisher, subscriber) = utils::new_event_channels(); let task = Spawn::new(TaskValidator::new(&subscriber)); (publisher, subscriber, task) } #[tokio::test] async fn test_sum_ok() { let (mut publisher, subscriber, mut task) = spawn_svc(); let mut round_params = subscriber.params_listener().get_latest().event; // make sure everyone is eligible round_params.sum = 1.0; publisher.broadcast_params(round_params.clone()); publisher.broadcast_phase(PhaseName::Sum); let (message, _) = utils::new_sum_message(&round_params); assert_ready!(task.poll_ready()).unwrap(); let resp = task.call(message.clone()).await.unwrap(); assert_eq!(resp, message); } #[tokio::test] async fn test_sum_not_eligible() { let (mut publisher, subscriber, mut task) = spawn_svc(); let mut round_params = subscriber.params_listener().get_latest().event; // make sure no-one is eligible round_params.sum = 0.0; publisher.broadcast_params(round_params.clone()); publisher.broadcast_phase(PhaseName::Sum); let (message, _) = utils::new_sum_message(&round_params); assert_ready!(task.poll_ready()).unwrap(); let err = task.call(message).await.unwrap_err(); match err { ServiceError::NotSumEligible => {} _ => panic!("expected ServiceError::NotSumEligible got {:?}", err), } } } ================================================ FILE: rust/xaynet-server/src/services/mod.rs ================================================ //! This module implements the services the PET protocol provides. //! //! There are two main types of services: //! //! - the services for fetching data broadcasted by the state //! machine. These services are implemented in the [`fetchers`] //! module //! - the services for processing PET message are provided by the //! [`messages`] module. 
pub mod fetchers;
pub mod messages;

#[cfg(test)]
mod tests;

================================================
FILE: rust/xaynet-server/src/services/tests/fetchers.rs
================================================
use std::sync::Arc;

use tokio_test::assert_ready;
use tower_test::mock::Spawn;

use crate::{
    services::{
        fetchers::{
            ModelRequest,
            ModelService,
            RoundParamsRequest,
            RoundParamsService,
            SeedDictRequest,
            SeedDictService,
            SumDictRequest,
            SumDictService,
        },
        tests::utils::{mask_config, new_event_channels},
    },
    state_machine::events::{DictionaryUpdate, ModelUpdate},
};
use xaynet_core::{
    common::{RoundParameters, RoundSeed},
    crypto::{ByteObject, PublicEncryptKey, PublicSigningKey},
    mask::{EncryptedMaskSeed, Model},
    SeedDict,
    SumDict,
    UpdateSeedDict,
};

#[tokio::test]
async fn test_model_svc() {
    let (mut publisher, subscriber) = new_event_channels();
    let mut svc = Spawn::new(ModelService::new(&subscriber));

    // nothing has been published yet
    assert_ready!(svc.poll_ready()).unwrap();
    let response = svc.call(ModelRequest).await;
    assert_eq!(response, Ok(None));

    // a newly broadcast model is served
    let model = Arc::new(Model::from(vec![]));
    publisher.broadcast_model(ModelUpdate::New(model.clone()));
    assert_ready!(svc.poll_ready()).unwrap();
    let response = svc.call(ModelRequest).await;
    assert_eq!(response, Ok(Some(model)));

    // invalidation clears it again
    publisher.broadcast_model(ModelUpdate::Invalidate);
    assert_ready!(svc.poll_ready()).unwrap();
    let response = svc.call(ModelRequest).await;
    assert_eq!(response, Ok(None));
}

#[tokio::test]
async fn test_round_params_svc() {
    let (mut publisher, subscriber) = new_event_channels();
    let initial_params = subscriber.params_listener().get_latest().event;
    let mut svc = Spawn::new(RoundParamsService::new(&subscriber));

    // the initial parameters are served before any broadcast
    assert_ready!(svc.poll_ready()).unwrap();
    let response = svc.call(RoundParamsRequest).await;
    assert_eq!(response, Ok(initial_params));

    let params = RoundParameters {
        pk: PublicEncryptKey::fill_with(0x11),
        sum: 0.42,
        update: 0.42,
        seed: RoundSeed::fill_with(0x11),
        mask_config: mask_config().into(),
        model_length: 42,
    };
    publisher.broadcast_params(params.clone());
    assert_ready!(svc.poll_ready()).unwrap();
    let response = svc.call(RoundParamsRequest).await;
    assert_eq!(response, Ok(params));
}

/// A seed dictionary with two sum participants (keys filled with
/// `0xaa` and `0xbb`), each mapped to [`dummy_update_dict`].
fn dummy_seed_dict() -> SeedDict {
    [0xaa, 0xbb]
        .iter()
        .map(|&byte| (PublicSigningKey::fill_with(byte), dummy_update_dict()))
        .collect()
}

/// An update seed dictionary with two entries whose key and seed are
/// filled with the same byte (`0x11` and `0x22`).
fn dummy_update_dict() -> UpdateSeedDict {
    [0x11, 0x22]
        .iter()
        .map(|&byte| {
            (
                PublicSigningKey::fill_with(byte),
                EncryptedMaskSeed::fill_with(byte),
            )
        })
        .collect()
}

#[tokio::test]
async fn test_seed_dict_svc() {
    let (mut publisher, subscriber) = new_event_channels();
    let mut svc = Spawn::new(SeedDictService::new(&subscriber));

    assert_ready!(svc.poll_ready()).unwrap();
    let response = svc.call(SeedDictRequest).await;
    assert_eq!(response, Ok(None));

    let seed_dict = Arc::new(dummy_seed_dict());
    publisher.broadcast_seed_dict(DictionaryUpdate::New(seed_dict.clone()));
    assert_ready!(svc.poll_ready()).unwrap();
    let response = svc.call(SeedDictRequest).await;
    assert_eq!(response, Ok(Some(seed_dict)));

    publisher.broadcast_seed_dict(DictionaryUpdate::Invalidate);
    assert_ready!(svc.poll_ready()).unwrap();
    let response = svc.call(SeedDictRequest).await;
    assert_eq!(response, Ok(None));
}

/// A sum dictionary with two participants: `0xaa -> 0xcc` and
/// `0xbb -> 0xdd` (signing key fill byte -> encryption key fill byte).
fn dummy_sum_dict() -> SumDict {
    [(0xaa, 0xcc), (0xbb, 0xdd)]
        .iter()
        .map(|&(signing, encrypt)| {
            (
                PublicSigningKey::fill_with(signing),
                PublicEncryptKey::fill_with(encrypt),
            )
        })
        .collect()
}

#[tokio::test]
async fn test_sum_dict_svc() {
    let (mut publisher, subscriber) = new_event_channels();
    let mut svc = Spawn::new(SumDictService::new(&subscriber));

    assert_ready!(svc.poll_ready()).unwrap();
    let response = svc.call(SumDictRequest).await;
    assert_eq!(response, Ok(None));

    let sum_dict = Arc::new(dummy_sum_dict());
    publisher.broadcast_sum_dict(DictionaryUpdate::New(sum_dict.clone()));
    assert_ready!(svc.poll_ready()).unwrap();
    let response = svc.call(SumDictRequest).await;
    assert_eq!(response, Ok(Some(sum_dict)));

    publisher.broadcast_sum_dict(DictionaryUpdate::Invalidate);
    assert_ready!(svc.poll_ready()).unwrap();
    let response = svc.call(SumDictRequest).await;
    assert_eq!(response, Ok(None));
}

================================================
FILE: rust/xaynet-server/src/services/tests/mod.rs
================================================
mod fetchers;
pub mod utils;

================================================
FILE: rust/xaynet-server/src/services/tests/utils.rs
================================================
use crate::state_machine::{
    events::{EventPublisher, EventSubscriber, ModelUpdate},
    phases::PhaseName,
};
use xaynet_core::{
    common::{RoundParameters, RoundSeed},
    crypto::{ByteObject, EncryptKeyPair, PublicEncryptKey, SigningKeyPair},
    mask::{self, MaskConfig},
    message::{Message, Sum},
};

/// The mask configuration used throughout the service tests.
pub fn mask_config() -> MaskConfig {
    MaskConfig {
        group_type: mask::GroupType::Integer,
        data_type: mask::DataType::F32,
        bound_type: mask::BoundType::B0,
        model_type: mask::ModelType::M3,
    }
}

/// Create an [`EventPublisher`]/[`EventSubscriber`] pair with default
/// values similar to those produced in practice when instantiating a
/// new coordinator: round 0, `Idle` phase, no model, and zero sum and
/// update probabilities.
pub fn new_event_channels() -> (EventPublisher, EventSubscriber) {
    let keys = EncryptKeyPair::generate();
    // Built before `keys` is moved into `init`.
    let params = RoundParameters {
        pk: keys.public,
        sum: 0.0,
        update: 0.0,
        seed: RoundSeed::generate(),
        mask_config: mask_config().into(),
        model_length: 0,
    };
    EventPublisher::init(0, keys, params, PhaseName::Idle, ModelUpdate::Invalidate)
}

/// Simulate a participant generating keys and crafting a valid sum
/// message for the given round parameters. The keys generated by the
/// participant are returned along with the message.
pub fn new_sum_message(round_params: &RoundParameters) -> (Message, SigningKeyPair) { let signing_keys = SigningKeyPair::generate(); let sum = Sum { sum_signature: signing_keys .secret .sign_detached(&[round_params.seed.as_slice(), b"sum"].concat()), ephm_pk: PublicEncryptKey::generate(), }; let message = Message::new_sum(signing_keys.public, round_params.pk, sum); (message, signing_keys) } /// Sign and encrypt the given message using the given round /// parameters and particpant keys. pub fn encrypt_message( message: &Message, round_params: &RoundParameters, participant_signing_keys: &SigningKeyPair, ) -> Vec { let serialized = serialize_message(message, participant_signing_keys); round_params.pk.encrypt(&serialized[..]) } pub fn serialize_message(message: &Message, participant_signing_keys: &SigningKeyPair) -> Vec { let mut buf = vec![0; message.buffer_length()]; message.to_bytes(&mut buf, &participant_signing_keys.secret); buf } ================================================ FILE: rust/xaynet-server/src/settings/mod.rs ================================================ //! Loading and validation of settings. //! //! Values defined in the configuration file can be overridden by environment variables. Examples of //! configuration files can be found in the `configs/` directory located in the repository root. 
#[cfg(feature = "tls")] use std::path::PathBuf; use std::{fmt, path::Path}; use config::{Config, ConfigError, Environment, File}; use displaydoc::Display; use redis::{ConnectionInfo, IntoConnectionInfo}; use serde::{ de::{self, Deserializer, Visitor}, Deserialize, }; use thiserror::Error; use tracing_subscriber::filter::EnvFilter; use validator::{Validate, ValidationError, ValidationErrors}; use xaynet_core::{ mask::{BoundType, DataType, GroupType, MaskConfig, ModelType}, message::{SUM_COUNT_MIN, UPDATE_COUNT_MIN}, }; #[cfg(feature = "model-persistence")] #[cfg_attr(docsrs, doc(cfg(feature = "model-persistence")))] pub mod s3; #[cfg(feature = "model-persistence")] pub use self::{s3::RestoreSettings, s3::S3BucketsSettings, s3::S3Settings}; #[derive(Debug, Display, Error)] /// An error related to loading and validation of settings. pub enum SettingsError { /// Configuration loading failed: {0}. Loading(#[from] ConfigError), /// Validation failed: {0}. Validation(#[from] ValidationErrors), } #[derive(Debug, Validate, Deserialize)] /// The combined settings. /// /// Each section in the configuration file corresponds to the identically named settings field. pub struct Settings { pub api: ApiSettings, #[validate] pub pet: PetSettings, pub mask: MaskSettings, pub log: LoggingSettings, pub model: ModelSettings, #[validate] pub metrics: MetricsSettings, pub redis: RedisSettings, #[cfg(feature = "model-persistence")] #[validate] pub s3: S3Settings, #[cfg(feature = "model-persistence")] #[validate] pub restore: RestoreSettings, #[serde(default)] pub trust_anchor: TrustAnchorSettings, } impl Settings { /// Loads and validates the settings via a configuration file. /// /// # Errors /// Fails when the loading of the configuration file or its validation failed. 
pub fn new(path: impl AsRef) -> Result { let settings: Settings = Self::load(path)?; settings.validate()?; Ok(settings) } fn load(path: impl AsRef) -> Result { Config::builder() .add_source(File::from(path.as_ref())) .add_source(Environment::with_prefix("xaynet").separator("__")) .build()? .try_deserialize() } } /// The PET protocol count settings. #[derive(Debug, Deserialize, Clone, Copy)] #[cfg_attr(test, derive(PartialEq))] pub struct PetSettingsCount { /// The minimal number of participants selected in a phase. pub min: u64, /// The maximal number of participants selected in a phase. pub max: u64, } /// The PET protocol time settings. #[derive(Debug, Deserialize, Clone, Copy)] #[cfg_attr(test, derive(PartialEq))] pub struct PetSettingsTime { /// The minimal amount of time reserved for a phase. pub min: u64, /// The maximal amount of time reserved for a phase. pub max: u64, } /// The PET protocol `sum` phase settings. #[derive(Debug, Deserialize, Clone, Copy)] #[cfg_attr(test, derive(PartialEq))] pub struct PetSettingsSum { /// The probability of participants selected for preparing and computing the aggregated mask. /// The value must be between `0` and `1` (i.e. `0 < sum.prob < 1`). /// /// # Examples /// /// **TOML** /// ```text /// [pet.sum] /// prob = 0.01 /// ``` /// /// **Environment variable** /// ```text /// XAYNET__PET__SUM__PROB=0.01 /// ``` pub prob: f64, /// The minimal and maximal number of participants selected for preparing the unmasking. /// /// The minimal value must be greater or equal to `1` (i.e. `sum.count.min >= 1`) for the PET /// protocol to function correctly. The maximal value must be greater or equal to the minimal /// value (i.e. `sum.count.min <= sum.count.max`). No more than `sum.count.max` messages will be /// processed in the `sum` phase if the `sum.time.min` has not yet elapsed. 
/// /// # Examples /// /// **TOML** /// ```text /// [pet.sum.count] /// min = 10 /// max = 100 /// ``` /// /// **Environment variable** /// ```text /// XAYNET__PET__SUM__COUNT__MIN=10 /// XAYNET__PET__SUM__COUNT__MAX=100 /// ``` pub count: PetSettingsCount, /// The minimal and maximal amount of time reserved for processing messages in the `sum` phase, /// in seconds. /// /// Once the minimal time has passed, the `sum` phase ends *as soon as* `sum.count.min` messages /// have been processed. Set this higher to allow for the possibility of more than /// `sum.count.min` messages to be processed in the `sum` phase. Set the maximal time lower to /// allow for the processing of `sum.count.min` messages to time-out sooner in the `sum` phase. /// /// # Examples /// /// **TOML** /// ```text /// [pet.sum.time] /// min = 5 /// max = 3600 /// ``` /// /// **Environment variable** /// ```text /// XAYNET__PET__SUM__TIME__MIN=5 /// XAYNET__PET__SUM__TIME__MAX=3600 /// ``` pub time: PetSettingsTime, } /// The PET protocol `update` phase settings. #[derive(Debug, Deserialize, Clone, Copy)] #[cfg_attr(test, derive(PartialEq))] pub struct PetSettingsUpdate { /// The probability of participants selected for submitting an updated local model for /// aggregation. The value must be between `0` and `1` (i.e. `0 < update.prob <= 1`). Here, `1` /// is included to be able to express that every participant who is not a sum participant must be /// an update participant. /// /// # Examples /// /// **TOML** /// ```text /// [pet.update] /// prob = 0.1 /// ``` /// /// **Environment variable** /// ```text /// XAYNET__PET__UPDATE__PROB=0.1 /// ``` pub prob: f64, /// The minimal and maximal number of participants selected for submitting an updated local /// model for aggregation. /// /// The minimal value must be greater or equal to `3` (i.e. `update.count.min >= 3`) for the PET /// protocol to function correctly. The maximal value must be greater or equal to the minimal /// value (i.e. 
`update.count.min <= update.count.max`). No more than `update.count.max` /// messages will be processed in the `update` phase if the `update.time.min` has not yet /// elapsed. /// /// # Examples /// /// **TOML** /// ```text /// [pet.update.count] /// min = 100 /// max = 10000 /// ``` /// /// **Environment variable** /// ```text /// XAYNET__PET__UPDATE__COUNT__MIN=100 /// XAYNET__PET__UPDATE__COUNT__MAX=10000 /// ``` pub count: PetSettingsCount, /// The minimal and maximal amount of time reserved for processing messages in the `update` /// phase, in seconds. /// /// Once the minimal time has passed, the `update` phase ends *as soon as* `update.count.min` /// messages have been processed. Set this higher to allow for the possibility of more than /// `update.count.min` messages to be processed in the `update` phase. Set the maximal time /// lower to allow for the processing of `update.count.min` messages to time-out sooner in the /// `update` phase. /// /// # Examples /// /// **TOML** /// ```text /// [pet.update.time] /// min = 10 /// max = 3600 /// ``` /// /// **Environment variable** /// ```text /// XAYNET__PET__UPDATE__TIME__MIN=10 /// XAYNET__PET__UPDATE__TIME__MAX=10 /// ``` pub time: PetSettingsTime, } /// The PET protocol `sum2` phase settings. #[derive(Debug, Deserialize, Clone, Copy)] #[cfg_attr(test, derive(PartialEq))] pub struct PetSettingsSum2 { /// The minimal and maximal number of participants selected for submitting the aggregated masks. /// /// The minimal value must be greater or equal to `1` (i.e. `sum2.count.min >= 1`) for the PET /// protocol to function correctly and less or equal to the maximal value of the `sum` phase /// (i.e. `sum2.count.sum <= sum.count.max`). The maximal value must be greater or equal to the /// minimal value (i.e. `sum2.count.min <= sum2.count.max`) and less or equal to the maximal /// value of the `sum` phase (i.e. `sum2.count.max <= sum.count.max`). 
No more than /// `sum2.count.max` messages will be processed in the `sum2` phase if the `sum2.time.min` has /// not yet elapsed. /// /// # Examples /// /// **TOML** /// ```text /// [pet.sum2.count] /// min = 10 /// max = 100 /// ``` /// /// **Environment variable** /// ```text /// XAYNET__PET__SUM2__COUNT__MIN=10 /// XAYNET__PET__SUM2__COUNT__MAX=100 /// ``` pub count: PetSettingsCount, /// The minimal and maximal amount of time reserved for processing messages in the `sum2` phase, /// in seconds. /// /// Once the minimal time has passed, the `sum2` phase ends *as soon as* `sum2.count.min` /// messages have been processed. Set this higher to allow for the possibility of more than /// `sum2.count.min` messages to be processed in the `sum2` phase. Set the maximal time lower to /// allow for the processing of `sum2.count.min` messages to time-out sooner in the `sum2` /// phase. /// /// # Examples /// /// **TOML** /// ```text /// [pet.sum2.time] /// min = 5 /// max = 3600 /// ``` /// /// **Environment variable** /// ```text /// XAYNET__PET__SUM2__TIME__MIN=5 /// XAYNET__PET__SUM2__TIME__MAX=3600 /// ``` pub time: PetSettingsTime, } /// The PET protocol settings. #[derive(Debug, Validate, Deserialize, Clone, Copy)] #[cfg_attr(test, derive(PartialEq))] #[validate(schema(function = "validate_pet"))] pub struct PetSettings { /// The PET settings for the `sum` phase. pub sum: PetSettingsSum, /// The PET settings for the `update` phase. pub update: PetSettingsUpdate, /// The PET settings for the `sum2` phase. pub sum2: PetSettingsSum2, } impl PetSettings { /// Checks the PET settings. fn validate_pet(&self) -> Result<(), ValidationError> { self.validate_counts()?; self.validate_times()?; self.validate_probabilities() } /// Checks the validity of phase count ranges. 
fn validate_counts(&self) -> Result<(), ValidationError> { // the validate attribute only accepts literals, therefore we check the invariants here if SUM_COUNT_MIN <= self.sum.count.min && self.sum.count.min <= self.sum.count.max && UPDATE_COUNT_MIN <= self.update.count.min && self.update.count.min <= self.update.count.max && SUM_COUNT_MIN <= self.sum2.count.min && self.sum2.count.min <= self.sum2.count.max && self.sum2.count.min <= self.sum.count.max && self.sum2.count.max <= self.sum.count.max { Ok(()) } else { Err(ValidationError::new("invalid phase count range(s)")) } } /// Checks the validity of phase time ranges. fn validate_times(&self) -> Result<(), ValidationError> { if self.sum.time.min <= self.sum.time.max && self.update.time.min <= self.update.time.max && self.sum2.time.min <= self.sum2.time.max { Ok(()) } else { Err(ValidationError::new("invalid phase time range(s)")) } } /// Checks the validity of fraction ranges including pathological cases of deadlocks. fn validate_probabilities(&self) -> Result<(), ValidationError> { if 0. < self.sum.prob && self.sum.prob < 1. && 0. < self.update.prob && self.update.prob <= 1. && 0. < self.sum.prob + self.update.prob - self.sum.prob * self.update.prob && self.sum.prob + self.update.prob - self.sum.prob * self.update.prob <= 1. { Ok(()) } else { Err(ValidationError::new("starvation")) } } } /// A wrapper for validate derive. fn validate_pet(s: &PetSettings) -> Result<(), ValidationError> { s.validate_pet() } #[derive(Debug, Deserialize, Clone)] #[cfg_attr( feature = "tls", derive(Validate), validate(schema(function = "validate_api")) )] /// REST API settings. /// /// Requires at least one of the following arguments if the `tls` feature is enabled: /// - `tls_certificate` together with `tls_key` for TLS server authentication // - `tls_client_auth` for TLS client authentication pub struct ApiSettings { /// The address to which the REST API should be bound. 
/// /// # Examples /// /// **TOML** /// ```text /// [api] /// bind_address = "0.0.0.0:8081" /// # or /// bind_address = "127.0.0.1:8081" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__API__BIND_ADDRESS=127.0.0.1:8081 /// ``` pub bind_address: std::net::SocketAddr, #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] /// The path to the server certificate to enable TLS server authentication. Leave this out to /// disable server authentication. If this is present, then `tls_key` must also be present. /// /// Requires the `tls` feature to be enabled. /// /// # Examples /// /// **TOML** /// ```text /// [api] /// tls_certificate = path/to/tls/files/cert.pem /// ``` /// /// **Environment variable** /// ```text /// XAYNET__API__TLS_CERTIFICATE=path/to/tls/files/certificate.pem /// ``` pub tls_certificate: Option, #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] /// The path to the server private key to enable TLS server authentication. Leave this out to /// disable server authentication. If this is present, then `tls_certificate` must also be /// present. /// /// Requires the `tls` feature to be enabled. /// /// # Examples /// /// **TOML** /// ```text /// [api] /// tls_key = path/to/tls/files/key.rsa /// ``` /// /// **Environment variable** /// ```text /// XAYNET__API__TLS_KEY=path/to/tls/files/key.rsa /// ``` pub tls_key: Option, #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] /// The path to the trust anchor to enable TLS client authentication. Leave this out to disable /// client authentication. /// /// Requires the `tls` feature to be enabled. /// /// # Examples /// /// **TOML** /// ```text /// [api] /// tls_client_auth = path/to/tls/files/trust_anchor.pem /// ``` /// /// **Environment variable** /// ```text /// XAYNET__API__TLS_CLIENT_AUTH=path/to/tls/files/trust_anchor.pem /// ``` pub tls_client_auth: Option, } #[cfg(feature = "tls")] impl ApiSettings { /// Checks API settings. 
fn validate_api(&self) -> Result<(), ValidationError> { match (&self.tls_certificate, &self.tls_key, &self.tls_client_auth) { (Some(_), Some(_), _) | (None, None, Some(_)) => Ok(()), _ => Err(ValidationError::new("invalid tls settings")), } } } /// A wrapper for validate derive. #[cfg(feature = "tls")] fn validate_api(s: &ApiSettings) -> Result<(), ValidationError> { s.validate_api() } #[derive(Debug, Validate, Deserialize, Clone, Copy)] #[cfg_attr(test, derive(PartialEq, Eq))] /// Masking settings. pub struct MaskSettings { /// The order of the finite group. /// /// # Examples /// /// **TOML** /// ```text /// [mask] /// group_type = "Integer" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__MASK__GROUP_TYPE=Integer /// ``` pub group_type: GroupType, /// The data type of the numbers to be masked. /// /// # Examples /// /// **TOML** /// ```text /// [mask] /// data_type = "F32" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__MASK__DATA_TYPE=F32 /// ``` pub data_type: DataType, /// The bounds of the numbers to be masked. /// /// # Examples /// /// **TOML** /// ```text /// [mask] /// bound_type = "B0" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__MASK__BOUND_TYPE=B0 /// ``` pub bound_type: BoundType, /// The maximum number of models to be aggregated. /// /// # Examples /// /// **TOML** /// ```text /// [mask] /// model_type = "M3" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__MASK__MODEL_TYPE=M3 /// ``` pub model_type: ModelType, } impl From for MaskConfig { fn from( MaskSettings { group_type, data_type, bound_type, model_type, }: MaskSettings, ) -> MaskConfig { MaskConfig { group_type, data_type, bound_type, model_type, } } } #[derive(Debug, Deserialize, Clone)] #[cfg_attr(test, derive(PartialEq))] /// Model settings. pub struct ModelSettings { /// The expected length of the model. The model length corresponds to the number of elements. 
/// This value is used to validate the uniform length of the submitted models/masks. /// /// # Examples /// /// **TOML** /// ```text /// [model] /// length = 100 /// ``` /// /// **Environment variable** /// ```text /// XAYNET__MODEL__LENGTH=100 /// ``` pub length: usize, } #[derive(Debug, Deserialize, Validate)] /// Metrics settings. pub struct MetricsSettings { #[validate] /// Settings for the InfluxDB backend. pub influxdb: InfluxSettings, } #[derive(Debug, Deserialize, Validate)] /// InfluxDB settings. pub struct InfluxSettings { #[validate(url)] /// The URL where InfluxDB is running. /// /// # Examples /// /// **TOML** /// ```text /// [metrics.influxdb] /// url = "http://localhost:8086" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__METRICS__INFLUXDB__URL=http://localhost:8086 /// ``` pub url: String, /// The InfluxDB database name. /// /// # Examples /// /// **TOML** /// ```text /// [metrics.influxdb] /// db = "test" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__METRICS__INFLUXDB__DB=test /// ``` pub db: String, } #[derive(Debug, Deserialize)] /// Redis settings. pub struct RedisSettings { /// The URL where Redis is running. /// /// The format of the URL is `redis://[][:@][:port][/]`. 
/// /// # Examples /// /// **TOML** /// ```text /// [redis] /// url = "redis://127.0.0.1/" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__REDIS__URL=redis://127.0.0.1/ /// ``` #[serde(deserialize_with = "deserialize_redis_url")] pub url: ConnectionInfo, } fn deserialize_redis_url<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { struct ConnectionInfoVisitor; impl<'de> Visitor<'de> for ConnectionInfoVisitor { type Value = ConnectionInfo; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!( formatter, "redis://[][:@][:port][/]" ) } fn visit_str(self, value: &str) -> Result where E: de::Error, { value .into_connection_info() .map_err(|_| de::Error::invalid_value(serde::de::Unexpected::Str(value), &self)) } } deserializer.deserialize_str(ConnectionInfoVisitor) } #[derive(Debug, Deserialize, Validate)] /// Trust anchor settings. pub struct TrustAnchorSettings {} // Default value for the global models bucket impl Default for TrustAnchorSettings { fn default() -> Self { Self {} } } #[derive(Debug, Deserialize)] /// Logging settings. pub struct LoggingSettings { /// A comma-separated list of logging directives. More information about logging directives /// can be found [here]. 
/// /// # Examples /// /// **TOML** /// ```text /// [log] /// filter = "info" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__LOG__FILTER=info /// ``` /// /// [here]: https://docs.rs/tracing-subscriber/0.2.15/tracing_subscriber/filter/struct.EnvFilter.html#directives #[serde(deserialize_with = "deserialize_env_filter")] pub filter: EnvFilter, } fn deserialize_env_filter<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { struct EnvFilterVisitor; impl<'de> Visitor<'de> for EnvFilterVisitor { type Value = EnvFilter; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!(formatter, "a valid tracing filter directive: https://docs.rs/tracing-subscriber/0.2.6/tracing_subscriber/filter/struct.EnvFilter.html#directives") } fn visit_str(self, value: &str) -> Result where E: de::Error, { EnvFilter::try_new(value) .map_err(|_| de::Error::invalid_value(serde::de::Unexpected::Str(value), &self)) } } deserializer.deserialize_str(EnvFilterVisitor) } #[cfg(test)] mod tests { use super::*; impl Default for PetSettings { fn default() -> Self { Self { sum: PetSettingsSum { prob: 0.01, count: PetSettingsCount { min: 10, max: 100 }, time: PetSettingsTime { min: 0, max: 604800, }, }, update: PetSettingsUpdate { prob: 0.1, count: PetSettingsCount { min: 100, max: 10000, }, time: PetSettingsTime { min: 0, max: 604800, }, }, sum2: PetSettingsSum2 { count: PetSettingsCount { min: 10, max: 100 }, time: PetSettingsTime { min: 0, max: 604800, }, }, } } } impl Default for MaskSettings { fn default() -> Self { Self { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3, } } } #[test] fn test_settings_new() { assert!(Settings::new("../../configs/config.toml").is_ok()); assert!(Settings::new("").is_err()); } #[test] fn test_validate_pet() { assert!(PetSettings::default().validate_pet().is_ok()); } #[test] fn test_validate_pet_counts() { assert_eq!(SUM_COUNT_MIN, 1); 
assert_eq!(UPDATE_COUNT_MIN, 3); let mut pet = PetSettings::default(); pet.sum.count.min = 0; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.sum.count.min = 11; pet.sum.count.max = 10; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.update.count.min = 2; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.update.count.min = 11; pet.update.count.max = 10; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.sum2.count.min = 0; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.sum2.count.min = 11; pet.sum2.count.max = 10; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.sum2.count.min = 11; pet.sum.count.max = 10; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.sum2.count.max = 11; pet.sum.count.max = 10; assert!(pet.validate().is_err()); } #[test] fn test_validate_pet_times() { let mut pet = PetSettings::default(); pet.sum.time.min = 2; pet.sum.time.max = 1; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.update.time.min = 2; pet.update.time.max = 1; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.sum2.time.min = 2; pet.sum2.time.max = 1; assert!(pet.validate().is_err()); } #[test] fn test_validate_pet_probabilities() { let mut pet = PetSettings::default(); pet.sum.prob = 0.; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.sum.prob = 1.; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.update.prob = 0.; assert!(pet.validate().is_err()); let mut pet = PetSettings::default(); pet.update.prob = 1. 
+ f64::EPSILON; assert!(pet.validate().is_err()); } #[cfg(feature = "tls")] #[test] fn test_validate_api() { let bind_address = ([0, 0, 0, 0], 0).into(); let some_path = Some(std::path::PathBuf::new()); assert!(ApiSettings { bind_address, tls_certificate: some_path.clone(), tls_key: some_path.clone(), tls_client_auth: some_path.clone(), } .validate() .is_ok()); assert!(ApiSettings { bind_address, tls_certificate: some_path.clone(), tls_key: some_path.clone(), tls_client_auth: None, } .validate() .is_ok()); assert!(ApiSettings { bind_address, tls_certificate: None, tls_key: None, tls_client_auth: some_path.clone(), } .validate() .is_ok()); assert!(ApiSettings { bind_address, tls_certificate: some_path.clone(), tls_key: None, tls_client_auth: some_path.clone(), } .validate() .is_err()); assert!(ApiSettings { bind_address, tls_certificate: None, tls_key: some_path.clone(), tls_client_auth: some_path.clone(), } .validate() .is_err()); assert!(ApiSettings { bind_address, tls_certificate: some_path.clone(), tls_key: None, tls_client_auth: None, } .validate() .is_err()); assert!(ApiSettings { bind_address, tls_certificate: None, tls_key: some_path, tls_client_auth: None, } .validate() .is_err()); assert!(ApiSettings { bind_address, tls_certificate: None, tls_key: None, tls_client_auth: None, } .validate() .is_err()); } } ================================================ FILE: rust/xaynet-server/src/settings/s3.rs ================================================ //! S3 settings. use std::fmt; use fancy_regex::Regex; use rusoto_core::Region; use serde::{ de::{self, value, Deserializer, Visitor}, Deserialize, }; use validator::{Validate, ValidationError}; #[derive(Debug, Validate, Deserialize)] /// S3 settings. pub struct S3Settings { /// The [access key ID](https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html). 
/// /// # Examples /// /// **TOML** /// ```text /// [s3] /// access_key = "AKIAIOSFODNN7EXAMPLE" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__S3__ACCESS_KEY=AKIAIOSFODNN7EXAMPLE /// ``` pub access_key: String, /// The [secret access key](https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html). /// /// # Examples /// /// **TOML** /// ```text /// [s3] /// secret_access_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__S3__SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY /// ``` pub secret_access_key: String, /// The Regional AWS endpoint. /// /// The region is specified using the [Region code](https://docs.aws.amazon.com/general/latest/gr/rande.html#regional-endpoints) /// /// # Examples /// /// **TOML** /// ```text /// [s3] /// region = ["eu-west-1"] /// ``` /// /// **Environment variable** /// ```text /// XAYNET__S3__REGION="eu-west-1" /// ``` /// /// To connect to AWS-compatible services such as Minio, you need to specify a custom region. /// /// # Examples /// /// **TOML** /// ```text /// [s3] /// region = ["minio", "http://localhost:8000"] /// ``` /// /// **Environment variable** /// ```text /// XAYNET__S3__REGION="minio http://localhost:8000" /// ``` #[serde(deserialize_with = "deserialize_s3_region")] pub region: Region, #[validate] #[serde(default)] pub buckets: S3BucketsSettings, } #[derive(Debug, Validate, Deserialize)] /// S3 buckets settings. pub struct S3BucketsSettings { /// The bucket name in which the global models are stored. /// Defaults to `global-models`. /// /// Please follow the [rules for bucket naming](https://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html) /// when creating the name. 
/// /// # Examples /// /// **TOML** /// ```text /// [s3.buckets] /// global_models = "global-models" /// ``` /// /// **Environment variable** /// ```text /// XAYNET__S3__BUCKETS__GLOBAL_MODELS="global-models" /// ``` #[validate(custom = "validate_s3_bucket_name")] pub global_models: String, } // Default value for the global models bucket impl Default for S3BucketsSettings { fn default() -> Self { Self { global_models: String::from("global-models"), } } } // Validates the bucket name // [Rules for AWS bucket naming](https://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html) fn validate_s3_bucket_name(bucket_name: &str) -> Result<(), ValidationError> { // https://stackoverflow.com/questions/50480924/regex-for-s3-bucket-name#comment104807676_58248645 // I had to use fancy_regex here because the std regex does not support `look-around` let re = Regex::new(r"(?!^(\d{1,3}\.){3}\d{1,3}$)(^[a-z0-9]([a-z0-9-]*(\.[a-z0-9])?)*$(? Ok(()), Ok(false) => Err(ValidationError::new("invalid bucket name\n See here: https://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html")), // something went wrong with the regex engine Err(_) => Err(ValidationError::new("can not validate bucket name")), } } // A small wrapper to support the list type for environment variable values. // config-rs always converts a environment variable value to a string // https://github.com/mehcode/config-rs/blob/master/src/env.rs#L114 . // Strings however, are not supported by the deserializer of rusoto_core::Region (only sequences). // Therefore we use S3RegionVisitor to implement `visit_str` and thus support // the deserialization of rusoto_core::Region from strings. 
fn deserialize_s3_region<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { struct S3RegionVisitor; impl<'de> Visitor<'de> for S3RegionVisitor { type Value = Region; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("sequence of \"name Optional\"") } // FIXME: a copy of https://rusoto.github.io/rusoto/src/rusoto_core/region.rs.html#185 // I haven't managed to create a sequence and call `self.visit_seq(seq)`. fn visit_str(self, value: &str) -> Result where E: de::Error, { let mut seq = value.split_whitespace(); let name: &str = seq .next() .ok_or_else(|| de::Error::custom("region is missing name"))?; let endpoint: Option<&str> = seq.next(); match (name, endpoint) { (name, Some(endpoint)) => Ok(Region::Custom { name: name.to_string(), endpoint: endpoint.to_string(), }), (name, None) => name.parse().map_err(de::Error::custom), } } // delegate the call for sequences to the deserializer of rusoto_core::Region fn visit_seq(self, seq: A) -> Result where A: de::SeqAccess<'de>, { Deserialize::deserialize(value::SeqAccessDeserializer::new(seq)) } } deserializer.deserialize_any(S3RegionVisitor) } #[derive(Debug, Deserialize, Validate)] /// Restore settings. pub struct RestoreSettings { /// If set to `false`, the restoring of coordinator state is prevented. /// Instead, the state is reset and the coordinator is started with the /// settings of the configuration file. /// /// # Examples /// /// **TOML** /// ```text /// [restore] /// enable = true /// ``` /// /// **Environment variable** /// ```text /// XAYNET__RESTORE__ENABLE=false /// ``` pub enable: bool, } #[cfg(test)] mod tests { use super::*; use crate::settings::Settings; use config::{Config, ConfigError, Environment, File, FileFormat}; use serial_test::serial; impl Settings { fn load_from_str(string: &str) -> Result { Config::builder() .add_source(File::from_str(string, FileFormat::Toml)) .add_source(Environment::with_prefix("xaynet").separator("__")) .build()? 
.try_deserialize() } } struct ConfigBuilder { config: String, } impl ConfigBuilder { fn new() -> Self { Self { config: String::new(), } } fn build(self) -> String { self.config } fn with_log(mut self) -> Self { let log = r#" [log] filter = "xaynet=debug,http=warn,info" "#; self.config.push_str(log); self } fn with_api(mut self) -> Self { let api = r#" [api] bind_address = "127.0.0.1:8081" tls_certificate = "/app/ssl/tls.pem" tls_key = "/app/ssl/tls.key" "#; self.config.push_str(api); self } fn with_pet(mut self) -> Self { let pet = r#" [pet.sum] prob = 0.5 count = { min = 1, max = 100 } time = { min = 5, max = 3600 } [pet.update] prob = 0.9 count = { min = 3, max = 10000 } time = { min = 10, max = 3600 } [pet.sum2] count = { min = 1, max = 100 } time = { min = 5, max = 3600 } "#; self.config.push_str(pet); self } fn with_mask(mut self) -> Self { let mask = r#" [mask] group_type = "Prime" data_type = "F32" bound_type = "B0" model_type = "M3" "#; self.config.push_str(mask); self } fn with_model(mut self) -> Self { let model = r#" [model] length = 4 "#; self.config.push_str(model); self } fn with_metrics(mut self) -> Self { let metrics = r#" [metrics.influxdb] url = "http://influxdb:8086" db = "metrics" "#; self.config.push_str(metrics); self } fn with_redis(mut self) -> Self { let redis = r#" [redis] url = "redis://127.0.0.1/" "#; self.config.push_str(redis); self } fn with_s3(mut self) -> Self { let s3 = r#" [s3] access_key = "minio" secret_access_key = "minio123" region = ["minio", "http://localhost:9000"] "#; self.config.push_str(s3); self } fn with_s3_buckets(mut self) -> Self { let s3_buckets = r#" [s3.buckets] global_models = "global-models-toml" "#; self.config.push_str(s3_buckets); self } fn with_restore(mut self) -> Self { let restore = r#" [restore] enable = true "#; self.config.push_str(restore); self } fn with_custom(mut self, custom_config: &str) -> Self { self.config.push_str(custom_config); self } } #[test] fn test_validate_s3_bucket_name() { // I took 
the examples from https://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html // valid names assert!(validate_s3_bucket_name("docexamplebucket").is_ok()); assert!(validate_s3_bucket_name("log-delivery-march-2020").is_ok()); assert!(validate_s3_bucket_name("my-hosted-content").is_ok()); // valid but not recommended names assert!(validate_s3_bucket_name("docexamplewebsite.com").is_ok()); assert!(validate_s3_bucket_name("www.docexamplewebsite.com").is_ok()); assert!(validate_s3_bucket_name("my.example.s3.bucket").is_ok()); // invalid names assert!(validate_s3_bucket_name("doc_example_bucket").is_err()); assert!(validate_s3_bucket_name("DocExampleBucket").is_err()); assert!(validate_s3_bucket_name("doc-example-bucket-").is_err()); } #[test] #[serial] fn test_s3_bucket_name_default() { let config = ConfigBuilder::new() .with_log() .with_api() .with_pet() .with_mask() .with_model() .with_metrics() .with_redis() .with_restore() .with_s3() .build(); let settings = Settings::load_from_str(&config).unwrap(); assert_eq!( settings.s3.buckets.global_models, S3BucketsSettings::default().global_models ) } #[test] #[serial] fn test_s3_bucket_name_toml_overrides_default() { let config = ConfigBuilder::new() .with_log() .with_api() .with_pet() .with_mask() .with_model() .with_metrics() .with_redis() .with_restore() .with_s3() .with_s3_buckets() .build(); let settings = Settings::load_from_str(&config).unwrap(); assert_eq!(settings.s3.buckets.global_models, "global-models-toml") } #[test] #[serial] fn test_s3_bucket_name_env_overrides_toml_and_default() { let config = ConfigBuilder::new() .with_log() .with_api() .with_pet() .with_mask() .with_model() .with_metrics() .with_redis() .with_restore() .with_s3() .with_s3_buckets() .build(); std::env::set_var("XAYNET__S3__BUCKETS__GLOBAL_MODELS", "global-models-env"); let settings = Settings::load_from_str(&config).unwrap(); assert_eq!(settings.s3.buckets.global_models, "global-models-env"); 
std::env::remove_var("XAYNET__S3__BUCKETS__GLOBAL_MODELS"); } #[test] #[serial] fn test_s3_bucket_name_env_overrides_default() { let config = ConfigBuilder::new() .with_log() .with_api() .with_pet() .with_mask() .with_model() .with_metrics() .with_redis() .with_restore() .with_s3() .build(); std::env::set_var("XAYNET__S3__BUCKETS__GLOBAL_MODELS", "global-models-env"); let settings = Settings::load_from_str(&config).unwrap(); assert_eq!(settings.s3.buckets.global_models, "global-models-env"); std::env::remove_var("XAYNET__S3__BUCKETS__GLOBAL_MODELS"); } #[test] #[serial] fn test_s3_region_toml() { let region = r#" [s3] access_key = "minio" secret_access_key = "minio123" region = ["eu-west-1"] "#; let config = ConfigBuilder::new() .with_log() .with_api() .with_pet() .with_mask() .with_model() .with_metrics() .with_redis() .with_restore() .with_custom(region) .build(); let settings = Settings::load_from_str(&config).unwrap(); assert!(matches!(settings.s3.region, Region::EuWest1)); } #[test] #[serial] fn test_s3_custom_region_toml() { let config = ConfigBuilder::new() .with_log() .with_api() .with_pet() .with_mask() .with_model() .with_metrics() .with_redis() .with_restore() .with_s3() .build(); let settings = Settings::load_from_str(&config).unwrap(); assert!(matches!( settings.s3.region, Region::Custom { name, endpoint } if name == "minio" && endpoint == "http://localhost:9000" )); } #[test] #[serial] fn test_s3_region_env() { let config = ConfigBuilder::new() .with_log() .with_api() .with_pet() .with_mask() .with_model() .with_metrics() .with_redis() .with_restore() .with_s3() .build(); std::env::set_var("XAYNET__S3__REGION", "eu-west-1"); let settings = Settings::load_from_str(&config).unwrap(); assert!(matches!(settings.s3.region, Region::EuWest1)); std::env::remove_var("XAYNET__S3__REGION"); } #[test] #[serial] fn test_restore() { let no_restore = r#" [restore] enable = false "#; let config = ConfigBuilder::new() .with_log() .with_api() .with_pet() .with_mask() 
.with_model() .with_metrics() .with_redis() .with_s3() .with_custom(no_restore) .build(); let settings = Settings::load_from_str(&config).unwrap(); assert!(!settings.restore.enable); } #[test] #[serial] fn test_s3_custom_region_env() { let config = ConfigBuilder::new() .with_log() .with_api() .with_pet() .with_mask() .with_model() .with_metrics() .with_redis() .with_restore() .with_s3() .build(); std::env::set_var("XAYNET__S3__REGION", "minio-env http://localhost:8000"); let settings = Settings::load_from_str(&config).unwrap(); assert!(matches!( settings.s3.region, Region::Custom { name, endpoint } if name == "minio-env" && endpoint == "http://localhost:8000" )); std::env::remove_var("XAYNET__S3__REGION"); } } ================================================ FILE: rust/xaynet-server/src/state_machine/coordinator.rs ================================================ //! Coordinator state and round parameter types. use serde::{Deserialize, Serialize}; use crate::settings::{ MaskSettings, ModelSettings, PetSettings, PetSettingsCount, PetSettingsSum, PetSettingsSum2, PetSettingsTime, PetSettingsUpdate, }; use xaynet_core::{ common::{RoundParameters, RoundSeed}, crypto::{ByteObject, EncryptKeyPair}, mask::MaskConfig, }; /// The phase count parameters. #[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub struct CountParameters { /// The minimal number of required messages. pub min: u64, /// The maximal number of accepted messages. pub max: u64, } impl From for CountParameters { fn from(count: PetSettingsCount) -> Self { let PetSettingsCount { min, max } = count; Self { min, max } } } /// The phase time parameters. #[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub struct TimeParameters { /// The minimal amount of time (in seconds) reserved for processing messages. pub min: u64, /// The maximal amount of time (in seconds) permitted for processing messages. 
pub max: u64, } impl From for TimeParameters { fn from(time: PetSettingsTime) -> Self { let PetSettingsTime { min, max } = time; Self { min, max } } } /// The phase parameters. #[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub struct PhaseParameters { /// The number of messages. pub count: CountParameters, /// The amount of time for processing messages. pub time: TimeParameters, } impl From for PhaseParameters { fn from(sum: PetSettingsSum) -> Self { let PetSettingsSum { count, time, .. } = sum; Self { count: count.into(), time: time.into(), } } } impl From for PhaseParameters { fn from(update: PetSettingsUpdate) -> Self { let PetSettingsUpdate { count, time, .. } = update; Self { count: count.into(), time: time.into(), } } } impl From for PhaseParameters { fn from(sum2: PetSettingsSum2) -> Self { let PetSettingsSum2 { count, time } = sum2; Self { count: count.into(), time: time.into(), } } } /// The coordinator state. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct CoordinatorState { /// The credentials of the coordinator. pub keys: EncryptKeyPair, /// Internal ID used to identify a round pub round_id: u64, /// The round parameters. pub round_params: RoundParameters, /// The sum phase parameters. pub sum: PhaseParameters, /// The update phase parameters. pub update: PhaseParameters, /// The sum2 phase parameters. 
pub sum2: PhaseParameters, } impl CoordinatorState { pub fn new( pet_settings: PetSettings, mask_settings: MaskSettings, model_settings: ModelSettings, ) -> Self { let keys = EncryptKeyPair::generate(); let round_params = RoundParameters { pk: keys.public, sum: pet_settings.sum.prob, update: pet_settings.update.prob, seed: RoundSeed::zeroed(), mask_config: MaskConfig::from(mask_settings).into(), model_length: model_settings.length, }; let round_id = 0; Self { keys, round_params, round_id, sum: pet_settings.sum.into(), update: pet_settings.update.into(), sum2: pet_settings.sum2.into(), } } } ================================================ FILE: rust/xaynet-server/src/state_machine/events.rs ================================================ //! This module provides the `StateMachine`, `Events`, `EventSubscriber` and `EventPublisher` types. use std::sync::Arc; use tokio::sync::watch; use crate::state_machine::phases::PhaseName; use xaynet_core::{ common::RoundParameters, crypto::EncryptKeyPair, mask::Model, SeedDict, SumDict, }; /// An event emitted by the coordinator. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Event { /// Metadata that associates this event to the round in which it is /// emitted. pub round_id: u64, /// The event itself pub event: E, } // FIXME: should we simply use `Option`s here? /// Global model update event. #[derive(Debug, Clone, PartialEq)] pub enum ModelUpdate { Invalidate, New(Arc), } /// Dictionary update event. #[derive(Debug, Clone, Eq, PartialEq)] pub enum DictionaryUpdate { Invalidate, New(Arc), } /// A convenience type to emit any coordinator event. #[derive(Debug)] pub struct EventPublisher { /// Round ID that is attached to all the requests. round_id: u64, keys_tx: EventBroadcaster, params_tx: EventBroadcaster, phase_tx: EventBroadcaster, model_tx: EventBroadcaster, sum_dict_tx: EventBroadcaster>, seed_dict_tx: EventBroadcaster>, } /// The `EventSubscriber` hands out `EventListener`s for any /// coordinator event. 
#[derive(Debug)] pub struct EventSubscriber { keys_rx: EventListener, params_rx: EventListener, phase_rx: EventListener, model_rx: EventListener, sum_dict_rx: EventListener>, seed_dict_rx: EventListener>, } impl EventPublisher { /// Initialize a new event publisher with the given initial events. pub fn init( round_id: u64, keys: EncryptKeyPair, params: RoundParameters, phase: PhaseName, model: ModelUpdate, ) -> (Self, EventSubscriber) { let (keys_tx, keys_rx) = watch::channel::>(Event { round_id, event: keys, }); let (params_tx, params_rx) = watch::channel::>(Event { round_id, event: params, }); let (phase_tx, phase_rx) = watch::channel::>(Event { round_id, event: phase, }); let (model_tx, model_rx) = watch::channel::>(Event { round_id, event: model, }); let (sum_dict_tx, sum_dict_rx) = watch::channel::>>(Event { round_id, event: DictionaryUpdate::Invalidate, }); let (seed_dict_tx, seed_dict_rx) = watch::channel::>>(Event { round_id, event: DictionaryUpdate::Invalidate, }); let publisher = EventPublisher { round_id, keys_tx: keys_tx.into(), params_tx: params_tx.into(), phase_tx: phase_tx.into(), model_tx: model_tx.into(), sum_dict_tx: sum_dict_tx.into(), seed_dict_tx: seed_dict_tx.into(), }; let subscriber = EventSubscriber { keys_rx: keys_rx.into(), params_rx: params_rx.into(), phase_rx: phase_rx.into(), model_rx: model_rx.into(), sum_dict_rx: sum_dict_rx.into(), seed_dict_rx: seed_dict_rx.into(), }; (publisher, subscriber) } /// Set the round ID that is attached to the events the publisher broadcasts. 
pub fn set_round_id(&mut self, id: u64) { self.round_id = id; } fn event(&self, event: T) -> Event { Event { round_id: self.round_id, event, } } /// Emit a keys event pub fn broadcast_keys(&mut self, keys: EncryptKeyPair) { let _ = self.keys_tx.broadcast(self.event(keys)); } /// Emit a round parameters event pub fn broadcast_params(&mut self, params: RoundParameters) { let _ = self.params_tx.broadcast(self.event(params)); } /// Emit a phase event pub fn broadcast_phase(&mut self, phase: PhaseName) { let _ = self.phase_tx.broadcast(self.event(phase)); } /// Emit a model event pub fn broadcast_model(&mut self, update: ModelUpdate) { let _ = self.model_tx.broadcast(self.event(update)); } /// Emit a sum dictionary update pub fn broadcast_sum_dict(&mut self, update: DictionaryUpdate) { let _ = self.sum_dict_tx.broadcast(self.event(update)); } /// Emit a seed dictionary update pub fn broadcast_seed_dict(&mut self, update: DictionaryUpdate) { let _ = self.seed_dict_tx.broadcast(self.event(update)); } } impl EventSubscriber { /// Get a listener for keys events. Callers must be careful not to /// leak the secret key they receive, since that would compromise /// the security of the coordinator. pub fn keys_listener(&self) -> EventListener { self.keys_rx.clone() } /// Get a listener for round parameters events pub fn params_listener(&self) -> EventListener { self.params_rx.clone() } /// Get a listener for new phase events pub fn phase_listener(&self) -> EventListener { self.phase_rx.clone() } /// Get a listener for new model events pub fn model_listener(&self) -> EventListener { self.model_rx.clone() } /// Get a listener for sum dictionary updates pub fn sum_dict_listener(&self) -> EventListener> { self.sum_dict_rx.clone() } /// Get a listener for seed dictionary updates pub fn seed_dict_listener(&self) -> EventListener> { self.seed_dict_rx.clone() } } /// A listener for coordinator events. 
It can be used to either /// retrieve the latest `Event` emitted by the coordinator (with /// `EventListener::get_latest`). #[derive(Debug, Clone)] pub struct EventListener(watch::Receiver>); impl From>> for EventListener { fn from(receiver: watch::Receiver>) -> Self { EventListener(receiver) } } impl EventListener where E: Clone, { pub fn get_latest(&self) -> Event { self.0.borrow().clone() } #[cfg(test)] pub async fn changed(&mut self) -> Result<(), watch::error::RecvError> { self.0.changed().await } } /// A channel to send `Event` to all the `EventListener`. #[derive(Debug)] pub struct EventBroadcaster(watch::Sender>); impl EventBroadcaster { /// Send `event` to all the `EventListener` fn broadcast(&self, event: Event) { // We don't care whether there's a listener or not let _ = self.0.send(event); } } impl From>> for EventBroadcaster { fn from(sender: watch::Sender>) -> Self { Self(sender) } } ================================================ FILE: rust/xaynet-server/src/state_machine/initializer.rs ================================================ //! A state machine initializer. use displaydoc::Display; use thiserror::Error; #[cfg(feature = "model-persistence")] use tracing::{debug, info}; #[cfg(feature = "model-persistence")] use crate::settings::RestoreSettings; use crate::{ settings::{MaskSettings, ModelSettings, PetSettings}, state_machine::{ coordinator::CoordinatorState, events::{EventPublisher, EventSubscriber, ModelUpdate}, phases::{Idle, PhaseName, PhaseState, Shared}, requests::{RequestReceiver, RequestSender}, StateMachine, }, storage::{Storage, StorageError}, }; #[cfg(feature = "model-persistence")] use xaynet_core::mask::Model; type StateMachineInitializationResult = Result; /// Errors which can occur during the initialization of the [`StateMachine`]. #[derive(Debug, Display, Error)] pub enum StateMachineInitializationError { /// Initializing crypto library failed. CryptoInit, /// Fetching coordinator state failed: {0}. 
FetchCoordinatorState(StorageError), /// Deleting coordinator data failed: {0}. DeleteCoordinatorData(StorageError), /// Fetching latest global model id failed: {0}. FetchLatestGlobalModelId(StorageError), /// Fetching global model failed: {0}. FetchGlobalModel(StorageError), /// Global model is unavailable: {0}. GlobalModelUnavailable(String), /// Global model is invalid: {0}. GlobalModelInvalid(String), } /// The state machine initializer that initializes a new state machine. pub struct StateMachineInitializer { pet_settings: PetSettings, mask_settings: MaskSettings, model_settings: ModelSettings, #[cfg(feature = "model-persistence")] restore_settings: RestoreSettings, store: T, } impl StateMachineInitializer { /// Creates a new [`StateMachineInitializer`]. pub fn new( pet_settings: PetSettings, mask_settings: MaskSettings, model_settings: ModelSettings, #[cfg(feature = "model-persistence")] restore_settings: RestoreSettings, store: T, ) -> Self { Self { pet_settings, mask_settings, model_settings, #[cfg(feature = "model-persistence")] restore_settings, store, } } // Initializes a new [`StateMachine`] with its components. fn init_state_machine( self, coordinator_state: CoordinatorState, global_model: ModelUpdate, ) -> (StateMachine, RequestSender, EventSubscriber) { let (event_publisher, event_subscriber) = EventPublisher::init( coordinator_state.round_id, coordinator_state.keys.clone(), coordinator_state.round_params.clone(), PhaseName::Idle, global_model, ); let (request_rx, request_tx) = RequestReceiver::new(); let shared = Shared::new(coordinator_state, event_publisher, request_rx, self.store); let state_machine = StateMachine::from(PhaseState::::new(shared)); (state_machine, request_tx, event_subscriber) } } impl StateMachineInitializer where T: Storage, { #[cfg(not(feature = "model-persistence"))] /// Initializes a new [`StateMachine`] with the given settings. 
pub async fn init( mut self, ) -> StateMachineInitializationResult<(StateMachine, RequestSender, EventSubscriber)> { // crucial: init must be called before anything else in this module sodiumoxide::init().or(Err(StateMachineInitializationError::CryptoInit))?; let (coordinator_state, global_model) = { self.from_settings().await? }; Ok(self.init_state_machine(coordinator_state, global_model)) } // Creates a new [`CoordinatorState`] from the given settings and deletes // all coordinator data. Should only be called for the first start // or if we need to perform reset. pub(in crate::state_machine) async fn from_settings( &mut self, ) -> StateMachineInitializationResult<(CoordinatorState, ModelUpdate)> { self.store .delete_coordinator_data() .await .map_err(StateMachineInitializationError::DeleteCoordinatorData)?; Ok(( CoordinatorState::new( self.pet_settings, self.mask_settings, self.model_settings.clone(), ), ModelUpdate::Invalidate, )) } } #[cfg(feature = "model-persistence")] #[cfg_attr(docsrs, doc(cfg(feature = "model-persistence")))] impl StateMachineInitializer where T: Storage, { /// Initializes a new [`StateMachine`] by trying to restore the previous coordinator state /// along with the latest global model. After a successful initialization, the state machine /// always starts from a new round. This means that the round id is increased by one. /// If the state machine is reset during the initialization, the state machine starts /// with the round id `1`. 
/// /// # Behavior /// ![](https://mermaid.ink/svg/eyJjb2RlIjoic2VxdWVuY2VEaWFncmFtXG4gICAgYWx0IHJlc3RvcmUuZW5hYmxlID0gZmFsc2VcbiAgICAgICAgQ29vcmRpbmF0b3ItPj4rUmVkaXM6IGZsdXNoIGRiXG4gICAgICAgIE5vdGUgb3ZlciBDb29yZGluYXRvcixSZWRpczogc3RhcnQgZnJvbSBzZXR0aW5nc1xuICAgIGVsc2VcbiAgICAgICAgQ29vcmRpbmF0b3ItPj4rUmVkaXM6IGdldCBzdGF0ZVxuICAgICAgICBSZWRpcy0tPj4tQ29vcmRpbmF0b3I6IHN0YXRlXG4gICAgICAgIGFsdCBzdGF0ZSBub24tZXhpc3RlbnRcbiAgICAgICAgICAgIENvb3JkaW5hdG9yLT4-K1JlZGlzOiBmbHVzaCBkYlxuICAgICAgICAgICAgTm90ZSBvdmVyIENvb3JkaW5hdG9yLFJlZGlzOiBzdGFydCBmcm9tIHNldHRpbmdzXG4gICAgICAgIGVsc2Ugc3RhdGUgZXhpc3RcbiAgICAgICAgICAgIENvb3JkaW5hdG9yLT4-K1JlZGlzOiBnZXQgbGF0ZXN0IGdsb2JhbCBtb2RlbCBpZFxuICAgICAgICAgICAgUmVkaXMtLT4-LUNvb3JkaW5hdG9yOiBnbG9iYWwgbW9kZWwgaWRcbiAgICAgICAgICAgIGFsdCBnbG9iYWwgbW9kZWwgaWQgbm9uLWV4aXN0ZW50XG4gICAgICAgICAgICAgICAgTm90ZSBvdmVyIENvb3JkaW5hdG9yLFMzOiByZXN0b3JlIGNvb3JkaW5hdG9yIHdpdGggbGF0ZXN0IHN0YXRlIGJ1dCB3aXRob3V0IGEgZ2xvYmFsIG1vZGVsXG4gICAgICAgICAgICBlbHNlIGdsb2JhbCBtb2RlbCBpZCBleGlzdFxuICAgICAgICAgICAgICBDb29yZGluYXRvci0-PitTMzogZ2V0IGdsb2JhbCBtb2RlbFxuICAgICAgICAgICAgICBTMy0tPj4tQ29vcmRpbmF0b3I6IGdsb2JhbCBtb2RlbFxuICAgICAgICAgICAgICBhbHQgZ2xvYmFsIG1vZGVsIG5vbi1leGlzdGVudFxuICAgICAgICAgICAgICAgIE5vdGUgb3ZlciBDb29yZGluYXRvcixTMzogZXhpdCB3aXRoIGVycm9yXG4gICAgICAgICAgICAgIGVsc2UgZ2xvYmFsIG1vZGVsIGV4aXN0XG4gICAgICAgICAgICAgICAgTm90ZSBvdmVyIENvb3JkaW5hdG9yLFMzOiByZXN0b3JlIGNvb3JkaW5hdG9yIHdpdGggbGF0ZXN0IHN0YXRlIGFuZCBsYXRlc3QgZ2xvYmFsIG1vZGVsXG4gICAgICAgICAgICAgIGVuZFxuICAgICAgICAgICAgZW5kXG4gICAgICAgICAgZW5kXG4gICAgICAgIGVuZCIsIm1lcm1haWQiOnsidGhlbWUiOiJkZWZhdWx0IiwidGhlbWVWYXJpYWJsZXMiOnsiYmFja2dyb3VuZCI6IndoaXRlIiwicHJpbWFyeUNvbG9yIjoiI0VDRUNGRiIsInNlY29uZGFyeUNvbG9yIjoiI2ZmZmZkZSIsInRlcnRpYXJ5Q29sb3IiOiJoc2woODAsIDEwMCUsIDk2LjI3NDUwOTgwMzklKSIsInByaW1hcnlCb3JkZXJDb2xvciI6ImhzbCgyNDAsIDYwJSwgODYuMjc0NTA5ODAzOSUpIiwic2Vjb25kYXJ5Qm9yZGVyQ29sb3IiOiJoc2woNjAsIDYwJSwgODMuNTI5NDExNzY0NyUpIiwidGVydGlhcnlCb3JkZXJDb2xvciI6ImhzbCg4MCwgNjAlLCA4Ni4yNzQ1MDk4MDM5JSkiLCJwcmltY
XJ5VGV4dENvbG9yIjoiIzEzMTMwMCIsInNlY29uZGFyeVRleHRDb2xvciI6IiMwMDAwMjEiLCJ0ZXJ0aWFyeVRleHRDb2xvciI6InJnYig5LjUwMDAwMDAwMDEsIDkuNTAwMDAwMDAwMSwgOS41MDAwMDAwMDAxKSIsImxpbmVDb2xvciI6IiMzMzMzMzMiLCJ0ZXh0Q29sb3IiOiIjMzMzIiwibWFpbkJrZyI6IiNFQ0VDRkYiLCJzZWNvbmRCa2ciOiIjZmZmZmRlIiwiYm9yZGVyMSI6IiM5MzcwREIiLCJib3JkZXIyIjoiI2FhYWEzMyIsImFycm93aGVhZENvbG9yIjoiIzMzMzMzMyIsImZvbnRGYW1pbHkiOiJcInRyZWJ1Y2hldCBtc1wiLCB2ZXJkYW5hLCBhcmlhbCIsImZvbnRTaXplIjoiMTZweCIsImxhYmVsQmFja2dyb3VuZCI6IiNlOGU4ZTgiLCJub2RlQmtnIjoiI0VDRUNGRiIsIm5vZGVCb3JkZXIiOiIjOTM3MERCIiwiY2x1c3RlckJrZyI6IiNmZmZmZGUiLCJjbHVzdGVyQm9yZGVyIjoiI2FhYWEzMyIsImRlZmF1bHRMaW5rQ29sb3IiOiIjMzMzMzMzIiwidGl0bGVDb2xvciI6IiMzMzMiLCJlZGdlTGFiZWxCYWNrZ3JvdW5kIjoiI2U4ZThlOCIsImFjdG9yQm9yZGVyIjoiaHNsKDI1OS42MjYxNjgyMjQzLCA1OS43NzY1MzYzMTI4JSwgODcuOTAxOTYwNzg0MyUpIiwiYWN0b3JCa2ciOiIjRUNFQ0ZGIiwiYWN0b3JUZXh0Q29sb3IiOiJibGFjayIsImFjdG9yTGluZUNvbG9yIjoiZ3JleSIsInNpZ25hbENvbG9yIjoiIzMzMyIsInNpZ25hbFRleHRDb2xvciI6IiMzMzMiLCJsYWJlbEJveEJrZ0NvbG9yIjoiI0VDRUNGRiIsImxhYmVsQm94Qm9yZGVyQ29sb3IiOiJoc2woMjU5LjYyNjE2ODIyNDMsIDU5Ljc3NjUzNjMxMjglLCA4Ny45MDE5NjA3ODQzJSkiLCJsYWJlbFRleHRDb2xvciI6ImJsYWNrIiwibG9vcFRleHRDb2xvciI6ImJsYWNrIiwibm90ZUJvcmRlckNvbG9yIjoiI2FhYWEzMyIsIm5vdGVCa2dDb2xvciI6IiNmZmY1YWQiLCJub3RlVGV4dENvbG9yIjoiYmxhY2siLCJhY3RpdmF0aW9uQm9yZGVyQ29sb3IiOiIjNjY2IiwiYWN0aXZhdGlvbkJrZ0NvbG9yIjoiI2Y0ZjRmNCIsInNlcXVlbmNlTnVtYmVyQ29sb3IiOiJ3aGl0ZSIsInNlY3Rpb25Ca2dDb2xvciI6InJnYmEoMTAyLCAxMDIsIDI1NSwgMC40OSkiLCJhbHRTZWN0aW9uQmtnQ29sb3IiOiJ3aGl0ZSIsInNlY3Rpb25Ca2dDb2xvcjIiOiIjZmZmNDAwIiwidGFza0JvcmRlckNvbG9yIjoiIzUzNGZiYyIsInRhc2tCa2dDb2xvciI6IiM4YTkwZGQiLCJ0YXNrVGV4dExpZ2h0Q29sb3IiOiJ3aGl0ZSIsInRhc2tUZXh0Q29sb3IiOiJ3aGl0ZSIsInRhc2tUZXh0RGFya0NvbG9yIjoiYmxhY2siLCJ0YXNrVGV4dE91dHNpZGVDb2xvciI6ImJsYWNrIiwidGFza1RleHRDbGlja2FibGVDb2xvciI6IiMwMDMxNjMiLCJhY3RpdmVUYXNrQm9yZGVyQ29sb3IiOiIjNTM0ZmJjIiwiYWN0aXZlVGFza0JrZ0NvbG9yIjoiI2JmYzdmZiIsImdyaWRDb2xvciI6ImxpZ2h0Z3JleSIsImRvbmVUYXNrQmtnQ29sb3IiOiJsaWdodGdyZXkiLCJkb25lVGFza0JvcmRlckNvbG9yIjoiZ3JleSIsI
mNyaXRCb3JkZXJDb2xvciI6IiNmZjg4ODgiLCJjcml0QmtnQ29sb3IiOiJyZWQiLCJ0b2RheUxpbmVDb2xvciI6InJlZCIsImxhYmVsQ29sb3IiOiJibGFjayIsImVycm9yQmtnQ29sb3IiOiIjNTUyMjIyIiwiZXJyb3JUZXh0Q29sb3IiOiIjNTUyMjIyIiwiY2xhc3NUZXh0IjoiIzEzMTMwMCIsImZpbGxUeXBlMCI6IiNFQ0VDRkYiLCJmaWxsVHlwZTEiOiIjZmZmZmRlIiwiZmlsbFR5cGUyIjoiaHNsKDMwNCwgMTAwJSwgOTYuMjc0NTA5ODAzOSUpIiwiZmlsbFR5cGUzIjoiaHNsKDEyNCwgMTAwJSwgOTMuNTI5NDExNzY0NyUpIiwiZmlsbFR5cGU0IjoiaHNsKDE3NiwgMTAwJSwgOTYuMjc0NTA5ODAzOSUpIiwiZmlsbFR5cGU1IjoiaHNsKC00LCAxMDAlLCA5My41Mjk0MTE3NjQ3JSkiLCJmaWxsVHlwZTYiOiJoc2woOCwgMTAwJSwgOTYuMjc0NTA5ODAzOSUpIiwiZmlsbFR5cGU3IjoiaHNsKDE4OCwgMTAwJSwgOTMuNTI5NDExNzY0NyUpIn19LCJ1cGRhdGVFZGl0b3IiOmZhbHNlfQ) /// /// - If the [`RestoreSettings.enable`] flag is set to `false`, the current coordinator /// state will be reset and a new [`StateMachine`] is created with the given settings. /// - If no coordinator state exists, the current coordinator state will be reset and a new /// [`StateMachine`] is created with the given settings. /// - If a coordinator state exists but no global model has been created so far, the /// [`StateMachine`] will be restored with the coordinator state but without a global model. /// - If a coordinator state and a global model exists, the [`StateMachine`] will be restored /// with the coordinator state and the global model. /// - If a global model has been created but does not exists, the initialization will fail with /// [`StateMachineInitializationError::GlobalModelUnavailable`]. /// - If a global model exists but its properties do not match the coordinator model settings, /// the initialization will fail with [`StateMachineInitializationError::GlobalModelInvalid`]. /// - Any network error will cause the initialization to fail. 
pub async fn init( mut self, ) -> StateMachineInitializationResult<(StateMachine, RequestSender, EventSubscriber)> { // crucial: init must be called before anything else in this module sodiumoxide::init().or(Err(StateMachineInitializationError::CryptoInit))?; let (coordinator_state, global_model) = if self.restore_settings.enable { self.from_previous_state().await? } else { info!("restoring coordinator state is disabled"); info!("initialize state machine from settings"); self.from_settings().await? }; Ok(self.init_state_machine(coordinator_state, global_model)) } // see [`StateMachineInitializer::init`] async fn from_previous_state( &mut self, ) -> StateMachineInitializationResult<(CoordinatorState, ModelUpdate)> { let (coordinator_state, global_model) = if let Some(coordinator_state) = self .store .coordinator_state() .await .map_err(StateMachineInitializationError::FetchCoordinatorState)? { self.try_restore_state(coordinator_state).await? } else { // no coordinator state available seems to be a fresh start self.from_settings().await? }; Ok((coordinator_state, global_model)) } // see [`StateMachineInitializer::init`] async fn try_restore_state( &mut self, coordinator_state: CoordinatorState, ) -> StateMachineInitializationResult<(CoordinatorState, ModelUpdate)> { let global_model_id = match self .store .latest_global_model_id() .await .map_err(StateMachineInitializationError::FetchLatestGlobalModelId)? 
{ // the state machine was shut down before completing a round // we cannot use the round_id here because we increment the round_id after each restart // that means even if the round id is larger than one, it doesn't mean that a // round has ever been completed None => { debug!("apparently no round has been completed yet"); debug!("restore coordinator without a global model"); return Ok((coordinator_state, ModelUpdate::Invalidate)); } Some(global_model_id) => global_model_id, }; let global_model = self .load_global_model(&coordinator_state, &global_model_id) .await?; debug!( "restore coordinator with global model id: {}", global_model_id ); Ok(( coordinator_state, ModelUpdate::New(std::sync::Arc::new(global_model)), )) } // Loads a global model and checks its properties for suitability. async fn load_global_model( &mut self, coordinator_state: &CoordinatorState, global_model_id: &str, ) -> StateMachineInitializationResult { match self .store .global_model(global_model_id) .await .map_err(StateMachineInitializationError::FetchGlobalModel)? { Some(global_model) => { if Self::model_properties_matches_settings(coordinator_state, &global_model) { Ok(global_model) } else { let error_msg = format!( "the length of global model with the id {} does not match with the value of the model length setting {} != {}", &global_model_id, global_model.len(), coordinator_state.round_params.model_length); Err(StateMachineInitializationError::GlobalModelInvalid( error_msg, )) } } None => { // the model id exists but we cannot find it in the model store // here we better fail because if we restart a coordinator with an empty model // the clients will throw away their current global model and start from scratch Err(StateMachineInitializationError::GlobalModelUnavailable( format!("cannot find global model {}", &global_model_id), )) } } } // Checks whether the properties of the loaded global model match the current // model settings of the coordinator. 
fn model_properties_matches_settings( coordinator_state: &CoordinatorState, global_model: &Model, ) -> bool { coordinator_state.round_params.model_length == global_model.len() } } ================================================ FILE: rust/xaynet-server/src/state_machine/mod.rs ================================================ //! The state machine that controls the execution of the PET protocol. //! //! # Overview //! //! ![State Machine](https://mermaid.ink/svg/eyJjb2RlIjoic3RhdGVEaWFncmFtXG5cdFsqXSAtLT4gSWRsZVxuXG4gICAgSWRsZSAtLT4gU3VtXG4gICAgU3VtIC0tPiBVcGRhdGVcbiAgICBVcGRhdGUgLS0-IFN1bTJcbiAgICBTdW0yIC0tPiBVbm1hc2tcbiAgICBVbm1hc2sgLS0-IElkbGVcblxuICAgIFN1bSAtLT4gRmFpbHVyZVxuICAgIFVwZGF0ZSAtLT4gRmFpbHVyZVxuICAgIFN1bTIgLS0-IEZhaWx1cmVcbiAgICBVbm1hc2sgLS0-IEZhaWx1cmVcbiAgICBGYWlsdXJlIC0tPiBJZGxlXG4gICAgRmFpbHVyZSAtLT4gU2h1dGRvd25cblxuICAgIFNodXRkb3duIC0tPiBbKl1cbiIsIm1lcm1haWQiOnsidGhlbWUiOiJuZXV0cmFsIn0sInVwZGF0ZUVkaXRvciI6ZmFsc2V9) //! //! The [`StateMachine`] is responsible for executing the individual tasks of the PET protocol. //! The main tasks include: building the sum and seed dictionaries, aggregating the masked //! models, determining the applicable mask and unmasking the global masked model. //! //! Furthermore, the [`StateMachine`] publishes protocol events and handles protocol errors. //! //! The [`StateMachine`] as well as the PET settings can be configured in the config file. //! See [here][settings] for more details. //! //! # Phase states //! //! **Idle** //! //! Publishes [`PhaseName::Idle`] and increments the `round_id` by `1`. Invalidates the [`SumDict`], //! [`SeedDict`], `scalar` and `mask length`. Updates the [`EncryptKeyPair`], `probabilities` for //! the tasks and the `seed`. Publishes the [`EncryptKeyPair`] and the [`RoundParameters`]. //! //! **Sum** //! //! Publishes [`PhaseName::Sum`], builds and publishes the [`SumDict`], ensures that enough sum //! messages have been submitted and initializes the [`SeedDict`]. //! //! **Update** //! 
//! Publishes [`PhaseName::Update`], builds and publishes the [`SeedDict`], ensures that enough
//! update messages have been submitted and aggregates the masked model.
//!
//! **Sum2**
//!
//! Publishes [`PhaseName::Sum2`], builds the mask dictionary, ensures that enough sum2
//! messages have been submitted and determines the applicable mask for unmasking the global
//! masked model.
//!
//! **Unmask**
//!
//! Publishes [`PhaseName::Unmask`], unmasks the global masked model and publishes the global
//! model.
//!
//! **Failure**
//!
//! Publishes [`PhaseName::Failure`] and handles [`PhaseError`]s that can occur during the
//! execution of the [`StateMachine`]. In most cases, the error is handled by restarting the round.
//! However, if a [`PhaseError::RequestChannel`] occurs, the [`StateMachine`] will shut down.
//!
//! **Shutdown**
//!
//! Publishes [`PhaseName::Shutdown`] and shuts down the [`StateMachine`]. During the shutdown,
//! the [`StateMachine`] performs a clean shutdown of the [Request][requests] channel by
//! closing it and consuming all remaining messages.
//!
//! # Requests
//!
//! Initializing a new [`StateMachine`] via [`StateMachineInitializer::init()`] creates a new
//! [StateMachineRequest][requests] channel, which is used to send
//! [`StateMachineRequest`]s to the [`StateMachine`]. The sender half of that channel
//! ([`RequestSender`]) is returned to the caller of
//! [`StateMachineInitializer::init()`], whereas the receiver half ([`RequestReceiver`])
//! is used by the [`StateMachine`].
//!
//! See [here][requests] for more details.
//!
//! # Events
//!
//! During the execution of the PET protocol, the [`StateMachine`] publishes various events
//! (see the phase states above). Anyone interested in these events can subscribe to them
//! via the [`EventSubscriber`]. An [`EventSubscriber`] is created automatically when a new
//! [`StateMachine`] is created through [`StateMachineInitializer::init()`].
//!
//! See [here][events] for more details. //! //! [settings]: crate::settings //! [`PhaseName::Idle`]: crate::state_machine::phases::PhaseName::Idle //! [`PhaseName::Sum`]: crate::state_machine::phases::PhaseName::Sum //! [`PhaseName::Update`]: crate::state_machine::phases::PhaseName::Update //! [`PhaseName::Sum2`]: crate::state_machine::phases::PhaseName::Sum2 //! [`PhaseName::Unmask`]: crate::state_machine::phases::PhaseName::Unmask //! [`PhaseName::Failure`]: crate::state_machine::phases::PhaseName::Failure //! [`PhaseName::Shutdown`]: crate::state_machine::phases::PhaseName::Shutdown //! [`PhaseError`]: crate::state_machine::phases::PhaseError //! [`PhaseError::RequestChannel`]: crate::state_machine::phases::PhaseError::RequestChannel //! [`SumDict`]: xaynet_core::SumDict //! [`SeedDict`]: xaynet_core::SeedDict //! [`EncryptKeyPair`]: xaynet_core::crypto::EncryptKeyPair //! [`RoundParameters`]: xaynet_core::common::RoundParameters //! [`StateMachineInitializer::init()`]: crate::state_machine::initializer::StateMachineInitializer::init //! [`StateMachineRequest`]: crate::state_machine::requests::StateMachineRequest //! [requests]: crate::state_machine::requests //! [`RequestSender`]: crate::state_machine::requests::RequestSender //! [`RequestReceiver`]: crate::state_machine::requests::RequestReceiver //! [events]: crate::state_machine::events //! [`EventSubscriber`]: crate::state_machine::events::EventSubscriber pub mod coordinator; pub mod events; pub mod initializer; pub mod phases; pub mod requests; use derive_more::From; use crate::{ state_machine::phases::{ Failure, Idle, Phase, PhaseState, Shutdown, Sum, Sum2, Unmask, Update, }, storage::Storage, }; /// The state machine with all its states. #[derive(From)] pub enum StateMachine { /// The [`Idle`] phase. Idle(PhaseState), /// The [`Sum`] phase. Sum(PhaseState), /// The [`Update`] phase. Update(PhaseState), /// The [`Sum2`] phase. Sum2(PhaseState), /// The [`Unmask`] phase. 
Unmask(PhaseState), /// The [`Failure`] phase. Failure(PhaseState), /// The [`Shutdown`] phase. Shutdown(PhaseState), } impl StateMachine where T: Storage, PhaseState: Phase, PhaseState: Phase, PhaseState: Phase, PhaseState: Phase, PhaseState: Phase, PhaseState: Phase, PhaseState: Phase, { /// Moves the [`StateMachine`] to the next state and consumes the current one. /// /// Returns the next state or `None` if the [`StateMachine`] reached the state [`Shutdown`]. pub async fn next(self) -> Option { match self { StateMachine::Idle(state) => state.run_phase().await, StateMachine::Sum(state) => state.run_phase().await, StateMachine::Update(state) => state.run_phase().await, StateMachine::Sum2(state) => state.run_phase().await, StateMachine::Unmask(state) => state.run_phase().await, StateMachine::Failure(state) => state.run_phase().await, StateMachine::Shutdown(state) => state.run_phase().await, } } /// Runs the state machine until it shuts down. /// /// The [`StateMachine`] shuts down once all [`RequestSender`] have been dropped. /// /// [`RequestSender`]: crate::state_machine::requests::RequestSender pub async fn run(mut self) -> Option<()> { loop { self = self.next().await?; } } } /// Records a message accepted metric. #[doc(hidden)] #[macro_export] macro_rules! accepted { ($round_id: expr, $phase: expr $(,)?) => { crate::metric!( crate::metrics::Measurement::MessageAccepted, 1, ("round_id", $round_id), ("phase", $phase as u8), ); }; } /// Records a message rejected metric. #[doc(hidden)] #[macro_export] macro_rules! rejected { ($round_id: expr, $phase: expr $(,)?) => { crate::metric!( crate::metrics::Measurement::MessageRejected, 1, ("round_id", $round_id), ("phase", $phase as u8), ); }; } /// Records a message discarded metric. #[doc(hidden)] #[macro_export] macro_rules! discarded { ($round_id: expr, $phase: expr $(,)?) 
=> { crate::metric!( crate::metrics::Measurement::MessageDiscarded, 1, ("round_id", $round_id), ("phase", $phase as u8), ); }; } #[cfg(test)] pub(crate) mod tests; ================================================ FILE: rust/xaynet-server/src/state_machine/phases/failure.rs ================================================ use std::time::Duration; use async_trait::async_trait; use displaydoc::Display; use thiserror::Error; use tokio::time::sleep; use tracing::{error, info}; use crate::{ event, state_machine::{ events::DictionaryUpdate, phases::{ Idle, IdleError, Phase, PhaseName, PhaseState, Shared, Shutdown, SumError, UnmaskError, UpdateError, }, StateMachine, }, storage::Storage, }; /// Errors which can occur during the execution of the [`StateMachine`]. #[derive(Debug, Display, Error)] pub enum PhaseError { /// Request channel error: {0}. RequestChannel(&'static str), /// Phase timeout. PhaseTimeout(#[from] tokio::time::error::Elapsed), /// Idle phase failed: {0}. Idle(#[from] IdleError), /// Sum phase failed: {0}. Sum(#[from] SumError), /// Update phase failed: {0}. Update(#[from] UpdateError), /// Unmask phase failed: {0}. Unmask(#[from] UnmaskError), } /// The failure state. 
///
/// Entered after another phase failed; it carries the error that caused the failure.
#[derive(Debug)]
pub struct Failure {
    /// The error that made the previous phase fail.
    pub(in crate::state_machine) error: PhaseError,
}

#[async_trait]
impl Phase for PhaseState
where
    T: Storage,
{
    const NAME: PhaseName = PhaseName::Failure;

    /// Logs the error of the failed phase and emits it as a "Phase error" event.
    ///
    /// This never fails.
    async fn process(&mut self) -> Result<(), PhaseError> {
        error!("phase state error: {}", self.private.error);
        event!("Phase error", self.private.error.to_string());
        Ok(())
    }

    /// Broadcasts [`DictionaryUpdate::Invalidate`] for both the sum dictionary and the seed
    /// dictionary.
    fn broadcast(&mut self) {
        info!("broadcasting invalidation of sum dictionary");
        self.shared
            .events
            .broadcast_sum_dict(DictionaryUpdate::Invalidate);

        info!("broadcasting invalidation of seed dictionary");
        self.shared
            .events
            .broadcast_seed_dict(DictionaryUpdate::Invalidate);
    }

    /// Moves to the next phase.
    ///
    /// - A [`PhaseError::RequestChannel`] error moves the state machine into the shutdown phase
    ///   right away, without checking the store.
    /// - Any other error first waits until the store is ready again (see
    ///   `wait_for_store_readiness`) and then moves into the idle phase.
    async fn next(mut self) -> Option> {
        if let PhaseError::RequestChannel(_) = self.private.error {
            Some(PhaseState::::new(self.shared).into())
        } else {
            self.wait_for_store_readiness().await;
            Some(PhaseState::::new(self.shared).into())
        }
    }
}

impl PhaseState {
    /// Creates a new failure state that wraps the given `error`.
    pub fn new(shared: Shared, error: PhaseError) -> Self {
        Self {
            private: Failure { error },
            shared,
        }
    }
}

impl PhaseState
where
    T: Storage,
{
    /// Waits until the [`Store`] is ready.
    ///
    /// Polls the store's `is_ready` check. On every failure it logs the error and retries after
    /// 5 seconds. There is no retry limit: this only returns once the store reports readiness.
    ///
    /// [`Store`]: crate::storage::Store
    async fn wait_for_store_readiness(&mut self) {
        while let Err(err) = ::is_ready(&mut self.shared.store).await {
            error!("store not ready: {}", err);
            info!("try again in 5 sec");
            sleep(Duration::from_secs(5)).await;
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use anyhow::anyhow;
    use tokio::time::{timeout, Duration, Instant};
    use xaynet_core::{SeedDict, SumDict};

    use crate::{
        state_machine::{
            coordinator::CoordinatorState,
            events::{EventPublisher, EventSubscriber, ModelUpdate},
            tests::{
                utils::{enable_logging, init_shared, EventSnapshot},
                CoordinatorStateBuilder,
                EventBusBuilder,
            },
        },
        storage::{
            tests::{utils::create_global_model, MockCoordinatorStore, MockModelStore},
            Store,
        },
    };

    /// Builds a coordinator state plus an event bus whose latest events look as if the
    /// coordinator is coming from the sum2 phase: a `Sum2` phase event, new (empty) sum and
    /// seed dictionaries and a global model.
    fn state_and_events_from_sum2_phase() -> (CoordinatorState, EventPublisher, EventSubscriber) {
        let state = CoordinatorStateBuilder::new().build();

        let (event_publisher, event_subscriber) = EventBusBuilder::new(&state)
            .broadcast_phase(PhaseName::Sum2)
            .broadcast_sum_dict(DictionaryUpdate::New(Arc::new(SumDict::new())))
            .broadcast_seed_dict(DictionaryUpdate::New(Arc::new(SeedDict::new())))
            .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1))))
            .build();

        (state, event_publisher, event_subscriber)
    }

    #[tokio::test]
    async fn error_to_idle_phase() {
        // No storage errors.
        //
        // What should happen:
        // 1. broadcast the Failure phase
        // 2. broadcast invalidation of sum and seed dict
        // 3. check if the store is ready to process requests
        // 4. move into idle phase
        //
        // What should not happen:
        // - the shared state has been changed
        //   (except for `round_id` when moving into idle phase)
        // - events have been broadcasted (except phase event and invalidation
        //   event of sum and seed dict)
        enable_logging();

        let mut cs = MockCoordinatorStore::new();
        cs.expect_is_ready().return_once(move || Ok(()));
        let mut ms = MockModelStore::new();
        ms.expect_is_ready().return_once(move || Ok(()));
        let store = Store::new(cs, ms);

        let (state, event_publisher, event_subscriber) = state_and_events_from_sum2_phase();
        let events_before_error = EventSnapshot::from(&event_subscriber);
        let state_before_error = state.clone();

        let (shared, _request_tx) = init_shared(state, store, event_publisher);
        let state_machine = StateMachine::from(PhaseState::::new(
            shared,
            PhaseError::Idle(IdleError::DeleteDictionaries(anyhow!(""))),
        ));
        assert!(state_machine.is_failure());

        let state_machine = state_machine.next().await.unwrap();

        let state_after_error = state_machine.as_ref().clone();

        // the round id is incremented when the idle phase state is created
        assert_ne!(state_after_error.round_id, state_before_error.round_id);
        assert_eq!(
            state_after_error.round_params,
            state_before_error.round_params
        );
        assert_eq!(state_after_error.keys, state_before_error.keys);
        assert_eq!(state_after_error.sum, state_before_error.sum);
        assert_eq!(state_after_error.update, state_before_error.update);
        assert_eq!(state_after_error.sum2, state_before_error.sum2);

        let events_after_error = EventSnapshot::from(&event_subscriber);
        assert_ne!(events_after_error.phase, events_before_error.phase);
        assert_eq!(events_after_error.keys, events_before_error.keys);
        assert_eq!(events_after_error.params, events_before_error.params);
        assert_eq!(
            events_after_error.sum_dict.event,
            DictionaryUpdate::Invalidate
        );
        assert_eq!(
            events_after_error.seed_dict.event,
            DictionaryUpdate::Invalidate
        );
        assert_eq!(events_after_error.model, events_before_error.model);
        assert_eq!(events_after_error.phase.event, PhaseName::Failure);

        assert!(state_machine.is_idle());
    }

    #[tokio::test]
    async fn test_error_to_shutdown_phase() {
        // No storage errors.
        //
        // What should happen:
        // 1. broadcast the Failure phase
        // 2. broadcast invalidation of sum and seed dict
        // 3. previous phase failed with PhaseError::RequestChannel
        //    which means that the state machine should be shut down
        // 4. move into shutdown phase
        //
        // What should not happen:
        // - the shared state has been changed
        // - events have been broadcasted (except phase event and invalidation
        //   event of sum and seed dict)
        enable_logging();

        let mut cs = MockCoordinatorStore::new();
        cs.expect_is_ready().return_once(move || Ok(()));
        let mut ms = MockModelStore::new();
        ms.expect_is_ready().return_once(move || Ok(()));
        let store = Store::new(cs, ms);

        let (state, event_publisher, event_subscriber) = state_and_events_from_sum2_phase();
        let events_before_error = EventSnapshot::from(&event_subscriber);
        let state_before_error = state.clone();

        let (shared, _request_tx) = init_shared(state, store, event_publisher);
        let state_machine = StateMachine::from(PhaseState::::new(
            shared,
            PhaseError::RequestChannel(""),
        ));
        assert!(state_machine.is_failure());

        let state_machine = state_machine.next().await.unwrap();

        let state_after_error = state_machine.as_ref().clone();
        assert_eq!(state_after_error, state_before_error);

        let events_after_error = EventSnapshot::from(&event_subscriber);
        assert_ne!(events_after_error.phase, events_before_error.phase);
        assert_eq!(events_after_error.keys, events_before_error.keys);
        assert_eq!(events_after_error.params, events_before_error.params);
        assert_eq!(
            events_after_error.sum_dict.event,
            DictionaryUpdate::Invalidate
        );
        assert_eq!(
            events_after_error.seed_dict.event,
            DictionaryUpdate::Invalidate
        );
        assert_eq!(events_after_error.model, events_before_error.model);
        assert_eq!(events_after_error.phase.event, PhaseName::Failure);

        assert!(state_machine.is_shutdown());
    }

    #[tokio::test]
    async fn test_error_to_idle_store_failed() {
        // Storage errors (the store readiness is polled until it passes):
        // - 1st call on `is_ready`: the coordinator store fails (model store not queried)
        // - 2nd call on `is_ready`: the coordinator store fails (model store not queried)
        // - 3rd call on `is_ready`: the coordinator store passes and the model store fails
        // - 4th call on `is_ready`: the coordinator store and the model store pass
        //
        // What should happen:
        // 1. broadcast the Failure phase
        // 2. broadcast invalidation of sum and seed dict
        // 3. check if the store is ready to process requests
        // 4. wait until the store is ready again (3 failed checks x 5 sec = 15 sec)
        // 5. move into idle phase
        //
        // What should not happen:
        // - the shared state has been changed
        //   (except for `round_id` when moving into idle phase)
        // - events have been broadcasted (except phase event and invalidation
        //   event of sum and seed dict)
        enable_logging();

        let mut cs = MockCoordinatorStore::new();
        let mut cs_counter = 0;
        cs.expect_is_ready().returning(move || {
            let res = match cs_counter {
                0 => Err(anyhow!("")),
                1 => Err(anyhow!("")),
                2 => Ok(()),
                3 => Ok(()),
                _ => panic!(""),
            };
            cs_counter += 1;
            res
        });

        let mut ms = MockModelStore::new();
        let mut ms_counter = 0;
        ms.expect_is_ready().returning(move || {
            let res = match ms_counter {
                // the 1st and 2nd readiness checks never reach this mock because
                // Storage::is_ready does not call MockModelStore::is_ready if
                // MockCoordinatorStore::is_ready has already failed
                0 => Err(anyhow!("")),
                1 => Ok(()),
                _ => panic!(""),
            };
            ms_counter += 1;
            res
        });

        let store = Store::new(cs, ms);

        let state = CoordinatorStateBuilder::new().build();
        let (event_publisher, _event_subscriber) = EventBusBuilder::new(&state).build();

        let (shared, _request_tx) = init_shared(state, store, event_publisher);
        let state_machine = StateMachine::from(PhaseState::::new(
            shared,
            PhaseError::Idle(IdleError::DeleteDictionaries(anyhow!(""))),
        ));
        assert!(state_machine.is_failure());

        let now = Instant::now();
        let state_machine = timeout(Duration::from_secs(20), state_machine.next())
            .await
            .unwrap()
            .unwrap();
        // three 5 second retries must have elapsed
        assert!(now.elapsed().as_secs() > 14);

        assert!(state_machine.is_idle());
    }

    #[tokio::test]
    async fn test_error_to_shutdown_skip_store_readiness_check() {
        // No storage expectations: the mocks have no `is_ready` expectations configured, so the
        // store must not be queried at all (presumably an unexpected call makes the mock panic).
        //
        // What should happen:
        // 1. broadcast the Failure phase
        // 2. broadcast invalidation of sum and seed dict
        // 3. previous phase failed with PhaseError::RequestChannel
        //    which means that the state machine should be shut down
        // 4. skip store readiness check
        // 5. move into shutdown phase
        //
        // What should not happen:
        // - wait for the store to be ready again
        // - the shared state has been changed
        // - events have been broadcasted (except phase event and invalidation
        //   event of sum and seed dict)
        enable_logging();

        let store = Store::new(MockCoordinatorStore::new(), MockModelStore::new());

        let state = CoordinatorStateBuilder::new().build();
        let (event_publisher, _event_subscriber) = EventBusBuilder::new(&state).build();

        let (shared, _request_tx) = init_shared(state, store, event_publisher);
        let state_machine = StateMachine::from(PhaseState::::new(
            shared,
            PhaseError::RequestChannel(""),
        ));
        assert!(state_machine.is_failure());

        // the 5 second timeout fails the test if the state machine waits on the store
        let state_machine = timeout(Duration::from_secs(5), state_machine.next())
            .await
            .unwrap()
            .unwrap();

        assert!(state_machine.is_shutdown());
    }
}
================================================
FILE: rust/xaynet-server/src/state_machine/phases/handler.rs
================================================
use async_trait::async_trait;
use tokio::time::{timeout, Duration};
use tracing::{debug, info, Span};

use crate::{
    accepted,
    discarded,
    rejected,
    state_machine::{
        coordinator::{CountParameters, PhaseParameters},
        phases::{Phase, PhaseError, PhaseState},
        requests::{RequestError, ResponseSender, StateMachineRequest},
    },
    storage::Storage,
};

/// A trait that must be implemented by a state to handle a request.
#[async_trait]
pub trait Handler {
    /// Handles a request.
    ///
    /// # Errors
    /// Fails on PET and storage errors.
    async fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), RequestError>;
}

/// A counter to keep track of handled messages.
struct Counter { /// The minimal number of successfully processed messages. min: u64, /// The maximal number of successfully processed messages. max: u64, /// The number of messages successfully processed. accepted: u64, /// The number of messages failed to processed. rejected: u64, /// The number of messages discarded without being processed. discarded: u64, } impl AsMut for Counter { fn as_mut(&mut self) -> &mut Self { self } } impl Counter { /// Creates a new message counter. fn new(CountParameters { min, max }: CountParameters) -> Self { Self { min, max, accepted: 0, rejected: 0, discarded: 0, } } /// Checks whether enough requests have been processed successfully wrt the PET settings. fn has_enough_messages(&self) -> bool { self.accepted >= self.min } /// Checks whether too many requests are processed wrt the PET settings. fn has_overmuch_messages(&self) -> bool { self.accepted >= self.max } /// Increments the counter for accepted requests. fn increment_accepted(&mut self) { self.accepted += 1; debug!( "{} messages accepted (min {} and max {} required)", self.accepted, self.min, self.max, ); } /// Increments the counter for rejected requests. fn increment_rejected(&mut self) { self.rejected += 1; debug!("{} messages rejected", self.rejected); } /// Increments the counter for discarded requests. fn increment_discarded(&mut self) { self.discarded += 1; debug!("{} messages discarded", self.discarded); } } impl PhaseState where T: Storage, Self: Phase + Handler, { /// Processes requests wrt the phase parameters. /// /// - Processes at most `count.max` requests during the time interval `[now, now + time.min]`. /// - Processes requests until there are enough (ie `count.min`) for the time interval /// `[now + time.min, now + time.max]`. /// - Aborts if either all connections were dropped or not enough requests were processed until /// timeout. 
pub(super) async fn process( &mut self, PhaseParameters { count, time }: PhaseParameters, ) -> Result<(), PhaseError> { let mut counter = Counter::new(count); info!("processing requests"); debug!( "processing for min {} and max {} seconds", time.min, time.max ); self.process_during(Duration::from_secs(time.min), counter.as_mut()) .await?; let time_left = time.max - time.min; timeout( Duration::from_secs(time_left), self.process_until_enough(counter.as_mut()), ) .await??; info!( "in total {} messages accepted (min {} and max {} required)", counter.accepted, counter.min, counter.max, ); info!("in total {} messages rejected", counter.rejected); info!( "in total {} messages discarded (purged not included)", counter.discarded, ); Ok(()) } /// Processes requests for as long as the given duration. async fn process_during( &mut self, dur: tokio::time::Duration, counter: &mut Counter, ) -> Result<(), PhaseError> { let deadline = tokio::time::sleep(dur); tokio::pin!(deadline); loop { tokio::select! { biased; _ = &mut deadline => { debug!("duration elapsed"); break Ok(()); } next = self.next_request() => { let (req, span, resp_tx) = next?; self.process_single(req, span, resp_tx, counter).await; } } } } /// Processes requests until there are enough. async fn process_until_enough(&mut self, counter: &mut Counter) -> Result<(), PhaseError> { while !counter.has_enough_messages() { let (req, span, resp_tx) = self.next_request().await?; self.process_single(req, span, resp_tx, counter).await; } Ok(()) } /// Processes a single request. /// /// The request is discarded if the maximum message count is reached, accepted if processed /// successfully and rejected otherwise. 
async fn process_single( &mut self, req: StateMachineRequest, span: Span, resp_tx: ResponseSender, counter: &mut Counter, ) { let _span_guard = span.enter(); let response = if counter.has_overmuch_messages() { counter.increment_discarded(); discarded!(self.shared.state.round_id, Self::NAME); Err(RequestError::MessageDiscarded) } else { let response = self.handle_request(req).await; if response.is_ok() { counter.increment_accepted(); accepted!(self.shared.state.round_id, Self::NAME); } else { counter.increment_rejected(); rejected!(self.shared.state.round_id, Self::NAME); } response }; // This may error out if the receiver has already been dropped but it doesn't matter for us. let _ = resp_tx.send(response); } } #[cfg(test)] mod tests { use super::*; #[test] fn test_counter() { // 0 accepted let mut counter = Counter::new(CountParameters { min: 1, max: 3 }); assert!(!counter.has_enough_messages()); assert!(!counter.has_overmuch_messages()); // 1 accepted counter.increment_accepted(); assert!(counter.has_enough_messages()); assert!(!counter.has_overmuch_messages()); // 2 accepted counter.increment_accepted(); assert!(counter.has_enough_messages()); assert!(!counter.has_overmuch_messages()); // 3 accepted counter.increment_accepted(); assert!(counter.has_enough_messages()); assert!(counter.has_overmuch_messages()); } } ================================================ FILE: rust/xaynet-server/src/state_machine/phases/idle.rs ================================================ use async_trait::async_trait; use displaydoc::Display; use sodiumoxide::crypto::hash::sha256; use thiserror::Error; use tracing::{debug, info, warn}; use crate::{ metric, metrics::Measurement, state_machine::{ phases::{Phase, PhaseError, PhaseName, PhaseState, Shared, Sum}, StateMachine, }, storage::{Storage, StorageError}, }; use xaynet_core::{ common::RoundSeed, crypto::{ByteObject, EncryptKeyPair, SigningKeySeed}, }; /// Errors which can occur during the idle phase. 
#[derive(Debug, Display, Error)]
pub enum IdleError {
    // NOTE: the `///` lines on the variants are `displaydoc` format strings, i.e. they are the
    // `Display` output of the errors; editing them changes the error messages.
    /// Setting the coordinator state failed: {0}.
    SetCoordinatorState(StorageError),
    /// Deleting the dictionaries failed: {0}.
    DeleteDictionaries(StorageError),
}

/// The idle state.
#[derive(Debug)]
pub struct Idle;

#[async_trait]
impl Phase for PhaseState
where
    T: Storage,
{
    const NAME: PhaseName = PhaseName::Idle;

    /// Prepares the next round.
    ///
    /// Deletes the dictionaries of the previous round, generates a fresh round key pair, updates
    /// the round probabilities (currently only logs) and the round seed, and persists the new
    /// coordinator state. The key pair is generated before the seed update because the new seed
    /// is signed with a key derived from the new round secret key.
    ///
    /// # Errors
    /// Fails if deleting the dictionaries or persisting the coordinator state fails.
    async fn process(&mut self) -> Result<(), PhaseError> {
        self.delete_dicts().await?;

        self.gen_round_keypair();
        self.update_round_probabilities();
        self.update_round_seed();

        self.set_coordinator_state().await?;

        Ok(())
    }

    /// Broadcasts the new keys and round parameters and emits the idle phase metrics.
    fn broadcast(&mut self) {
        self.broadcast_keys();
        self.broadcast_params();
        self.broadcast_metrics();
    }

    /// Moves into the sum phase.
    async fn next(self) -> Option> {
        Some(PhaseState::::new(self.shared).into())
    }
}

impl PhaseState {
    /// Creates a new idle state and increments the round id by one.
    pub fn new(mut shared: Shared) -> Self {
        // Since some events are emitted very early, the round id must
        // be correct when the idle phase starts. Therefore, we update
        // it here, when instantiating the idle PhaseState.
        shared.set_round_id(shared.round_id() + 1);
        debug!("new round ID = {}", shared.round_id());
        Self {
            private: Idle,
            shared,
        }
    }

    /// Updates the participant probabilities round parameters.
    ///
    /// Not implemented yet: the `sum` and `update` probabilities stay unchanged.
    fn update_round_probabilities(&mut self) {
        info!("updating round probabilities");
        warn!("round probabilities stay constant, no update strategy implemented yet");
    }

    /// Updates the seed round parameter.
    ///
    /// The new seed is the SHA-256 hash of a signature over the previous seed followed by the
    /// little-endian bytes of the `sum` and `update` probabilities. The signing key pair is
    /// derived from the current round secret key.
    fn update_round_seed(&mut self) {
        info!("updating round seed");
        // `from_slice_unchecked` is fine here: `sk` and `seed` have the same number of bytes
        let (_, sk) =
            SigningKeySeed::from_slice_unchecked(self.shared.state.keys.secret.as_slice())
                .derive_signing_key_pair();
        let signature = sk.sign_detached(
            &[
                self.shared.state.round_params.seed.as_slice(),
                &self.shared.state.round_params.sum.to_le_bytes(),
                &self.shared.state.round_params.update.to_le_bytes(),
            ]
            .concat(),
        );
        // `from_slice_unchecked` is fine here: the length of the hash is 32 bytes
        self.shared.state.round_params.seed =
            RoundSeed::from_slice_unchecked(sha256::hash(signature.as_slice()).as_ref());
    }

    /// Generates fresh round credentials and publishes the new public key in the round
    /// parameters.
    fn gen_round_keypair(&mut self) {
        info!("updating the keys");
        self.shared.state.keys = EncryptKeyPair::generate();
        self.shared.state.round_params.pk = self.shared.state.keys.public;
    }

    /// Broadcasts the keys.
    fn broadcast_keys(&mut self) {
        info!("broadcasting new keys");
        self.shared
            .events
            .broadcast_keys(self.shared.state.keys.clone());
    }

    /// Broadcasts the round parameters.
    fn broadcast_params(&mut self) {
        info!("broadcasting new round parameters");
        self.shared
            .events
            .broadcast_params(self.shared.state.round_params.clone());
    }
}

impl PhaseState
where
    T: Storage,
{
    /// Deletes the dicts from the store.
    ///
    /// # Errors
    /// Fails with [`IdleError::DeleteDictionaries`] if the store fails.
    async fn delete_dicts(&mut self) -> Result<(), IdleError> {
        info!("removing phase dictionaries from previous round");
        self.shared
            .store
            .delete_dicts()
            .await
            .map_err(IdleError::DeleteDictionaries)
    }

    /// Persists the coordinator state to the store.
    ///
    /// # Errors
    /// Fails with [`IdleError::SetCoordinatorState`] if the store fails.
    async fn set_coordinator_state(&mut self) -> Result<(), IdleError> {
        info!("storing new coordinator state");
        self.shared
            .store
            .set_coordinator_state(&self.shared.state)
            .await
            .map_err(IdleError::SetCoordinatorState)
    }
}

impl PhaseState
where
    T: Storage,
    Self: Phase,
{
    /// Broadcasts idle phase metrics: the round number and the `sum` and `update` round
    /// probabilities (the latter two tagged with the round id and the phase).
    fn broadcast_metrics(&self) {
        metric!(Measurement::RoundTotalNumber, self.shared.state.round_id);
        metric!(
            Measurement::RoundParamSum,
            self.shared.state.round_params.sum,
            ("round_id", self.shared.state.round_id),
            ("phase", Self::NAME as u8),
        );
        metric!(
            Measurement::RoundParamUpdate,
            self.shared.state.round_params.update,
            ("round_id", self.shared.state.round_id),
            ("phase", Self::NAME as u8),
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use anyhow::anyhow;
    use xaynet_core::common::RoundParameters;

    use crate::{
        state_machine::{
            coordinator::CoordinatorState,
            events::{DictionaryUpdate, EventPublisher, EventSubscriber, ModelUpdate},
            tests::{
                utils::{assert_event_updated_with_id, enable_logging, init_shared, EventSnapshot},
                CoordinatorStateBuilder,
                EventBusBuilder,
            },
        },
        storage::{
            tests::{utils::create_global_model, MockCoordinatorStore, MockModelStore},
            Store,
        },
    };

    /// Builds a coordinator state plus an event bus whose latest events look as if the
    /// coordinator is coming from the unmask phase: an `Unmask` phase event, invalidated sum and
    /// seed dictionaries and a global model.
    fn state_and_events_from_unmask_phase() -> (CoordinatorState, EventPublisher, EventSubscriber) {
        let state = CoordinatorStateBuilder::new().build();

        let (event_publisher, event_subscriber) = EventBusBuilder::new(&state)
            .broadcast_phase(PhaseName::Unmask)
            .broadcast_sum_dict(DictionaryUpdate::Invalidate)
            .broadcast_seed_dict(DictionaryUpdate::Invalidate)
            .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1))))
            .build();

        (state, event_publisher, event_subscriber)
    }

    /// Asserts that `pk` and `seed` differ between the two parameter sets while `sum`, `update`
    /// (within `f64::EPSILON`), `mask_config` and `model_length` are equal.
    fn assert_params(params1: &RoundParameters, params2: &RoundParameters) {
        assert_ne!(params1.pk, params2.pk);
        assert_ne!(params1.seed, params2.seed);
        assert!((params1.sum - params2.sum).abs() <= f64::EPSILON);
        assert!((params1.update - params2.update).abs() <= f64::EPSILON);
        assert_eq!(params1.mask_config, params2.mask_config);
        assert_eq!(params1.model_length, params2.model_length);
    }

    /// Asserts the state and events after the idle phase failed to delete the dictionaries:
    /// only the round id changed (to 1) and only a new `Idle` phase event was broadcast.
    fn assert_after_delete_dict_failure(
        state_before: &CoordinatorState,
        events_before: &EventSnapshot,
        state_after: &CoordinatorState,
        events_after: &EventSnapshot,
    ) {
        assert_eq!(state_after.round_params.pk, state_before.round_params.pk);
        assert_eq!(
            state_after.round_params.seed,
            state_before.round_params.seed
        );
        assert!(
            (state_after.round_params.sum - state_before.round_params.sum).abs() <= f64::EPSILON
        );
        assert!(
            (state_after.round_params.update - state_before.round_params.update).abs()
                <= f64::EPSILON
        );
        assert_eq!(
            state_after.round_params.mask_config,
            state_before.round_params.mask_config
        );
        assert_eq!(
            state_after.round_params.model_length,
            state_before.round_params.model_length
        );
        assert_ne!(state_after.round_id, state_before.round_id);
        assert_eq!(state_after.keys, state_before.keys);
        assert_eq!(state_after.sum, state_before.sum);
        assert_eq!(state_after.update, state_before.update);
        assert_eq!(state_after.sum2, state_before.sum2);
        assert_eq!(state_after.keys.public, state_after.round_params.pk);
        assert_eq!(state_after.round_id, 1);

        assert_event_updated_with_id(&events_after.phase, &events_before.phase);
        assert_eq!(events_after.phase.event, PhaseName::Idle);
        assert_eq!(&events_after.keys, &events_before.keys);
        assert_eq!(&events_after.sum_dict, &events_before.sum_dict);
        assert_eq!(&events_after.seed_dict, &events_before.seed_dict);
        assert_eq!(events_after.params, events_before.params);
        assert_eq!(events_after.model, events_before.model);
    }

    #[tokio::test]
    async fn test_idle_to_sum_phase() {
        // No storage errors; let's pretend we come from the unmask phase.
        //
        // What should happen:
        // 1. increase round id by 1
        // 2. broadcast Idle phase
        // 3. delete the sum/seed/mask dict
        // 4. update coordinator keys
        // 5. update round thresholds (not implemented yet)
        // 6. update round seeds
        // 7. save the new coordinator state
        // 8. broadcast updated keys
        // 9. broadcast new round parameters
        // 10. move into sum phase
        //
        // What should not happen:
        // - the global model has been invalidated
        enable_logging();

        let mut cs = MockCoordinatorStore::new();
        cs.expect_delete_dicts().return_once(move || Ok(()));
        cs.expect_set_coordinator_state()
            .return_once(move |_| Ok(()));
        let store = Store::new(cs, MockModelStore::new());

        let (state, event_publisher, event_subscriber) = state_and_events_from_unmask_phase();
        let events_before_idle = EventSnapshot::from(&event_subscriber);
        let state_before_idle = state.clone();

        let (shared, _request_tx) = init_shared(state, store, event_publisher);
        let state_machine = StateMachine::from(PhaseState::::new(shared));
        assert!(state_machine.is_idle());

        let state_machine = state_machine.next().await.unwrap();

        let state_after_idle = state_machine.as_ref().clone();
        assert_params(
            &state_after_idle.round_params,
            &state_before_idle.round_params,
        );
        assert_ne!(state_after_idle.keys, state_before_idle.keys);
        assert_ne!(state_after_idle.round_id, state_before_idle.round_id);
        assert_eq!(state_after_idle.sum, state_before_idle.sum);
        assert_eq!(state_after_idle.update, state_before_idle.update);
        assert_eq!(state_after_idle.sum2, state_before_idle.sum2);
        assert_eq!(
            state_after_idle.keys.public,
            state_after_idle.round_params.pk
        );
        assert_eq!(state_after_idle.round_id, 1);

        let events_after_idle = EventSnapshot::from(&event_subscriber);
        assert_event_updated_with_id(&events_after_idle.keys, &events_before_idle.keys);
        assert_event_updated_with_id(&events_after_idle.params, &events_before_idle.params);
        assert_event_updated_with_id(&events_after_idle.phase, &events_before_idle.phase);
        assert_eq!(events_after_idle.phase.event, PhaseName::Idle);
        assert_eq!(events_after_idle.sum_dict, events_before_idle.sum_dict);
        assert_eq!(events_after_idle.seed_dict, events_before_idle.seed_dict);
        assert_eq!(events_after_idle.model, events_before_idle.model);

        assert!(state_machine.is_sum());
    }

    #[tokio::test]
    async fn test_idle_to_sum_delete_dicts_failed() {
        // Storage:
        // - delete_dicts fails
        //
        // What should happen:
        // 1. increase round id by 1
        // 2. broadcast Idle phase
        // 3. delete the sum/seed/mask dict (fails)
        // 4. move into error phase
        //
        // What should not happen:
        // - new keys have been broadcasted
        // - new round parameters have been broadcasted
        // - the global model has been invalidated
        // - the state machine has moved into sum phase
        enable_logging();

        let mut cs = MockCoordinatorStore::new();
        cs.expect_delete_dicts()
            .return_once(move || Err(anyhow!("")));
        let store = Store::new(cs, MockModelStore::new());

        let (state, event_publisher, event_subscriber) = state_and_events_from_unmask_phase();
        let events_before_idle = EventSnapshot::from(&event_subscriber);
        let state_before_idle = state.clone();

        let (shared, _request_tx) = init_shared(state, store, event_publisher);
        let state_machine = StateMachine::from(PhaseState::::new(shared));
        assert!(state_machine.is_idle());

        let state_machine = state_machine.next().await.unwrap();

        let state_after_idle = state_machine.as_ref().clone();
        let events_after_idle = EventSnapshot::from(&event_subscriber);
        assert_after_delete_dict_failure(
            &state_before_idle,
            &events_before_idle,
            &state_after_idle,
            &events_after_idle,
        );

        assert!(state_machine.is_failure());
        assert!(matches!(
            state_machine.into_failure_phase_state().private.error,
            PhaseError::Idle(IdleError::DeleteDictionaries(_))
        ))
    }

    #[tokio::test]
    async fn test_idle_to_sum_save_state_failed() {
        // Storage:
        // - set_coordinator_state fails
        //
        // What should happen:
        // 1. increase round id by 1
        // 2. broadcast Idle phase
        // 3. delete the sum/seed/mask dict
        // 4. update coordinator keys
        // 5. update round thresholds (not implemented yet)
        // 6. update round seeds
        // 7. save the new coordinator state (fails)
        // 8. move into error phase
        //
        // What should not happen:
        // - new keys have been broadcast (the phase failed before broadcasting)
        // - new round parameters have been broadcast
        // - the global model has been invalidated
        // - the state machine has moved into sum phase
        enable_logging();

        let mut cs = MockCoordinatorStore::new();
        cs.expect_delete_dicts().return_once(move || Ok(()));
        cs.expect_set_coordinator_state()
            .return_once(move |_| Err(anyhow!("")));
        let store = Store::new(cs, MockModelStore::new());

        let (state, event_publisher, event_subscriber) = state_and_events_from_unmask_phase();
        let events_before_idle = EventSnapshot::from(&event_subscriber);
        let state_before_idle = state.clone();

        let (shared, _request_tx) = init_shared(state, store, event_publisher);
        let state_machine = StateMachine::from(PhaseState::::new(shared));
        assert!(state_machine.is_idle());

        let state_machine = state_machine.next().await.unwrap();

        let state_after_idle = state_machine.as_ref().clone();
        let events_after_idle = EventSnapshot::from(&event_subscriber);
        assert_params(
            &state_after_idle.round_params,
            &state_before_idle.round_params,
        );
        assert_ne!(state_after_idle.keys, state_before_idle.keys);
        assert_ne!(state_after_idle.round_id, state_before_idle.round_id);
        assert_eq!(state_after_idle.sum, state_before_idle.sum);
        assert_eq!(state_after_idle.update, state_before_idle.update);
        assert_eq!(state_after_idle.sum2, state_before_idle.sum2);
        assert_eq!(
            state_after_idle.keys.public,
            state_after_idle.round_params.pk
        );
        assert_eq!(state_after_idle.round_id, 1);

        assert_event_updated_with_id(&events_after_idle.phase, &events_before_idle.phase);
        assert_eq!(events_after_idle.phase.event, PhaseName::Idle);
        assert_eq!(&events_after_idle.keys, &events_before_idle.keys);
        assert_eq!(&events_after_idle.sum_dict, &events_before_idle.sum_dict);
        assert_eq!(&events_after_idle.seed_dict, &events_before_idle.seed_dict);
        assert_eq!(events_after_idle.params, events_before_idle.params);
        assert_eq!(events_after_idle.model, events_before_idle.model);

        assert!(state_machine.is_failure());
        assert!(matches!(
            state_machine.into_failure_phase_state().private.error,
            PhaseError::Idle(IdleError::SetCoordinatorState(_))
        ))
    }
}
================================================
FILE: rust/xaynet-server/src/state_machine/phases/mod.rs
================================================
//! This module provides the states (aka phases) of the [`StateMachine`].
//!
//! [`StateMachine`]: crate::state_machine::StateMachine

mod failure;
mod handler;
mod idle;
mod phase;
mod shutdown;
mod sum;
mod sum2;
mod unmask;
mod update;

pub use self::{
    failure::{Failure, PhaseError},
    handler::Handler,
    idle::{Idle, IdleError},
    phase::{Phase, PhaseName, PhaseState, Shared},
    shutdown::Shutdown,
    sum::{Sum, SumError},
    sum2::Sum2,
    unmask::{Unmask, UnmaskError},
    update::{Update, UpdateError},
};
================================================
FILE: rust/xaynet-server/src/state_machine/phases/phase.rs
================================================
use std::fmt;

use async_trait::async_trait;
use derive_more::Display;
use futures::StreamExt;
use tracing::{debug, error, error_span, info, warn, Span};
use tracing_futures::Instrument;

use crate::{
    discarded,
    metric,
    metrics::Measurement,
    state_machine::{
        coordinator::CoordinatorState,
        events::EventPublisher,
        phases::{Failure, PhaseError},
        requests::{RequestError, RequestReceiver, ResponseSender, StateMachineRequest},
        StateMachine,
    },
    storage::Storage,
};

/// The name of the current phase.
#[derive(Clone, Copy, Debug, Display, Eq, PartialEq)]
pub enum PhaseName {
    #[display(fmt = "Idle")]
    Idle,
    #[display(fmt = "Sum")]
    Sum,
    #[display(fmt = "Update")]
    Update,
    #[display(fmt = "Sum2")]
    Sum2,
    #[display(fmt = "Unmask")]
    Unmask,
    #[display(fmt = "Failure")]
    Failure,
    #[display(fmt = "Shutdown")]
    Shutdown,
}

/// A trait that must be implemented by a state in order to perform its tasks and to move to a next
/// state.
///
/// See the [module level documentation] for more details.
/// /// [module level documentation]: crate::state_machine #[async_trait] pub trait Phase where T: Storage, { /// The name of the current phase. const NAME: PhaseName; /// Performs the tasks of this phase. async fn process(&mut self) -> Result<(), PhaseError>; // TODO: add a filter service in PetMessageHandler that only passes through messages if // the state machine is in one of the Sum, Update or Sum2 phases. then we can add a Purge // phase here which gets broadcasted when the purge starts to prevent further incomming // messages, which means we can split `purge()` from `process()` and use a no-op default impl // for all phases except Sum, Update and Sum. until then we have to have a purge impl in every // phase, which also means that the metrics can be a bit off. /// Broadcasts data of this phase (nothing by default). fn broadcast(&mut self) {} /// Moves from this phase to the next phase. async fn next(self) -> Option>; } /// The coordinator state and the I/O interfaces that are shared and accessible by all /// [`PhaseState`]s. pub struct Shared { /// The coordinator state. pub(in crate::state_machine) state: CoordinatorState, /// The request receiver half. pub(in crate::state_machine) request_rx: RequestReceiver, /// The event publisher. pub(in crate::state_machine) events: EventPublisher, /// The store for storing coordinator and model data. pub(in crate::state_machine) store: T, } impl fmt::Debug for Shared { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Shared") .field("state", &self.state) .field("request_rx", &self.request_rx) .field("events", &self.events) .finish() } } impl Shared { /// Creates a new shared state. pub fn new( coordinator_state: CoordinatorState, publisher: EventPublisher, request_rx: RequestReceiver, store: T, ) -> Self { Self { state: coordinator_state, request_rx, events: publisher, store, } } /// Sets the round ID to the given value. 
pub fn set_round_id(&mut self, id: u64) { self.state.round_id = id; self.events.set_round_id(id); } /// Returns the current round ID. pub fn round_id(&self) -> u64 { self.state.round_id } } /// The state corresponding to a phase of the PET protocol. /// /// This contains the state-dependent `private` state and the state-independent `shared` state /// which is shared across state transitions. pub struct PhaseState { /// The private state. pub(in crate::state_machine) private: S, /// The shared coordinator state and I/O interfaces. pub(in crate::state_machine) shared: Shared, } impl PhaseState where S: Send, T: Storage, Self: Phase, { /// Runs the current phase to completion. /// /// 1. Performs the phase tasks. /// 2. Purges outdated phase messages. /// 3. Broadcasts the phase data. /// 4. Transitions to the next phase. pub async fn run_phase(mut self) -> Option> { let phase = Self::NAME; let span = error_span!("run_phase", phase = %phase); async move { info!("starting phase"); self.shared.events.broadcast_phase(phase); metric!(Measurement::Phase, phase as u8); if let Err(err) = self.process().await { warn!("failed to perform the phase tasks"); return Some(self.into_failure_state(err)); } info!("phase ran successfully"); if let Err(err) = self.purge_outdated_requests() { warn!("failed to purge outdated requests"); if let PhaseName::Failure | PhaseName::Shutdown = phase { debug!( "already in {} phase: ignoring error while purging outdated requests", phase, ); } else { return Some(self.into_failure_state(err)); } } self.broadcast(); info!("transitioning to the next phase"); self.next().await } .instrument(span) .await } /// Purges all pending requests that are considered outdated at the end of a successful phase. fn purge_outdated_requests(&mut self) -> Result<(), PhaseError> { info!("discarding outdated requests"); while let Some((_, span, resp_tx)) = self.try_next_request()? 
{ debug!("discarding outdated request"); let _span_guard = span.enter(); discarded!(self.shared.state.round_id, Self::NAME); let _ = resp_tx.send(Err(RequestError::MessageDiscarded)); } Ok(()) } } impl PhaseState { /// Receives the next [`StateMachineRequest`]. /// /// # Errors /// Returns [`PhaseError::RequestChannel`] when all sender halves have been dropped. pub async fn next_request( &mut self, ) -> Result<(StateMachineRequest, Span, ResponseSender), PhaseError> { debug!("waiting for the next incoming request"); self.shared.request_rx.next().await.ok_or_else(|| { error!("request receiver broken: senders have been dropped"); PhaseError::RequestChannel("all message senders have been dropped!") }) } pub fn try_next_request( &mut self, ) -> Result, PhaseError> { match self.shared.request_rx.try_recv() { Some(Some(item)) => Ok(Some(item)), None => { debug!("no pending request"); Ok(None) } Some(None) => { warn!("failed to get next pending request: channel shut down"); Err(PhaseError::RequestChannel( "all message senders have been dropped!", )) } } } fn into_failure_state(self, err: PhaseError) -> StateMachine { PhaseState::::new(self.shared, err).into() } } ================================================ FILE: rust/xaynet-server/src/state_machine/phases/shutdown.rs ================================================ use async_trait::async_trait; use tracing::debug; use crate::{ state_machine::{ phases::{Phase, PhaseError, PhaseName, PhaseState, Shared}, StateMachine, }, storage::Storage, }; /// The shutdown state. #[derive(Debug)] pub struct Shutdown; #[async_trait] impl Phase for PhaseState where T: Storage, { const NAME: PhaseName = PhaseName::Shutdown; async fn process(&mut self) -> Result<(), PhaseError> { debug!("clearing the request channel"); self.shared.request_rx.close(); while self.shared.request_rx.recv().await.is_some() {} Ok(()) } async fn next(self) -> Option> { None } } impl PhaseState { /// Creates a new shutdown state. 
pub fn new(shared: Shared) -> Self { Self { private: Shutdown, shared, } } } #[cfg(test)] mod tests { use super::*; use crate::{ state_machine::tests::{ utils::{enable_logging, init_shared}, CoordinatorStateBuilder, EventBusBuilder, }, storage::{ tests::{MockCoordinatorStore, MockModelStore}, Store, }, }; #[tokio::test] async fn test_shutdown_to_none() { // No Storage errors // // What should happen: // 1. broadcast Shutdown phase // 2. request channel is closed // 3. state machine is stopped // // What should not happen: // - events have been broadcasted (except phase event) enable_logging(); let store = Store::new(MockCoordinatorStore::new(), MockModelStore::new()); let state = CoordinatorStateBuilder::new().build(); let (event_publisher, _event_subscriber) = EventBusBuilder::new(&state).build(); let (shared, request_tx) = init_shared(state, store, event_publisher); let state_machine = StateMachine::from(PhaseState::::new(shared)); assert!(state_machine.is_shutdown()); assert!(!request_tx.is_closed()); let state_machine = state_machine.next().await; assert!(request_tx.is_closed()); assert!(state_machine.is_none()); } } ================================================ FILE: rust/xaynet-server/src/state_machine/phases/sum.rs ================================================ use std::sync::Arc; use async_trait::async_trait; use displaydoc::Display; use thiserror::Error; use tracing::info; use crate::{ state_machine::{ events::DictionaryUpdate, phases::{Handler, Phase, PhaseError, PhaseName, PhaseState, Shared, Update}, requests::{RequestError, StateMachineRequest, SumRequest}, StateMachine, }, storage::{Storage, StorageError}, }; use xaynet_core::{SumDict, SumParticipantEphemeralPublicKey, SumParticipantPublicKey}; /// Errors which can occur during the sum phase. #[derive(Debug, Display, Error)] pub enum SumError { /// Sum dictionary does not exists. NoSumDict, /// Fetching sum dictionary failed: {0}. FetchSumDict(StorageError), } /// The sum state. 
#[derive(Debug)] pub struct Sum { /// The sum dictionary which gets assembled during the sum phase. sum_dict: Option, } #[async_trait] impl Phase for PhaseState where T: Storage, Self: Handler, { const NAME: PhaseName = PhaseName::Sum; async fn process(&mut self) -> Result<(), PhaseError> { self.process(self.shared.state.sum).await?; self.sum_dict().await?; Ok(()) } fn broadcast(&mut self) { info!("broadcasting sum dictionary"); let sum_dict = self .private .sum_dict .take() .expect("unreachable: never fails when `broadcast()` is called after `process()`"); self.shared .events .broadcast_sum_dict(DictionaryUpdate::New(Arc::new(sum_dict))); } async fn next(self) -> Option> { Some(PhaseState::::new(self.shared).into()) } } #[async_trait] impl Handler for PhaseState where T: Storage, { async fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), RequestError> { if let StateMachineRequest::Sum(SumRequest { participant_pk, ephm_pk, }) = req { self.update_sum_dict(participant_pk, ephm_pk).await } else { Err(RequestError::MessageRejected) } } } impl PhaseState { /// Creates a new sum state. pub fn new(shared: Shared) -> Self { Self { private: Sum { sum_dict: None }, shared, } } } impl PhaseState where T: Storage, { /// Updates the sum dict with a sum participant request. async fn update_sum_dict( &mut self, participant_pk: SumParticipantPublicKey, ephm_pk: SumParticipantEphemeralPublicKey, ) -> Result<(), RequestError> { self.shared .store .add_sum_participant(&participant_pk, &ephm_pk) .await? .into_inner() .map_err(RequestError::from) } /// Gets the sum dict from the store. async fn sum_dict(&mut self) -> Result<(), SumError> { self.private.sum_dict = self .shared .store .sum_dict() .await .map_err(SumError::FetchSumDict)? .ok_or(SumError::NoSumDict)? 
.into(); Ok(()) } } #[cfg(test)] mod tests { use super::*; use anyhow::anyhow; use tokio::time::{timeout, Duration}; use xaynet_core::SumDict; use crate::{ state_machine::{ coordinator::CoordinatorState, events::{EventPublisher, EventSubscriber, ModelUpdate}, tests::{ utils::{ assert_event_updated, enable_logging, init_shared, send_sum2_messages, send_sum_messages, send_update_messages, EventSnapshot, }, CoordinatorStateBuilder, EventBusBuilder, }, }, storage::{ tests::{utils::create_global_model, MockCoordinatorStore, MockModelStore}, Store, SumPartAdd, SumPartAddError, }, }; fn events_from_idle_phase(state: &CoordinatorState) -> (EventPublisher, EventSubscriber) { EventBusBuilder::new(state) .broadcast_phase(PhaseName::Idle) .broadcast_sum_dict(DictionaryUpdate::Invalidate) .broadcast_seed_dict(DictionaryUpdate::Invalidate) .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1)))) .build() } fn assert_after_phase_success( state_before: &CoordinatorState, events_before: &EventSnapshot, state_after: &CoordinatorState, events_after: &EventSnapshot, ) { assert_eq!(state_after, state_before); assert_event_updated(&events_after.phase, &events_before.phase); assert_event_updated(&events_after.sum_dict, &events_before.sum_dict); assert_eq!(events_after.keys, events_before.keys); assert_eq!(events_after.params, events_before.params); assert_eq!(events_after.phase.event, PhaseName::Sum); assert_eq!(events_after.seed_dict, events_before.seed_dict); assert_eq!(events_after.model, events_before.model); } fn assert_after_phase_failure( state_before: &CoordinatorState, events_before: &EventSnapshot, state_after: &CoordinatorState, events_after: &EventSnapshot, ) { assert_eq!(state_after, state_before); assert_event_updated(&events_after.phase, &events_before.phase); assert_eq!(events_after.keys, events_before.keys); assert_eq!(events_after.params, events_before.params); assert_eq!(events_after.phase.event, PhaseName::Sum); assert_eq!(events_after.sum_dict, 
events_before.sum_dict); assert_eq!(events_after.seed_dict, events_before.seed_dict); assert_eq!(events_after.model, events_before.model); } #[tokio::test] async fn test_sum_to_update_phase() { // No Storage errors // lets pretend we come from the sum phase // // What should happen: // 1. broadcast Sum phase // 2. accept 10 sum messages // 3. fetch sum dict // 4. broadcast sum dict // 5. move into update phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated enable_logging(); let mut cs = MockCoordinatorStore::new(); cs.expect_add_sum_participant() .times(10) .returning(move |_, _| Ok(SumPartAdd(Ok(())))); cs.expect_sum_dict() .return_once(move || Ok(Some(SumDict::new()))); let store = Store::new(cs, MockModelStore::new()); let state = CoordinatorStateBuilder::new() .with_round_id(1) .with_sum_count_min(10) .with_sum_count_max(10) .with_sum_time_min(1) .build(); let (event_publisher, event_subscriber) = events_from_idle_phase(&state); let events_before_sum = EventSnapshot::from(&event_subscriber); let state_before_sum = state.clone(); let (shared, request_tx) = init_shared(state, store, event_publisher); let state_machine = StateMachine::from(PhaseState::::new(shared)); assert!(state_machine.is_sum()); send_sum_messages(10, request_tx.clone()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum = state_machine.as_ref().clone(); let events_after_sum = EventSnapshot::from(&event_subscriber); assert_after_phase_success( &state_before_sum, &events_before_sum, &state_after_sum, &events_after_sum, ); assert!(state_machine.is_update()); } #[tokio::test] async fn test_sum_phase_timeout() { // No Storage errors // // What should happen: // 1. broadcast Sum phase // 2. phase should timeout // 3. 
move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated // - the sum dict has been fetched // - the sum dict has been broadcasted enable_logging(); let store = Store::new(MockCoordinatorStore::new(), MockModelStore::new()); let state = CoordinatorStateBuilder::new() .with_round_id(1) .with_sum_time_min(1) .with_sum_time_max(2) .build(); let (event_publisher, event_subscriber) = events_from_idle_phase(&state); let events_before_sum = EventSnapshot::from(&event_subscriber); let state_before_sum = state.clone(); let (shared, _request_tx) = init_shared(state, store, event_publisher); let state_machine = StateMachine::from(PhaseState::::new(shared)); assert!(state_machine.is_sum()); let state_machine = timeout(Duration::from_secs(4), state_machine.next()) .await .unwrap() .unwrap(); let state_after_sum = state_machine.as_ref().clone(); let events_after_sum = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum, &events_before_sum, &state_after_sum, &events_after_sum, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::PhaseTimeout(_) )) } #[tokio::test] async fn test_rejected_messages() { // No Storage errors // // What should happen: // 1. broadcast Sum phase // 2. accept 7 sum messages // 3. reject 3 update and 5 sum2 messages // 4. fetch sum dict // 5. broadcast sum dict // 6. 
move into update phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated enable_logging(); let mut cs = MockCoordinatorStore::new(); cs.expect_add_sum_participant() .times(7) .returning(move |_, _| Ok(SumPartAdd(Ok(())))); cs.expect_sum_dict() .return_once(move || Ok(Some(SumDict::new()))); let store = Store::new(cs, MockModelStore::new()); let state = CoordinatorStateBuilder::new() .with_round_id(1) .with_sum_count_min(7) .with_sum_count_max(7) .build(); let (event_publisher, event_subscriber) = events_from_idle_phase(&state); let events_before_sum = EventSnapshot::from(&event_subscriber); let state_before_sum = state.clone(); let (shared, request_tx) = init_shared(state, store, event_publisher); let state_machine = StateMachine::from(PhaseState::::new(shared)); assert!(state_machine.is_sum()); send_update_messages(3, request_tx.clone()); send_sum2_messages(5, request_tx.clone()); send_sum_messages(7, request_tx.clone()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum = state_machine.as_ref().clone(); let events_after_sum = EventSnapshot::from(&event_subscriber); assert_after_phase_success( &state_before_sum, &events_before_sum, &state_after_sum, &events_after_sum, ); assert!(state_machine.is_update()); } #[tokio::test] async fn test_discarded_messages() { // No Storage errors // // What should happen: // 1. broadcast Sum phase // 2. accept 5 sum messages // 3. discard 5 sum messages // 4. fetch sum dict // 5. broadcast sum dict // 6. 
move into update phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated enable_logging(); let mut cs = MockCoordinatorStore::new(); cs.expect_add_sum_participant() .times(5) .returning(move |_, _| Ok(SumPartAdd(Ok(())))); cs.expect_sum_dict() .return_once(move || Ok(Some(SumDict::new()))); let store = Store::new(cs, MockModelStore::new()); let state = CoordinatorStateBuilder::new() .with_round_id(1) .with_sum_count_min(5) .with_sum_count_max(5) .with_sum_time_min(5) .with_sum_time_max(10) .build(); let (event_publisher, event_subscriber) = events_from_idle_phase(&state); let events_before_sum = EventSnapshot::from(&event_subscriber); let state_before_sum = state.clone(); let (shared, request_tx) = init_shared(state, store, event_publisher); let state_machine = StateMachine::from(PhaseState::::new(shared)); assert!(state_machine.is_sum()); send_sum_messages(10, request_tx.clone()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum = state_machine.as_ref().clone(); let events_after_sum = EventSnapshot::from(&event_subscriber); assert_after_phase_success( &state_before_sum, &events_before_sum, &state_after_sum, &events_after_sum, ); assert!(state_machine.is_update()); } #[tokio::test] async fn test_request_channel_is_dropped() { // No Storage errors // // What should happen: // 1. broadcast Sum phase // 2. request channel is dropped // 3. 
move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated // - the sum dict has been fetched // - the sum dict has been broadcasted enable_logging(); let store = Store::new(MockCoordinatorStore::new(), MockModelStore::new()); let state = CoordinatorStateBuilder::new() .with_round_id(1) .with_sum_count_min(1) .with_sum_count_max(1) .with_sum_time_min(1) .with_sum_time_max(5) .build(); let (event_publisher, event_subscriber) = events_from_idle_phase(&state); let events_before_sum = EventSnapshot::from(&event_subscriber); let state_before_sum = state.clone(); let (shared, request_tx) = init_shared(state, store, event_publisher); let state_machine = StateMachine::from(PhaseState::::new(shared)); assert!(state_machine.is_sum()); drop(request_tx); let state_machine = state_machine.next().await.unwrap(); let state_after_sum = state_machine.as_ref().clone(); let events_after_sum = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum, &events_before_sum, &state_after_sum, &events_after_sum, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::RequestChannel(_) )) } #[tokio::test] async fn test_sum_to_update_fetch_sum_dict_failed() { // Storage errors // - sum_dict fails // // What should happen: // 1. broadcast Sum phase // 2. accept 1 sum message // 3. fetch sum dict (fails) // 4. 
move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated // - the sum dict has been broadcasted enable_logging(); let mut cs = MockCoordinatorStore::new(); cs.expect_add_sum_participant() .times(1) .returning(move |_, _| Ok(SumPartAdd(Ok(())))); cs.expect_sum_dict().return_once(move || Err(anyhow!(""))); let store = Store::new(cs, MockModelStore::new()); let state = CoordinatorStateBuilder::new() .with_round_id(1) .with_sum_count_min(1) .with_sum_count_max(1) .with_sum_time_min(1) .with_sum_time_max(5) .build(); let (event_publisher, event_subscriber) = events_from_idle_phase(&state); let events_before_sum = EventSnapshot::from(&event_subscriber); let state_before_sum = state.clone(); let (shared, request_tx) = init_shared(state, store, event_publisher); let state_machine = StateMachine::from(PhaseState::::new(shared)); assert!(state_machine.is_sum()); send_sum_messages(1, request_tx.clone()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum = state_machine.as_ref().clone(); let events_after_sum = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum, &events_before_sum, &state_after_sum, &events_after_sum, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::Sum(SumError::FetchSumDict(_)) )) } #[tokio::test] async fn test_sum_to_update_sum_dict_none() { // No Storage errors // // What should happen: // 1. broadcast Sum phase // 2. accept 1 sum message // 3. fetch sum dict (no storage error but the sum dict is None) // 4. 
move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated // - the sum dict has been broadcasted enable_logging(); let mut cs = MockCoordinatorStore::new(); cs.expect_add_sum_participant() .times(1) .returning(move |_, _| Ok(SumPartAdd(Ok(())))); cs.expect_sum_dict().return_once(move || Ok(None)); let store = Store::new(cs, MockModelStore::new()); let state = CoordinatorStateBuilder::new() .with_round_id(1) .with_sum_count_min(1) .with_sum_count_max(1) .with_sum_time_min(1) .with_sum_time_max(5) .build(); let (event_publisher, event_subscriber) = events_from_idle_phase(&state); let events_before_sum = EventSnapshot::from(&event_subscriber); let state_before_sum = state.clone(); let (shared, request_tx) = init_shared(state, store, event_publisher); let state_machine = StateMachine::from(PhaseState::::new(shared)); assert!(state_machine.is_sum()); send_sum_messages(1, request_tx.clone()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum = state_machine.as_ref().clone(); let events_after_sum = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum, &events_before_sum, &state_after_sum, &events_after_sum, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::Sum(SumError::NoSumDict) )) } #[tokio::test] async fn test_rejected_messages_pet_error() { // No Storage errors // // What should happen: // 1. broadcast Sum phase // 2. reject 3 sum messages (pet error SumPartAddError::AlreadyExists) // 3. phase should timeout // 4. 
move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated // - the sum dict has been fetched // - the sum dict has been broadcasted enable_logging(); let mut cs = MockCoordinatorStore::new(); cs.expect_add_sum_participant() .times(3) .returning(move |_, _| Ok(SumPartAdd(Err(SumPartAddError::AlreadyExists)))); let store = Store::new(cs, MockModelStore::new()); let state = CoordinatorStateBuilder::new() .with_round_id(1) .with_sum_count_min(3) .with_sum_count_max(3) .with_sum_time_min(0) .with_sum_time_max(2) .build(); let (event_publisher, event_subscriber) = events_from_idle_phase(&state); let events_before_sum = EventSnapshot::from(&event_subscriber); let state_before_sum = state.clone(); let (shared, request_tx) = init_shared(state, store, event_publisher); let state_machine = StateMachine::from(PhaseState::::new(shared)); assert!(state_machine.is_sum()); send_sum_messages(3, request_tx.clone()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum = state_machine.as_ref().clone(); let events_after_sum = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum, &events_before_sum, &state_after_sum, &events_after_sum, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::PhaseTimeout(_) )) } // #[tokio::test] // async fn test_sum_phase_publish_after_purge() { // // Publish sum dict after purging all remaining messages. 
// enable_logging(); // let mut cs = MockCoordinatorStore::new(); // cs.expect_add_sum_participant() // .returning(move |_, _| Ok(SumPartAdd(Ok(())))); // cs.expect_sum_dict() // .return_once(move || Ok(Some(SumDict::new()))); // let store = Store::new(cs, MockModelStore::new()); // let state = CoordinatorStateBuilder::new() // .with_round_id(1) // .with_sum_count_min(2) // .with_sum_count_max(500) // .with_sum_time_min(0) // .build(); // let (event_publisher, event_subscriber) = events_from_idle_phase(&state); // let (shared, request_tx) = init_shared(state, store, event_publisher); // let state_machine = StateMachine::from(PhaseState::::new(shared)); // assert!(state_machine.is_sum()); // let (mut ready, latch) = Readiness::new(); // send_sum_messages_with_latch(1000, request_tx.clone(), latch); // let mut sum_dict_listener = event_subscriber.sum_dict_listener(); // sum_dict_listener.changed().await.unwrap(); // tokio::time::sleep(Duration::from_secs(10)).await; // tokio::select! { // // TODO: purge_outdated_requests blocks the current thread (we should fix that) // // and sum_dict_listener.changed() would always be executed after // // state_machine.next(). The test always passes although it shouldn't // // therefore we need to spawn it here to run the state machine on a separate // // thread // // // // Further more we suffer from the https://github.com/tokio-rs/tokio/issues/3350 // // issue in request_tx::try_recv(). We fill the request channel with 1000 // // before we start the machine. Nevertheless, the message purging stops after // // around 134 messages. 
// _ = state_machine.next() => { // panic!("state did no run successfully") // } // _ = sum_dict_listener.changed() => { // panic!("sum dict was broadcasted before all requests has been purged") // } // _ = ready.is_ready() => { // } // } // } } ================================================ FILE: rust/xaynet-server/src/state_machine/phases/sum2.rs ================================================ use async_trait::async_trait; use tracing::info; use crate::{ state_machine::{ events::DictionaryUpdate, phases::{Handler, Phase, PhaseError, PhaseName, PhaseState, Shared, Unmask}, requests::{RequestError, StateMachineRequest, Sum2Request}, StateMachine, }, storage::Storage, }; use xaynet_core::{ mask::{Aggregation, MaskObject}, SumParticipantPublicKey, }; /// The sum2 state. #[derive(Debug)] pub struct Sum2 { /// The aggregator for masked models. model_agg: Aggregation, } #[async_trait] impl Phase for PhaseState where T: Storage, Self: Handler, { const NAME: PhaseName = PhaseName::Sum2; async fn process(&mut self) -> Result<(), PhaseError> { self.process(self.shared.state.sum2).await } fn broadcast(&mut self) { info!("broadcasting invalidation of sum dictionary"); self.shared .events .broadcast_sum_dict(DictionaryUpdate::Invalidate); info!("broadcasting invalidation of seed dictionary"); self.shared .events .broadcast_seed_dict(DictionaryUpdate::Invalidate); } async fn next(self) -> Option> { Some(PhaseState::::new(self.shared, self.private.model_agg).into()) } } #[async_trait] impl Handler for PhaseState where T: Storage, { async fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), RequestError> { if let StateMachineRequest::Sum2(Sum2Request { participant_pk, model_mask, }) = req { self.update_mask_dict(participant_pk, model_mask).await } else { Err(RequestError::MessageRejected) } } } impl PhaseState { /// Creates a new sum2 state. 
pub fn new(shared: Shared, model_agg: Aggregation) -> Self { Self { private: Sum2 { model_agg }, shared, } } } impl PhaseState where T: Storage, { /// Updates the mask dict with a sum2 participant request. async fn update_mask_dict( &mut self, participant_pk: SumParticipantPublicKey, model_mask: MaskObject, ) -> Result<(), RequestError> { self.shared .store .incr_mask_score(&participant_pk, &model_mask) .await? .into_inner() .map_err(RequestError::from) } } #[cfg(test)] mod tests { use super::*; use std::sync::Arc; use xaynet_core::{SeedDict, SumDict}; use crate::{ state_machine::{ coordinator::CoordinatorState, events::{DictionaryUpdate, EventPublisher, EventSubscriber, ModelUpdate}, tests::{ utils::{ assert_event_updated, enable_logging, init_shared, send_sum2_messages, EventSnapshot, }, CoordinatorStateBuilder, EventBusBuilder, }, }, storage::{ tests::{utils::create_global_model, MockCoordinatorStore, MockModelStore}, MaskScoreIncr, MaskScoreIncrError, Store, }, }; fn events_from_update_phase(state: &CoordinatorState) -> (EventPublisher, EventSubscriber) { EventBusBuilder::new(state) .broadcast_phase(PhaseName::Update) .broadcast_sum_dict(DictionaryUpdate::New(Arc::new(SumDict::new()))) .broadcast_seed_dict(DictionaryUpdate::New(Arc::new(SeedDict::new()))) .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1)))) .build() } fn assert_after_phase_success( state_before: &CoordinatorState, events_before: &EventSnapshot, state_after: &CoordinatorState, events_after: &EventSnapshot, ) { assert_eq!(state_after, state_before); assert_event_updated(&events_after.phase, &events_before.phase); assert_event_updated(&events_after.sum_dict, &events_before.sum_dict); assert_event_updated(&events_after.seed_dict, &events_before.seed_dict); assert_eq!(events_after.sum_dict.event, DictionaryUpdate::Invalidate); assert_eq!(events_after.seed_dict.event, DictionaryUpdate::Invalidate); assert_eq!(events_after.keys, events_before.keys); assert_eq!(events_after.params, 
events_before.params); assert_eq!(events_after.phase.event, PhaseName::Sum2); assert_eq!(events_after.model, events_before.model); } fn assert_after_phase_failure( state_before: &CoordinatorState, events_before: &EventSnapshot, state_after: &CoordinatorState, events_after: &EventSnapshot, ) { assert_eq!(state_after, state_before); assert_event_updated(&events_after.phase, &events_before.phase); assert_eq!(events_after.keys, events_before.keys); assert_eq!(events_after.params, events_before.params); assert_eq!(events_after.phase.event, PhaseName::Sum2); assert_eq!(events_after.sum_dict, events_before.sum_dict); assert_eq!(events_after.seed_dict, events_before.seed_dict); assert_eq!(events_after.model, events_before.model); } #[tokio::test] async fn test_sum2_to_unmask_phase() { // No Storage errors // lets pretend we come from the update phase // // What should happen: // 1. broadcast Sum2 phase // 2. accept 10 sum2 messages // 3. broadcast invalidation of sum and seed dict // 4. move into unmask phase // // What should not happen: // - the shared state has been changed // - events have been broadcasted (except phase event and invalidation // event of sum and seed dict) enable_logging(); let mut cs = MockCoordinatorStore::new(); cs.expect_incr_mask_score() .times(10) .returning(move |_, _| Ok(MaskScoreIncr(Ok(())))); let store = Store::new(cs, MockModelStore::new()); let state = CoordinatorStateBuilder::new() .with_round_id(1) .with_sum2_count_min(10) .with_sum2_count_max(10) .with_sum2_time_min(1) .build(); let (event_publisher, event_subscriber) = events_from_update_phase(&state); let events_before_sum2 = EventSnapshot::from(&event_subscriber); let state_before_sum2 = state.clone(); let (shared, request_tx) = init_shared(state, store, event_publisher); let agg = Aggregation::new( state_before_sum2.round_params.mask_config, state_before_sum2.round_params.model_length, ); let state_machine = StateMachine::from(PhaseState::::new(shared, agg)); 
assert!(state_machine.is_sum2()); send_sum2_messages(10, request_tx.clone()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum2 = state_machine.as_ref().clone(); let events_after_sum2 = EventSnapshot::from(&event_subscriber); assert_after_phase_success( &state_before_sum2, &events_before_sum2, &state_after_sum2, &events_after_sum2, ); assert!(state_machine.is_unmask()); } #[tokio::test] async fn test_rejected_messages_pet_error() { // No Storage errors // // What should happen: // 1. broadcast Sum2 phase // 2. reject 3 sum2 messages (pet error MaskScoreIncrError::UnknownSumPk) // 3. phase should timeout // 4. move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated // - the sum dict has been invalidated // - the seed dict has been invalidated enable_logging(); let mut cs = MockCoordinatorStore::new(); cs.expect_incr_mask_score() .times(3) .returning(move |_, _| Ok(MaskScoreIncr(Err(MaskScoreIncrError::UnknownSumPk)))); let store = Store::new(cs, MockModelStore::new()); let state = CoordinatorStateBuilder::new() .with_round_id(1) .with_sum2_count_min(3) .with_sum2_count_max(3) .with_sum2_time_min(0) .with_sum2_time_max(2) .build(); let (event_publisher, event_subscriber) = events_from_update_phase(&state); let events_before_sum2 = EventSnapshot::from(&event_subscriber); let state_before_sum2 = state.clone(); let (shared, request_tx) = init_shared(state, store, event_publisher); let agg = Aggregation::new( state_before_sum2.round_params.mask_config, state_before_sum2.round_params.model_length, ); let state_machine = StateMachine::from(PhaseState::::new(shared, agg)); assert!(state_machine.is_sum2()); send_sum2_messages(3, request_tx.clone()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum2 = state_machine.as_ref().clone(); let events_after_sum2 = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum2, 
&events_before_sum2, &state_after_sum2, &events_after_sum2, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::PhaseTimeout(_) )) } } ================================================ FILE: rust/xaynet-server/src/state_machine/phases/unmask.rs ================================================
use std::{cmp::Ordering, sync::Arc};

use async_trait::async_trait;
use displaydoc::Display;
use thiserror::Error;
#[cfg(feature = "model-persistence")]
use tracing::warn;
use tracing::{error, info};

use crate::{
    metric,
    metrics::{GlobalRecorder, Measurement},
    state_machine::{
        events::ModelUpdate,
        phases::{Idle, Phase, PhaseError, PhaseName, PhaseState, Shared},
        StateMachine,
    },
    storage::{Storage, StorageError},
};
use xaynet_core::mask::{Aggregation, MaskObject, Model, UnmaskingError};

/// Errors which can occur during the unmask phase.
// NOTE: the variant doc comments below are turned into `Display` messages by
// `displaydoc`, so their text is user-visible runtime output.
#[derive(Debug, Display, Error)]
pub enum UnmaskError {
    /// Ambiguous masks were computed by the sum participants.
    AmbiguousMasks,
    /// No mask found.
    NoMask,
    /// Unmasking global model failed: {0}.
    Unmasking(#[from] UnmaskingError),
    /// Fetching best masks failed: {0}.
    FetchBestMasks(#[from] StorageError),
    #[cfg(feature = "model-persistence")]
    /// Saving the global model failed: {0}.
    SaveGlobalModel(StorageError),
    /// Publishing the proof of the global model failed: {0}.
    PublishProof(StorageError),
}

/// The unmask state.
#[derive(Debug)]
pub struct Unmask {
    /// The aggregator for masked models.
    ///
    /// Always `Some` after construction; taken (left `None`) by `end_round()`.
    model_agg: Option<Aggregation>,
    /// The global model of the current round.
    ///
    /// Set by `end_round()`; taken by `broadcast()`.
    global_model: Option<Arc<Model>>,
}

#[async_trait]
impl<T> Phase<T> for PhaseState<Unmask, T>
where
    T: Storage,
{
    const NAME: PhaseName = PhaseName::Unmask;

    /// Runs the unmask phase.
    ///
    /// Emits mask metrics, fetches the best masks, unmasks the global model,
    /// optionally persists it (feature `model-persistence`) and publishes its
    /// proof. Any failure aborts the phase with the corresponding error.
    async fn process(&mut self) -> Result<(), PhaseError> {
        self.emit_number_of_unique_masks_metrics();

        let best_masks = self.best_masks().await?;
        self.end_round(best_masks)?;

        #[cfg(feature = "model-persistence")]
        self.save_global_model().await?;

        self.publish_proof().await?;

        Ok(())
    }

    /// Broadcasts the unmasked global model to subscribers.
    ///
    /// # Panics
    /// Panics if called before `end_round()` has produced the global model.
    fn broadcast(&mut self) {
        info!("broadcasting the new global model");
        let global_model = self.private.global_model.take().expect(
            "unreachable: never fails when `broadcast()` is called after `end_round()`",
        );
        self.shared
            .events
            .broadcast_model(ModelUpdate::New(global_model));
    }

    /// Moves into the idle phase.
    async fn next(self) -> Option<StateMachine<T>> {
        Some(PhaseState::<Idle, _>::new(self.shared).into())
    }
}

impl<T> PhaseState<Unmask, T> {
    /// Creates a new unmask state from the shared state and the aggregator
    /// holding the masked models of the round.
    pub fn new(shared: Shared<T>, model_agg: Aggregation) -> Self {
        Self {
            private: Unmask {
                model_agg: Some(model_agg),
                global_model: None,
            },
            shared,
        }
    }

    /// Freezes the mask dictionary by picking the mask with the strictly
    /// highest count.
    ///
    /// # Errors
    /// Returns [`UnmaskError::AmbiguousMasks`] if the highest count is shared
    /// by more than one mask, or if `best_masks` is empty or contains only
    /// zero counts.
    fn freeze_mask_dict(
        best_masks: Vec<(MaskObject, u64)>,
    ) -> Result<MaskObject, UnmaskError> {
        let (unique_mask, _) = best_masks.into_iter().fold(
            (None, 0),
            |(unique_mask, unique_count), (mask, count)| match unique_count.cmp(&count) {
                Ordering::Less => (Some(mask), count),
                Ordering::Greater => (unique_mask, unique_count),
                // a tie on the current maximum makes the result ambiguous
                // until a strictly higher count shows up
                Ordering::Equal => (None, unique_count),
            },
        );
        unique_mask.ok_or(UnmaskError::AmbiguousMasks)
    }

    /// Ends the round by unmasking the global model.
    ///
    /// Consumes the aggregator, so it must run at most once per phase. On
    /// success the unmasked model is kept for `save_global_model()`,
    /// `publish_proof()` and `broadcast()`.
    ///
    /// # Errors
    /// Fails if the best masks are ambiguous or if the aggregated model cannot
    /// be unmasked with the selected mask.
    fn end_round(&mut self, best_masks: Vec<(MaskObject, u64)>) -> Result<(), UnmaskError> {
        let mask = Self::freeze_mask_dict(best_masks)?;

        let model_agg = self.private.model_agg.take().expect(
            "unreachable: `PhaseState::<Unmask, _>::new()` always sets the aggregator",
        );
        model_agg
            .validate_unmasking(&mask)
            .map_err(UnmaskError::from)?;
        self.private.global_model = Some(Arc::new(model_agg.unmask(mask)));

        Ok(())
    }
}

impl<T> PhaseState<Unmask, T>
where
    T: Storage,
{
    /// Broadcasts mask metrics.
///
    /// The count is fetched on a spawned task that is not awaited, so the
    /// phase never waits for it; a failure is only logged. Nothing happens
    /// when no global metrics recorder is installed.
    fn emit_number_of_unique_masks_metrics(&mut self) {
        if GlobalRecorder::global().is_none() {
            return;
        }

        // the spawned task is `async move`, so it needs its own store handle
        let mut store = self.shared.store.clone();
        let (round_id, phase_name) = (self.shared.state.round_id, Self::NAME);

        tokio::spawn(async move {
            match store.number_of_unique_masks().await {
                Ok(number_of_masks) => metric!(
                    Measurement::MasksTotalNumber,
                    number_of_masks,
                    ("round_id", round_id),
                    ("phase", phase_name as u8),
                ),
                Err(err) => error!("failed to fetch total number of masks: {}", err),
            };
        });
    }

    /// Gets the two masks with the highest score.
    ///
    /// # Errors
    /// Returns [`UnmaskError::FetchBestMasks`] if the store query fails and
    /// [`UnmaskError::NoMask`] if the store returns no masks.
    async fn best_masks(&mut self) -> Result<Vec<(MaskObject, u64)>, UnmaskError> {
        self.shared
            .store
            .best_masks()
            .await
            .map_err(UnmaskError::FetchBestMasks)?
            .ok_or(UnmaskError::NoMask)
    }

    /// Persists the global model to the store.
    ///
    /// Failing to store the model aborts the phase; failing to update the
    /// latest global model id afterwards is only logged as a warning.
    ///
    /// # Panics
    /// Panics if called before `end_round()` has produced the global model.
    #[cfg(feature = "model-persistence")]
    async fn save_global_model(&mut self) -> Result<(), UnmaskError> {
        info!("saving global model");
        // borrow only the `private` field so `shared.store` can be borrowed mutably below
        let global_model = self
            .private
            .global_model
            .as_ref()
            .expect(
                "unreachable: never fails when `save_global_model()` is called after `end_round()`",
            )
            .as_ref();
        let global_model_id = self
            .shared
            .store
            .set_global_model(
                self.shared.state.round_id,
                &self.shared.state.round_params.seed,
                global_model,
            )
            .await
            .map_err(UnmaskError::SaveGlobalModel)?;

        if let Err(err) = self
            .shared
            .store
            .set_latest_global_model_id(&global_model_id)
            .await
        {
            warn!("failed to update latest global model id: {}", err);
        }

        Ok(())
    }

    /// Publishes proof of the global model.
    ///
    /// # Errors
    /// Returns [`UnmaskError::PublishProof`] if the store fails to publish it.
    ///
    /// # Panics
    /// Panics if called before `end_round()` has produced the global model.
    async fn publish_proof(&mut self) -> Result<(), UnmaskError> {
        info!("publishing proof of the new global model");
        let global_model = self
            .private
            .global_model
            .as_ref()
            .expect(
                "unreachable: never fails when `publish_proof()` is called after `end_round()`",
            )
            .as_ref();

        self.shared
            .store
            .publish_proof(global_model)
            .await
            .map_err(UnmaskError::PublishProof)
    }
}
#[cfg(test)] mod tests { use super::*; use std::sync::Arc; use anyhow::anyhow; use crate::{ state_machine::{ coordinator::CoordinatorState, events::{DictionaryUpdate, EventPublisher, EventSubscriber, ModelUpdate}, tests::{ utils::{assert_event_updated, enable_logging, init_shared, EventSnapshot}, CoordinatorStateBuilder, EventBusBuilder, }, }, storage::{ tests::{ utils::{create_global_model, create_mask}, MockCoordinatorStore, MockModelStore, MockTrustAnchor, }, Store, }, }; fn events_from_sum2_phase(state: &CoordinatorState) -> (EventPublisher, EventSubscriber) { EventBusBuilder::new(state) .broadcast_phase(PhaseName::Sum2) .broadcast_sum_dict(DictionaryUpdate::Invalidate) .broadcast_seed_dict(DictionaryUpdate::Invalidate) .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1)))) .build() } fn assert_after_phase_success( state_before: &CoordinatorState, events_before: &EventSnapshot, state_after: &CoordinatorState, events_after: &EventSnapshot, ) { assert_ne!(state_after.round_id, state_before.round_id); assert_eq!(state_after.round_params, state_before.round_params); assert_eq!(state_after.keys, state_before.keys); assert_eq!(state_after.sum, state_before.sum); assert_eq!(state_after.update, state_before.update); assert_eq!(state_after.sum2, state_before.sum2); assert_event_updated(&events_after.phase, &events_before.phase); assert_event_updated(&events_after.model, &events_before.model); assert_eq!(events_after.keys, events_before.keys); assert_eq!(events_after.params, events_before.params); assert_eq!(events_after.phase.event, PhaseName::Unmask); assert_eq!(events_after.sum_dict,
events_before.sum_dict); assert_eq!(events_after.seed_dict, events_before.seed_dict); } fn assert_after_phase_failure( state_before: &CoordinatorState, events_before: &EventSnapshot, state_after: &CoordinatorState, events_after: &EventSnapshot, ) { assert_eq!(state_after, state_before); assert_event_updated(&events_after.phase, &events_before.phase); assert_eq!(events_after.keys, events_before.keys); assert_eq!(events_after.params, events_before.params); assert_eq!(events_after.phase.event, PhaseName::Unmask); assert_eq!(events_after.sum_dict, events_before.sum_dict); assert_eq!(events_after.seed_dict, events_before.seed_dict); assert_eq!(events_after.model, events_before.model); } fn init_aggregator(state: &CoordinatorState) -> Aggregation { let mut aggregator = Aggregation::new( state.round_params.mask_config, state.round_params.model_length, ); aggregator.aggregate(create_mask(state.round_params.model_length, 1)); aggregator } #[tokio::test] async fn test_unmask_to_idle_phase() { // No Storage errors // lets pretend we come from the sum2 phase // // What should happen: // 1. broadcast Unmask phase // 2 fetch best masks (return only one) // 3. unmask the masked global model // 4. publish proof // 5. broadcast unmasked global model // 6. 
move into idle phase // // What should not happen: // - the shared state has been changed // - events have been broadcasted (except phase event and global model) enable_logging(); let state = CoordinatorStateBuilder::new().with_round_id(1).build(); let model_length = state.round_params.model_length; let mut cs = MockCoordinatorStore::new(); cs.expect_best_masks() .returning(move || Ok(Some(vec![(create_mask(model_length, 1), 1)]))); #[cfg(feature = "model-persistence")] { cs.expect_set_latest_global_model_id() .returning(move |_| Ok(())); } let ms = { #[cfg(not(feature = "model-persistence"))] { MockModelStore::new() } #[cfg(feature = "model-persistence")] { let mut ms = MockModelStore::new(); ms.expect_set_global_model() .returning(move |_, _, _| Ok("id".to_string())); ms } }; let store = Store::new(cs, ms); let (event_publisher, event_subscriber) = events_from_sum2_phase(&state); let events_before_sum2 = EventSnapshot::from(&event_subscriber); let state_before_sum2 = state.clone(); let (shared, _request_tx) = init_shared(state, store, event_publisher); let aggregator = init_aggregator(&state_before_sum2); let state_machine = StateMachine::from(PhaseState::::new(shared, aggregator)); assert!(state_machine.is_unmask()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum2 = state_machine.as_ref().clone(); let events_after_sum2 = EventSnapshot::from(&event_subscriber); assert_after_phase_success( &state_before_sum2, &events_before_sum2, &state_after_sum2, &events_after_sum2, ); assert!(state_machine.is_idle()); } #[tokio::test] async fn test_unmask_to_idle_phase_best_masks_fails() { // Storage: // - best_masks fails // // What should happen: // 1. broadcast Unmask phase // 2. fetch best masks (fails) // 3. 
move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated/changed // - the sum dict has been invalidated // - the seed dict has been invalidated enable_logging(); let mut cs = MockCoordinatorStore::new(); cs.expect_best_masks().returning(move || Err(anyhow!(""))); let store = Store::new(cs, MockModelStore::new()); let state = CoordinatorStateBuilder::new().with_round_id(1).build(); let (event_publisher, event_subscriber) = events_from_sum2_phase(&state); let events_before_sum2 = EventSnapshot::from(&event_subscriber); let state_before_sum2 = state.clone(); let (shared, _request_tx) = init_shared(state, store, event_publisher); let aggregator = init_aggregator(&state_before_sum2); let state_machine = StateMachine::from(PhaseState::::new(shared, aggregator)); assert!(state_machine.is_unmask()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum2 = state_machine.as_ref().clone(); let events_after_sum2 = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum2, &events_before_sum2, &state_after_sum2, &events_after_sum2, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::Unmask(UnmaskError::FetchBestMasks(_)) )) } #[tokio::test] async fn test_unmask_to_idle_phase_no_mask() { // No Storage errors // // What should happen: // 1. broadcast Unmask phase // 2. fetch best masks (no storage error but the mask vec is None) // 3. 
move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated/changed // - the sum dict has been invalidated // - the seed dict has been invalidated enable_logging(); let mut cs = MockCoordinatorStore::new(); cs.expect_best_masks().returning(move || Ok(None)); let store = Store::new(cs, MockModelStore::new()); let state = CoordinatorStateBuilder::new().with_round_id(1).build(); let (event_publisher, event_subscriber) = events_from_sum2_phase(&state); let events_before_sum2 = EventSnapshot::from(&event_subscriber); let state_before_sum2 = state.clone(); let (shared, _request_tx) = init_shared(state, store, event_publisher); let aggregator = init_aggregator(&state_before_sum2); let state_machine = StateMachine::from(PhaseState::::new(shared, aggregator)); assert!(state_machine.is_unmask()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum2 = state_machine.as_ref().clone(); let events_after_sum2 = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum2, &events_before_sum2, &state_after_sum2, &events_after_sum2, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::Unmask(UnmaskError::NoMask) )) } #[tokio::test] async fn test_unmask_to_idle_phase_ambiguous_masks() { // No Storage errors // // What should happen: // 1. broadcast Unmask phase // 2. fetch best masks // 3. unmask the masked global model (fails because of ambiguous masks) // 4. 
move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated/changed // - the sum dict has been invalidated // - the seed dict has been invalidated enable_logging(); let state = CoordinatorStateBuilder::new().with_round_id(1).build(); let model_length = state.round_params.model_length; let mut cs = MockCoordinatorStore::new(); cs.expect_best_masks().returning(move || { Ok(Some(vec![ (create_mask(model_length, 1), 1), (create_mask(model_length, 2), 1), ])) }); let store = Store::new(cs, MockModelStore::new()); let (event_publisher, event_subscriber) = events_from_sum2_phase(&state); let events_before_sum2 = EventSnapshot::from(&event_subscriber); let state_before_sum2 = state.clone(); let (shared, _request_tx) = init_shared(state, store, event_publisher); let aggregator = init_aggregator(&state_before_sum2); let state_machine = StateMachine::from(PhaseState::::new(shared, aggregator)); assert!(state_machine.is_unmask()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum2 = state_machine.as_ref().clone(); let events_after_sum2 = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum2, &events_before_sum2, &state_after_sum2, &events_after_sum2, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::Unmask(UnmaskError::AmbiguousMasks) )) } #[tokio::test] async fn test_unmask_to_idle_phase_validate_unmasking_fails() { // No Storage errors // // What should happen: // 1. broadcast Unmask phase // 2. fetch best masks // 3. unmask the masked global model (fails because of validate unmasking error) // 4. 
move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated/changed // - the sum dict has been invalidated // - the seed dict has been invalidated enable_logging(); let state = CoordinatorStateBuilder::new().with_round_id(1).build(); let model_length = state.round_params.model_length; let mut cs = MockCoordinatorStore::new(); cs.expect_best_masks() .returning(move || Ok(Some(vec![(create_mask(model_length, 1), 1)]))); let store = Store::new(cs, MockModelStore::new()); let (event_publisher, event_subscriber) = events_from_sum2_phase(&state); let events_before_sum2 = EventSnapshot::from(&event_subscriber); let state_before_sum2 = state.clone(); let (shared, _request_tx) = init_shared(state, store, event_publisher); let aggregator = Aggregation::new( state_before_sum2.round_params.mask_config, state_before_sum2.round_params.model_length, ); let state_machine = StateMachine::from(PhaseState::::new(shared, aggregator)); assert!(state_machine.is_unmask()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum2 = state_machine.as_ref().clone(); let events_after_sum2 = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum2, &events_before_sum2, &state_after_sum2, &events_after_sum2, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::Unmask(UnmaskError::Unmasking(UnmaskingError::NoModel)) )) } #[tokio::test] async fn test_unmask_to_idle_phase_publish_proof_fails() { // TODO: we should set the latest_global_model_id only if the // the proof was successfully published // // Why? If the coordinator were to restart after this phase, they would // be using a model that has no evidence and therefore cannot be validated // by the user. // Storage: // - publish_proof fails // // What should happen: // 1. broadcast Unmask phase // 2. fetch best masks // 3. 
unmask the masked global model // 4. save global model and model id (model-persistence feature) // 5. publish proof (fails) // 6. move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated/changed // - the sum dict has been invalidated // - the seed dict has been invalidated enable_logging(); let state = CoordinatorStateBuilder::new().with_round_id(1).build(); let model_length = state.round_params.model_length; let mut cs = MockCoordinatorStore::new(); cs.expect_best_masks() .returning(move || Ok(Some(vec![(create_mask(model_length, 1), 1)]))); #[cfg(feature = "model-persistence")] { cs.expect_set_latest_global_model_id() .returning(move |_| Ok(())); } let ms = { #[cfg(not(feature = "model-persistence"))] { MockModelStore::new() } #[cfg(feature = "model-persistence")] { let mut ms = MockModelStore::new(); ms.expect_set_global_model() .returning(move |_, _, _| Ok("id".to_string())); ms } }; let mut ta = MockTrustAnchor::new(); ta.expect_publish_proof() .returning(move |_| Err(anyhow!(""))); let store = Store::new_with_trust_anchor(cs, ms, ta); let (event_publisher, event_subscriber) = events_from_sum2_phase(&state); let events_before_sum2 = EventSnapshot::from(&event_subscriber); let state_before_sum2 = state.clone(); let (shared, _request_tx) = init_shared(state, store, event_publisher); let aggregator = init_aggregator(&state_before_sum2); let state_machine = StateMachine::from(PhaseState::::new(shared, aggregator)); assert!(state_machine.is_unmask()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum2 = state_machine.as_ref().clone(); let events_after_sum2 = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum2, &events_before_sum2, &state_after_sum2, &events_after_sum2, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::Unmask(UnmaskError::PublishProof(_)) 
)) } #[cfg(feature = "model-persistence")] #[tokio::test] async fn test_unmask_to_idle_phase_set_global_model_fails() { // Storage: // - set_global_model fails // // What should happen: // 1. broadcast Unmask phase // 2. fetch best masks // 3. unmask the masked global model // 4. save global model (fails) // 5. move into error phase // // What should not happen: // - the shared state has been changed // - the global model has been invalidated/changed // - the sum dict has been invalidated // - the seed dict has been invalidated enable_logging(); let state = CoordinatorStateBuilder::new().with_round_id(1).build(); let model_length = state.round_params.model_length; let mut cs = MockCoordinatorStore::new(); cs.expect_best_masks() .returning(move || Ok(Some(vec![(create_mask(model_length, 1), 1)]))); cs.expect_set_latest_global_model_id() .returning(move |_| Ok(())); let mut ms = MockModelStore::new(); ms.expect_set_global_model() .returning(move |_, _, _| Err(anyhow!(""))); let store = Store::new(cs, ms); let (event_publisher, event_subscriber) = events_from_sum2_phase(&state); let events_before_sum2 = EventSnapshot::from(&event_subscriber); let state_before_sum2 = state.clone(); let (shared, _request_tx) = init_shared(state, store, event_publisher); let aggregator = init_aggregator(&state_before_sum2); let state_machine = StateMachine::from(PhaseState::::new(shared, aggregator)); assert!(state_machine.is_unmask()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum2 = state_machine.as_ref().clone(); let events_after_sum2 = EventSnapshot::from(&event_subscriber); assert_after_phase_failure( &state_before_sum2, &events_before_sum2, &state_after_sum2, &events_after_sum2, ); assert!(state_machine.is_failure()); assert!(matches!( state_machine.into_failure_phase_state().private.error, PhaseError::Unmask(UnmaskError::SaveGlobalModel(_)) )) } #[cfg(feature = "model-persistence")] #[tokio::test] async fn 
test_unmask_to_idle_phase_set_global_model_id_fails() { // Storage: // - set_latest_global_model_id fails // // What should happen: // 1. broadcast Unmask phase // 2. fetch best masks // 3. unmask the masked global model // 4. save global model and model id (fails) // 5. publish proof // 6. broadcast unmasked global model // 7. move into idle phase // // What should not happen: // - the shared state has been changed // - events have been broadcasted (except phase event and global model) enable_logging(); let state = CoordinatorStateBuilder::new().with_round_id(1).build(); let model_length = state.round_params.model_length; let mut cs = MockCoordinatorStore::new(); cs.expect_best_masks() .returning(move || Ok(Some(vec![(create_mask(model_length, 1), 1)]))); cs.expect_set_latest_global_model_id() .returning(move |_| Err(anyhow!(""))); let mut ms = MockModelStore::new(); ms.expect_set_global_model() .returning(move |_, _, _| Ok("id".to_string())); let store = Store::new(cs, ms); let (event_publisher, event_subscriber) = events_from_sum2_phase(&state); let events_before_sum2 = EventSnapshot::from(&event_subscriber); let state_before_sum2 = state.clone(); let (shared, _request_tx) = init_shared(state, store, event_publisher); let aggregator = init_aggregator(&state_before_sum2); let state_machine = StateMachine::from(PhaseState::::new(shared, aggregator)); assert!(state_machine.is_unmask()); let state_machine = state_machine.next().await.unwrap(); let state_after_sum2 = state_machine.as_ref().clone(); let events_after_sum2 = EventSnapshot::from(&event_subscriber); assert_after_phase_success( &state_before_sum2, &events_before_sum2, &state_after_sum2, &events_after_sum2, ); assert!(state_machine.is_idle()); } } ================================================ FILE: rust/xaynet-server/src/state_machine/phases/update.rs ================================================ use std::sync::Arc; use async_trait::async_trait; use displaydoc::Display; use thiserror::Error; use 
tracing::{debug, info, warn}; use crate::{ state_machine::{ events::DictionaryUpdate, phases::{Handler, Phase, PhaseError, PhaseName, PhaseState, Shared, Sum2}, requests::{RequestError, StateMachineRequest, UpdateRequest}, StateMachine, }, storage::{Storage, StorageError}, }; use xaynet_core::{ mask::{Aggregation, MaskObject}, LocalSeedDict, SeedDict, UpdateParticipantPublicKey, }; /// Errors which can occur during the update phase. #[derive(Debug, Display, Error)] pub enum UpdateError { /// Seed dictionary does not exists. NoSeedDict, /// Fetching seed dictionary failed: {0}. FetchSeedDict(StorageError), } /// The update state. #[derive(Debug)] pub struct Update { /// The aggregator for masked models. model_agg: Aggregation, /// The seed dictionary which gets assembled during the update phase. seed_dict: Option, } #[async_trait] impl Phase for PhaseState where T: Storage, Self: Handler, { const NAME: PhaseName = PhaseName::Update; async fn process(&mut self) -> Result<(), PhaseError> { self.process(self.shared.state.update).await?; self.seed_dict().await?; Ok(()) } fn broadcast(&mut self) { info!("broadcasting the global seed dictionary"); let seed_dict = self .private .seed_dict .take() .expect("unreachable: never fails when `broadcast()` is called after `process()`"); self.shared .events .broadcast_seed_dict(DictionaryUpdate::New(Arc::new(seed_dict))); } async fn next(self) -> Option> { Some(PhaseState::::new(self.shared, self.private.model_agg).into()) } } #[async_trait] impl Handler for PhaseState where T: Storage, { async fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), RequestError> { if let StateMachineRequest::Update(UpdateRequest { participant_pk, local_seed_dict, masked_model, }) = req { self.update_seed_dict_and_aggregate_mask( &participant_pk, &local_seed_dict, masked_model, ) .await } else { Err(RequestError::MessageRejected) } } } impl PhaseState { /// Creates a new update state. 
///
    /// The aggregator is built from the round's mask configuration and model
    /// length; the seed dictionary starts out as `None`.
    pub fn new(shared: Shared<T>) -> Self {
        let model_agg = Aggregation::new(
            shared.state.round_params.mask_config,
            shared.state.round_params.model_length,
        );
        let private = Update {
            model_agg,
            seed_dict: None,
        };
        Self { private, shared }
    }
}

impl<T> PhaseState<Update, T>
where
    T: Storage,
{
    /// Updates the local seed dict and aggregates the masked model.
    ///
    /// The order matters: the masked model is validated first, so an invalid
    /// model never adds its local seed dict; the seed dict is added second, so
    /// a rejected seed dict never gets its model aggregated.
    async fn update_seed_dict_and_aggregate_mask(
        &mut self,
        pk: &UpdateParticipantPublicKey,
        local_seed_dict: &LocalSeedDict,
        mask_object: MaskObject,
    ) -> Result<(), RequestError> {
        debug!("checking whether the masked model can be aggregated");
        if let Err(e) = self.private.model_agg.validate_aggregation(&mask_object) {
            warn!("model aggregation error: {}", e);
            return Err(RequestError::AggregationFailed);
        }

        info!("updating the global seed dictionary");
        if let Err(err) = self.add_local_seed_dict(pk, local_seed_dict).await {
            warn!("invalid local seed dictionary, ignoring update message");
            return Err(err);
        }

        info!("aggregating the masked model and scalar");
        self.private.model_agg.aggregate(mask_object);
        Ok(())
    }

    /// Adds a local seed dictionary to the global seed dictionary.
    ///
    /// # Errors
    ///
    /// Fails if the local seed dict cannot be added due to a PET or [`StorageError`].
    async fn add_local_seed_dict(
        &mut self,
        pk: &UpdateParticipantPublicKey,
        local_seed_dict: &LocalSeedDict,
    ) -> Result<(), RequestError> {
        let outcome = self
            .shared
            .store
            .add_local_seed_dict(pk, local_seed_dict)
            .await?;
        outcome.into_inner().map_err(RequestError::from)
    }

    /// Gets the global seed dict from the store.
    ///
    /// # Errors
    ///
    /// Returns [`UpdateError::FetchSeedDict`] if the store query fails and
    /// [`UpdateError::NoSeedDict`] if the store returns no seed dict; in both
    /// cases the current state is left untouched.
    async fn seed_dict(&mut self) -> Result<(), UpdateError> {
        let seed_dict = match self.shared.store.seed_dict().await {
            Ok(Some(seed_dict)) => seed_dict,
            Ok(None) => return Err(UpdateError::NoSeedDict),
            Err(err) => return Err(UpdateError::FetchSeedDict(err)),
        };
        self.private.seed_dict = Some(seed_dict);
        Ok(())
    }
}
#[cfg(test)] mod tests { use super::*; use anyhow::anyhow; use xaynet_core::{SeedDict, SumDict}; use crate::{ state_machine::{ coordinator::CoordinatorState, events::{EventPublisher, EventSubscriber, ModelUpdate}, tests::{ utils::{ assert_event_updated, enable_logging, init_shared, send_update_messages, send_update_messages_with_model, EventSnapshot, }, CoordinatorStateBuilder, EventBusBuilder, }, }, storage::{ tests::{ utils::{create_global_model, create_mask}, MockCoordinatorStore, MockModelStore, }, LocalSeedDictAdd, LocalSeedDictAddError, Store, }, }; fn events_from_sum_phase(state: &CoordinatorState) -> (EventPublisher, EventSubscriber) { EventBusBuilder::new(state) .broadcast_phase(PhaseName::Sum) .broadcast_sum_dict(DictionaryUpdate::New(Arc::new(SumDict::new()))) .broadcast_seed_dict(DictionaryUpdate::Invalidate) .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1)))) .build() } fn assert_after_phase_success( state_before: &CoordinatorState, events_before: &EventSnapshot, state_after: &CoordinatorState, events_after: &EventSnapshot, ) { assert_eq!(state_after, state_before); assert_event_updated(&events_after.phase, &events_before.phase); assert_event_updated(&events_after.seed_dict, &events_before.seed_dict); assert_eq!(events_after.keys, events_before.keys); assert_eq!(events_after.params, events_before.params); assert_eq!(events_after.phase.event, PhaseName::Update); assert_eq!(events_after.sum_dict, events_before.sum_dict); assert_eq!(events_after.model, events_before.model); } fn assert_after_phase_failure( state_before: &CoordinatorState, events_before: &EventSnapshot, state_after: &CoordinatorState, events_after: &EventSnapshot, ) { assert_eq!(state_after, state_before); assert_event_updated(&events_after.phase, &events_before.phase); assert_eq!(events_after.keys, events_before.keys); assert_eq!(events_after.params, events_before.params); assert_eq!(events_after.phase.event, PhaseName::Update);
assert_eq!(events_after.sum_dict, events_before.sum_dict);
assert_eq!(events_after.seed_dict, events_before.seed_dict);
assert_eq!(events_after.model, events_before.model);
}

// NOTE(review): the generic arguments of `PhaseState::::new` (used in every test
// below) are missing in this copy; `::::` is not valid Rust — TODO restore them
// from the upstream source.

/// Happy path: 10 update messages are accepted, the seed dict is fetched and the
/// machine transitions from the update phase into the sum2 phase.
#[tokio::test]
async fn test_update_to_sum2_phase() {
    // No Storage errors
    // let's pretend we come from the sum phase
    //
    // What should happen:
    // 1. broadcast Update phase
    // 2. accept 10 update messages
    // 3. fetch seed dict
    // 4. broadcast seed dict
    // 5. move into sum2 phase
    //
    // What should not happen:
    // - the shared state has been changed
    // - the global model has been invalidated
    // - the sum dict has been invalidated
    enable_logging();
    // The mock expects exactly one local seed dict addition per accepted message.
    let mut cs = MockCoordinatorStore::new();
    cs.expect_add_local_seed_dict()
        .times(10)
        .returning(move |_, _| Ok(LocalSeedDictAdd(Ok(()))));
    cs.expect_seed_dict()
        .return_once(move || Ok(Some(SeedDict::new())));
    let store = Store::new(cs, MockModelStore::new());
    let state = CoordinatorStateBuilder::new()
        .with_round_id(1)
        .with_update_count_min(10)
        .with_update_count_max(10)
        .with_update_time_min(1)
        .build();
    let (event_publisher, event_subscriber) = events_from_sum_phase(&state);
    // Snapshot events and state before the phase runs so they can be compared afterwards.
    let events_before_update = EventSnapshot::from(&event_subscriber);
    let state_before_update = state.clone();
    let (shared, request_tx) = init_shared(state, store, event_publisher);
    let state_machine = StateMachine::from(PhaseState::::new(shared));
    assert!(state_machine.is_update());
    send_update_messages(10, request_tx.clone());
    let state_machine = state_machine.next().await.unwrap();
    let state_after_update = state_machine.as_ref().clone();
    let events_after_update = EventSnapshot::from(&event_subscriber);
    assert_after_phase_success(
        &state_before_update,
        &events_before_update,
        &state_after_update,
        &events_after_update,
    );
    assert!(state_machine.is_sum2());
}

/// A storage error while fetching the seed dict must move the machine into the
/// failure phase with `UpdateError::FetchSeedDict`.
#[tokio::test]
async fn test_update_to_sum2_fetch_seed_dict_failed() {
    // Storage errors
    // - seed_dict fails
    //
    // What should happen:
    // 1. broadcast Update phase
    // 2. accept 1 update message
    // 3. fetch seed dict (fails)
    // 4. move into error phase
    //
    // What should not happen:
    // - the shared state has been changed
    // - the global model has been invalidated
    // - the sum dict has been invalidated
    // - the seed dict has been broadcasted
    enable_logging();
    let mut cs = MockCoordinatorStore::new();
    cs.expect_add_local_seed_dict()
        .times(1)
        .returning(move |_, _| Ok(LocalSeedDictAdd(Ok(()))));
    // Simulate a storage failure when the seed dict is requested.
    cs.expect_seed_dict().return_once(move || Err(anyhow!("")));
    let store = Store::new(cs, MockModelStore::new());
    let state = CoordinatorStateBuilder::new()
        .with_round_id(1)
        .with_update_count_min(1)
        .with_update_count_max(1)
        .with_update_time_min(1)
        .with_update_time_max(5)
        .build();
    let (event_publisher, event_subscriber) = events_from_sum_phase(&state);
    let events_before_update = EventSnapshot::from(&event_subscriber);
    let state_before_update = state.clone();
    let (shared, request_tx) = init_shared(state, store, event_publisher);
    let state_machine = StateMachine::from(PhaseState::::new(shared));
    assert!(state_machine.is_update());
    send_update_messages(1, request_tx.clone());
    let state_machine = state_machine.next().await.unwrap();
    let state_after_update = state_machine.as_ref().clone();
    let events_after_update = EventSnapshot::from(&event_subscriber);
    assert_after_phase_failure(
        &state_before_update,
        &events_before_update,
        &state_after_update,
        &events_after_update,
    );
    assert!(state_machine.is_failure());
    assert!(matches!(
        state_machine.into_failure_phase_state().private.error,
        PhaseError::Update(UpdateError::FetchSeedDict(_))
    ))
}

/// A seed dict that is absent from storage (`Ok(None)`) must move the machine into
/// the failure phase with `UpdateError::NoSeedDict`.
#[tokio::test]
async fn test_update_to_sum2_seed_dict_none() {
    // No Storage errors
    //
    // What should happen:
    // 1. broadcast Update phase
    // 2. accept 1 update message
    // 3. fetch seed dict (no storage error but the seed dict is None)
    // 4. move into error phase
    //
    // What should not happen:
    // - the shared state has been changed
    // - the global model has been invalidated
    // - the sum dict has been invalidated
    // - the seed dict has been broadcasted
    enable_logging();
    let mut cs = MockCoordinatorStore::new();
    cs.expect_add_local_seed_dict()
        .times(1)
        .returning(move |_, _| Ok(LocalSeedDictAdd(Ok(()))));
    cs.expect_seed_dict().return_once(move || Ok(None));
    let store = Store::new(cs, MockModelStore::new());
    let state = CoordinatorStateBuilder::new()
        .with_round_id(1)
        .with_update_count_min(1)
        .with_update_count_max(1)
        .with_update_time_min(1)
        .with_update_time_max(5)
        .build();
    let (event_publisher, event_subscriber) = events_from_sum_phase(&state);
    let events_before_update = EventSnapshot::from(&event_subscriber);
    let state_before_update = state.clone();
    let (shared, request_tx) = init_shared(state, store, event_publisher);
    let state_machine = StateMachine::from(PhaseState::::new(shared));
    assert!(state_machine.is_update());
    send_update_messages(1, request_tx.clone());
    let state_machine = state_machine.next().await.unwrap();
    let state_after_update = state_machine.as_ref().clone();
    let events_after_update = EventSnapshot::from(&event_subscriber);
    assert_after_phase_failure(
        &state_before_update,
        &events_before_update,
        &state_after_update,
        &events_after_update,
    );
    assert!(state_machine.is_failure());
    assert!(matches!(
        state_machine.into_failure_phase_state().private.error,
        PhaseError::Update(UpdateError::NoSeedDict)
    ))
}

/// Update messages whose model cannot be aggregated are rejected, while valid ones
/// are still accepted and the phase completes into sum2.
#[tokio::test]
async fn test_aggregation_error() {
    // No Storage errors
    //
    // What should happen:
    // 1. broadcast Update phase
    // 2. reject 3 update messages (validation of the models fail due to an invalid length)
    // 3. accept 3 update messages
    // 4. fetch seed dict
    // 5. broadcast seed dict
    // 6. move into sum2 phase
    //
    // What should not happen:
    // - the shared state has been changed
    // - the global model has been invalidated
    // - the sum dict has been invalidated
    enable_logging();
    // Exactly 3 additions are expected: the 3 rejected messages must never reach
    // the store.
    let mut cs = MockCoordinatorStore::new();
    cs.expect_add_local_seed_dict()
        .times(3)
        .returning(move |_, _| Ok(LocalSeedDictAdd(Ok(()))));
    cs.expect_seed_dict()
        .return_once(move || Ok(Some(SeedDict::new())));
    let store = Store::new(cs, MockModelStore::new());
    let state = CoordinatorStateBuilder::new()
        .with_round_id(1)
        .with_update_count_min(3)
        .with_update_count_max(3)
        .with_update_time_min(1)
        .build();
    let (event_publisher, event_subscriber) = events_from_sum_phase(&state);
    let events_before_update = EventSnapshot::from(&event_subscriber);
    let state_before_update = state.clone();
    let (shared, request_tx) = init_shared(state, store, event_publisher);
    let state_machine = StateMachine::from(PhaseState::::new(shared));
    assert!(state_machine.is_update());
    // presumably `create_mask(2, 1)` builds a mask of length 2, which does not match
    // the model length 1 from `model_settings()` — TODO confirm against create_mask.
    send_update_messages_with_model(3, request_tx.clone(), create_mask(2, 1));
    send_update_messages(3, request_tx.clone());
    let state_machine = state_machine.next().await.unwrap();
    let state_after_update = state_machine.as_ref().clone();
    let events_after_update = EventSnapshot::from(&event_subscriber);
    assert_after_phase_success(
        &state_before_update,
        &events_before_update,
        &state_after_update,
        &events_after_update,
    );
    assert!(state_machine.is_sum2());
}

/// When every update is rejected by storage (`LengthMisMatch`), the phase cannot
/// reach its minimum count and must end in a `PhaseTimeout` failure.
#[tokio::test]
async fn test_rejected_messages_pet_error() {
    // No Storage errors
    //
    // What should happen:
    // 1. broadcast Update phase
    // 2. reject 3 update messages (pet error LocalSeedDictAddError::LengthMisMatch)
    // 3. phase should timeout
    // 4. move into error phase
    //
    // What should not happen:
    // - the shared state has been changed
    // - the global model has been invalidated
    // - the sum dict has been invalidated
    // - the seed dict has been broadcasted
    enable_logging();
    let mut cs = MockCoordinatorStore::new();
    cs.expect_add_local_seed_dict()
        .times(3)
        .returning(move |_, _| {
            Ok(LocalSeedDictAdd(Err(LocalSeedDictAddError::LengthMisMatch)))
        });
    let store = Store::new(cs, MockModelStore::new());
    // A short maximum phase time (2) keeps the expected timeout quick.
    let state = CoordinatorStateBuilder::new()
        .with_round_id(1)
        .with_update_count_min(3)
        .with_update_count_max(3)
        .with_update_time_min(0)
        .with_update_time_max(2)
        .build();
    let (event_publisher, event_subscriber) = events_from_sum_phase(&state);
    let events_before_update = EventSnapshot::from(&event_subscriber);
    let state_before_update = state.clone();
    let (shared, request_tx) = init_shared(state, store, event_publisher);
    let state_machine = StateMachine::from(PhaseState::::new(shared));
    assert!(state_machine.is_update());
    send_update_messages(3, request_tx.clone());
    let state_machine = state_machine.next().await.unwrap();
    let state_after_update = state_machine.as_ref().clone();
    let events_after_update = EventSnapshot::from(&event_subscriber);
    assert_after_phase_failure(
        &state_before_update,
        &events_before_update,
        &state_after_update,
        &events_after_update,
    );
    assert!(state_machine.is_failure());
    assert!(matches!(
        state_machine.into_failure_phase_state().private.error,
        PhaseError::PhaseTimeout(_)
    ))
}
}
================================================
FILE: rust/xaynet-server/src/state_machine/requests.rs
================================================
//! This module provides the request types used to communicate with the state machine:
//! `StateMachineRequest` (with its `SumRequest`, `UpdateRequest` and `Sum2Request`
//! payloads), the `RequestSender` / `RequestReceiver` channel halves and `RequestError`.
use std::{
    pin::Pin,
    task::{Context, Poll},
};

use derive_more::From;
use displaydoc::Display;
use futures::{future::FutureExt, Stream};
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
use tracing::{trace, Span};

use crate::storage::{LocalSeedDictAddError, MaskScoreIncrError, StorageError, SumPartAddError};
use xaynet_core::{
    mask::MaskObject,
    message::{Message, Payload, Update},
    LocalSeedDict,
    ParticipantPublicKey,
    SumParticipantEphemeralPublicKey,
    SumParticipantPublicKey,
    UpdateParticipantPublicKey,
};

/// Errors which can occur while the state machine handles a request.
// NOTE: `displaydoc::Display` derives `Display` from each variant's `///` doc
// comment (`{0}` refers to the wrapped value), so the variant doc comments below
// are runtime error messages — edit them only intentionally. The `#[from]`
// attributes let thiserror generate `From` conversions, so `?` can turn the
// wrapped storage errors into a `RequestError`.
#[derive(Debug, Display, Error)]
pub enum RequestError {
    /// The message was rejected.
    MessageRejected,
    /// The message was discarded.
    MessageDiscarded,
    /// Invalid update: the model or scalar sent by the participant could not be aggregated.
    AggregationFailed,
    /// The request could not be processed due to an internal error: {0}.
    InternalError(&'static str),
    /// Storage request failed: {0}.
    CoordinatorStorage(#[from] StorageError),
    /// Adding a local seed dict to the seed dictionary failed: {0}.
    LocalSeedDictAdd(#[from] LocalSeedDictAddError),
    /// Adding a sum participant to the sum dictionary failed: {0}.
    SumPartAdd(#[from] SumPartAddError),
    /// Incrementing a mask score failed: {0}.
    MaskScoreIncr(#[from] MaskScoreIncrError),
}

/// A sum request, built from the payload of a sum message.
#[derive(Debug)]
pub struct SumRequest {
    /// The public key of the participant.
    pub participant_pk: SumParticipantPublicKey,
    /// The ephemeral public key of the participant.
    pub ephm_pk: SumParticipantEphemeralPublicKey,
}

/// An update request, built from the payload of an update message.
#[derive(Debug)]
pub struct UpdateRequest {
    /// The public key of the participant.
    pub participant_pk: UpdateParticipantPublicKey,
    /// The local seed dict that contains the seed used to mask `masked_model`.
    pub local_seed_dict: LocalSeedDict,
    /// The masked model trained by the participant.
    pub masked_model: MaskObject,
}

/// A sum2 request.
#[derive(Debug)]
pub struct Sum2Request {
    /// The public key of the participant.
    pub participant_pk: ParticipantPublicKey,
    /// The model mask computed by the participant.
    pub model_mask: MaskObject,
}

/// A [`StateMachine`] request.
///
/// `derive_more::From` generates `From<SumRequest>`, `From<UpdateRequest>` and
/// `From<Sum2Request>` for this enum.
///
/// [`StateMachine`]: crate::state_machine
#[derive(Debug, From)]
pub enum StateMachineRequest {
    Sum(SumRequest),
    Update(UpdateRequest),
    Sum2(Sum2Request),
}

// Converts a decoded message into the matching request, keeping only the fields
// the state machine needs (e.g. the update signatures are dropped via `..`).
// Chunked messages are not supported: converting a `Payload::Chunk` panics.
// NOTE(review): the trait's type argument is missing in this copy; given
// `fn from(message: Message)` it is presumably `From<Message>` — TODO restore.
impl From for StateMachineRequest {
    fn from(message: Message) -> Self {
        let participant_pk = message.participant_pk;
        match message.payload {
            Payload::Sum(sum) => StateMachineRequest::Sum(SumRequest {
                participant_pk,
                ephm_pk: sum.ephm_pk,
            }),
            Payload::Update(update) => {
                let Update {
                    local_seed_dict,
                    masked_model,
                    ..
                } = update;
                StateMachineRequest::Update(UpdateRequest {
                    participant_pk,
                    local_seed_dict,
                    masked_model,
                })
            }
            Payload::Sum2(sum2) => StateMachineRequest::Sum2(Sum2Request {
                participant_pk,
                model_mask: sum2.model_mask,
            }),
            Payload::Chunk(_) => unimplemented!(),
        }
    }
}

/// A handle to send requests to the [`StateMachine`].
///
/// [`StateMachine`]: crate::state_machine
#[derive(Clone, From, Debug)]
pub struct RequestSender(mpsc::UnboundedSender<(StateMachineRequest, Span, ResponseSender)>);

impl RequestSender {
    /// Sends a request to the [`StateMachine`] and waits for its response.
    ///
    /// A fresh one-shot channel is created per request; its sending half travels
    /// with the request so the state machine can answer it.
    ///
    /// # Errors
    /// - [`RequestError::InternalError`] if the [`StateMachine`] has already shut down
    ///   and the `Request` channel has been closed as a result.
    /// - [`RequestError::InternalError`] if the response sender is dropped before a
    ///   response is sent.
    /// - Otherwise, whatever error the state machine sent back for this request.
    ///
    /// [`StateMachine`]: crate::state_machine
    pub async fn request(&self, req: StateMachineRequest, span: Span) -> Result<(), RequestError> {
        // NOTE(review): the turbofish type is missing in this copy (`channel::>()`);
        // it should match `ResponseSender`'s payload type — TODO restore.
        let (resp_tx, resp_rx) = oneshot::channel::>();
        self.0.send((req, span, resp_tx)).map_err(|_| {
            RequestError::InternalError(
                "failed to send request to the state machine: state machine is shutting down",
            )
        })?;
        // The outer `?` handles a dropped response sender; the inner result (the
        // state machine's verdict) is returned as-is.
        resp_rx.await.map_err(|_| {
            RequestError::InternalError("failed to receive response from the state machine")
        })?
    }

    /// Returns `true` if the receiving half of the `Request` channel has been closed
    /// (test-only).
    #[cfg(test)]
    pub fn is_closed(&self) -> bool {
        self.0.is_closed()
    }
}

/// The sending half of the one-shot channel that the state machine uses to send
/// the response to a [`StateMachineRequest`].
pub(in crate::state_machine) type ResponseSender = oneshot::Sender>;

/// The receiver half of the `Request` channel that is used by the [`StateMachine`] to receive
/// requests.
///
/// [`StateMachine`]: crate::state_machine
#[derive(From, Debug)]
pub struct RequestReceiver(mpsc::UnboundedReceiver<(StateMachineRequest, Span, ResponseSender)>);

// Exposes the receiver as a `Stream` by delegating to the inner channel's `poll_recv`.
impl Stream for RequestReceiver {
    type Item = (StateMachineRequest, Span, ResponseSender);

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> {
        trace!("RequestReceiver: polling");
        Pin::new(&mut self.get_mut().0).poll_recv(cx)
    }
}

impl RequestReceiver {
    /// Creates a new `Request` channel and returns the [`RequestReceiver`] as well as the
    /// [`RequestSender`] half.
    pub fn new() -> (Self, RequestSender) {
        let (tx, rx) = mpsc::unbounded_channel::<(StateMachineRequest, Span, ResponseSender)>();
        let receiver = RequestReceiver::from(rx);
        let handle = RequestSender::from(tx);
        (receiver, handle)
    }

    /// Closes the `Request` channel.
    /// See [the `tokio` documentation][close] for more information.
    ///
    /// [close]: https://docs.rs/tokio/1.1.0/tokio/sync/mpsc/struct.UnboundedReceiver.html#method.close
    pub fn close(&mut self) {
        self.0.close()
    }

    /// Receives the next request.
    /// See [the `tokio` documentation][receive] for more information.
    ///
    /// [receive]: https://docs.rs/tokio/1.1.0/tokio/sync/mpsc/struct.UnboundedReceiver.html#method.recv
    pub async fn recv(&mut self) -> Option<(StateMachineRequest, Span, ResponseSender)> {
        self.0.recv().await
    }

    /// Tries to retrieve the next request without blocking.
    ///
    /// Returns `None` if the receive future is not immediately ready.
    /// Returns `Some(None)` if the channel is closed and empty.
    /// Returns `Some(Some(request))` if a request was available.
    pub fn try_recv(&mut self) -> Option> {
        // Note `try_recv` (tokio 0.2.x) or `recv().now_or_never()` (tokio 1.x)
        // has an implementation bug where previously sent messages may not be
        // available immediately.
        // Related issue: https://github.com/tokio-rs/tokio/issues/3350
        // At the moment it behaves like `try_recv`, but we should check if this
        // bug is a problem for us. But first we should replace the unbounded channel with
        // a bounded channel (XN-1162)
        self.0.recv().now_or_never()
    }
}
================================================
FILE: rust/xaynet-server/src/state_machine/tests/coordinator_state.rs
================================================
use xaynet_core::{common::RoundSeed, crypto::EncryptKeyPair, mask::MaskConfig};

use crate::state_machine::coordinator::CoordinatorState;

use super::utils::{mask_settings, model_settings, pet_settings};

/// Test builder for a [`CoordinatorState`].
///
/// Starts from the default test settings (`pet_settings()`, `mask_settings()`,
/// `model_settings()`) and lets each test override individual fields.
pub struct CoordinatorStateBuilder {
    state: CoordinatorState,
}

#[allow(dead_code)]
impl CoordinatorStateBuilder {
    /// Creates a builder initialised from the default test settings.
    pub fn new() -> Self {
        Self {
            state: CoordinatorState::new(pet_settings(), mask_settings(), model_settings()),
        }
    }

    /// Returns the configured coordinator state.
    pub fn build(self) -> CoordinatorState {
        self.state
    }

    /// Sets the coordinator key pair and keeps `round_params.pk` in sync with its
    /// public key.
    pub fn with_keys(mut self, keys: EncryptKeyPair) -> Self {
        self.state.round_params.pk = keys.public;
        self.state.keys = keys;
        self
    }

    pub fn with_round_id(mut self, id: u64) -> Self {
        self.state.round_id = id;
        self
    }

    pub fn with_sum_probability(mut self, prob: f64) -> Self {
        self.state.round_params.sum = prob;
        self
    }

    pub fn with_update_probability(mut self, prob: f64) -> Self {
        self.state.round_params.update = prob;
        self
    }

    pub fn with_seed(mut self, seed: RoundSeed) -> Self {
        self.state.round_params.seed = seed;
        self
    }

    pub fn with_sum_count_min(mut self, min: u64) -> Self {
        self.state.sum.count.min = min;
        self
    }

    pub fn with_sum_count_max(mut self, max: u64) -> Self {
        self.state.sum.count.max = max;
        self
    }

    pub fn with_mask_config(mut self, mask_config: MaskConfig) -> Self {
        self.state.round_params.mask_config = mask_config.into();
        self
    }

    pub fn with_update_count_min(mut self, min: u64) -> Self {
        self.state.update.count.min = min;
        self
    }

    pub fn with_update_count_max(mut self, max: u64) -> Self {
        self.state.update.count.max = max;
        self
    }

    pub fn with_sum2_count_min(mut self, min: u64) -> Self {
        self.state.sum2.count.min = min;
        self
    }

    pub fn with_sum2_count_max(mut self, max: u64) -> Self {
        self.state.sum2.count.max = max;
        self
    }

    pub fn with_model_length(mut self, model_length: usize) -> Self {
        self.state.round_params.model_length = model_length;
        self
    }

    pub fn with_sum_time_min(mut self, min: u64) -> Self {
        self.state.sum.time.min = min;
        self
    }

    pub fn with_sum_time_max(mut self, max: u64) -> Self {
        self.state.sum.time.max = max;
        self
    }

    pub fn with_update_time_min(mut self, min: u64) -> Self {
        self.state.update.time.min = min;
        self
    }

    pub fn with_update_time_max(mut self, max: u64) -> Self {
        self.state.update.time.max = max;
        self
    }

    pub fn with_sum2_time_min(mut self, min: u64) -> Self {
        self.state.sum2.time.min = min;
        self
    }

    pub fn with_sum2_time_max(mut self, max: u64) -> Self {
        self.state.sum2.time.max = max;
        self
    }
}
================================================
FILE: rust/xaynet-server/src/state_machine/tests/event_bus.rs
================================================
use xaynet_core::{SeedDict, SumDict};

use crate::state_machine::{
    coordinator::CoordinatorState,
    events::{DictionaryUpdate, EventPublisher, EventSubscriber, ModelUpdate},
    phases::PhaseName,
};

use super::{utils::EventSnapshot, CoordinatorStateBuilder, WARNING};

/// Test builder for an event bus whose publisher can be pre-loaded with events.
// NOTE(review): `SumDict`/`SeedDict` are imported but their uses (presumably the
// `DictionaryUpdate` type arguments below) are missing in this copy — TODO restore.
pub struct EventBusBuilder {
    event_publisher: EventPublisher,
    event_subscriber: EventSubscriber,
}

impl EventBusBuilder {
    /// Initialises the event bus from `state`.
    ///
    /// The bus starts in `PhaseName::Idle` with an invalidated model.
    pub fn new(state: &CoordinatorState) -> Self {
        let (event_publisher, event_subscriber) = EventPublisher::init(
            state.round_id,
            state.keys.clone(),
            state.round_params.clone(),
            PhaseName::Idle,
            ModelUpdate::Invalidate,
        );
        Self {
            event_publisher,
            event_subscriber,
        }
    }

    pub fn broadcast_phase(mut self, phase: PhaseName) -> Self {
        self.event_publisher.broadcast_phase(phase);
        self
    }

    pub fn broadcast_model(mut self, update: ModelUpdate) -> Self {
        self.event_publisher.broadcast_model(update);
        self
    }

    pub fn broadcast_sum_dict(mut self, update: DictionaryUpdate) -> Self {
        self.event_publisher.broadcast_sum_dict(update);
        self
    }

    pub fn broadcast_seed_dict(mut self, update: DictionaryUpdate) -> Self {
        self.event_publisher.broadcast_seed_dict(update);
        self
    }

    /// Returns the publisher/subscriber pair.
    pub fn build(self) -> (EventPublisher, EventSubscriber) {
        (self.event_publisher, self.event_subscriber)
    }
}

/// Guards the initial events that all state machine tests rely on.
///
/// The initial phase must be `Idle`, and the model, sum dict and seed dict must all
/// be `Invalidate`.
#[test]
fn test_initial_events() {
    const PANIC_MESSAGE: &str = "the initial events have been changed.";
    let state = CoordinatorStateBuilder::new().build();
    let (_, subscriber) = EventBusBuilder::new(&state).build();
    let events = EventSnapshot::from(&subscriber);

    assert_eq!(
        events.phase.event,
        PhaseName::Idle,
        "{} {}",
        PANIC_MESSAGE,
        WARNING
    );
    assert_eq!(
        events.model.event,
        ModelUpdate::Invalidate,
        "{} {}",
        PANIC_MESSAGE,
        WARNING
    );
    assert_eq!(
        events.sum_dict.event,
        DictionaryUpdate::Invalidate,
        "{} {}",
        PANIC_MESSAGE,
        WARNING
    );
    assert_eq!(
        events.seed_dict.event,
        DictionaryUpdate::Invalidate,
        "{} {}",
        PANIC_MESSAGE,
        WARNING
    );
}
================================================
FILE: rust/xaynet-server/src/state_machine/tests/impls.rs
================================================
use tracing::Span;

use xaynet_core::message::Message;

use crate::state_machine::{
    coordinator::CoordinatorState,
    events::DictionaryUpdate,
    phases::{Failure, Idle, PhaseState, Shutdown, Sum, Sum2, Unmask, Update},
    requests::{RequestError, RequestSender},
    StateMachine,
};

impl RequestSender {
    /// Test helper: converts a clone of `msg` into a request and sends it with an
    /// empty tracing span.
    pub async fn msg(&self, msg: &Message) -> Result<(), RequestError> {
        self.request(msg.clone().into(), Span::none()).await
    }
}

// Test helpers for inspecting the current phase.
// - The `is_*` methods report whether the machine is in the given phase.
// - The `into_*_phase_state` methods extract that phase's state and panic if the
//   machine is in any other phase.
// NOTE(review): the type parameters of this impl and of the `PhaseState` return
// types are missing in this copy — TODO restore.
impl StateMachine {
    pub fn is_idle(&self) -> bool {
        matches!(self, StateMachine::Idle(_))
    }

    pub fn into_idle_phase_state(self) -> PhaseState {
        match self {
            StateMachine::Idle(state) => state,
            _ => panic!("not in idle state"),
        }
    }

    pub fn is_sum(&self) -> bool {
        matches!(self, StateMachine::Sum(_))
    }

    pub fn into_sum_phase_state(self) -> PhaseState {
        match self {
            StateMachine::Sum(state) => state,
            _ => panic!("not in sum state"),
        }
    }

    pub fn is_update(&self) -> bool {
        matches!(self, StateMachine::Update(_))
    }

    pub fn into_update_phase_state(self) -> PhaseState {
        match self {
            StateMachine::Update(state) => state,
            _ => panic!("not in update state"),
        }
    }

    pub fn is_sum2(&self) -> bool {
        matches!(self, StateMachine::Sum2(_))
    }

    pub fn into_sum2_phase_state(self) -> PhaseState {
        match self {
            StateMachine::Sum2(state) => state,
            _ => panic!("not in sum2 state"),
        }
    }

    pub fn is_unmask(&self) -> bool {
        matches!(self, StateMachine::Unmask(_))
    }

    pub fn into_unmask_phase_state(self) -> PhaseState {
        match self {
            StateMachine::Unmask(state) => state,
            _ => panic!("not in unmask state"),
        }
    }

    pub fn is_failure(&self) -> bool {
        matches!(self, StateMachine::Failure(_))
    }

    pub fn into_failure_phase_state(self) -> PhaseState {
        match self {
            StateMachine::Failure(state) => state,
            _ => panic!("not in error state"),
        }
    }

    pub fn is_shutdown(&self) -> bool {
        matches!(self, StateMachine::Shutdown(_))
    }

    pub fn into_shutdown_phase_state(self) -> PhaseState {
        match self {
            StateMachine::Shutdown(state) => state,
            _ => panic!("not in shutdown state"),
        }
    }
}

// Gives tests read access to the coordinator state shared by every phase.
impl AsRef for StateMachine {
    fn as_ref(&self) -> &CoordinatorState {
        match self {
            StateMachine::Idle(state) => &state.shared.state,
            StateMachine::Sum(state) => &state.shared.state,
            StateMachine::Update(state) => &state.shared.state,
            StateMachine::Sum2(state) => &state.shared.state,
            StateMachine::Unmask(state) => &state.shared.state,
            StateMachine::Failure(state) => &state.shared.state,
            StateMachine::Shutdown(state) => &state.shared.state,
        }
    }
}

impl DictionaryUpdate {
    /// Returns the dictionary carried by `DictionaryUpdate::New`.
    ///
    /// # Panics
    /// Panics for any other variant (e.g. `DictionaryUpdate::Invalidate`).
    pub fn unwrap(self) -> std::sync::Arc {
        if let DictionaryUpdate::New(inner) = self {
            inner
        } else {
            panic!("DictionaryUpdate::Invalidate");
        }
    }
}
================================================
FILE: rust/xaynet-server/src/state_machine/tests/initializer.rs
================================================
//! State machine initialization test utilities.
use serial_test::serial; #[cfg(feature = "model-persistence")] use crate::{ settings::RestoreSettings, state_machine::{ events::{DictionaryUpdate, ModelUpdate}, initializer::StateMachineInitializationError, phases::PhaseName, }, storage::tests::utils::create_global_model, storage::ModelStorage, }; use crate::{ state_machine::{ coordinator::CoordinatorState, initializer::StateMachineInitializer, tests::utils::{mask_settings, model_settings, pet_settings}, }, storage::{tests::init_store, CoordinatorStorage}, }; #[cfg(feature = "model-persistence")] #[tokio::test] #[serial] #[ignore] async fn integration_state_machine_initializer_no_restore() { let store = init_store().await; let smi = StateMachineInitializer::new( pet_settings(), mask_settings(), model_settings(), RestoreSettings { enable: false }, store, ); let (state_machine, _request_sender, event_subscriber) = smi.init().await.unwrap(); assert!(state_machine.is_idle()); let phase = event_subscriber.phase_listener().get_latest().event; assert!(matches!(phase, PhaseName::Idle)); let sum_dict = event_subscriber.sum_dict_listener().get_latest().event; assert!(matches!(sum_dict, DictionaryUpdate::Invalidate)); let seed_dict = event_subscriber.seed_dict_listener().get_latest().event; assert!(matches!(seed_dict, DictionaryUpdate::Invalidate)); let global_model = event_subscriber.model_listener().get_latest().event; assert!(matches!(global_model, ModelUpdate::Invalidate)); let round_id = event_subscriber.params_listener().get_latest().round_id; assert_eq!(round_id, 0); } #[cfg(feature = "model-persistence")] #[tokio::test] #[serial] #[ignore] async fn integration_state_machine_initializer_no_state() { let store = init_store().await; let smi = StateMachineInitializer::new( pet_settings(), mask_settings(), model_settings(), RestoreSettings { enable: true }, store, ); let (state_machine, _request_sender, event_subscriber) = smi.init().await.unwrap(); assert!(state_machine.is_idle()); let phase = 
event_subscriber.phase_listener().get_latest().event; assert!(matches!(phase, PhaseName::Idle)); let sum_dict = event_subscriber.sum_dict_listener().get_latest().event; assert!(matches!(sum_dict, DictionaryUpdate::Invalidate)); let seed_dict = event_subscriber.seed_dict_listener().get_latest().event; assert!(matches!(seed_dict, DictionaryUpdate::Invalidate)); let global_model = event_subscriber.model_listener().get_latest().event; assert!(matches!(global_model, ModelUpdate::Invalidate)); let round_id = event_subscriber.params_listener().get_latest().round_id; assert_eq!(round_id, 0); } #[cfg(feature = "model-persistence")] #[tokio::test] #[serial] #[ignore] async fn integration_state_machine_initializer_without_global_model() { let pet_settings = pet_settings(); let mask_settings = mask_settings(); let model_settings = model_settings(); // we change the round id to ensure that the state machine is // initialized with the coordinator state in the store // if we don't update the round_id we can't check if the state in the store was used or if the state was reset // because in both cases the round id will be 0 let mut store = init_store().await; let mut state = CoordinatorState::new(pet_settings, mask_settings, model_settings.clone()); let new_round_id = 5; state.round_id = new_round_id; store.set_coordinator_state(&state).await.unwrap(); let smi = StateMachineInitializer::new( pet_settings, mask_settings, model_settings, RestoreSettings { enable: true }, store, ); let (state_machine, _request_sender, event_subscriber) = smi.init().await.unwrap(); assert!(state_machine.is_idle()); let phase = event_subscriber.phase_listener().get_latest().event; assert!(matches!(phase, PhaseName::Idle)); let sum_dict = event_subscriber.sum_dict_listener().get_latest().event; assert!(matches!(sum_dict, DictionaryUpdate::Invalidate)); let seed_dict = event_subscriber.seed_dict_listener().get_latest().event; assert!(matches!(seed_dict, DictionaryUpdate::Invalidate)); let global_model = 
event_subscriber.model_listener().get_latest().event; assert!(matches!(global_model, ModelUpdate::Invalidate)); let round_id = event_subscriber.params_listener().get_latest().round_id; assert_eq!(round_id, new_round_id); } #[cfg(feature = "model-persistence")] #[tokio::test] #[serial] #[ignore] async fn integration_state_machine_initializer_with_global_model() { let pet_settings = pet_settings(); let mask_settings = mask_settings(); let model_settings = model_settings(); let mut store = init_store().await; let mut state = CoordinatorState::new(pet_settings, mask_settings, model_settings.clone()); let new_round_id = 7; state.round_id = new_round_id; store.set_coordinator_state(&state).await.unwrap(); // upload a global model and set the id let uploaded_global_model = create_global_model(state.round_params.model_length); let global_model_id = store .set_global_model( state.round_id, &state.round_params.seed, &uploaded_global_model, ) .await .unwrap(); store .set_latest_global_model_id(&global_model_id) .await .unwrap(); let smi = StateMachineInitializer::new( pet_settings, mask_settings, model_settings, RestoreSettings { enable: true }, store, ); let (state_machine, _request_sender, event_subscriber) = smi.init().await.unwrap(); assert!(state_machine.is_idle()); let phase = event_subscriber.phase_listener().get_latest().event; assert!(matches!(phase, PhaseName::Idle)); let sum_dict = event_subscriber.sum_dict_listener().get_latest().event; assert!(matches!(sum_dict, DictionaryUpdate::Invalidate)); let seed_dict = event_subscriber.seed_dict_listener().get_latest().event; assert!(matches!(seed_dict, DictionaryUpdate::Invalidate)); let global_model = event_subscriber.model_listener().get_latest().event; assert!( matches!(global_model, ModelUpdate::New(broadcasted_model) if uploaded_global_model == *broadcasted_model) ); let round_id = event_subscriber.params_listener().get_latest().round_id; assert_eq!(round_id, new_round_id); } #[cfg(feature = "model-persistence")] 
#[tokio::test] #[serial] #[ignore] async fn integration_state_machine_initializer_failed_because_of_wrong_size() { let pet_settings = pet_settings(); let mask_settings = mask_settings(); let model_settings = model_settings(); let mut store = init_store().await; let mut state = CoordinatorState::new(pet_settings, mask_settings, model_settings.clone()); let new_round_id = 9; state.round_id = new_round_id; store.set_coordinator_state(&state).await.unwrap(); // upload a global model with a wrong model length and set the id let uploaded_global_model = create_global_model(state.round_params.model_length + 10); let global_model_id = store .set_global_model( state.round_id, &state.round_params.seed, &uploaded_global_model, ) .await .unwrap(); store .set_latest_global_model_id(&global_model_id) .await .unwrap(); let smi = StateMachineInitializer::new( pet_settings, mask_settings, model_settings, RestoreSettings { enable: true }, store, ); let result = smi.init().await; assert!(matches!( result, Err(StateMachineInitializationError::GlobalModelInvalid(_)) )); } #[cfg(feature = "model-persistence")] #[tokio::test] #[serial] #[ignore] async fn integration_state_machine_initializer_failed_to_find_global_model() { let pet_settings = pet_settings(); let mask_settings = mask_settings(); let model_settings = model_settings(); let mut store = init_store().await; let mut state = CoordinatorState::new(pet_settings, mask_settings, model_settings.clone()); let new_round_id = 11; state.round_id = new_round_id; store.set_coordinator_state(&state).await.unwrap(); // set a model id but don't store a model let global_model_id = "1_412957050209fcfa733b1fb4ad51f321"; store .set_latest_global_model_id(global_model_id) .await .unwrap(); let smi = StateMachineInitializer::new( pet_settings, mask_settings, model_settings, RestoreSettings { enable: true }, store, ); let result = smi.init().await; assert!(matches!( result, Err(StateMachineInitializationError::GlobalModelUnavailable(_)) )); } 
/// `from_settings` must wipe the coordinator data from the store.
/// Afterwards there must be no state, dictionaries, best masks or latest model id,
/// and zero unique masks.
#[tokio::test]
#[serial]
#[ignore]
async fn integration_state_machine_initializer_reset_state() {
    let pet_cfg = pet_settings();
    let mask_cfg = mask_settings();
    let model_cfg = model_settings();

    // Seed the store with a coordinator state that the reset has to remove.
    let mut store = init_store().await;
    let seeded_state = CoordinatorState::new(pet_cfg, mask_cfg, model_cfg.clone());
    store.set_coordinator_state(&seeded_state).await.unwrap();

    // The initializer works on a clone, so `store` can still inspect the result.
    let mut initializer = StateMachineInitializer::new(
        pet_cfg,
        mask_cfg,
        model_cfg,
        #[cfg(feature = "model-persistence")]
        RestoreSettings { enable: true },
        store.clone(),
    );
    initializer.from_settings().await.unwrap();

    assert!(store.coordinator_state().await.unwrap().is_none());
    assert!(store.sum_dict().await.unwrap().is_none());
    assert!(store.seed_dict().await.unwrap().is_none());
    assert!(store.best_masks().await.unwrap().is_none());
    assert!(store.latest_global_model_id().await.unwrap().is_none());
    assert_eq!(store.number_of_unique_masks().await.unwrap(), 0);
}
================================================
FILE: rust/xaynet-server/src/state_machine/tests/mod.rs
================================================
//! State machine test utilities.

pub mod coordinator_state;
pub mod event_bus;
pub mod impls;
pub mod initializer;
pub mod utils;

pub use coordinator_state::CoordinatorStateBuilder;
pub use event_bus::EventBusBuilder;

// Appended to assertion messages that guard the initial test values
// (e.g. in `event_bus::test_initial_events`).
const WARNING: &str = "All state machine tests were written assuming these initial values. First, carefully check the correctness of the state machine test before finally changing these values.";
================================================
FILE: rust/xaynet-server/src/state_machine/tests/utils.rs
================================================
//! State machine misc test utilities.
use std::fmt::Debug;

use tokio::sync::mpsc;
use tracing_subscriber::{EnvFilter, FmtSubscriber};
use xaynet_core::{
    common::RoundParameters,
    crypto::{ByteObject, EncryptKeyPair, PublicEncryptKey, PublicSigningKey},
    mask::{BoundType, DataType, GroupType, MaskObject, ModelType},
    message::{Message, Sum, Sum2, Update},
    LocalSeedDict,
    ParticipantTaskSignature,
    SeedDict,
    SumDict,
};

use crate::{
    settings::{
        MaskSettings,
        ModelSettings,
        PetSettings,
        PetSettingsCount,
        PetSettingsSum,
        PetSettingsSum2,
        PetSettingsTime,
        PetSettingsUpdate,
    },
    state_machine::{
        coordinator::CoordinatorState,
        events::{DictionaryUpdate, Event, EventPublisher, EventSubscriber, ModelUpdate},
        phases::{PhaseName, Shared},
        requests::{RequestReceiver, RequestSender},
    },
    storage::tests::utils::create_mask,
};

use super::WARNING;

/// Installs a global `tracing` subscriber whose filter is taken from the
/// environment (`EnvFilter::from_default_env`), with ANSI colors enabled.
///
/// The result of `try_init` is discarded, so calling this when a global
/// subscriber is already installed does nothing instead of panicking.
pub fn enable_logging() {
    let _fmt_subscriber = FmtSubscriber::builder()
        .with_env_filter(EnvFilter::from_default_env())
        .with_ansi(true)
        .try_init();
}

/// PET settings that every state machine test is written against.
///
/// The values are pinned by `test_initial_settings`; read `WARNING` before changing them.
pub fn pet_settings() -> PetSettings {
    PetSettings {
        sum: PetSettingsSum {
            prob: 0.4,
            count: PetSettingsCount { min: 1, max: 100 },
            time: PetSettingsTime { min: 1, max: 2 },
        },
        update: PetSettingsUpdate {
            prob: 0.5,
            count: PetSettingsCount { min: 3, max: 1000 },
            time: PetSettingsTime { min: 1, max: 2 },
        },
        sum2: PetSettingsSum2 {
            count: PetSettingsCount { min: 1, max: 100 },
            time: PetSettingsTime { min: 1, max: 2 },
        },
    }
}

/// Mask settings that every state machine test is written against (pinned by
/// `test_initial_settings`).
pub fn mask_settings() -> MaskSettings {
    MaskSettings {
        group_type: GroupType::Prime,
        data_type: DataType::F32,
        bound_type: BoundType::B0,
        model_type: ModelType::M3,
    }
}

/// Model settings that every state machine test is written against (pinned by
/// `test_initial_settings`).
pub fn model_settings() -> ModelSettings {
    ModelSettings { length: 1 }
}

/// Builds the [`Shared`] phase state from `coordinator_state`, `store` and
/// `event_publisher`, wired to a fresh request channel.
///
/// Returns the shared state (which owns the receiving end of the channel) and the
/// [`RequestSender`] used to submit messages to it.
pub fn init_shared(
    coordinator_state: CoordinatorState,
    store: T,
    event_publisher: EventPublisher,
) -> (Shared, RequestSender) {
    let (request_rx, request_tx) = RequestReceiver::new();
    (
        Shared::new(coordinator_state, event_publisher, request_rx, store),
        request_tx,
    )
}

/// The latest event of every channel of an [`EventSubscriber`], captured at one
/// point in time.
#[derive(Debug, Clone, PartialEq)]
pub struct EventSnapshot {
    pub keys: Event,
    pub params: Event,
    pub phase: Event,
    pub model: Event,
    pub sum_dict: Event>,
    pub seed_dict: Event>,
}

impl From<&EventSubscriber> for EventSnapshot {
    fn from(event_subscriber: &EventSubscriber) -> Self {
        Self {
            keys: event_subscriber.keys_listener().get_latest(),
            params: event_subscriber.params_listener().get_latest(),
            phase: event_subscriber.phase_listener().get_latest(),
            model: event_subscriber.model_listener().get_latest(),
            sum_dict: event_subscriber.sum_dict_listener().get_latest(),
            seed_dict: event_subscriber.seed_dict_listener().get_latest(),
        }
    }
}

/// Asserts that both the round id and the payload differ between the two events.
pub fn assert_event_updated_with_id(event1: &Event, event2: &Event) {
    assert_ne!(event1.round_id, event2.round_id);
    assert_ne!(event1.event, event2.event);
}

/// Asserts that the payload changed while the round id stayed the same.
pub fn assert_event_updated(event1: &Event, event2: &Event) {
    assert_eq!(event1.round_id, event2.round_id);
    assert_ne!(event1.event, event2.event);
}

/// Builds a sum message whose keys and signature are all zeroed.
pub fn compose_sum_message() -> Message {
    let payload = Sum {
        sum_signature: ParticipantTaskSignature::zeroed(),
        ephm_pk: PublicEncryptKey::zeroed(),
    };

    Message::new_sum(
        PublicSigningKey::zeroed(),
        PublicEncryptKey::zeroed(),
        payload,
    )
}

/// Builds an update message with zeroed keys and signatures, the given
/// `masked_model` and an empty local seed dict.
pub fn compose_update_message(masked_model: MaskObject) -> Message {
    let payload = Update {
        sum_signature: ParticipantTaskSignature::zeroed(),
        update_signature: ParticipantTaskSignature::zeroed(),
        masked_model,
        local_seed_dict: LocalSeedDict::new(),
    };

    Message::new_update(
        PublicSigningKey::zeroed(),
        PublicEncryptKey::zeroed(),
        payload,
    )
}

/// Builds a sum2 message with zeroed keys and signature and `create_mask(1, 1)`
/// as its model mask.
pub fn compose_sum2_message() -> Message {
    let payload = Sum2 {
        sum_signature: ParticipantTaskSignature::zeroed(),
        model_mask: create_mask(1, 1),
    };

    Message::new_sum2(
        PublicSigningKey::zeroed(),
        PublicEncryptKey::zeroed(),
        payload,
    )
}

/// Spawns `n` tasks, each sending one sum message via a clone of `request_tx`.
///
/// The join handles are dropped, so this returns without waiting for the sends.
pub fn send_sum_messages(n: u32, request_tx: RequestSender) {
    for _ in 0..n {
        let request = request_tx.clone();
        tokio::spawn(async move { request.msg(&compose_sum_message()).await });
    }
}

/// Like [`send_sum_messages`], but each task releases its own clone of `latch`
/// after its request has completed (the request result is ignored).
///
/// The `latch` passed in is dropped when this function returns, so only the
/// per-task clones keep the [`Readiness`] channel open.
#[allow(dead_code)]
pub fn send_sum_messages_with_latch(n: u32, request_tx: RequestSender, latch: Latch) {
    for _ in 0..n {
        let request = request_tx.clone();
        let l = latch.clone();
        tokio::spawn(async move {
            let _ = request.msg(&compose_sum_message()).await;
            l.release();
        });
    }
}

/// Spawns `n` tasks, each sending one sum2 message via a clone of `request_tx`.
pub fn send_sum2_messages(n: u32, request_tx: RequestSender) {
    for _ in 0..n {
        let request = request_tx.clone();
        tokio::spawn(async move { request.msg(&compose_sum2_message()).await });
    }
}

/// Spawns `n` tasks, each sending one update message that carries a clone of
/// `create_mask(1, 1)` as its masked model.
pub fn send_update_messages(n: u32, request_tx: RequestSender) {
    let default_model = create_mask(1, 1);
    for _ in 0..n {
        let request = request_tx.clone();
        let masked_model = default_model.clone();
        tokio::spawn(async move { request.msg(&compose_update_message(masked_model)).await });
    }
}

/// Spawns `n` tasks, each sending one update message that carries a clone of
/// `masked_model`.
pub fn send_update_messages_with_model(
    n: u32,
    request_tx: RequestSender,
    masked_model: MaskObject,
) {
    for _ in 0..n {
        let request = request_tx.clone();
        let moved_masked_model = masked_model.clone();
        tokio::spawn(async move {
            request
                .msg(&compose_update_message(moved_masked_model))
                .await
        });
    }
}

/// Waiting half of a latch built on a capacity-1 `mpsc` channel.
///
/// Nothing in this module sends on the channel, so [`Readiness::is_ready`]
/// returns once the channel is closed, i.e. after every [`Latch`] clone has been
/// dropped.
#[allow(dead_code)]
pub struct Readiness(mpsc::Receiver<()>);

/// Releasing half of a latch; hand one clone to every task that must finish.
#[allow(dead_code)]
#[derive(Clone)]
pub struct Latch(mpsc::Sender<()>);

#[allow(dead_code)]
impl Readiness {
    /// Creates a connected ([`Readiness`], [`Latch`]) pair.
    pub fn new() -> (Readiness, Latch) {
        let (sender, receiver) = mpsc::channel(1);
        (Readiness(receiver), Latch(sender))
    }

    /// Waits until the channel yields a value or is closed.
    pub async fn is_ready(&mut self) {
        let _ = self.0.recv().await;
    }
}

impl Latch {
    /// Releases this readiness latch.
    ///
    /// Dropping the sender is what eventually closes the channel that
    /// [`Readiness::is_ready`] waits on.
    pub fn release(self) {
        drop(self);
    }
}

// Guards the fixture values above: if any of them changes, every state machine
// test has to be reviewed (see `WARNING`).
#[test]
fn test_initial_settings() {
    let pet = PetSettings {
        sum: PetSettingsSum {
            prob: 0.4,
            count: PetSettingsCount { min: 1, max: 100 },
            time: PetSettingsTime { min: 1, max: 2 },
        },
        update: PetSettingsUpdate {
            prob: 0.5,
            count: PetSettingsCount { min: 3, max: 1000 },
            time: PetSettingsTime { min: 1, max: 2 },
        },
        sum2: PetSettingsSum2 {
            count: PetSettingsCount { min: 1, max: 100 },
            time: PetSettingsTime { min: 1, max: 2 },
        },
    };

    assert_eq!(
        pet,
        pet_settings(),
        "the initial PetSettings have been changed. 
{}",
        WARNING
    );

    let mask = MaskSettings {
        group_type: GroupType::Prime,
        data_type: DataType::F32,
        bound_type: BoundType::B0,
        model_type: ModelType::M3,
    };

    assert_eq!(
        mask,
        mask_settings(),
        "the initial MaskSettings have been changed. {}",
        WARNING
    );

    let model = ModelSettings { length: 1 };

    assert_eq!(
        model,
        model_settings(),
        "the initial ModelSettings have been changed. {}",
        WARNING
    );
}
================================================
FILE: rust/xaynet-server/src/storage/coordinator_storage/mod.rs
================================================
//! Storage backends to manage the coordinator state.

pub mod redis;
================================================
FILE: rust/xaynet-server/src/storage/coordinator_storage/redis/impls.rs
================================================
use std::convert::TryFrom;

use derive_more::{From, Into};
use paste::paste;
use redis::{ErrorKind, FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs, Value};
use serde::{Deserialize, Serialize};

use crate::{
    state_machine::coordinator::CoordinatorState,
    storage::{
        LocalSeedDictAdd,
        LocalSeedDictAddError,
        MaskScoreIncr,
        MaskScoreIncrError,
        SumPartAdd,
        SumPartAddError,
    },
};
use xaynet_core::{
    crypto::{ByteObject, PublicEncryptKey, PublicSigningKey},
    mask::{EncryptedMaskSeed, MaskObject},
    LocalSeedDict,
};

/// Builds a Redis [`ErrorKind::TypeError`] with the description `desc` and, if
/// given, the additional `details` text.
pub fn redis_type_error(desc: &'static str, details: Option) -> RedisError {
    if let Some(details) = details {
        RedisError::from((ErrorKind::TypeError, desc, details))
    } else {
        RedisError::from((ErrorKind::TypeError, desc))
    }
}

/// Type error for a reply that is not one of the expected integer status codes.
/// The debug representation of the offending reply is included in the details.
fn error_code_type_error(response: &Value) -> RedisError {
    redis_type_error(
        "Response status not valid integer",
        Some(format!("Response was {:?}", response)),
    )
}

/// Implements [`FromRedisValue`] and [`ToRedisArgs`] for types that implement [`ByteObject`].
/// The Redis traits as well as the crypto types are both defined in foreign crates.
/// To bypass the restrictions of the orphan rule, we use `Newtypes` for the crypto types.
///
/// Each crypto type has two `Newtypes`, one for reading and one for writing.
/// The difference between `Read` and `Write` is that the write `Newtype` does not take the
/// ownership of the value but only a reference. This allows us to use references in the
/// [`Client`] methods. The `Read` Newtype also implements [`ToRedisArgs`] to reduce the
/// conversion overhead that you would get if you wanted to reuse a `Read` value for another
/// Redis query.
///
/// Both `Newtypes` write the raw bytes returned by `as_slice`; the `Read` type is
/// rebuilt from a bulk-data reply with `from_slice`.
///
/// Example:
///
/// ```compile_fail
/// let sum_pks: Vec<PublicSigningKeyRead> = self.connection.hkeys("sum_dict").await?;
/// for sum_pk in sum_pks {
///     let sum_pk_seed_dict: HashMap<PublicSigningKeyRead, EncryptedMaskSeedRead>
///         = self.connection.hgetall(&sum_pk).await?; // no need to convert sum_pk from PublicSigningKeyRead to PublicSigningKeyWrite
/// }
/// ```
///
/// [`Client`]: crate::storage::redis::Client
macro_rules! impl_byte_object_redis_traits {
    ($ty: ty) => {
        paste! {
            // Owned wrapper produced when reading the type from Redis.
            #[derive(Into, Hash, Eq, PartialEq)]
            pub(crate) struct [<$ty Read>]($ty);

            impl FromRedisValue for [<$ty Read>] {
                fn from_redis_value(v: &Value) -> RedisResult<[<$ty Read>]> {
                    match *v {
                        // Only `Value::Data` replies carry the raw bytes; a wrong byte
                        // length makes `from_slice` return `None`, which becomes a type error.
                        Value::Data(ref bytes) => {
                            let inner = <$ty>::from_slice(bytes).ok_or_else(|| {
                                redis_type_error(concat!("Invalid ", stringify!($ty)), None)
                            })?;
                            Ok([<$ty Read>](inner))
                        }
                        _ => Err(redis_type_error(
                            concat!("Response not ", stringify!($ty), " compatible"),
                            None,
                        )),
                    }
                }
            }

            impl ToRedisArgs for [<$ty Read>] {
                fn write_redis_args(&self, out: &mut W)
                where
                    W: ?Sized + RedisWrite,
                {
                    self.0.as_slice().write_redis_args(out)
                }
            }

            // Borrowing wrapper used to pass a reference as a Redis argument.
            #[derive(From)]
            pub(crate) struct [<$ty Write>]<'a>(&'a $ty);

            impl ToRedisArgs for [<$ty Write>]<'_> {
                fn write_redis_args(&self, out: &mut W)
                where
                    W: ?Sized + RedisWrite,
                {
                    self.0.as_slice().write_redis_args(out)
                }
            }
        }
    };
}

impl_byte_object_redis_traits!(PublicEncryptKey);
impl_byte_object_redis_traits!(PublicSigningKey);
impl_byte_object_redis_traits!(EncryptedMaskSeed);

/// Implements [`FromRedisValue`] and [`ToRedisArgs`] for types that implement
/// [`Serialize`] and [`Deserialize`]. 
The data is de/serialized via bincode.
///
/// # Errors
///
/// `from_redis_value` returns a `TypeError` if the reply is not bulk data or if
/// `bincode` fails to deserialize it.
///
/// # Panics
///
/// `write_redis_args` will panic if the data cannot be serialized with `bincode`
///
/// More information about what can cause a panic in bincode:
/// - <https://github.com/servo/bincode/issues/293>
/// - <https://github.com/servo/bincode/issues/255>
/// - <https://github.com/servo/bincode/issues/130#issuecomment-284641263>
macro_rules! impl_bincode_redis_traits {
    ($ty: ty) => {
        impl FromRedisValue for $ty {
            fn from_redis_value(v: &Value) -> RedisResult<$ty> {
                match *v {
                    Value::Data(ref bytes) => bincode::deserialize(bytes)
                        .map_err(|e| redis_type_error("Invalid data", Some(e.to_string()))),
                    _ => Err(redis_type_error("Response not bincode compatible", None)),
                }
            }
        }

        impl ToRedisArgs for $ty {
            fn write_redis_args(&self, out: &mut W)
            where
                W: ?Sized + RedisWrite,
            {
                // See the `# Panics` section of the macro documentation.
                let data = bincode::serialize(self).unwrap();
                data.write_redis_args(out)
            }
        }
    };
}

// CoordinatorState is pretty straightforward:
// - all the sequences have a known length
// - no untagged enum
// so bincode will not panic.
impl_bincode_redis_traits!(CoordinatorState);

/// Owned wrapper used to read a bincode-encoded [`MaskObject`] from Redis.
#[derive(From, Into, Serialize, Deserialize)]
pub(crate) struct MaskObjectRead(MaskObject);

impl_bincode_redis_traits!(MaskObjectRead);

/// Borrowing wrapper used to write a [`MaskObject`] to Redis without cloning it.
/// Like [`MaskObjectRead`], it is serialized with `bincode`.
///
/// # Panics
///
/// `write_redis_args` panics if `bincode` serialization fails.
#[derive(From, Serialize)]
pub(crate) struct MaskObjectWrite<'a>(&'a MaskObject);

impl ToRedisArgs for MaskObjectWrite<'_> {
    fn write_redis_args(&self, out: &mut W)
    where
        W: ?Sized + RedisWrite,
    {
        let data = bincode::serialize(self).unwrap();
        data.write_redis_args(out)
    }
}

/// Borrowing wrapper that writes a [`LocalSeedDict`] as a flat sequence of
/// `sum_pk, encrypted_seed` arguments (one pair per entry).
#[derive(From)]
pub(crate) struct LocalSeedDictWrite<'a>(&'a LocalSeedDict);

impl ToRedisArgs for LocalSeedDictWrite<'_> {
    fn write_redis_args(&self, out: &mut W)
    where
        W: ?Sized + RedisWrite,
    {
        let args: Vec<(PublicSigningKeyWrite, EncryptedMaskSeedWrite)> = self
            .0
            .iter()
            .map(|(pk, seed)| {
                (
                    PublicSigningKeyWrite::from(pk),
                    EncryptedMaskSeedWrite::from(seed),
                )
            })
            .collect();

        args.write_redis_args(out)
    }
}

// Reply of the local seed dict script: `0` is success. Any other integer is
// converted with `LocalSeedDictAddError::try_from`, and an unknown code or a
// non-integer reply becomes a type error.
impl FromRedisValue for LocalSeedDictAdd {
    fn from_redis_value(v: &Value) -> RedisResult {
        match *v {
            Value::Int(0) => Ok(LocalSeedDictAdd(Ok(()))),
            Value::Int(error_code) => match LocalSeedDictAddError::try_from(error_code) {
                Ok(error_variant) => Ok(LocalSeedDictAdd(Err(error_variant))),
                Err(_) => Err(error_code_type_error(v)),
            },
            _ => Err(error_code_type_error(v)),
        }
    }
}

// Reply of `HSETNX`: `1` (the field was newly set) is success. Other integers are
// converted with `SumPartAddError::try_from`.
impl FromRedisValue for SumPartAdd {
    fn from_redis_value(v: &Value) -> RedisResult {
        match *v {
            Value::Int(1) => Ok(SumPartAdd(Ok(()))),
            Value::Int(error_code) => match SumPartAddError::try_from(error_code) {
                Ok(error_variant) => Ok(SumPartAdd(Err(error_variant))),
                Err(_) => Err(error_code_type_error(v)),
            },
            _ => Err(error_code_type_error(v)),
        }
    }
}

// Reply of the mask score script: `0` is success. Other integers are converted
// with `MaskScoreIncrError::try_from`.
impl FromRedisValue for MaskScoreIncr {
    fn from_redis_value(v: &Value) -> RedisResult {
        match *v {
            Value::Int(0) => Ok(MaskScoreIncr(Ok(()))),
            Value::Int(error_code) => match MaskScoreIncrError::try_from(error_code) {
                Ok(error_variant) => Ok(MaskScoreIncr(Err(error_variant))),
                Err(_) => Err(error_code_type_error(v)),
            },
            _ => Err(error_code_type_error(v)),
        }
    }
}

#[cfg(test)]
/// Test-only result of removing an entry from the sum dict.
#[derive(derive_more::Deref)]
pub struct SumDictDelete(Result<(), SumDictDeleteError>);

#[cfg(test)]
impl SumDictDelete {
    /// Unwraps the inner result.
    pub fn into_inner(self) -> Result<(), SumDictDeleteError> {
        self.0
    }
}

#[cfg(test)]
/// Error codes of [`SumDictDelete`].
#[derive(thiserror::Error, Debug, num_enum::TryFromPrimitive)]
#[repr(i64)]
pub enum SumDictDeleteError {
    // `HDEL` returns 0 when no field was removed.
    #[error("sum participant does not exist")]
    DoesNotExist = 0,
}

// Reply of `HDEL` for a single field: `1` means the field was removed.
#[cfg(test)]
impl FromRedisValue for SumDictDelete {
    fn from_redis_value(v: &Value) -> RedisResult {
        match *v {
            Value::Int(1) => Ok(SumDictDelete(Ok(()))),
            Value::Int(error_code) => match SumDictDeleteError::try_from(error_code) {
                Ok(error_variant) => Ok(SumDictDelete(Err(error_variant))),
                Err(_) => Err(error_code_type_error(v)),
            },
            _ => Err(error_code_type_error(v)),
        }
    }
}
================================================
FILE: rust/xaynet-server/src/storage/coordinator_storage/redis/mod.rs
================================================
//! A Redis [`CoordinatorStorage`] backend.
//!
//! # Redis Data Model
//!
//!```text
//! {
//!     // Coordinator state
//!     "coordinator_state": "...", // bincode encoded string
//!     // Sum dict
//!     "sum_dict": { // hash
//!         "SumParticipantPublicKey_1": SumParticipantEphemeralPublicKey_1,
//!         "SumParticipantPublicKey_2": SumParticipantEphemeralPublicKey_2
//!     },
//!     // Seed dict
//!     "update_participants": [ // set
//!         UpdateParticipantPublicKey_1,
//!         UpdateParticipantPublicKey_2
//!     ],
//!     "SumParticipantPublicKey_1": { // hash
//!         "UpdateParticipantPublicKey_1": EncryptedMaskSeed,
//!         "UpdateParticipantPublicKey_2": EncryptedMaskSeed
//!     },
//!     "SumParticipantPublicKey_2": {
//!         "UpdateParticipantPublicKey_1": EncryptedMaskSeed,
//!         "UpdateParticipantPublicKey_2": EncryptedMaskSeed
//!     },
//!     // Mask dict
//!     "mask_submitted": [ // set
//!         SumParticipantPublicKey_1,
//!         SumParticipantPublicKey_2
//!     ],
//!     "mask_dict": [ // sorted set
//!         (mask_object_1, 2), // (mask: bincode encoded string, score/counter: number)
//!         (mask_object_2, 1)
//!     ],
//! 
"latest_global_model_id": global_model_id
//! }
//! ```

pub(in crate::storage) mod impls;

use std::collections::HashMap;

use async_trait::async_trait;
use redis::{aio::ConnectionManager, AsyncCommands, IntoConnectionInfo, Pipeline, Script};
pub use redis::{RedisError, RedisResult};
use tracing::debug;

use self::impls::{
    EncryptedMaskSeedRead,
    LocalSeedDictWrite,
    MaskObjectRead,
    MaskObjectWrite,
    PublicEncryptKeyRead,
    PublicEncryptKeyWrite,
    PublicSigningKeyRead,
    PublicSigningKeyWrite,
};
use crate::{
    state_machine::coordinator::CoordinatorState,
    storage::{
        CoordinatorStorage,
        LocalSeedDictAdd,
        MaskScoreIncr,
        StorageError,
        StorageResult,
        SumPartAdd,
    },
};
use xaynet_core::{
    mask::MaskObject,
    LocalSeedDict,
    SeedDict,
    SumDict,
    SumParticipantEphemeralPublicKey,
    SumParticipantPublicKey,
    UpdateParticipantPublicKey,
};

/// Redis client.
///
/// All commands are sent through a single [`ConnectionManager`].
#[derive(Clone)]
pub struct Client {
    connection: ConnectionManager,
}

/// Wraps a [`RedisError`] into the generic [`StorageError`] using `anyhow`.
fn to_storage_err(e: RedisError) -> StorageError {
    anyhow::anyhow!(e)
}

impl Client {
    /// Creates a new Redis client.
    ///
    /// `url` is the address of the Redis instance the client should connect to.
    /// The URL format is `redis://[<username>][:<password>@]<hostname>[:port][/<db>]`.
    ///
    /// The [`Client`] uses a [`ConnectionManager`] that automatically reconnects
    /// if the connection is dropped.
    ///
    /// # Errors
    ///
    /// Fails if `redis::Client::open` rejects `url` or if the initial connection
    /// manager cannot be created.
    pub async fn new(url: T) -> Result {
        let client = redis::Client::open(url)?;
        let connection = client.get_tokio_connection_manager().await?;
        Ok(Self { connection })
    }

    /// Builds, but does not execute, a pipeline that deletes the sum dict, the seed
    /// dict (the `update_participants` set and one hash per sum participant) and the
    /// mask dict (`mask_submitted` and `mask_dict`).
    ///
    /// The per-participant hashes are discovered with a separate `HKEYS sum_dict`
    /// call first, so executing the returned pipeline is not atomic with that read.
    async fn create_flush_dicts_pipeline(&mut self) -> RedisResult {
        // https://redis.io/commands/hkeys
        // > Return value:
        // Array reply: list of fields in the hash, or an empty list when key does not exist.
        let sum_pks: Vec = self.connection.hkeys("sum_dict").await?;

        let mut pipe = redis::pipe();

        // https://redis.io/commands/del
        // > Return value:
        // The number of keys that were removed.
        //
        // Returns `0` if the key does not exist.
        // We ignore the return value because we are not interested in it.

        // delete sum dict
        pipe.del("sum_dict").ignore();

        // delete seed dict
        pipe.del("update_participants").ignore();
        for sum_pk in sum_pks {
            pipe.del(sum_pk).ignore();
        }

        // delete mask dict
        pipe.del("mask_submitted").ignore();
        pipe.del("mask_dict").ignore();
        Ok(pipe)
    }
}

#[async_trait]
impl CoordinatorStorage for Client {
    async fn set_coordinator_state(&mut self, state: &CoordinatorState) -> StorageResult<()> {
        debug!("set coordinator state");
        // https://redis.io/commands/set
        // > Set key to hold the string value. If key already holds a value,
        // it is overwritten, regardless of its type.
        // Possible return value in our case:
        // > Simple string reply: OK if SET was executed correctly.
        self.connection
            .set("coordinator_state", state)
            .await
            .map_err(to_storage_err)
    }

    async fn coordinator_state(&mut self) -> StorageResult> {
        // https://redis.io/commands/get
        // > Get the value of key. If the key does not exist the special value nil is returned.
        // An error is returned if the value stored at key is not a string, because GET only
        // handles string values.
        // > Return value
        // Bulk string reply: the value of key, or nil when key does not exist.
        self.connection
            .get("coordinator_state")
            .await
            .map_err(to_storage_err)
    }

    async fn add_sum_participant(
        &mut self,
        pk: &SumParticipantPublicKey,
        ephm_pk: &SumParticipantEphemeralPublicKey,
    ) -> StorageResult {
        debug!("add sum participant with pk {:?}", pk);
        // https://redis.io/commands/hsetnx
        // > If field already exists, this operation has no effect.
        // > Return value
        // Integer reply, specifically:
        // 1 if field is a new field in the hash and value was set.
        // 0 if field already exists in the hash and no operation was performed.
        self.connection
            .hset_nx(
                "sum_dict",
                PublicSigningKeyWrite::from(pk),
                PublicEncryptKeyWrite::from(ephm_pk),
            )
            .await
            .map_err(to_storage_err)
    }

    async fn sum_dict(&mut self) -> StorageResult> {
        debug!("get sum dictionary");
        // https://redis.io/commands/hgetall
        // > Return value
        // Array reply: list of fields and their values stored in the hash, or an empty
        // list when key does not exist.
        let reply: Vec<(PublicSigningKeyRead, PublicEncryptKeyRead)> = self
            .connection
            .hgetall("sum_dict")
            .await
            .map_err(to_storage_err)?;

        // A missing key and an empty hash both produce an empty reply and map to `None`.
        if reply.is_empty() {
            return Ok(None);
        };

        let sum_dict = reply
            .into_iter()
            .map(|(pk, ephm_pk)| (pk.into(), ephm_pk.into()))
            .collect();

        Ok(Some(sum_dict))
    }

    /// Adds the local seed dict of `update_pk` to the seed dict in a single Lua script.
    ///
    /// Script return codes: `0` success; `-1` the local seed dict and the sum dict
    /// differ in length; `-2` a sum pk of the local seed dict is not in the sum dict;
    /// `-3` `update_pk` has already submitted a local seed dict; `-4` a seed for
    /// `update_pk` already existed in a sum participant's hash (corrupted data).
    async fn add_local_seed_dict(
        &mut self,
        update_pk: &UpdateParticipantPublicKey,
        local_seed_dict: &LocalSeedDict,
    ) -> StorageResult {
        debug!(
            "update seed dictionary for update participant with pk {:?}",
            update_pk
        );
        let script = Script::new(
            r#" -- lua lists (tables) start at 1 local update_pk = ARGV[1] -- check if the local seed dict has the same length as the sum_dict -- KEYS is a list (table) of key value pairs ([sum_pk_1, seed_1, sum_pk_2, seed_2, ...]) local seed_dict_len = #KEYS / 2 local sum_dict_len = redis.call("HLEN", "sum_dict") if seed_dict_len ~= sum_dict_len then return -1 end -- check if all pks of the local seed dict exists in sum_dict for i = 1, #KEYS, 2 do local exist_in_sum_dict = redis.call("HEXISTS", "sum_dict", KEYS[i]) if exist_in_sum_dict == 0 then return -2 end end -- check if the update pk already exists (i.e. 
the local seed dict has already been submitted) local exist_in_seed_dict = redis.call("SADD", "update_participants", update_pk) -- SADD returns 0 if the key already exists if exist_in_seed_dict == 0 then return -3 end -- update the seed dict for i = 1, #KEYS, 2 do local exist_in_update_seed_dict = redis.call("HSETNX", KEYS[i], update_pk, KEYS[i + 1]) -- HSETNX returns 0 if the update pk already exists if exist_in_update_seed_dict == 0 then -- This condition should never apply. -- If this condition is true, it is an indication that the data in redis is corrupted. return -4 end end return 0 "#,
        );

        // The local seed dict is passed as the script KEYS (flattened `sum_pk, seed`
        // pairs) and the update pk as ARGV[1].
        script
            .key(LocalSeedDictWrite::from(local_seed_dict))
            .arg(PublicSigningKeyWrite::from(update_pk))
            .invoke_async(&mut self.connection)
            .await
            .map_err(to_storage_err)
    }

    /// # Note
    /// This method is **not** an atomic operation.
    ///
    /// The sum pks are read first, and each sum participant's hash is then fetched
    /// with its own `HGETALL`, so other writes can interleave between these reads.
    async fn seed_dict(&mut self) -> StorageResult> {
        debug!("get seed dictionary");
        // https://redis.io/commands/hkeys
        // > Return value:
        // Array reply: list of fields in the hash, or an empty list when key does not exist.
        let sum_pks: Vec = self.connection.hkeys("sum_dict").await?;

        if sum_pks.is_empty() {
            return Ok(None);
        };

        let mut seed_dict: SeedDict = SeedDict::new();
        for sum_pk in sum_pks {
            // https://redis.io/commands/hgetall
            // > Return value
            // Array reply: list of fields and their values stored in the hash, or an empty
            // list when key does not exist.
            let sum_pk_seed_dict: HashMap = self.connection.hgetall(&sum_pk).await?;
            seed_dict.insert(
                sum_pk.into(),
                sum_pk_seed_dict
                    .into_iter()
                    .map(|(pk, seed)| (pk.into(), seed.into()))
                    .collect(),
            );
        }

        Ok(Some(seed_dict))
    }

    /// Increments the score of `mask` in the mask dict on behalf of `sum_pk`.
    ///
    /// The maximum length of a serialized mask is 512 Megabytes (presumably Redis'
    /// maximum string size, because the serialized mask is stored as a sorted-set member).
async fn incr_mask_score(
    &mut self,
    sum_pk: &SumParticipantPublicKey,
    mask: &MaskObject,
) -> StorageResult {
    debug!("increment mask count");
    // Script return codes: `0` success; `-1` `sum_pk` is not in the sum dict;
    // `-2` `sum_pk` has already submitted a mask (it is already in `mask_submitted`).
    let script = Script::new(
        r#" -- lua lists (tables) start at 1 local sum_pk = ARGV[1] -- check if the client participated in sum phase -- -- Note: we cannot delete the sum_pk in the sum_dict because we -- need the sum_dict later to delete the seed_dict local sum_pk_exist = redis.call("HEXISTS", "sum_dict", sum_pk) if sum_pk_exist == 0 then return -1 end -- check if sum participant has not already submitted a mask local mask_already_submitted = redis.call("SADD", "mask_submitted", sum_pk) -- SADD returns 0 if the key already exists if mask_already_submitted == 0 then return -2 end redis.call("ZINCRBY", "mask_dict", 1, KEYS[1]) return 0 "#,
    );

    // The serialized mask is KEYS[1] (the sorted-set member); the sum pk is ARGV[1].
    script
        .key(MaskObjectWrite::from(mask))
        .arg(PublicSigningKeyWrite::from(sum_pk))
        .invoke_async(&mut self.connection)
        .await
        .map_err(to_storage_err)
}

async fn best_masks(&mut self) -> StorageResult>> {
    debug!("get best masks");
    // https://redis.io/commands/zrevrange
    // > Return value:
    // Array reply: list of elements in the specified range (optionally with their scores,
    // in case the WITHSCORES option is given).
    //
    // ZREVRANGE ranks are inclusive, so `0, 1` returns at most the two highest-scored
    // masks (presumably so that callers can detect a tie for the best mask).
    let reply: Vec<(MaskObjectRead, u64)> = self
        .connection
        .zrevrange_withscores("mask_dict", 0, 1)
        .await?;

    // An empty or missing sorted set maps to `None`.
    let result = match reply.is_empty() {
        true => None,
        _ => {
            let masks = reply
                .into_iter()
                .map(|(mask, count)| (mask.into(), count))
                .collect();
            Some(masks)
        }
    };

    Ok(result)
}

async fn number_of_unique_masks(&mut self) -> StorageResult {
    debug!("get number of unique masks");
    // https://redis.io/commands/zcount
    // > Return value:
    // Integer reply: the number of elements in the specified score range.
    //
    // The range `-inf..+inf` counts every member, i.e. every distinct mask.
    self.connection
        .zcount("mask_dict", "-inf", "+inf")
        .await
        .map_err(to_storage_err)
}

/// # Note
/// This method is **not** an atomic operation.
///
/// The deletions run in one transaction, but the sum pks they depend on are read
/// beforehand by a separate `HKEYS` call.
async fn delete_coordinator_data(&mut self) -> StorageResult<()> {
    debug!("flush coordinator data");
    // Extend the dictionary flush pipeline with the remaining coordinator keys.
    let mut pipe = self.create_flush_dicts_pipeline().await?;
    pipe.del("coordinator_state").ignore();
    pipe.del("latest_global_model_id").ignore();
    // `atomic` wraps the queued commands in MULTI/EXEC.
    pipe.atomic()
        .query_async(&mut self.connection)
        .await
        .map_err(to_storage_err)
}

/// # Note
/// This method is **not** an atomic operation.
///
/// Only the deletions run inside one MULTI/EXEC transaction; the sum pks they
/// depend on are read beforehand by a separate `HKEYS` call.
async fn delete_dicts(&mut self) -> StorageResult<()> {
    debug!("flush all dictionaries");
    let mut pipe = self.create_flush_dicts_pipeline().await?;
    pipe.atomic()
        .query_async(&mut self.connection)
        .await
        .map_err(to_storage_err)
}

async fn set_latest_global_model_id(&mut self, global_model_id: &str) -> StorageResult<()> {
    debug!("set latest global model with id {}", global_model_id);
    // https://redis.io/commands/set
    // > Set key to hold the string value. If key already holds a value,
    // it is overwritten, regardless of its type.
    // Possible return value in our case:
    // > Simple string reply: OK if SET was executed correctly.
    self.connection
        .set("latest_global_model_id", global_model_id)
        .await
        .map_err(to_storage_err)
}

async fn latest_global_model_id(&mut self) -> StorageResult> {
    debug!("get latest global model id");
    // https://redis.io/commands/get
    // > Get the value of key. If the key does not exist the special value nil is returned.
    // An error is returned if the value stored at key is not a string, because GET only
    // handles string values.
    // > Return value
    // Bulk string reply: the value of key, or nil when key does not exist.
    self.connection
        .get("latest_global_model_id")
        .await
        .map_err(to_storage_err)
}

// Health check: succeeds if the server answers a PING.
async fn is_ready(&mut self) -> StorageResult<()> {
    // https://redis.io/commands/ping
    redis::cmd("PING")
        .query_async(&mut self.connection)
        .await
        .map_err(to_storage_err)
}
}

#[cfg(test)]
// Functions that are not needed in the state machine but handy for testing.
impl Client {
    // Removes an entry in the [`SumDict`].
// // Returns [`SumDictDelete(Ok(()))`] if field was deleted or // [`SumDictDelete(Err(SumDictDeleteError::DoesNotExist)`] if field does not exist. pub async fn remove_sum_dict_entry( &mut self, pk: &SumParticipantPublicKey, ) -> RedisResult { // https://redis.io/commands/hdel // > Return value // Integer reply: the number of fields that were removed from the hash, // not including specified but non existing fields. self.connection .hdel("sum_dict", PublicSigningKeyWrite::from(pk)) .await } // Returns the length of the [`SumDict`]. pub async fn sum_dict_len(&mut self) -> RedisResult { // https://redis.io/commands/hlen // > Return value // Integer reply: number of fields in the hash, or 0 when key does not exist. self.connection.hlen("sum_dict").await } // Returns the [`SumParticipantPublicKey`] of the [`SumDict`] or an empty list when the // [`SumDict`] does not exist. pub async fn sum_pks( &mut self, ) -> RedisResult> { // https://redis.io/commands/hkeys // > Return value: // Array reply: list of fields in the hash, or an empty list when key does not exist. let result: std::collections::HashSet = self.connection.hkeys("sum_dict").await?; let sum_pks = result.into_iter().map(|pk| pk.into()).collect(); Ok(sum_pks) } // Removes an update pk from the the `update_participants` set. pub async fn remove_update_participant( &mut self, update_pk: &UpdateParticipantPublicKey, ) -> RedisResult { self.connection .srem( "update_participants", PublicSigningKeyWrite::from(update_pk), ) .await } pub async fn mask_submitted_set(&mut self) -> RedisResult> { let result: Vec = self.connection.smembers("update_submitted").await?; let sum_pks = result.into_iter().map(|pk| pk.into()).collect(); Ok(sum_pks) } // Returns all keys in the current database pub async fn keys(&mut self) -> RedisResult> { self.connection.keys("*").await } /// Returns the [`SeedDict`] entry for the given ['SumParticipantPublicKey'] or an empty map /// when a [`SeedDict`] entry does not exist. 
pub async fn seed_dict_for_sum_pk(
    &mut self,
    sum_pk: &SumParticipantPublicKey,
) -> RedisResult> {
    debug!(
        "get seed dictionary for sum participant with pk {:?}",
        sum_pk
    );
    // https://redis.io/commands/hgetall
    // > Return value
    // Array reply: list of fields and their values stored in the hash, or an empty
    // list when key does not exist.
    // Each sum participant's seed dict entry is stored in a hash keyed by its public key.
    let result: Vec<(PublicSigningKeyRead, EncryptedMaskSeedRead)> = self
        .connection
        .hgetall(PublicSigningKeyWrite::from(sum_pk))
        .await?;

    let seed_dict = result
        .into_iter()
        .map(|(pk, seed)| (pk.into(), seed.into()))
        .collect();

    Ok(seed_dict)
}

/// Deletes all data in the current database.
///
/// The `ASYNC` modifier asks Redis to free the deleted data in the background.
pub async fn flush_db(&mut self) -> RedisResult<()> {
    debug!("flush current database");
    // https://redis.io/commands/flushdb
    // > This command never fails.
    redis::cmd("FLUSHDB")
        .arg("ASYNC")
        .query_async(&mut self.connection)
        .await
}
}

#[cfg(test)]
pub(in crate) mod tests {
    use self::impls::SumDictDeleteError;
    use super::*;
    use crate::{
        state_machine::tests::utils::{mask_settings, model_settings, pet_settings},
        storage::{tests::utils::*, LocalSeedDictAddError, MaskScoreIncrError, SumPartAddError},
    };
    use serial_test::serial;

    // Connects to the Redis instance at 127.0.0.1 that the integration tests expect.
    async fn create_redis_client() -> Client {
        Client::new("redis://127.0.0.1/").await.unwrap()
    }

    // Returns a client whose database has been flushed, so each test starts empty.
    // NOTE(review): the tests share one database, which presumably is why they are
    // `#[serial]`; they are `#[ignore]`d, likely because they need a running Redis.
    pub async fn init_client() -> Client {
        let mut client = create_redis_client().await;
        client.flush_db().await.unwrap();
        client
    }

    #[tokio::test]
    #[serial]
    #[ignore]
    async fn integration_set_and_get_coordinator_state() {
        // test the writing and reading of the coordinator state
        let mut client = init_client().await;
        let set_state = CoordinatorState::new(pet_settings(), mask_settings(), model_settings());
        client.set_coordinator_state(&set_state).await.unwrap();

        let get_state = client.coordinator_state().await.unwrap().unwrap();

        assert_eq!(set_state, get_state)
    }

    #[tokio::test]
    #[serial]
    #[ignore]
    async fn integration_get_coordinator_empty() {
        // test the reading of a non existing coordinator state
        let mut client = init_client().await;
let get_state = client.coordinator_state().await.unwrap(); assert_eq!(None, get_state) } #[tokio::test] #[serial] #[ignore] async fn integration_incr_mask_score() { // test the increment of the mask counter let mut client = init_client().await; let should_be_none = client.best_masks().await.unwrap(); assert!(should_be_none.is_none()); let sum_pks = create_and_add_sum_participant_entries(&mut client, 3).await; let mask = create_mask_zeroed(10); for sum_pk in sum_pks { let res = client.incr_mask_score(&sum_pk, &mask).await; assert!(res.is_ok()) } let best_masks = client.best_masks().await.unwrap().unwrap(); assert!(best_masks.len() == 1); let (best_mask, count) = best_masks.into_iter().next().unwrap(); assert_eq!(best_mask, mask); assert_eq!(count, 3); } #[tokio::test] #[serial] #[ignore] async fn integration_get_incr_mask_count_unknown_sum_pk() { // test the writing and reading of one mask let mut client = init_client().await; let should_be_none = client.best_masks().await.unwrap(); assert!(should_be_none.is_none()); let (sum_pk, _) = create_sum_participant_entry(); let mask = create_mask_zeroed(10); let unknown_sum_pk = client.incr_mask_score(&sum_pk, &mask).await.unwrap(); assert!(matches!( unknown_sum_pk.into_inner().unwrap_err(), MaskScoreIncrError::UnknownSumPk )); } #[tokio::test] #[serial] #[ignore] async fn integration_get_incr_mask_score_sum_pk_already_submitted() { // test the writing and reading of one mask let mut client = init_client().await; let should_be_none = client.best_masks().await.unwrap(); assert!(should_be_none.is_none()); let mut sum_pks = create_and_add_sum_participant_entries(&mut client, 1).await; let sum_pk = sum_pks.pop().unwrap(); let mask = create_mask_zeroed(10); let result = client.incr_mask_score(&sum_pk, &mask).await.unwrap(); assert!(result.is_ok()); let already_submitted = client.incr_mask_score(&sum_pk, &mask).await.unwrap(); assert!(matches!( already_submitted.into_inner().unwrap_err(), MaskScoreIncrError::MaskAlreadySubmitted 
)); } #[tokio::test] #[serial] #[ignore] async fn integration_get_best_masks_only_one_mask() { // test the writing and reading of one mask let mut client = init_client().await; let should_be_none = client.best_masks().await.unwrap(); assert!(should_be_none.is_none()); let sum_pks = create_and_add_sum_participant_entries(&mut client, 1).await; let mask = create_mask_zeroed(10); let res = client.incr_mask_score(sum_pks.get(0).unwrap(), &mask).await; assert!(res.is_ok()); let best_masks = client.best_masks().await.unwrap().unwrap(); assert!(best_masks.len() == 1); let (best_mask, count) = best_masks.into_iter().next().unwrap(); assert_eq!(best_mask, mask); assert_eq!(count, 1); } #[tokio::test] #[serial] #[ignore] async fn integration_get_best_masks_two_masks() { // test the writing and reading of two masks // the first mask is incremented twice let mut client = init_client().await; let should_be_none = client.best_masks().await.unwrap(); assert!(should_be_none.is_none()); let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await; let mask_1 = create_mask_zeroed(10); for sum_pk in sum_pks { let res = client.incr_mask_score(&sum_pk, &mask_1).await; assert!(res.is_ok()) } let sum_pks = create_and_add_sum_participant_entries(&mut client, 1).await; let mask_2 = create_mask_zeroed(100); for sum_pk in sum_pks { let res = client.incr_mask_score(&sum_pk, &mask_2).await; assert!(res.is_ok()) } let best_masks = client.best_masks().await.unwrap().unwrap(); assert!(best_masks.len() == 2); let mut best_masks_iter = best_masks.into_iter(); let (first_mask, count) = best_masks_iter.next().unwrap(); assert_eq!(first_mask, mask_1); assert_eq!(count, 2); let (second_mask, count) = best_masks_iter.next().unwrap(); assert_eq!(second_mask, mask_2); assert_eq!(count, 1); } #[tokio::test] #[serial] #[ignore] async fn integration_get_best_masks_no_mask() { // ensure that get_best_masks returns an empty vec if no mask exist let mut client = init_client().await; let best_masks 
= client.best_masks().await.unwrap(); assert!(best_masks.is_none()) } #[tokio::test] #[serial] #[ignore] async fn integration_get_number_of_unique_masks_empty() { // ensure that get_best_masks returns an empty vec if no mask exist let mut client = init_client().await; let number_of_unique_masks = client.number_of_unique_masks().await.unwrap(); assert_eq!(number_of_unique_masks, 0) } #[tokio::test] #[serial] #[ignore] async fn integration_get_number_of_unique_masks() { // ensure that get_best_masks returns an empty vec if no mask exist let mut client = init_client().await; let should_be_none = client.best_masks().await.unwrap(); assert!(should_be_none.is_none()); let sum_pks = create_and_add_sum_participant_entries(&mut client, 4).await; for (number, sum_pk) in sum_pks.iter().enumerate() { let mask_1 = create_mask(10, number as u32); let res = client.incr_mask_score(sum_pk, &mask_1).await; assert!(res.is_ok()) } let number_of_unique_masks = client.number_of_unique_masks().await.unwrap(); assert_eq!(number_of_unique_masks, 4) } #[tokio::test] #[serial] #[ignore] async fn integration_sum_dict() { // test multiple sum dict related methods let mut client = init_client().await; // create two entries and write them into redis let mut entries = vec![]; for _ in 0..2 { let (pk, epk) = create_sum_participant_entry(); let add_new_key = client.add_sum_participant(&pk, &epk).await.unwrap(); assert!(add_new_key.is_ok()); entries.push((pk, epk)); } // ensure that add_sum_participant returns SumPartAddError::AlreadyExists if the key already exist let (pk, epk) = entries.get(0).unwrap(); let key_already_exist = client.add_sum_participant(pk, epk).await.unwrap(); assert!(matches!( key_already_exist.into_inner().unwrap_err(), SumPartAddError::AlreadyExists )); // ensure that get_sum_dict_len returns 2 let len_of_sum_dict = client.sum_dict_len().await.unwrap(); assert_eq!(len_of_sum_dict, 2); // read the written sum keys // ensure they are equal let sum_pks = 
client.sum_pks().await.unwrap(); for (sum_pk, _) in entries.iter() { assert!(sum_pks.contains(sum_pk)); } // remove both sum entries for (sum_pk, _) in entries.iter() { let remove_sum_pk = client.remove_sum_dict_entry(sum_pk).await.unwrap(); assert!(remove_sum_pk.is_ok()); } // ensure that add_sum_participant returns SumDictDeleteError::DoesNotExist if the key does not exist let (sum_pk, _) = entries.get(0).unwrap(); let key_does_not_exist = client.remove_sum_dict_entry(sum_pk).await.unwrap(); assert!(matches!( key_does_not_exist.into_inner().unwrap_err(), SumDictDeleteError::DoesNotExist )); // ensure that get_sum_dict an empty sum dict let sum_dict = client.sum_dict().await.unwrap(); assert!(sum_dict.is_none()); } #[tokio::test] #[serial] #[ignore] async fn integration_seed_dict() { let mut client = init_client().await; let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await; let local_seed_dicts = create_local_seed_entries(&sum_pks); let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await; update_result.iter().for_each(|res| assert!(res.is_ok())); let redis_sum_dict = client.sum_dict().await.unwrap().unwrap(); let seed_dict = create_seed_dict(redis_sum_dict, &local_seed_dicts); let redis_seed_dict = client.seed_dict().await.unwrap().unwrap(); assert_eq!(seed_dict, redis_seed_dict) } #[tokio::test] #[serial] #[ignore] async fn integration_seed_dict_len_mis_match() { let mut client = init_client().await; let mut sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await; // remove one sum pk to create invalid local seed dicts sum_pks.pop(); let local_seed_dicts = create_local_seed_entries(&sum_pks); let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await; update_result.into_iter().for_each(|res| { assert!(matches!( res.into_inner().unwrap_err(), LocalSeedDictAddError::LengthMisMatch )) }); } #[tokio::test] #[serial] #[ignore] async fn integration_seed_dict_unknown_sum_participant() { 
let mut client = init_client().await; let mut sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await; // replace a known sum_pk with an unknown one sum_pks.pop(); let (pk, _) = create_sum_participant_entry(); sum_pks.push(pk); let local_seed_dicts = create_local_seed_entries(&sum_pks); let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await; update_result.into_iter().for_each(|res| { assert!(matches!( res.into_inner().unwrap_err(), LocalSeedDictAddError::UnknownSumParticipant )) }); } #[tokio::test] #[serial] #[ignore] async fn integration_seed_dict_update_pk_already_submitted() { let mut client = init_client().await; let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await; let local_seed_dicts = create_local_seed_entries(&sum_pks); let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await; update_result.iter().for_each(|res| assert!(res.is_ok())); let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await; update_result.into_iter().for_each(|res| { assert!(matches!( res.into_inner().unwrap_err(), LocalSeedDictAddError::UpdatePkAlreadySubmitted )) }); } #[tokio::test] #[serial] #[ignore] async fn integration_seed_dict_update_pk_already_exists_in_update_seed_dict() { let mut client = init_client().await; let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await; let local_seed_dicts = create_local_seed_entries(&sum_pks); let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await; update_result.iter().for_each(|res| assert!(res.is_ok())); let (update_participant, local_seed_dict) = local_seed_dicts.get(0).unwrap().clone(); let remove_result = client .remove_update_participant(&update_participant) .await .unwrap(); assert_eq!(remove_result, 1); let update_result = add_local_seed_entries(&mut client, &[(update_participant, local_seed_dict)]).await; update_result.into_iter().for_each(|res| { assert!(matches!( 
res.into_inner().unwrap_err(), LocalSeedDictAddError::UpdatePkAlreadyExistsInUpdateSeedDict )) }); } #[tokio::test] #[serial] #[ignore] async fn integration_seed_dict_get_seed_dict_for_sum_pk() { let mut client = init_client().await; let mut sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await; let local_seed_dicts = create_local_seed_entries(&sum_pks); let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await; update_result.iter().for_each(|res| assert!(res.is_ok())); let redis_sum_dict = client.sum_dict().await.unwrap().unwrap(); let seed_dict = create_seed_dict(redis_sum_dict, &local_seed_dicts); let sum_pk = sum_pks.pop().unwrap(); let redis_sum_seed_dict = client.seed_dict_for_sum_pk(&sum_pk).await.unwrap(); assert_eq!(&redis_sum_seed_dict, seed_dict.get(&sum_pk).unwrap()) } #[tokio::test] #[serial] #[ignore] async fn integration_seed_dict_get_seed_dict_for_sum_pk_empty() { let mut client = init_client().await; let (sum_pk, _) = create_sum_participant_entry(); let result = client.seed_dict_for_sum_pk(&sum_pk).await.unwrap(); assert!(result.is_empty()) } #[tokio::test] #[serial] #[ignore] async fn integration_flush_dicts() { let mut client = init_client().await; // write some data into redis let set_state = CoordinatorState::new(pet_settings(), mask_settings(), model_settings()); let res = client.set_coordinator_state(&set_state).await; assert!(res.is_ok()); let res = client.set_latest_global_model_id("global_model_id").await; assert!(res.is_ok()); let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await; let local_seed_dicts = create_local_seed_entries(&sum_pks); let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await; update_result.iter().for_each(|res| assert!(res.is_ok())); let mask = create_mask_zeroed(10); client .incr_mask_score(sum_pks.get(0).unwrap(), &mask) .await .unwrap(); // remove dicts let res = client.delete_dicts().await; assert!(res.is_ok()); // ensure that 
only the coordinator state and latest global model id exists let res = client.coordinator_state().await; assert!(res.unwrap().is_some()); let res = client.latest_global_model_id().await; assert!(res.unwrap().is_some()); let res = client.sum_dict().await; assert!(res.unwrap().is_none()); let res = client.seed_dict().await; assert!(res.unwrap().is_none()); let res = client.mask_submitted_set().await; assert!(res.unwrap().is_empty()); let res = client.best_masks().await; assert!(res.unwrap().is_none()); } #[tokio::test] #[serial] #[ignore] async fn integration_flush_coordinator_data() { let mut client = init_client().await; // write some data into redis let set_state = CoordinatorState::new(pet_settings(), mask_settings(), model_settings()); let res = client.set_coordinator_state(&set_state).await; assert!(res.is_ok()); let res = client.set_latest_global_model_id("global_model_id").await; assert!(res.is_ok()); let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await; let local_seed_dicts = create_local_seed_entries(&sum_pks); let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await; update_result.iter().for_each(|res| assert!(res.is_ok())); let mask = create_mask_zeroed(10); client .incr_mask_score(sum_pks.get(0).unwrap(), &mask) .await .unwrap(); // remove all coordinator data let res = client.delete_coordinator_data().await; assert!(res.is_ok()); let keys = client.keys().await.unwrap(); assert!(keys.is_empty()); } #[tokio::test] #[serial] #[ignore] async fn integration_set_and_get_latest_global_model_id() { // test the writing and reading of the global model id let mut client = init_client().await; let set_id = "global_model_id"; client.set_latest_global_model_id(set_id).await.unwrap(); let get_id = client.latest_global_model_id().await.unwrap().unwrap(); assert_eq!(set_id, get_id) } #[tokio::test] #[serial] #[ignore] async fn integration_is_ready_ok() { // test is_ready command let mut client = init_client().await; let res 
= client.is_ready().await; assert!(res.is_ok()) } #[tokio::test] #[serial] #[ignore] async fn integration_get_latest_global_model_id_empty() { // test the reading of a non existing global model id let mut client = init_client().await; let get_id = client.latest_global_model_id().await.unwrap(); assert_eq!(None, get_id) } } ================================================ FILE: rust/xaynet-server/src/storage/mod.rs ================================================ //! Storage backends for the coordinator. pub mod coordinator_storage; pub mod model_storage; pub mod store; #[cfg(test)] pub(crate) mod tests; pub mod traits; pub mod trust_anchor; pub use self::{ store::Store, traits::{ CoordinatorStorage, LocalSeedDictAdd, LocalSeedDictAddError, MaskScoreIncr, MaskScoreIncrError, ModelStorage, Storage, StorageError, StorageResult, SumPartAdd, SumPartAddError, TrustAnchor, }, }; ================================================ FILE: rust/xaynet-server/src/storage/model_storage/mod.rs ================================================ //! Storage backends to manage global models. pub mod noop; #[cfg(feature = "model-persistence")] #[cfg_attr(docsrs, doc(cfg(feature = "model-persistence")))] pub mod s3; ================================================ FILE: rust/xaynet-server/src/storage/model_storage/noop.rs ================================================ //! A NoOp [`ModelStorage`] backend. 
use crate::storage::{ModelStorage, StorageResult}; use async_trait::async_trait; use xaynet_core::{common::RoundSeed, mask::Model}; #[derive(Clone)] pub struct NoOp; #[async_trait] impl ModelStorage for NoOp { async fn set_global_model( &mut self, round_id: u64, round_seed: &RoundSeed, _global_model: &Model, ) -> StorageResult { Ok(Self::create_global_model_id(round_id, round_seed)) } async fn global_model(&mut self, _id: &str) -> StorageResult> { Err(anyhow::anyhow!("No-op model store")) } async fn is_ready(&mut self) -> StorageResult<()> { Ok(()) } } ================================================ FILE: rust/xaynet-server/src/storage/model_storage/s3.rs ================================================ //! A S3 [`ModelStorage`] backend. use std::sync::Arc; use async_trait::async_trait; use displaydoc::Display; use http::StatusCode; use rusoto_core::{credential::StaticProvider, request::TlsError, HttpClient, RusotoError}; use rusoto_s3::{ CreateBucketError, CreateBucketOutput, CreateBucketRequest, DeleteObjectsError, GetObjectError, GetObjectOutput, GetObjectRequest, HeadBucketError, HeadBucketRequest, ListObjectsV2Error, PutObjectError, PutObjectOutput, PutObjectRequest, S3Client, StreamingBody, S3, }; use thiserror::Error; use tokio::io::AsyncReadExt; use tracing::debug; use crate::{ settings::{S3BucketsSettings, S3Settings}, storage::{ModelStorage, StorageResult}, }; use xaynet_core::{common::RoundSeed, mask::Model}; type ClientResult = Result; #[derive(Debug, Display, Error)] pub enum ClientError { /// Failed to create bucket: {0}. CreateBucket(#[from] RusotoError), /// Failed to get object: {0}. GetObject(#[from] RusotoError), /// Failed to put object: {0}. PutObject(#[from] RusotoError), /// Failed to list objects: {0}. ListObjects(#[from] RusotoError), /// Failed to delete objects: {0}. DeleteObjects(#[from] RusotoError), /// Failed to dispatch: {0}. Dispatcher(#[from] TlsError), /// Failed to serialize: {0}. 
Serialization(bincode::Error), /// Failed to deserialize: {0}. Deserialization(bincode::Error), /// Response contains no body. NoBody, /// Failed to download body: {0}. DownloadBody(std::io::Error), /// Object {0} already exists. ObjectAlreadyExists(String), /// Storage not ready: {0}. NotReady(RusotoError), } #[derive(Clone)] pub struct Client { buckets: Arc, client: S3Client, } impl Client { /// Creates a new S3 client. The client creates and maintains one bucket for storing global models. /// /// To connect to AWS-compatible services such as Minio, you need to specify a custom region. /// ``` /// use rusoto_core::Region; /// use xaynet_server::{ /// settings::{S3BucketsSettings, S3Settings}, /// storage::model_storage::s3::Client, /// }; /// /// let region = Region::Custom { /// name: String::from("minio"), /// endpoint: String::from("http://127.0.0.1:9000"), // URL of minio /// }; /// /// let s3_settings = S3Settings { /// region, /// access_key: String::from("minio"), /// secret_access_key: String::from("minio123"), /// buckets: S3BucketsSettings { /// global_models: String::from("global-models"), /// }, /// }; /// /// let store = Client::new(s3_settings).unwrap(); /// ``` pub fn new(settings: S3Settings) -> ClientResult { let credentials_provider = StaticProvider::new_minimal(settings.access_key, settings.secret_access_key); let dispatcher = HttpClient::new()?; Ok(Self { buckets: Arc::new(settings.buckets), client: S3Client::new_with(dispatcher, credentials_provider, settings.region), }) } /// Creates the `global models` bucket. /// This method does not fail if the bucket already exists or is already owned by you. 
pub async fn create_global_models_bucket(&self) -> ClientResult<()> { debug!("create {} bucket", &self.buckets.global_models); match self.create_bucket(&self.buckets.global_models).await { Ok(_) | Err(RusotoError::Service(CreateBucketError::BucketAlreadyExists(_))) | Err(RusotoError::Service(CreateBucketError::BucketAlreadyOwnedByYou(_))) => Ok(()), Err(err) => Err(ClientError::from(err)), } } // Downloads the content of the given object. async fn download_object_body(object: GetObjectOutput) -> ClientResult> { let mut body = Vec::new(); object .body .ok_or(ClientError::NoBody)? .into_async_read() .read_to_end(&mut body) .await .map_err(ClientError::DownloadBody)?; Ok(body) } // Fetches the metadata of the object with the given key from the given bucket. async fn fetch_object_meta( &self, bucket: &str, key: &str, ) -> Result> { // If an object does not exist, S3 / Minio will return an error let req = GetObjectRequest { bucket: bucket.to_string(), key: key.to_string(), ..Default::default() }; self.client.get_object(req).await } // Uploads an object with the given key to the given bucket. async fn upload_object( &self, bucket: &str, key: &str, data: Vec, ) -> Result> { let req = PutObjectRequest { bucket: bucket.to_string(), key: key.to_string(), body: Some(StreamingBody::from(data)), ..Default::default() }; self.client.put_object(req).await } // Creates a new bucket with the given bucket name. 
async fn create_bucket( &self, bucket: &str, ) -> Result> { let req = CreateBucketRequest { bucket: bucket.to_string(), ..Default::default() }; self.client.create_bucket(req).await } } #[async_trait] impl ModelStorage for Client { async fn set_global_model( &mut self, round_id: u64, round_seed: &RoundSeed, global_model: &Model, ) -> StorageResult { let id = Self::create_global_model_id(round_id, round_seed); debug!("upload global model: {}", id); let output = self .fetch_object_meta(&self.buckets.global_models, &id) .await; if output.is_ok() { return Err(anyhow::anyhow!(ClientError::ObjectAlreadyExists( id.to_string() ))); }; let data = bincode::serialize(global_model).map_err(ClientError::Serialization)?; self.upload_object(&self.buckets.global_models, &id, data) .await .map(|_| Ok(id))? } async fn global_model(&mut self, id: &str) -> StorageResult> { debug!("download global model {}", id); let output = self .fetch_object_meta(&self.buckets.global_models, id) .await; let object_meta = match output { Err(RusotoError::Service(GetObjectError::NoSuchKey(_))) => return Ok(None), Err(err) => return Err(anyhow::anyhow!(err)), Ok(object) => object, }; let body = Self::download_object_body(object_meta).await?; let model = bincode::deserialize(&body).map_err(ClientError::Deserialization)?; Ok(Some(model)) } async fn is_ready(&mut self) -> StorageResult<()> { let req = HeadBucketRequest { // we can't use an empty string because S3/Minio would return BAD_REQUEST bucket: self.buckets.global_models.clone(), ..Default::default() }; let res = self.client.head_bucket(req).await; match res { // rusoto doesn't return NoSuchBucket if the bucket doesn't exist // https://github.com/rusoto/rusoto/issues/1099 // // a workaround is to check if the StatusCode is NOT_FOUND Err(RusotoError::Service(HeadBucketError::NoSuchBucket(_))) | Ok(_) => Ok(()), Err(RusotoError::Unknown(resp)) => match resp.status { // 
https://github.com/timberio/vector/blob/803c68c031e5872876e1167c428cd41358123d64/src/sinks/aws_s3.rs#L229 StatusCode::NOT_FOUND => Ok(()), _ => Err(anyhow::anyhow!(ClientError::NotReady( RusotoError::Unknown(resp) ))), }, Err(e) => Err(anyhow::anyhow!(ClientError::NotReady(e))), } } } #[cfg(test)] pub(in crate) mod tests { use super::*; use crate::storage::tests::utils::create_global_model; use rusoto_core::Region; use rusoto_s3::{ Delete, DeleteBucketError, DeleteBucketRequest, DeleteObjectsOutput, DeleteObjectsRequest, ListObjectsV2Output, ListObjectsV2Request, ObjectIdentifier, }; use serial_test::serial; use xaynet_core::{common::RoundSeed, crypto::ByteObject}; impl Client { // Deletes all objects in a bucket. pub async fn clear_bucket(&self, bucket: &str) -> ClientResult<()> { let mut continuation_token: Option = None; loop { let list_obj_resp = self.list_objects(bucket, continuation_token).await?; if let Some(identifiers) = Self::unpack_object_identifier(&list_obj_resp) { self.delete_objects(bucket, identifiers).await?; } else { break; } // check if more objects exist continuation_token = Self::unpack_next_continuation_token(&list_obj_resp); if continuation_token.is_none() { break; } } Ok(()) } // Unpacks the object identifier/keys of a [`ListObjectsV2Output`] response. fn unpack_object_identifier( list_obj_resp: &ListObjectsV2Output, ) -> Option> { if let Some(objects) = &list_obj_resp.contents { let keys = objects .iter() .filter_map(|obj| obj.key.clone()) .map(|key| ObjectIdentifier { key, ..Default::default() }) .collect(); Some(keys) } else { None } } // Deletes the objects of the given bucket. async fn delete_objects( &self, bucket: &str, identifiers: Vec, ) -> Result> { let req = DeleteObjectsRequest { bucket: bucket.to_string(), delete: Delete { objects: identifiers, ..Default::default() }, ..Default::default() }; self.client.delete_objects(req).await.map_err(From::from) } // Returns all object keys for the given bucket. 
async fn list_objects( &self, bucket: &str, continuation_token: Option, ) -> Result> { let req = ListObjectsV2Request { bucket: bucket.to_string(), continuation_token, // the S3 response is limited to 1000 keys max. // https://docs.aws.amazon.com/AWSJavaScriptSDK/latest/AWS/S3.html#listObjectsV2-property // However, Minio could return more. max_keys: Some(1000), ..Default::default() }; self.client.list_objects_v2(req).await.map_err(From::from) } // Unpacks the next_continuation_token of the [`ListObjectsV2Output`] response. fn unpack_next_continuation_token(list_obj_resp: &ListObjectsV2Output) -> Option { // https://docs.aws.amazon.com/AmazonS3/latest/dev/ListingObjectKeysUsingJava.html if let Some(is_truncated) = list_obj_resp.is_truncated { if is_truncated { list_obj_resp.next_continuation_token.clone() } else { None } } else { None } } async fn delete_bucket(&self, bucket: &str) -> Result<(), RusotoError> { let req = DeleteBucketRequest { bucket: bucket.to_string(), ..Default::default() }; self.client.delete_bucket(req).await } } fn create_minio_setup(url: &str) -> S3Settings { let region = Region::Custom { name: String::from("minio"), endpoint: String::from(url), }; S3Settings { region, access_key: String::from("minio"), secret_access_key: String::from("minio123"), buckets: S3BucketsSettings::default(), } } pub async fn init_client() -> Client { let settings = create_minio_setup("http://localhost:9000"); let client = Client::new(settings).unwrap(); client.create_global_models_bucket().await.unwrap(); client.clear_bucket("global-models").await.unwrap(); client } async fn init_disconnected_client() -> Client { let settings = create_minio_setup("http://localhost:11000"); Client::new(settings).unwrap() } #[tokio::test] #[serial] #[ignore] async fn integration_test_set_and_get_global_model() { let mut client = init_client().await; let global_model = create_global_model(10); let id = client .set_global_model(1, &RoundSeed::generate(), &global_model) .await .unwrap(); 
let downloaded_global_model = client.global_model(&id).await.unwrap().unwrap(); assert_eq!(global_model, downloaded_global_model) } #[tokio::test] #[serial] #[ignore] async fn integration_test_get_global_model_non_existent() { let mut client = init_client().await; let id = Client::create_global_model_id(1, &RoundSeed::generate()); let res = client.global_model(&id).await.unwrap(); assert!(res.is_none()) } #[tokio::test] #[serial] #[ignore] async fn integration_test_global_model_already_exists() { let mut client = init_client().await; let global_model = create_global_model(10); let round_seed = RoundSeed::generate(); let id = client .set_global_model(1, &round_seed, &global_model) .await .unwrap(); let global_model_2 = create_global_model(20); let res = client .set_global_model(1, &round_seed, &global_model_2) .await .unwrap_err(); assert!(matches!( res.downcast_ref::().unwrap(), ClientError::ObjectAlreadyExists(_) )); let downloaded_global_model = client.global_model(&id).await.unwrap().unwrap(); assert_eq!(global_model, downloaded_global_model) } #[tokio::test] #[serial] #[ignore] async fn integration_test_is_ready_ok() { let mut client = init_client().await; let res = client.is_ready().await; assert!(res.is_ok()) } #[tokio::test] #[serial] #[ignore] async fn integration_test_is_ready_ok_no_such_bucket() { // test that is_ready returns Ok even if the bucket doesn't exist let mut client = init_client().await; client .delete_bucket(&S3BucketsSettings::default().global_models) .await .unwrap(); let res = client.is_ready().await; assert!(res.is_ok()) } #[tokio::test] #[serial] #[ignore] async fn integration_test_is_ready_err() { let mut client = init_disconnected_client().await; let res = client.is_ready().await; assert!(res.is_err()) } } ================================================ FILE: rust/xaynet-server/src/storage/store.rs ================================================ //! A generic store. 
use async_trait::async_trait;

use crate::{
    state_machine::coordinator::CoordinatorState,
    storage::{
        trust_anchor::noop::NoOp, CoordinatorStorage, LocalSeedDictAdd, MaskScoreIncr,
        ModelStorage, Storage, StorageResult, SumPartAdd, TrustAnchor,
    },
};
use xaynet_core::{
    common::RoundSeed,
    mask::{MaskObject, Model},
    LocalSeedDict, SeedDict, SumDict, SumParticipantEphemeralPublicKey, SumParticipantPublicKey,
    UpdateParticipantPublicKey,
};

#[derive(Clone)]
/// A generic store.
///
/// Bundles a coordinator store, a model store and a trust anchor. It
/// implements each storage trait by forwarding to the matching component.
pub struct Store
where
    C: CoordinatorStorage,
    M: ModelStorage,
    T: TrustAnchor,
{
    /// A coordinator store.
    coordinator: C,
    /// A model store.
    model: M,
    /// A trust anchor.
    trust_anchor: T,
}

impl Store
where
    C: CoordinatorStorage,
    M: ModelStorage,
    T: TrustAnchor,
{
    /// Creates a new [`Store`] with the given trust anchor.
    pub fn new_with_trust_anchor(coordinator: C, model: M, trust_anchor: T) -> Self {
        Self {
            coordinator,
            model,
            trust_anchor,
        }
    }
}

impl Store
where
    C: CoordinatorStorage,
    M: ModelStorage,
{
    /// Creates a new [`Store`].
    ///
    /// The store uses the no-op trust anchor ([`NoOp`]).
    pub fn new(coordinator: C, model: M) -> Self {
        Self {
            coordinator,
            model,
            trust_anchor: NoOp,
        }
    }
}

// Every `CoordinatorStorage` method forwards unchanged to the coordinator store.
#[async_trait]
impl CoordinatorStorage for Store
where
    C: CoordinatorStorage,
    M: ModelStorage,
    T: TrustAnchor,
{
    async fn set_coordinator_state(&mut self, state: &CoordinatorState) -> StorageResult<()> {
        self.coordinator.set_coordinator_state(state).await
    }

    async fn coordinator_state(&mut self) -> StorageResult> {
        self.coordinator.coordinator_state().await
    }

    async fn add_sum_participant(
        &mut self,
        pk: &SumParticipantPublicKey,
        ephm_pk: &SumParticipantEphemeralPublicKey,
    ) -> StorageResult {
        self.coordinator.add_sum_participant(pk, ephm_pk).await
    }

    async fn sum_dict(&mut self) -> StorageResult> {
        self.coordinator.sum_dict().await
    }

    async fn add_local_seed_dict(
        &mut self,
        update_pk: &UpdateParticipantPublicKey,
        local_seed_dict: &LocalSeedDict,
    ) -> StorageResult {
        self.coordinator
            .add_local_seed_dict(update_pk, local_seed_dict)
            .await
    }

    async fn seed_dict(&mut self) -> StorageResult> {
        self.coordinator.seed_dict().await
    }

    async fn incr_mask_score(
        &mut self,
        pk: &SumParticipantPublicKey,
        mask: &MaskObject,
    ) -> StorageResult {
        self.coordinator.incr_mask_score(pk, mask).await
    }

    async fn best_masks(&mut self) -> StorageResult>> {
        self.coordinator.best_masks().await
    }

    async fn number_of_unique_masks(&mut self) -> StorageResult {
        self.coordinator.number_of_unique_masks().await
    }

    async fn delete_coordinator_data(&mut self) -> StorageResult<()> {
        self.coordinator.delete_coordinator_data().await
    }

    async fn delete_dicts(&mut self) -> StorageResult<()> {
        self.coordinator.delete_dicts().await
    }

    async fn set_latest_global_model_id(&mut self, id: &str) -> StorageResult<()> {
        self.coordinator.set_latest_global_model_id(id).await
    }

    async fn latest_global_model_id(&mut self) -> StorageResult> {
        self.coordinator.latest_global_model_id().await
    }

    // Readiness of the coordinator store only; see `Storage::is_ready` for all backends.
    async fn is_ready(&mut self) -> StorageResult<()> {
        self.coordinator.is_ready().await
    }
}

// Every `ModelStorage` method forwards unchanged to the model store.
#[async_trait]
impl ModelStorage for Store
where
    C: CoordinatorStorage,
    M: ModelStorage,
    T: TrustAnchor,
{
    async fn set_global_model(
        &mut self,
        round_id: u64,
        round_seed: &RoundSeed,
        global_model: &Model,
    ) -> StorageResult {
        self.model
            .set_global_model(round_id, round_seed, global_model)
            .await
    }

    async fn global_model(&mut self, id: &str) -> StorageResult> {
        self.model.global_model(id).await
    }

    // Readiness of the model store only.
    async fn is_ready(&mut self) -> StorageResult<()> {
        self.model.is_ready().await
    }
}

// Every `TrustAnchor` method forwards unchanged to the trust anchor.
#[async_trait]
impl TrustAnchor for Store
where
    C: CoordinatorStorage,
    M: ModelStorage,
    T: TrustAnchor,
{
    async fn publish_proof(&mut self, global_model: &Model) -> StorageResult<()> {
        self.trust_anchor.publish_proof(global_model).await
    }

    // Readiness of the trust anchor only.
    async fn is_ready(&mut self) -> StorageResult<()> {
        self.trust_anchor.is_ready().await
    }
}

#[async_trait]
impl Storage for Store
where
    C: CoordinatorStorage,
    M: ModelStorage,
    T: TrustAnchor,
{
    /// Checks the readiness of all three backends concurrently.
    ///
    /// `tokio::try_join!` returns the first error reported by any backend.
    async fn is_ready(&mut self) -> StorageResult<()> {
        tokio::try_join!(
            self.coordinator.is_ready(),
            self.model.is_ready(),
            self.trust_anchor.is_ready()
        )
        .map(|_| ())
    }
}
================================================
FILE: rust/xaynet-server/src/storage/tests/mod.rs
================================================
// Storage test support: a constructor for a store backed by the test clients,
// and `mockall` mocks of the three storage traits.
// NOTE(review): generic argument lists (e.g. after `StorageResult` and `Vec`)
// appear stripped in this copy; confirm against the original source.
use crate::{
    state_machine::coordinator::CoordinatorState,
    storage::{
        coordinator_storage::redis,
        model_storage,
        CoordinatorStorage,
        LocalSeedDictAdd,
        MaskScoreIncr,
        ModelStorage,
        Storage,
        StorageResult,
        Store,
        SumPartAdd,
        TrustAnchor,
    },
};
use async_trait::async_trait;
use mockall::*;
use xaynet_core::{
    common::RoundSeed,
    mask::{MaskObject, Model},
    LocalSeedDict,
    SeedDict,
    SumDict,
    SumParticipantEphemeralPublicKey,
    SumParticipantPublicKey,
    UpdateParticipantPublicKey,
};

pub mod utils;

/// Creates a [`Store`] for integration tests.
///
/// The coordinator store comes from `redis::tests::init_client`. The model
/// store is `model_storage::noop::NoOp` unless the `model-persistence` feature
/// is enabled, in which case `model_storage::s3::tests::init_client` is used.
/// The store is built with `Store::new`, i.e. no explicit trust anchor.
pub async fn init_store() -> impl Storage {
    let coordinator_store = redis::tests::init_client().await;

    // Exactly one of the two cfg-gated blocks is compiled in, and it becomes
    // the value of this block expression.
    let model_store = {
        #[cfg(not(feature = "model-persistence"))]
        {
            model_storage::noop::NoOp
        }

        #[cfg(feature = "model-persistence")]
        {
            model_storage::s3::tests::init_client().await
        }
    };

    Store::new(coordinator_store, model_store)
}

// Mock of `CoordinatorStorage` (mockall generates `MockCoordinatorStore`).
// `Clone` is mocked as well because the storage traits require `Self: Clone`.
mock! {
    pub CoordinatorStore {}
    #[async_trait]
    impl CoordinatorStorage for CoordinatorStore {
        async fn set_coordinator_state(&mut self, state: &CoordinatorState) -> StorageResult<()>;
        async fn coordinator_state(&mut self) -> StorageResult>;
        async fn add_sum_participant(
            &mut self,
            pk: &SumParticipantPublicKey,
            ephm_pk: &SumParticipantEphemeralPublicKey,
        ) -> StorageResult;
        async fn sum_dict(&mut self) -> StorageResult>;
        async fn add_local_seed_dict(
            &mut self,
            update_pk: &UpdateParticipantPublicKey,
            local_seed_dict: &LocalSeedDict,
        ) -> StorageResult;
        async fn seed_dict(&mut self) -> StorageResult>;
        async fn incr_mask_score(
            &mut self,
            pk: &SumParticipantPublicKey,
            mask: &MaskObject,
        ) -> StorageResult;
        async fn best_masks(&mut self) -> StorageResult>>;
        async fn number_of_unique_masks(&mut self) -> StorageResult;
        async fn delete_coordinator_data(&mut self) -> StorageResult<()>;
        async fn delete_dicts(&mut self) -> StorageResult<()>;
        async fn set_latest_global_model_id(&mut self, id: &str) -> StorageResult<()>;
        async fn latest_global_model_id(&mut self) -> StorageResult>;
        async fn is_ready(&mut self) -> StorageResult<()>;
    }
    impl Clone for CoordinatorStore {
        fn clone(&self) -> Self;
    }
}

// Mock of `ModelStorage` (generates `MockModelStore`), including `Clone`.
mock! {
    pub ModelStore {}
    #[async_trait]
    impl ModelStorage for ModelStore {
        async fn set_global_model(
            &mut self,
            round_id: u64,
            round_seed: &RoundSeed,
            global_model: &Model,
        ) -> StorageResult;
        async fn global_model(&mut self, id: &str) -> StorageResult>;
        async fn is_ready(&mut self) -> StorageResult<()>;
    }
    impl Clone for ModelStore {
        fn clone(&self) -> Self;
    }
}

// Mock of `TrustAnchor` (generates `MockTrustAnchor`), including `Clone`.
mock! {
    pub TrustAnchor {}
    #[async_trait]
    impl TrustAnchor for TrustAnchor {
        async fn publish_proof(&mut self, global_model: &Model) -> StorageResult<()>;
        async fn is_ready(&mut self) -> StorageResult<()>;
    }
    impl Clone for TrustAnchor {
        fn clone(&self) -> Self;
    }
}
================================================
FILE: rust/xaynet-server/src/storage/tests/utils.rs
================================================
use num::{bigint::BigUint, traits::identities::Zero};

use crate::{
    state_machine::tests::utils::mask_settings,
    storage::{CoordinatorStorage, LocalSeedDictAdd},
};
use xaynet_core::{
    crypto::{ByteObject, EncryptKeyPair, SigningKeyPair},
    mask::{EncryptedMaskSeed, MaskConfig, MaskObject},
    LocalSeedDict,
    SeedDict,
    SumDict,
    SumParticipantEphemeralPublicKey,
    SumParticipantPublicKey,
    UpdateParticipantPublicKey,
};

/// Generates a sum participant entry: the public key of a freshly generated
/// signing key pair, paired with the public key of a freshly generated
/// encryption (ephemeral) key pair.
pub fn create_sum_participant_entry() -> (SumParticipantPublicKey, SumParticipantEphemeralPublicKey) {
    let SigningKeyPair { public: pk, .. } = SigningKeyPair::generate();
    let EncryptKeyPair { public: ephm_pk, .. } = EncryptKeyPair::generate();
    (pk, ephm_pk)
}

/// Generates `sum_pks.len()` local seed dict entries.
///
/// Each entry belongs to a freshly generated update participant and maps every
/// key in `sum_pks` to a zeroed [`EncryptedMaskSeed`].
pub fn create_local_seed_entries(
    sum_pks: &[SumParticipantPublicKey],
) -> Vec<(UpdateParticipantPublicKey, LocalSeedDict)> {
    let mut entries = Vec::new();
    // One update participant per sum participant key.
    for _ in 0..sum_pks.len() {
        let SigningKeyPair {
            public: update_pk, ..
        } = SigningKeyPair::generate();

        let mut local_seed_dict = LocalSeedDict::new();
        for sum_pk in sum_pks {
            let seed = EncryptedMaskSeed::zeroed();
            local_seed_dict.insert(*sum_pk, seed);
        }
        entries.push((update_pk, local_seed_dict))
    }
    entries
}

/// Creates a mask of `model_length` zero elements using the mask configuration
/// from `mask_settings()`. The last `MaskObject::new` argument is also zero
/// (presumably the unit/scalar part — confirm). Panics if construction fails.
pub fn create_mask_zeroed(model_length: usize) -> MaskObject {
    MaskObject::new(
        MaskConfig::from(mask_settings()).into(),
        vec![BigUint::zero(); model_length],
        BigUint::zero(),
    )
    .unwrap()
}

/// Like `create_mask_zeroed`, but every one of the `model_length` elements is
/// `number`. Panics if construction fails.
pub fn create_mask(model_length: usize, number: u32) -> MaskObject {
    MaskObject::new(
        MaskConfig::from(mask_settings()).into(),
        vec![BigUint::from(number); model_length],
        BigUint::zero(),
    )
    .unwrap()
}

/// Builds the [`SeedDict`] expected after applying `seed_updates`.
///
/// It starts with an empty inner dict for every sum participant in `sum_dict`.
/// Then, for each `(update_pk, local_seed_dict)` pair, it inserts each seed
/// under `seed_dict[sum_pk][update_pk]`. Panics (via `unwrap`) if an update
/// references a sum participant that is not in `sum_dict`.
pub fn create_seed_dict(
    sum_dict: SumDict,
    seed_updates: &[(UpdateParticipantPublicKey, LocalSeedDict)],
) -> SeedDict {
    let mut seed_dict: SeedDict = sum_dict
        .keys()
        .map(|pk| (*pk, LocalSeedDict::new()))
        .collect();

    for (pk, local_seed_dict) in seed_updates {
        for (sum_pk, seed) in local_seed_dict {
            seed_dict.get_mut(sum_pk).unwrap().insert(*pk, seed.clone());
        }
    }

    seed_dict
}

/// Generates `n` sum participants, adds them to `client` and returns their
/// public keys in insertion order.
///
/// Panics if a storage call returns an error. The inner `SumPartAdd` outcome
/// is discarded (`let _ =`), so a PET-level rejection is not detected here.
pub async fn create_and_add_sum_participant_entries(
    client: &mut impl CoordinatorStorage,
    n: u32,
) -> Vec {
    let mut sum_pks = Vec::new();
    for _ in 0..n {
        let (pk, ephm_pk) = create_sum_participant_entry();
        let _ = client.add_sum_participant(&pk, &ephm_pk).await.unwrap();
        sum_pks.push(pk);
    }

    sum_pks
}

/// Adds every `(update_pk, local_seed_dict)` entry to `client` and returns the
/// per-entry [`LocalSeedDictAdd`] results in input order.
///
/// Panics if a storage call returns an error. PET-level rejections are
/// returned to the caller unchecked.
pub async fn add_local_seed_entries(
    client: &mut impl CoordinatorStorage,
    local_seed_entries: &[(UpdateParticipantPublicKey, LocalSeedDict)],
) -> Vec {
    let mut update_result = Vec::new();

    for (update_pk, local_seed_dict) in local_seed_entries {
        let res = client.add_local_seed_dict(update_pk, local_seed_dict).await;
        assert!(res.is_ok());
        update_result.push(res.unwrap())
    }

    update_result
}

use xaynet_core::mask::{FromPrimitives, Model};

/// Creates a global model consisting of `model_length` zeros. Panics if the
/// conversion fails.
pub fn create_global_model(model_length: usize) -> Model {
    Model::from_primitives(vec![0; model_length].into_iter()).unwrap()
}
================================================
FILE: rust/xaynet-server/src/storage/traits.rs
================================================
//! Storage API.
//!
//! Defines the abstract storage traits used by the coordinator
//! ([`CoordinatorStorage`], [`ModelStorage`], [`TrustAnchor`] and the combined
//! [`Storage`]), plus the result wrappers and error codes returned by the PET
//! storage operations.
use async_trait::async_trait;
use derive_more::Deref;
use displaydoc::Display;
use num_enum::TryFromPrimitive;
use thiserror::Error;

use crate::state_machine::coordinator::CoordinatorState;
use xaynet_core::{
    common::RoundSeed,
    crypto::ByteObject,
    mask::{MaskObject, Model},
    LocalSeedDict,
    SeedDict,
    SumDict,
    SumParticipantEphemeralPublicKey,
    SumParticipantPublicKey,
    UpdateParticipantPublicKey,
};

// NOTE(review): generic argument lists appear to be stripped in this copy
// (e.g. `StorageResult = Result;` and return types such as `StorageResult>`).
// Confirm the exact signatures against the original source.

/// The error type for storage operations that are not directly related to the application domain.
/// These include, for example, IO errors like broken pipe, file not found, out-of-memory, etc.
pub type StorageError = anyhow::Error;

/// The result of a storage operation.
pub type StorageResult = Result;

#[async_trait]
/// An abstract coordinator storage.
pub trait CoordinatorStorage
where
    Self: Clone + Send + Sync + 'static,
{
    /// Sets a [`CoordinatorState`].
    ///
    /// # Behavior
    ///
    /// - If no state has been set yet, set the state and return `StorageResult::Ok(())`.
    /// - If a state already exists, override the state and return `StorageResult::Ok(())`.
    async fn set_coordinator_state(&mut self, state: &CoordinatorState) -> StorageResult<()>;

    /// Returns a [`CoordinatorState`].
    ///
    /// # Behavior
    ///
    /// - If no state has been set yet, return `StorageResult::Ok(Option::None)`.
    /// - If a state exists, return `StorageResult::Ok(Some(CoordinatorState))`.
    async fn coordinator_state(&mut self) -> StorageResult>;

    /// Adds a sum participant entry to the [`SumDict`].
    ///
    /// # Behavior
    ///
    /// - If a sum participant has been successfully added, return `StorageResult::Ok(SumPartAdd)`
    ///   containing a `Result::Ok(())`.
    /// - If the participant could not be added due to a PET protocol error, return
    ///   the corresponding `StorageResult::Ok(SumPartAdd)` containing a
    ///   `Result::Err(SumPartAddError)`.
    async fn add_sum_participant(
        &mut self,
        pk: &SumParticipantPublicKey,
        ephm_pk: &SumParticipantEphemeralPublicKey,
    ) -> StorageResult;

    /// Returns the [`SumDict`].
    ///
    /// # Behavior
    ///
    /// - If the sum dict does not exist, return `StorageResult::Ok(Option::None)`.
    /// - If the sum dict exists, return `StorageResult::Ok(Option::Some(SumDict))`.
    async fn sum_dict(&mut self) -> StorageResult>;

    /// Adds a local [`LocalSeedDict`] of the given [`UpdateParticipantPublicKey`] to the [`SeedDict`].
    ///
    /// # Behavior
    ///
    /// - If the local seed dict has been successfully added, return
    ///   `StorageResult::Ok(LocalSeedDictAdd)` containing a `Result::Ok(())`.
    /// - If the local seed dict could not be added due to a PET protocol error, return
    ///   the corresponding `StorageResult::Ok(LocalSeedDictAdd)` containing a
    ///   `Result::Err(LocalSeedDictAddError)`.
    async fn add_local_seed_dict(
        &mut self,
        update_pk: &UpdateParticipantPublicKey,
        local_seed_dict: &LocalSeedDict,
    ) -> StorageResult;

    /// Returns the [`SeedDict`].
    ///
    /// # Behavior
    ///
    /// - If the seed dict does not exist, return `StorageResult::Ok(Option::None)`.
    /// - If the seed dict exists, return `StorageResult::Ok(Option::Some(SeedDict))`.
    async fn seed_dict(&mut self) -> StorageResult>;

    /// Increments the score of the given [`MaskObject`] by one.
    ///
    /// # Behavior
    ///
    /// - If the mask score has been successfully incremented, return
    ///   `StorageResult::Ok(MaskScoreIncr)` containing a `Result::Ok(())`.
    /// - If the mask score could not be incremented due to a PET protocol error,
    ///   return the corresponding `StorageResult::Ok(MaskScoreIncr)` containing a
    ///   `Result::Err(MaskScoreIncrError)`.
    async fn incr_mask_score(
        &mut self,
        pk: &SumParticipantPublicKey,
        mask: &MaskObject,
    ) -> StorageResult;

    /// Returns the two masks with the highest score.
    ///
    /// # Behavior
    ///
    /// - If no masks exist, return `StorageResult::Ok(Option::None)`.
    /// - If only one mask exists, return this mask
    ///   `StorageResult::Ok(Option::Some(Vec<(MaskObject, u64)>))`.
    /// - If two masks exist with the same score, return both
    ///   `StorageResult::Ok(Option::Some(Vec<(MaskObject, u64)>))`.
    /// - If two masks exist with different scores, return
    ///   both in descending order `StorageResult::Ok(Option::Some(Vec<(MaskObject, u64)>))`.
    async fn best_masks(&mut self) -> StorageResult>>;

    /// Returns the number of unique masks.
    async fn number_of_unique_masks(&mut self) -> StorageResult;

    /// Deletes all coordinator data. This includes the coordinator
    /// state as well as the [`SumDict`], [`SeedDict`] and `mask` dictionary.
    async fn delete_coordinator_data(&mut self) -> StorageResult<()>;

    /// Deletes the [`SumDict`], [`SeedDict`] and `mask` dictionary.
    /// Unlike [`CoordinatorStorage::delete_coordinator_data`], the coordinator
    /// state is not mentioned as being deleted here.
    async fn delete_dicts(&mut self) -> StorageResult<()>;

    /// Sets the latest global model id.
    ///
    /// # Behavior
    ///
    /// - If no global model id has been set yet, set the new id and return `StorageResult::Ok(())`.
    /// - If the global model id already exists, override with the new id and
    ///   return `StorageResult::Ok(())`.
    async fn set_latest_global_model_id(&mut self, id: &str) -> StorageResult<()>;

    /// Returns the latest global model id.
    ///
    /// # Behavior
    ///
    /// - If the global model id does not exist, return `StorageResult::Ok(None)`.
    /// - If the global model id exists, return `StorageResult::Ok(Some(String))`.
    async fn latest_global_model_id(&mut self) -> StorageResult>;

    /// Checks if the [`CoordinatorStorage`] is ready to process requests.
    ///
    /// # Behavior
    ///
    /// If the [`CoordinatorStorage`] is ready to process requests, return `StorageResult::Ok(())`.
    /// If the [`CoordinatorStorage`] cannot process requests because of a connection error,
    /// for example, return `StorageResult::Err(error)`.
    async fn is_ready(&mut self) -> StorageResult<()>;
}

#[async_trait]
/// An abstract model storage.
pub trait ModelStorage
where
    Self: Clone + Send + Sync + 'static,
{
    /// Sets a global model.
    ///
    /// # Behavior
    ///
    /// - If the global model already exists (has the same model id), return
    ///   `StorageResult::Err(StorageError)`.
    /// - If the global model does not exist, set the model and return `StorageResult::Ok(String)`
    ///   (presumably the id under which the model was stored — confirm).
    async fn set_global_model(
        &mut self,
        round_id: u64,
        round_seed: &RoundSeed,
        global_model: &Model,
    ) -> StorageResult;

    /// Returns a global model.
    ///
    /// # Behavior
    ///
    /// - If the global model does not exist, return `StorageResult::Ok(Option::None)`.
    /// - If the global model exists, return `StorageResult::Ok(Option::Some(Model))`.
    async fn global_model(&mut self, id: &str) -> StorageResult>;

    /// Creates a unique global model id by using the round id and the round seed in which
    /// the global model was created.
    ///
    /// The format of the default implementation is `roundid_roundseed`,
    /// where the [`RoundSeed`] is encoded in hexadecimal.
    fn create_global_model_id(round_id: u64, round_seed: &RoundSeed) -> String {
        // `as_slice` comes from the `ByteObject` trait imported above.
        let round_seed = hex::encode(round_seed.as_slice());
        format!("{}_{}", round_id, round_seed)
    }

    /// Checks if the [`ModelStorage`] is ready to process requests.
    ///
    /// # Behavior
    ///
    /// If the [`ModelStorage`] is ready to process requests, return `StorageResult::Ok(())`.
    /// If the [`ModelStorage`] cannot process requests because of a connection error,
    /// for example, return `StorageResult::Err(error)`.
    async fn is_ready(&mut self) -> StorageResult<()>;
}

#[async_trait]
/// An abstract trust anchor provider.
pub trait TrustAnchor
where
    Self: Clone + Send + Sync + 'static,
{
    /// Publishes a proof of the global model.
    ///
    /// # Behavior
    ///
    /// Return `StorageResult::Ok(())` if the proof was published successfully,
    /// otherwise return `StorageResult::Err(error)`.
    async fn publish_proof(&mut self, global_model: &Model) -> StorageResult<()>;

    /// Checks if the [`TrustAnchor`] is ready to process requests.
    ///
    /// # Behavior
    ///
    /// If the [`TrustAnchor`] is ready to process requests, return `StorageResult::Ok(())`.
    /// If the [`TrustAnchor`] cannot process requests because of a connection error,
    /// for example, return `StorageResult::Err(error)`.
    async fn is_ready(&mut self) -> StorageResult<()>;
}

#[async_trait]
/// The complete storage: a single type that is a [`CoordinatorStorage`], a
/// [`ModelStorage`] and a [`TrustAnchor`] at the same time.
pub trait Storage: CoordinatorStorage + ModelStorage + TrustAnchor {
    /// Checks if the [`CoordinatorStorage`], [`ModelStorage`] and [`TrustAnchor`]
    /// are ready to process requests.
    ///
    /// # Behavior
    ///
    /// If all inner services are ready to process requests,
    /// return `StorageResult::Ok(())`.
    /// If any inner service cannot process requests because of a connection error,
    /// for example, return `StorageResult::Err(error)`.
    async fn is_ready(&mut self) -> StorageResult<()>;
}

/// A wrapper that contains the result of the "add sum participant" operation.
/// Derefs to the inner `Result`.
#[derive(Deref)]
pub struct SumPartAdd(pub(crate) Result<(), SumPartAddError>);

impl SumPartAdd {
    /// Unwraps this wrapper, returning the underlying result.
    pub fn into_inner(self) -> Result<(), SumPartAddError> {
        self.0
    }
}

// The `///` docs on the variants of the error enums in this file double as
// their `Display` messages (derived via `displaydoc`), so rewording them
// changes runtime output.
// NOTE(review): the explicit `#[repr(i64)]` discriminants combined with
// `TryFromPrimitive` suggest these are numeric codes produced by a storage
// backend — confirm before renumbering any variant.
/// Error that can occur when adding a sum participant to the [`SumDict`].
#[derive(Display, Error, Debug, TryFromPrimitive)]
#[repr(i64)]
pub enum SumPartAddError {
    /// sum participant already exists
    AlreadyExists = 0,
}

/// A wrapper that contains the result of the "add local seed dict" operation.
/// Derefs to the inner `Result`.
#[derive(Deref)]
pub struct LocalSeedDictAdd(pub(crate) Result<(), LocalSeedDictAddError>);

impl LocalSeedDictAdd {
    /// Unwraps this wrapper, returning the underlying result.
    pub fn into_inner(self) -> Result<(), LocalSeedDictAddError> {
        self.0
    }
}

// Variant docs are `Display` messages (displaydoc); keep their wording stable.
/// Error that can occur when adding a local seed dict to the [`SeedDict`].
#[derive(Display, Error, Debug, TryFromPrimitive)]
#[repr(i64)]
pub enum LocalSeedDictAddError {
    /// the length of the local seed dict and the length of sum dict are not equal
    LengthMisMatch = -1,
    /// local dict contains an unknown sum participant
    UnknownSumParticipant = -2,
    /// update participant already submitted an update
    UpdatePkAlreadySubmitted = -3,
    /// update participant already exists in the inner update seed dict
    UpdatePkAlreadyExistsInUpdateSeedDict = -4,
}

/// A wrapper that contains the result of the "increment mask score" operation.
/// Derefs to the inner `Result`.
#[derive(Deref)]
pub struct MaskScoreIncr(pub(crate) Result<(), MaskScoreIncrError>);

impl MaskScoreIncr {
    /// Unwraps this wrapper, returning the underlying result.
    pub fn into_inner(self) -> Result<(), MaskScoreIncrError> {
        self.0
    }
}

// Variant docs are `Display` messages (displaydoc); keep their wording stable.
/// Error that can occur when incrementing a mask score.
#[derive(Display, Error, Debug, TryFromPrimitive)]
#[repr(i64)]
pub enum MaskScoreIncrError {
    /// unknown sum participant
    UnknownSumPk = -1,
    /// sum participant submitted a mask already
    MaskAlreadySubmitted = -2,
}
================================================
FILE: rust/xaynet-server/src/storage/trust_anchor/mod.rs
================================================
// Trust anchor implementations.
pub mod noop;
================================================
FILE: rust/xaynet-server/src/storage/trust_anchor/noop.rs
================================================
use crate::storage::traits::{StorageResult, TrustAnchor};
use async_trait::async_trait;
use xaynet_core::mask::Model;

/// A [`TrustAnchor`] that does nothing: publishing a proof always succeeds
/// without side effects, and the anchor always reports itself as ready.
#[derive(Clone)]
pub struct NoOp;

#[async_trait]
impl TrustAnchor for NoOp {
    async fn publish_proof(&mut self, _global_model: &Model) -> StorageResult<()> {
        Ok(())
    }

    async fn is_ready(&mut self) -> StorageResult<()> {
        Ok(())
    }
}
================================================
FILE: scripts/bump_version.sh
================================================
#!/usr/bin/env bash

# Repository root; main() changes into it before doing anything else, so the
# relative paths used below (rust/Cargo.toml, CHANGELOG.md) resolve.
WORKDIR="$(git rev-parse --show-toplevel)"

# Save the git HEAD before running the script
HEAD=
# Latest tag
PREV_TAG=
# Commit that corresponds to the latest tag
PREV_TAGGED_COMMIT=
# Latest version numbers
PREV_MAJOR=
PREV_MINOR=
PREV_PATCH=
# New version numbers
MAJOR=
MINOR=
PATCH=

# Print the new version number (MAJOR.MINOR.PATCH) on stdout.
version() {
  echo "${MAJOR}.${MINOR}.${PATCH}"
}

# Print the previous version number (PREV_MAJOR.PREV_MINOR.PREV_PATCH) on stdout.
prev_version() {
  echo "${PREV_MAJOR}.${PREV_MINOR}.${PREV_PATCH}"
}

#######################################
# Find and parse the latest tag reachable from HEAD, and populate the global
# version variables. The new version starts out equal to the previous one.
#
# Tags must look like vMAJOR.MINOR.PATCH. Each component may have several
# digits (e.g. v0.10.0), which the old fixed-offset slicing could not handle.
# Leading zeros are rejected, so the components are always safe to use in
# $(( )) arithmetic, where a value like 08 would be parsed as invalid octal.
# Globals:   PREV_TAG, PREV_TAGGED_COMMIT, PREV_MAJOR, PREV_MINOR, PREV_PATCH,
#            MAJOR, MINOR, PATCH (written)
# Exits:     1 if the latest tag is not a valid version tag
#######################################
fetch_latest_version() {
  local -r tag_regex='^v(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)$'

  PREV_TAG=$(git describe --tags --abbrev=0)
  PREV_TAGGED_COMMIT=$(git rev-list -n 1 "${PREV_TAG}")
  echo "latest tag found: ${PREV_TAG} (commit ${PREV_TAGGED_COMMIT})"

  if ! [[ ${PREV_TAG} =~ ${tag_regex} ]]; then
    echo "error: invalid tag ${PREV_TAG}" >&2
    exit 1
  fi

  PREV_MAJOR=${BASH_REMATCH[1]}
  PREV_MINOR=${BASH_REMATCH[2]}
  PREV_PATCH=${BASH_REMATCH[3]}

  MAJOR=${PREV_MAJOR}
  MINOR=${PREV_MINOR}
  PATCH=${PREV_PATCH}
}

# Check that the working directory doesn't have un-committed changes to
# tracked files. If it does, error out.
check_workdir_is_clean() {
  if [[ -z "$(git status --untracked-files=no --porcelain)" ]]; then
    echo "git working directory is clean, continuing"
  else
    echo "git working directory is dirty, aborting" >&2
    exit 1
  fi
}

#######################################
# Interactively ask the user whether the script should continue.
# Returns 0 on "Yes". Exits with 1 on "No", and also when stdin reaches EOF
# without an answer: a missing answer must never count as consent.
#######################################
ask_yes_or_no() {
  local yn
  # `|| true`: some bash versions make `select` return 1 on EOF, which would
  # otherwise abort silently under `set -e`; the EOF case is handled below.
  select yn in "Yes" "No"; do
    case "${yn}" in
      Yes)
        echo "continuing"
        return 0
        ;;
      No)
        echo "aborting" >&2
        exit 1
        ;;
    esac
  done || true
  echo "no answer given, aborting" >&2
  exit 1
}

# Print a message explaining what the script does, and how to undo the changes
# if necessary. Reads the global HEAD (the commit checked out before the run).
disclaimer() {
  cat << EOF
*********************************** IMPORTANT ***********************************

This script modifies the git commit history. If anything goes wrong, or if you
have a doubt, you can always rollback to where this script start by running:

    git reset --hard ${HEAD}

This script will:

    1. Find the latest tag on the current branch
    2. Make sure that the CHANGELOG.md file was updated since this tag was pushed
    3. Update the version number in various files in the repository, and commit
       these changes
    4. Create a new annotated tag

EOF
}

# Print a help message
usage() {
  cat << EOF
./bump_version.sh [-h|--help] [-M|--major] [-m|--minor] [-p|--patch]

bump_version.sh is used for bumping the previous version number and creating a new tag.
Bumping a component resets all lower components to 0 (e.g. 1.4.2 --minor -> 1.5.0).

OPTIONS:
  -h|--help: print this help message
  -M|--major: bump the major version number
  -m|--minor: bump the minor version number
  -p|--patch: bump the patch version number
EOF
}

# Run `git diff` on CHANGELOG.md between the previous release tag and HEAD.
# Extra arguments (e.g. --quiet) are passed through to git diff. This was
# previously a nested function named `diff`, which, because bash functions
# are global, shadowed diff(1) for the rest of the script.
changelog_diff() {
  git --no-pager diff "$@" "${PREV_TAGGED_COMMIT}" HEAD -- CHANGELOG.md
}

# Make sure the CHANGELOG was updated, and ask the user to double check the
# changes.
check_changelog_was_updated() {
  # `git diff --quiet` exits 0 when there is no difference.
  if changelog_diff --quiet; then
    echo "error: the CHANGELOG has not been updated since ${PREV_TAG}" >&2
    echo "Do you want to continue anyway?"
    ask_yes_or_no
  else
    echo "The CHANGELOG has been updated since ${PREV_TAG}"
    changelog_diff
    echo "Does the change above look correct for v$(version)"
    ask_yes_or_no
  fi
}

# Small helper to update the version number in a file, using (GNU) sed -i.
# Arguments: $1 - sed expression, $2 - file to edit in place
set_version_in_file() {
  local sed_expr=${1}
  local file=${2}
  echo "Setting version to $(version) in ${file}"
  sed -i "${sed_expr}" "${file}"
}

# Update the version numbers in various files, and ask confirmation from the
# user before committing these changes.
update_versions() {
  set_version_in_file 's/^version = ".*"$/version = "'"$(version)"'"/g' rust/Cargo.toml
  # NOTE(review): a plain `cargo update` also refreshes every dependency in
  # Cargo.lock, not just the bumped package version — confirm this is intended.
  (cd rust && cargo update -v)

  if git --no-pager diff --quiet; then
    echo "No changes were made, it seems that the version files were already updated to $(version)"
    echo "Do you want to continue?"
    ask_yes_or_no
  else
    git --no-pager diff
    echo "Do you want to commit the changes above?"
    ask_yes_or_no
    git add rust/Cargo.toml rust/Cargo.lock
    git commit -m "bump version $(prev_version) -> $(version)"
  fi
}

# Check that the Rust package in the current directory could be published.
cargo_publish_dry_run() {
  echo "Checking that the Rust package is ready to be published"
  cargo publish --dry-run
  echo "The Rust package is ready to be published"
}

#######################################
# Entry point: parse the bump flags, compute the new version, commit the
# version bump and create the annotated release tag.
# Globals:   WORKDIR (read), HEAD and all version globals (written)
# Arguments: any of -M|--major, -m|--minor, -p|--patch, -h|--help
#######################################
main() {
  local bump_major=false
  local bump_minor=false
  local bump_patch=false
  local release_commit

  if [[ "$#" -eq 0 ]]; then
    usage
    exit 0
  fi

  # WORKDIR is empty when the script is run outside a git repository.
  cd "${WORKDIR:?not inside a git repository}"
  check_workdir_is_clean

  while (( $# > 0 )); do
    case "$1" in
      -M|--major)
        bump_major=true
        shift
        ;;
      -m|--minor)
        bump_minor=true
        shift
        ;;
      -p|--patch)
        bump_patch=true
        shift
        ;;
      -h|--help|help)
        usage
        exit 0
        ;;
      *)
        echo "error: unsupported argument \"$1\"" >&2
        usage
        exit 1
        ;;
    esac
  done

  HEAD=$(git rev-parse HEAD)
  disclaimer
  fetch_latest_version

  # Semantic versioning: bumping a component resets every lower component to
  # 0. If several flags are given, the most significant one wins.
  if [[ "${bump_major}" == true ]]; then
    MAJOR=$((PREV_MAJOR + 1))
    MINOR=0
    PATCH=0
  elif [[ "${bump_minor}" == true ]]; then
    MINOR=$((PREV_MINOR + 1))
    PATCH=0
  elif [[ "${bump_patch}" == true ]]; then
    PATCH=$((PREV_PATCH + 1))
  fi

  if [[ "$(prev_version)" == "$(version)" ]]; then
    echo "error: new version is the same than previous version" >&2
    exit 1
  fi

  echo "Bumping version from $(prev_version) to $(version)"
  ask_yes_or_no

  update_versions
  check_changelog_was_updated
  (cd rust && cargo_publish_dry_run)

  # The tag lands on the current HEAD, i.e. the version-bump commit created by
  # update_versions, not on ${HEAD} (the commit checked out before the run).
  release_commit=$(git rev-parse HEAD)
  echo "Tagging ${release_commit} as \"v$(version)\""
  git tag -a "v$(version)" -m "release v$(version)"
  echo "Done!"

  # `git push master --tags` would treat `master` as the remote name; the
  # remote is named explicitly here (assumed to be `origin`).
  cat << EOF

You can now publish the Rust package:

    (cd rust && cargo publish)

Finally: push the new tag to Github:

    git push origin master "v$(version)"
EOF
}

set -euo pipefail
main "$@"