[
  {
    "path": ".dockerignore",
    "content": "**/.ignore\n**/shell.nix\n**/.envrc\n\n.git\n.github\nassets\nbindings\nconfigs\ndocker\nk8s\nrust/target\nscripts\n"
  },
  {
    "path": ".github/codecov.yml",
    "content": "coverage:\n  status:\n    patch: off\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "version: 2\nupdates:\n  - package-ecosystem: cargo\n    directory: \"/rust\"\n    schedule:\n      interval: daily\n      time: \"09:00\"\n      timezone: \"Europe/Berlin\"\n\n  - package-ecosystem: cargo\n    directory: \"/bindings/python\"\n    schedule:\n      interval: weekly\n      day: \"monday\"\n\n  - package-ecosystem: pip\n    directory: \"/bindings/python/examples/keras_house_prices\"\n    schedule:\n      interval: weekly\n      day: \"monday\"\n\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      interval: \"weekly\"\n      day: \"monday\"\n"
  },
  {
    "path": ".github/workflows/dockercompose-validation.yml",
    "content": "name: docker-compose validation\n\non:\n  push:\n    paths:\n      - 'docker/docker-compose*yml'\n\njobs:\n  check-docker-compose:\n    name: docker-compose validation\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Verify docker-compose\n        working-directory: ./docker\n        run: docker-compose -f docker-compose.yml config -q\n"
  },
  {
    "path": ".github/workflows/dockerfile-validation.yml",
    "content": "name: Dockerfiles linting\n\non:\n  push:\n    paths:\n      - 'docker/Dockerfile**'\n\njobs:\n  lint:\n    name: Dockerfiles linting\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Lint file\n        run: docker run -v $GITHUB_WORKSPACE/docker/Dockerfile:/Dockerfile replicated/dockerfilelint /Dockerfile\n"
  },
  {
    "path": ".github/workflows/dockerhub-cleanup.yml",
    "content": "name: DockerHub Scheduled Cleanup\n\non:\n  schedule:\n    - cron: '00 00 * * sun'\n  workflow_dispatch:\n\njobs:\n  dockerhub-cleanup-inactive:\n    name: Cleanup inactive xaynet tags on Dockerhub\n    runs-on: ubuntu-latest\n    steps:\n      - name: Setup hub-tool\n        env:\n          DHUSER: ${{ secrets.DOCKER_USERNAME }}\n          DHTOKEN: ${{ secrets.DOCKER_PASSWORD }}\n        run: |\n          export DEBIAN_FRONTEND=\"noninteractive\"\n          sudo apt update\n          sudo apt install -y jq\n          LATEST=$(curl -s \"https://api.github.com/repos/docker/hub-tool/releases/latest\" | grep '\"tag_name\":' | sed -E 's/.*\"([^\"]+)\".*/\\1/')\n          wget https://github.com/docker/hub-tool/releases/download/${LATEST}/hub-tool-linux-amd64.tar.gz -O /tmp/hub-tool-linux-amd64.tar.gz\n          tar xzvf /tmp/hub-tool-linux-amd64.tar.gz --strip-components 1 -C /tmp hub-tool/hub-tool\n          mkdir -pv -m 700 ~/.docker\n          chmod -v 600 ~/.docker/config.json\n          echo -ne \"ewogICJ1c2VybmFtZSI6ICJESFVTRVIiLAogICJwYXNzd29yZCI6ICJESFRPS0VOIgp9Cg==\" | base64 -d > /tmp/auth.json\n          echo -ne \"ewogICJhdXRocyI6IHsKICAgICJodWItdG9vbCI6IHsKICAgICAgImF1dGgiOiAiREhVU0VSVE9LRU4iCiAgICB9LAogICAgImh1Yi10b29sLXJlZnJlc2gtdG9rZW4iOiB7CiAgICAgICJhdXRoIjogIkRIVVNFUiIKICAgIH0sCiAgICAiaHViLXRvb2wtdG9rZW4iOiB7CiAgICAgICJhdXRoIjogIkRIVVNFUiIsCiAgICAgICJpZGVudGl0eXRva2VuIjogIkpXVFRPS0VOIgogICAgfQogIH0KfQoK\" | base64 -d > ~/.docker/config.json\n          RUSERTOKEN=$(echo -ne \"${DHUSER}:${DHTOKEN}\" | base64 -w0)\n          RUSER=$(echo -ne \"${DHUSER}:\" | base64 -w0)\n          RTOKEN=$(echo -ne \"${DHTOKEN}\" | base64 -w0)\n          sed -i -e \"s,DHUSERTOKEN,${RUSERTOKEN},g\" -e \"s,DHUSER,${RUSER},g\" -e \"s,DHTOKEN,${RTOKEN},g\" /tmp/auth.json ~/.docker/config.json\n          JWT=$(curl -s -XPOST \"https://hub.docker.com/v2/users/login\" -H \"Content-Type:application/json\" -d \"@/tmp/auth.json\" | jq -r .token)\n          sed -i 
-e \"s,JWTTOKEN,${JWT},g\" ~/.docker/config.json\n      - name: Delete target tags\n        run: |\n          echo -e \"Inactive tags:\"\n          /tmp/hub-tool tag ls xaynetwork/xaynet | grep -e STATUS -e inactive\n          TAGS=$(/tmp/hub-tool tag ls xaynetwork/xaynet | grep inactive | grep -v -e \"v[0-9]\\+\\.[0-9]\\+\\.[0-9]\\+\" | awk '{ print $1 }')\n          if [[ ! -z ${TAGS} ]]\n            then\n              echo -e \"\\n\\n\"\n              for tag in ${TAGS}\n                do\n                  /tmp/hub-tool tag rm -f ${tag}\n              done\n          fi\n"
  },
  {
    "path": ".github/workflows/dockerhub-master.yml",
    "content": "name: DockerHub (master)\n\non:\n  push:\n    branches:\n      - master\n\njobs:\n  build-tag-push-master:\n    name: build-tag-push-master\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v3\n\n      - name: Login to DockerHub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKER_USERNAME }}\n          password: ${{ secrets.DOCKER_PASSWORD }}\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v2\n\n      - name: build-tag-push\n        uses: docker/build-push-action@v3\n        id: docker\n        with:\n          context: .\n          file: docker/Dockerfile\n          tags: xaynetwork/xaynet:development\n          push: true\n          build-args: COORDINATOR_FEATURES=metrics\n\n      - name: Notify on Slack\n        uses: 8398a7/action-slack@v3\n        if: always()\n        with:\n          status: custom\n          fields: workflow,job,repo,ref\n          custom_payload: |\n            {\n              username: 'GitHub Actions',\n              icon_emoji: ':octocat:',\n              attachments: [{\n                color: '${{ steps.docker.outcome }}' === 'success' ? 'good' : '${{ steps.docker.outcome }}' === 'failure' ? 'danger' : 'warning',\n                text: `${process.env.AS_WORKFLOW}\\nRepository: :xaynet: ${process.env.AS_REPO}\\nRef: ${process.env.AS_REF}\\nTags: development`,\n              }]\n            }\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}\n"
  },
  {
    "path": ".github/workflows/dockerhub-pr-with-parameters.yml",
    "content": "name: DockerHub (PR) with parameters\n\non:\n  issue_comment:\n    types: [created]\n\njobs:\n  check_comments:\n    name: Check comments for /deploy\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check for Command\n        id: command\n        uses: xt0rted/slash-command-action@v1\n        with:\n          repo-token: ${{ secrets.GITHUB_TOKEN }}\n          command: deploy\n          reaction: \"true\"\n          reaction-type: \"eyes\"\n          allow-edits: \"false\"\n          permission-level: write\n\n      - uses: jungwinter/split@v2\n        id: split\n        with:\n          msg: '${{ steps.command.outputs.command-arguments }}'\n          maxsplit: 1\n\n      - uses: xt0rted/pull-request-comment-branch@v1\n        id: comment-branch\n        with:\n          repo_token: ${{ secrets.GITHUB_TOKEN }}\n\n      - uses: actions/checkout@v3\n        if: success()\n        with:\n          ref: ${{ steps.comment-branch.outputs.head_ref }}\n\n      - name: Find and Replace\n        uses: jacobtomlinson/gha-find-replace@master\n        with:\n          find: \"newTag: development\"\n          replace: \"newTag: ${{ steps.comment-branch.outputs.head_ref }}\"\n          include: \"kustomization.yaml\"\n\n      - name: Login to DockerHub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKER_USERNAME }}\n          password: ${{ secrets.DOCKER_PASSWORD }}\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v2\n\n      - name: build-tag-push\n        uses: docker/build-push-action@v3\n        id: docker\n        with:\n          context: .\n          file: docker/Dockerfile\n          tags: xaynetwork/xaynet:${{ steps.comment-branch.outputs.head_ref }}\n          push: true\n          build-args: |\n            ${{ steps.split.outputs._0 }}\n            ${{ steps.split.outputs._1 }}\n\n      - name: Notify on Slack\n        uses: 8398a7/action-slack@v3\n        if: ${{ success() 
}}\n        with:\n          status: custom\n          fields: workflow,job,repo,ref\n          custom_payload: |\n            {\n              username: 'GitHub Actions',\n              icon_emoji: ':octocat:',\n              attachments: [{\n                color: '${{ steps.docker.outcome }}' === 'success' ? 'good' : '${{ steps.docker.outcome }}' === 'failure' ? 'danger' : 'warning',\n                text: `${process.env.AS_WORKFLOW}\\nRepository: :xaynet: ${process.env.AS_REPO}\\nRef: ${process.env.AS_REF}\\nTags: ${{ steps.comment-branch.outputs.head_ref }}`,\n              }]\n            }\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}\n"
  },
  {
    "path": ".github/workflows/dockerhub-release.yml",
    "content": "name: DockerHub (Release)\n\non:\n  push:\n    tags:\n      - v[0-9]+.[0-9]+.[0-9]+\n\njobs:\n  build-tag-push-release:\n    name: build-tag-push-release\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v3\n\n      - name: Login to DockerHub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKER_USERNAME }}\n          password: ${{ secrets.DOCKER_PASSWORD }}\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v2\n\n      - name: build-tag-push\n        uses: docker/build-push-action@v3\n        id: docker\n        with:\n          context: .\n          file: docker/Dockerfile\n          tags: xaynetwork/xaynet:latest\n          push: true\n          build-args: RELEASE_BUILD=1\n\n      - name: Notify on Slack\n        uses: 8398a7/action-slack@v3\n        if: always()\n        with:\n          status: custom\n          fields: workflow,job,repo,ref\n          custom_payload: |\n            {\n              username: 'GitHub Actions',\n              icon_emoji: ':octocat:',\n              attachments: [{\n                color: '${{ steps.docker.outcome }}' === 'success' ? 'good' : '${{ steps.docker.outcome }}' === 'failure' ? 'danger' : 'warning',\n                text: `${process.env.AS_WORKFLOW}\\nRepository: :xaynet: ${process.env.AS_REPO}\\nRef: ${process.env.AS_REF} :heavy_check_mark:`,\n              }]\n            }\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}"
  },
  {
    "path": ".github/workflows/kubernetes-manifests.yml",
    "content": "name: Kubernetes manifests validation\n\non:\n  push:\n    paths:\n      - 'k8s/**'\n\njobs:\n  k8s-kustomize-validation:\n    name: Kubernetes manifests validation\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Verify Kubernetes manifests\n        run: kubectl kustomize $GITHUB_WORKSPACE/k8s/coordinator/development > /dev/null # Print only errors, if any\n"
  },
  {
    "path": ".github/workflows/rust-audit-cron.yml",
    "content": "name: Rust Audit for Security Vulnerabilities (master)\n\non:\n  schedule:\n    - cron: '00 08 * * mon-fri'\n\njobs:\n  audit:\n    name: Rust Audit\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          ref: master\n\n      - name: Run rust-audit\n        id: rust-audit\n        run: |\n          cargo audit --deny-warnings -f rust/Cargo.lock\n\n      - name: Notify on Slack\n        uses: 8398a7/action-slack@v3\n        if: ${{ failure() }}\n        with:\n          status: custom\n          fields: workflow,job,repo\n          custom_payload: |\n            {\n              username: 'GitHub Actions',\n              icon_emoji: ':octocat:',\n              attachments: [{\n                color: '${{ steps.rust-audit.outcome }}' === 'success' ? 'good' : '${{ steps.rust-audit.outcome }}' === 'failure' ? 'danger' : 'warning',\n                text: `${process.env.AS_WORKFLOW}\\nRepository: ${process.env.AS_REPO}\\nRef: master :warning:`,\n              }]\n            }\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}\n"
  },
  {
    "path": ".github/workflows/rust-next.yml",
    "content": "name: Rust-CI Next\n\non:\n  schedule:\n    - cron: '00 04 10,20 * *'\n\njobs:\n  registry-cache:\n    name: cargo-fetch\n    timeout-minutes: 5\n    runs-on: ubuntu-latest\n    outputs:\n      cache-key: ${{ steps.cache-key.outputs.key }}\n      cache-date: ${{ steps.get-date.outputs.date }}\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install stable toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: stable\n          default: true\n\n      # We want to create a new cache after a week. Otherwise, the cache will\n      # take up too much space by caching old dependencies\n      - name: Year + ISO week number\n        id: get-date\n        run: echo \"::set-output name=date::$(/bin/date -u \"+%Y-%V\")\"\n        shell: bash\n\n      # We can use the registry cache of the normal rust ci\n      - name: Cache key\n        id: cache-key\n        run: echo \"::set-output name=key::$(echo ${{ runner.os }}-cargo-registry-${{ steps.get-date.outputs.date }}-${{ hashFiles('**/Cargo.lock') }})\"\n        shell: bash\n\n      - name: Cache cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ steps.cache-key.outputs.key }}\n          restore-keys: ${{ runner.os }}-cargo-registry-${{ steps.get-date.outputs.date }}-\n\n      - name: cargo fetch\n        working-directory: ./rust\n        run: cargo fetch\n\n  format:\n    name: cargo-fmt\n    needs: registry-cache\n    timeout-minutes: 10\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        cargo_manifest: [rust, bindings/python]\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install nightly toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        
with:\n          profile: minimal\n          toolchain: nightly\n          components: rustfmt\n          default: true\n\n      - name: Use cached cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      # cargo fmt does not create any artifacts, therefore we don't need to cache the target folder\n\n      - name: cargo fmt\n        working-directory: ${{ matrix.cargo_manifest }}\n        run: cargo fmt --all -- --check\n\n  check:\n    name: cargo-check\n    needs: registry-cache\n    timeout-minutes: 20\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        rust_version: [stable, beta]\n        cargo_manifest: [rust, bindings/python]\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: ${{ matrix.rust_version }}\n          default: true\n\n      - name: Use cached cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      - name: cargo check\n        working-directory: ${{ matrix.cargo_manifest }}\n        env:\n          RUSTFLAGS: \"-D warnings\"\n        run: |\n          cargo check --all-targets\n          cargo check --all-targets --all-features\n\n  clippy:\n    name: cargo-clippy\n    needs: [registry-cache, check]\n    timeout-minutes: 20\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        rust_version: [stable, beta]\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install toolchain\n        id: rust-toolchain\n        uses: 
actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: ${{ matrix.rust_version }}\n          default: true\n          components: clippy\n\n      - name: Use cached cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      - name: cargo clippy\n        working-directory: rust\n        run: |\n          cargo clippy --all-targets -- --deny warnings --deny clippy::cargo\n          cargo clippy --all-targets --all-features -- --deny warnings --deny clippy::cargo\n\n  docs:\n    name: cargo-doc\n    needs: [registry-cache, check]\n    timeout-minutes: 20\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        rust_version: [stable, beta]\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: ${{ matrix.rust_version }}\n          default: true\n\n      - name: Use cached cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      - name: Check the building of docs\n        working-directory: ./rust\n        run: cargo doc --all-features --document-private-items --no-deps --color always\n\n  notify:\n    name: notify\n    if: failure()\n    needs: [format, check, clippy, docs]\n    timeout-minutes: 20\n    runs-on: ubuntu-latest\n    steps:\n      - name: Notify on Slack\n        uses: 8398a7/action-slack@v3\n        with:\n          status: custom\n          fields: workflow,repo\n          custom_payload: |\n            {\n              username: 'GitHub Actions',\n              icon_emoji: ':octocat:',\n     
         attachments: [{\n                color: 'danger',\n                text: `${process.env.AS_WORKFLOW} :warning:\\nRepository: ${process.env.AS_REPO}`,\n              }]\n            }\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}\n"
  },
  {
    "path": ".github/workflows/rust.yml",
    "content": "name: Rust-CI\n\non:\n  push:\n    paths:\n      - 'rust/**'\n      - 'bindings/python/**'\n      - '.github/workflows/rust.yml'\n      - 'README.md'\n      - 'README.tpl'\n\nenv:\n  RUST_STABLE: 1.55.0\n  RUST_NIGHTLY: nightly-2021-09-09\n\njobs:\n  registry-cache:\n    name: cargo-fetch\n    timeout-minutes: 5\n    runs-on: ubuntu-latest\n    outputs:\n      cache-key: ${{ steps.cache-key.outputs.key }}\n      cache-date: ${{ steps.get-date.outputs.date }}\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install stable toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: ${{ env.RUST_STABLE }}\n          default: true\n\n      # We want to create a new cache after a week. Otherwise, the cache will\n      # take up too much space by caching old dependencies\n      - name: Year + ISO week number\n        id: get-date\n        run: echo \"::set-output name=date::$(/bin/date -u \"+%Y-%V\")\"\n        shell: bash\n\n      - name: Cache key\n        id: cache-key\n        run: echo \"::set-output name=key::$(echo ${{ runner.os }}-cargo-registry-${{ steps.get-date.outputs.date }}-${{ hashFiles('**/Cargo.lock') }})\"\n        shell: bash\n\n      - name: Cache cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ steps.cache-key.outputs.key }}\n          restore-keys: ${{ runner.os }}-cargo-registry-${{ steps.get-date.outputs.date }}-\n\n      - name: cargo fetch\n        working-directory: ./rust\n        run: cargo fetch\n\n  format:\n    name: cargo-fmt\n    needs: registry-cache\n    timeout-minutes: 10\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        cargo_manifest: [rust, bindings/python]\n    steps:\n      - name: Checkout repository\n        uses: 
actions/checkout@v3\n\n      - name: Install nightly toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: ${{ env.RUST_NIGHTLY }}\n          components: rustfmt\n          default: true\n\n      - name: Use cached cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      # cargo fmt does not create any artifacts, therefore we don't need to cache the target folder\n\n      - name: cargo fmt\n        working-directory: ${{ matrix.cargo_manifest }}\n        run: cargo fmt --all -- --check\n\n  check:\n    name: cargo-check\n    needs: registry-cache\n    timeout-minutes: 20\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        cargo_manifest: [rust, bindings/python]\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install stable toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: ${{ env.RUST_STABLE }}\n          default: true\n\n      - name: Use cached cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      - name: Cache build artifacts\n        uses: actions/cache@v3.0.8\n        with:\n          path: ${{ matrix.cargo_manifest }}/target\n          key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-check-${{ matrix.cargo_manifest }}-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }}\n          restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-check-${{ matrix.cargo_manifest }}-${{ needs.registry-cache.outputs.cache-date 
}}-\n\n      - name: cargo check\n        working-directory: ${{ matrix.cargo_manifest }}\n        env:\n          RUSTFLAGS: \"-D warnings\"\n        run: |\n          cargo check --all-targets\n          cargo check --all-targets --all-features\n\n  clippy:\n    name: cargo-clippy\n    needs: [registry-cache, check]\n    timeout-minutes: 20\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install stable toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: ${{ env.RUST_STABLE }}\n          default: true\n          components: clippy\n\n      - name: Use cached cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      - name: Cache build artifacts\n        uses: actions/cache@v3.0.8\n        with:\n          path: ${{ github.workspace }}/rust/target\n          key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-clippy-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }}\n          restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-clippy-${{ needs.registry-cache.outputs.cache-date }}-\n\n      - name: cargo clippy\n        working-directory: rust\n        run: |\n          cargo clippy --all-targets -- --deny warnings --deny clippy::cargo\n          cargo clippy --all-targets --all-features -- --deny warnings --deny clippy::cargo\n\n  test:\n    name: cargo-test\n    needs: [registry-cache, check]\n    timeout-minutes: 20\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install stable toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: 
minimal\n          toolchain: ${{ env.RUST_STABLE }}\n          default: true\n\n      - name: Use cached cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      - name: Cache build artifacts\n        uses: actions/cache@v3.0.8\n        with:\n          path: ${{ github.workspace }}/rust/target\n          key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-tests-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }}\n          restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-tests-${{ needs.registry-cache.outputs.cache-date }}-\n\n      - name: Start docker-compose\n        working-directory: ./docker\n        run: docker-compose up -d influxdb minio redis\n\n      - name: Run tests (unit & integration & doc)\n        working-directory: ./rust\n        env:\n          RUSTFLAGS: \"-D warnings\"\n        run: |\n          cargo test --lib --bins --examples --tests -- -Z unstable-options --include-ignored\n          cargo test --lib --bins --examples --tests --all-features -- -Z unstable-options --include-ignored\n          cargo test --doc --all-features\n\n      - name: Stop docker-compose\n        working-directory: ./docker\n        run: docker-compose down\n\n  bench:\n    name: cargo-bench\n    needs: [registry-cache, check]\n    timeout-minutes: 20\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install stable toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: ${{ env.RUST_STABLE }}\n          default: true\n\n      - name: Use cached cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            
~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      - name: Cache build artifacts\n        uses: actions/cache@v3.0.8\n        with:\n          path: ${{ github.workspace }}/rust/target\n          key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-bench-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }}\n          restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-bench-${{ needs.registry-cache.outputs.cache-date }}-\n\n      - name: Run Bench\n        working-directory: ./rust/benches\n        run: cargo bench\n\n      - name: Upload bench artifacts\n        uses: actions/upload-artifact@v3\n        with:\n          name: bench_${{ github.sha }}\n          path: ${{ github.workspace }}/rust/benches/target/criterion\n\n  docs:\n    name: cargo-doc\n    needs: [registry-cache, check]\n    timeout-minutes: 20\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install stable toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: ${{ env.RUST_STABLE }}\n          default: true\n\n      - name: Use cached cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      - name: Cache build artifacts\n        uses: actions/cache@v3.0.8\n        with:\n          path: ${{ github.workspace }}/rust/target\n          key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-doc-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }}\n          restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-doc-${{ needs.registry-cache.outputs.cache-date }}-\n\n      - name: Check the building of docs\n        
working-directory: ./rust\n        run: cargo doc --all-features --document-private-items --no-deps --color always\n\n  coverage:\n    name: cargo-tarpaulin\n    needs: [registry-cache, check]\n    timeout-minutes: 20\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install stable toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          toolchain: ${{ env.RUST_STABLE }}\n          default: true\n          profile: minimal\n\n      - name: Use cached cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      - name: Cache build artifacts\n        uses: actions/cache@v3.0.8\n        with:\n          path: ${{ github.workspace }}/rust/target\n          key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-tarpaulin-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }}\n          restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-tarpaulin-${{ needs.registry-cache.outputs.cache-date }}-\n\n      - name: Start docker-compose\n        working-directory: ./docker\n        run: docker-compose up -d influxdb minio redis\n\n      - name: Run cargo-tarpaulin\n        uses: actions-rs/tarpaulin@v0.1\n        with:\n          version: '0.16.0'\n          args: '--manifest-path rust/Cargo.toml --all-features --force-clean --lib --ignore-tests --ignored --workspace --exclude xaynet-analytics'\n\n      - name: Stop docker-compose\n        working-directory: ./docker\n        run: docker-compose down\n\n      - name: Upload to codecov.io\n        uses: codecov/codecov-action@v3.1.0\n        with:\n          token: ${{ secrets.CODECOV_TOKEN }}\n\n  python_sdk:\n    name: python sdk\n    needs: [registry-cache, format, check]\n    
timeout-minutes: 20\n    runs-on: ubuntu-latest\n    env:\n        working-directory: ./bindings/python\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install Rust\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: ${{ env.RUST_STABLE }}\n          default: true\n\n      - name: Cache cargo registry\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n          key: ${{ needs.registry-cache.outputs.cache-key }}\n\n      - name: Cache cargo target\n        uses: actions/cache@v3.0.8\n        with:\n          path: ${{ env.working-directory }}/target\n          key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-python-bindings-${{ needs.registry-cache.outputs.cache-date }}-${{ hashFiles('**/Cargo.lock') }}\n          restore-keys: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.rustc }}-python-bindings-${{ needs.registry-cache.outputs.cache-date }}-\n\n      - name: Setup Python 3.6\n        uses: actions/setup-python@v4\n        with:\n          python-version: 3.6\n          architecture: \"x64\"\n\n      - name: Get pip cache dir\n        id: pip-cache\n        run: echo \"::set-output name=dir::$(pip cache dir)\"\n\n      - name: Cache pip packages\n        uses: actions/cache@v3.0.8\n        with:\n          path: ${{ steps.pip-cache.outputs.dir }}\n          key: ${{ runner.os }}-pip-${{ hashFiles('./bindings/python/setup.py') }}\n\n      - name: Install dependencies and build sdk\n        run: |\n          pip install --upgrade pip\n          pip install --upgrade setuptools\n          pip install maturin==0.9.1 black==20.8b1 isort==5.7.0\n          maturin build\n        working-directory: ${{ env.working-directory }}\n\n      - name: black\n        working-directory: ${{ env.working-directory }}\n        run: black --check 
.\n\n      - name: isort\n        working-directory: ${{ env.working-directory }}\n        run: isort --check-only --diff .\n\n  readme:\n    name: cargo-readme\n    timeout-minutes: 20\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install stable toolchain\n        id: rust-toolchain\n        uses: actions-rs/toolchain@v1\n        with:\n          profile: minimal\n          toolchain: ${{ env.RUST_STABLE }}\n          default: true\n\n      - name: Cache cargo readme\n        uses: actions/cache@v3.0.8\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n            ~/.cargo/bin/cargo-readme\n          key: ${{ runner.os }}-cargo-readme-bin\n\n      - name: Install cargo readme\n        run: cargo install cargo-readme || true\n\n      - name: Check that readme matches docs\n        working-directory: ./\n        run: |\n          cargo readme --project-root rust/xaynet/ --template ../../README.tpl --output ../../CARGO_README.md\n          git diff --exit-code --no-index README.md CARGO_README.md\n"
  },
  {
    "path": ".gitignore",
    "content": "**/.ignore/\n\n# https://github.com/github/gitignore/blob/master/Global/macOS.gitignore\n# General\n**.DS_Store\n**.swp\n\n# vscode workspace settings\n.vscode/*\n\nCARGO_README.md\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# Changelog\n\nAll notable changes to this project will be documented in this file.\n\nThe format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to the [Semantic Versioning](http://semver.org/spec/v2.0.0.html).\n\n## [unreleased]\n\n### Changed\n\n#### `xaynet-sdk`\n\n- Update to `tokio` `v1.x`\n- Update to `reqwest` `v0.11.x`\n- Update to `bytes` `v1.x`\n\n#### `xaynet-mobile`\n\n- Update to `tokio` `v1.x`\n- Update to `reqwest` `v0.11.x`\n\n#### `examples`\n\n- Update to `tokio` `v1.x`\n- Update to `reqwest` `v0.11.x`\n\n#### `xaynet-server`\n\n- Update to `tokio` `v1.x`\n- Update to `warp` `v0.3.x`\n- Update to `bytes` `v1.x`\n- Update to `rusoto_core` `v0.46.x`\n- Update to `rusoto_s3` `v0.46.x`\n- Update to `tower` `v0.4.x`\n- Update to `redis` `v0.19.x`\n- Enable optional server side client authentication via tls\n- Environment variable prefixes respect the `__` separator now, i.e. all envs have changed from\n`XAYNET_*` to `XAYNET__*`.\n\n## [0.11.0] - 2021-01-18\n\n### Added\n\n#### Rust SDK `xaynet-sdk`\n\n`xaynet-sdk` contains the basic building blocks required to run the _Privacy-Enhancing Technology_\n(PET) Protocol. It consists of a state machine and two I/O interfaces with which specific Xaynet\nparticipants can be developed that are adapted to the respective environments/requirements.\n\nIf you are interested in building your own Xaynet participant, you can take a look at\n`xaynet-sdk`, our [Rust participant](https://github.com/xaynetwork/xaynet/blob/master/rust/examples/test-drive/participant.rs)\nwhich we use primarily for testing or at\n[`xaynet-mobile`](https://github.com/xaynetwork/xaynet/blob/master/rust/xaynet-mobile/src/participant.rs)\nour mobile friendly participant.\n\n#### A Mobile friendly Xaynet participant `xaynet-mobile`\n\n`xaynet-mobile` provides a mobile friendly implementation of a Xaynet participant. 
It gives the user\na lot of control over how to drive the participant execution. You can regularly pause the execution of\nthe participant, save it, and later restore it and continue the execution. When running on a device\nthat is low on battery or does not have access to Wi-Fi for instance, it can be useful to be able to\npause the participant.\n\n**C API**\n\nFurthermore, `xaynet-mobile` offers `C` bindings that allow `xaynet-mobile` to be used in other\nprogramming languages such as `Dart`.\n\n#### Python participant SDK `xaynet-sdk-python`\n\nWe are happy to announce that we finally released `xaynet-sdk-python`, a Python SDK that\nconsists of two experimental Xaynet participants (`ParticipantABC` and `AsyncParticipant`).\n\nThe `ParticipantABC` API is similar to the old one which we introduced in `v0.8.0`. Aside from some\nchanges to the method signature, the biggest change is that the participant now runs in its own\nthread. To migrate from `v0.8.0` to `v0.11.0`, please follow the\n[migration guide](https://github.com/xaynetwork/xaynet/blob/master/bindings/python/migration_guide.md).\n\nHowever, we noticed that our Participant API may be difficult to integrate with existing\napplications, considering the code for the training has to be moved into the `train_round` method,\nwhich can lead to significant changes to the existing code. Therefore, we offer a second API\n(`AsyncParticipant`) in which the training of the model is no longer part of the participant.\n\nA more in-depth explanation of the differences between the Participant APIs\nand examples of how to use them can be found\n[here](https://github.com/xaynetwork/xaynet/blob/master/bindings/python/README.md).\n\n#### Multi-part messages\n\nParticipant messages can get large, possibly too large to be sent successfully in one go. On mobile\ndevices in particular, the internet connection may not be as reliable. 
In order to make the\ntransmission of messages more robust, we implemented multi-part messages to break a large message\ninto parts and send them sequentially to the coordinator. If the transmission of part of\na message fails, only that part will be resent and not the entire message.\n\n#### Coordinator state managed in Redis\n\nIn order to be able to restore the state of the coordinator after a failure or shutdown,\nthe state is managed in Redis and no longer in memory.\n\nThe Redis client can be configured via the `[redis]` setting:\n\n```toml\n[redis]\nurl = \"redis://127.0.0.1/\"\n```\n\n#### Support for storing global models in S3/Minio\n\nThe coordinator is able to save a global model in S3/Minio after a successful round.\n\nThe S3 client can be configured via the `[s3]` setting:\n\n```toml\n[s3]\naccess_key = \"minio\"\nsecret_access_key = \"minio123\"\nregion = [\"minio\", \"http://localhost:9000\"]\n\n[s3.buckets]\nglobal_models = \"global-models\"\n```\n\n`xaynet-server` must be compiled with the feature flag `model-persistence` in order to enable\nthis feature.\n\n#### Restore coordinator state\n\nThe state of the coordinator can be restored after a failure or shutdown.\n\nRestoring the coordinator state can be configured via the `[restore]` setting:\n\n```toml\n[restore]\nenable = true\n```\n\n`xaynet-server` must be compiled with the feature flag `model-persistence` in order to enable\nthis feature.\n\n#### Improved collection of state machine metrics\n\nIn `v0.10.0` we introduced the collection of metrics that are emitted in the state machine of\n`xaynet-server` and sent to an InfluxDB instance. In `v0.11.0` we have revised the implementation\nand improved it further. 
Metrics are now sent much faster and adding metrics to the code has\nbecome much easier.\n\n### Removed\n\n  - `xaynet_client` (was split into `xaynet_sdk` and `xaynet_mobile`)\n  - `xaynet_ffi` (is now part of `xaynet_mobile`)\n  - `xaynet_macro`\n\n## [0.10.0] - 2020-09-22\n\n### Added\n\n- Preparation for redis support: prepare for `xaynet_server` to store PET data in redis [#416](https://github.com/xaynetwork/xaynet/pull/416), [#515](https://github.com/xaynetwork/xaynet/pull/515)\n- Add support for multipart messages in the message structure [#508](https://github.com/xaynetwork/xaynet/pull/508), [#513](https://github.com/xaynetwork/xaynet/pull/513), [#514](https://github.com/xaynetwork/xaynet/pull/514)\n- Generalised scalar extension [#496](https://github.com/xaynetwork/xaynet/pull/496), [#507](https://github.com/xaynetwork/xaynet/pull/507)\n- Add server metrics [#487](https://github.com/xaynetwork/xaynet/pull/487), [#488](https://github.com/xaynetwork/xaynet/pull/488), [#489](https://github.com/xaynetwork/xaynet/pull/489), [#493](https://github.com/xaynetwork/xaynet/pull/493)\n- Refactor the client into a state machine, and add a client tailored for mobile devices [#471](https://github.com/xaynetwork/xaynet/pull/471), [#497](https://github.com/xaynetwork/xaynet/pull/497), [#506](https://github.com/xaynetwork/xaynet/pull/506)\n\n### Changed\n\n- Split the xaynet crate into several sub-crates:\n  - `xaynet_core` (0.1.0 released), re-exported as `xaynet::core`\n  - `xaynet_client` (0.1.0 released), re-exported as `xaynet::client` when compiled with `--features client`\n  - `xaynet_server` (0.1.0 released), re-exported as `xaynet::server` when compiled with `--features server`\n  - `xaynet_macro` (0.1.0 released)\n  - `xaynet_ffi` (not released)\n\n## [0.9.0] - 2020-07-24\n\n`xain/xain-fl` repository was renamed to `xaynetwork/xaynet`.\n\nThe new crate will be published as `xaynet` under `v0.9.0`.\n\n### Added\n\nThis release introduces the integration of the [PET 
protocol](https://uploads-ssl.webflow.com/5f0c5c0bb18a279f0a62919e/5f157004da6585f299fa542b_XayNet%20Whitepaper%202.1.pdf) into the platform.\n\n**Note:**\nThe integration of the PET protocol required a complete rewrite of the codebase and is therefore not compatible with the previous release.\n\n## [0.8.0] - 2020-04-08\n\n### Added\n\n- New tutorial for the Python SDK [#355](https://github.com/xaynetwork/xaynet/pull/355)\n- Swagger description of the REST API [#345](https://github.com/xaynetwork/xaynet/pull/345), published at https://xain-fl.readthedocs.io/en/latest/ [#358](https://github.com/xaynetwork/xaynet/pull/358)\n- The Python examples now accept additional parameters (model size, heartbeat period, verbosity, etc.) [#351](https://github.com/xaynetwork/xaynet/pull/351)\n- Publish docker images to dockerhub\n\n### Security\n\n- Stop using `pickle` for message serialization\n  [#355](https://github.com/xaynetwork/xaynet/pull/355). `pickle` is insecure\n  and can lead to remote code execution. 
Instead, the default\n  aggregator uses `numpy.save()`.\n\n### Fixed\n\n- The documentation has been updated at https://xain-fl.readthedocs.io/en/latest/ [#358](https://github.com/xaynetwork/xaynet/pull/358)\n- Document aggregator error on Darwin platform [#365](https://github.com/xaynetwork/xaynet/pull/365/files)\n\n### Changed\n\n- Simplified the Python SDK API [#355](https://github.com/xaynetwork/xaynet/pull/355)\n- Added unit tests for the coordinator and aggregator [#353](https://github.com/xaynetwork/xaynet/pull/353), [#352](https://github.com/xaynetwork/xaynet/pull/352)\n- Refactor the metrics store [#340](https://github.com/xaynetwork/xaynet/pull/340)\n- Speed up the docker builds [#348](https://github.com/xaynetwork/xaynet/pull/348)\n\n## [0.7.0] - 2020-03-25\n\nIn this release we archived the Python code under the `legacy` folder and shifted the development to Rust.\nThis release has many breaking changes from the previous versions.\nMore details will be made available through the updated README.md of the repository.\n\n## [0.6.0] - 2020-02-26\n\n- HOTFIX add disclaimer (#309) [janpetschexain]\n- PB-314: document the new weight exchange mechanism (#308) [Corentin Henry]\n- PB-407 add more debug level logging (#303) [janpetschexain]\n- PB-44 add heartbeat time and timeout to config (#305) [Robert Steiner]\n- PB-423 lock round access (#304) [kwok]\n- PB-439 Make thread pool workers configurable (#302) [Robert Steiner]\n- PB-159: update xain-{proto,sdk} dependencies to the right branch (#301) [Corentin Henry]\n- PB-159: remove weights from gRPC messages (#298) [Corentin Henry]\n- PB-431 send participant state to influxdb (#300) [Robert Steiner]\n- PB-434 separate metrics (#296) [Robert Steiner]\n- PB-406 :snowflake: Configure mypy (#297) [Anastasiia Tymoshchuk]\n- PB-428 send coordinator states (#292) [Robert Steiner]\n- PB-425 split weight init from training (#295) [janpetschexain]\n- PB-398 Round resumption in Coordinator (#285) [kwok]\n- Merge pull 
request #294 from xainag/master. [Daniel Kravetz]\n- Hotfix: PB-432 :pencil: :books: Update test badge and CI to reflect changes. [Daniel Kravetz]\n- PB-417 Start new development cycle (#291) [Anastasiia Tymoshchuk, kwok]\n\n## [0.5.0] - 2020-02-12\n\nFix minor issues, update documentation.\n\n- PB-402 Add more logs (#281) [Robert Steiner]\n- DO-76 :whale: non alpine image (#287) [Daniel Kravetz]\n- PB-401 Add console renderer (#280) [Robert Steiner]\n- DO-80 :ambulance: Update dev Dockerfile to build gRPC (#286) [Daniel Kravetz]\n- DO-78 :sparkles: add grafana (#284) [Daniel Kravetz]\n- DO-66 :sparkles: Add keycloak (#283) [Daniel Kravetz]\n- PB-400 increment epoch base (#282) [janpetschexain]\n- PB-397 Simplify write metrics function (#279) [Robert Steiner]\n- PB-385 Fix xain-sdk test (#278) [Robert Steiner]\n- PB-352 Add sdk config (#272) [Robert Steiner]\n- Merge pull request #277 from xainag/master. [Daniel Kravetz]\n- Hotfix: update ci. [Daniel Kravetz]\n- DO-72 :art: Make CI name and feature consistent with other repos. [Daniel Kravetz]\n- DO-47 :newspaper: Build test package on release branch. [Daniel Kravetz]\n- PB-269: enable reading participants weights from S3 (#254) [Corentin Henry]\n- PB-363 Start new development cycle (#271) [Anastasiia Tymoshchuk]\n- PB-119 enable isort diff (#262) [janpetschexain]\n- PB-363 :gem: Release v0.4.0. [Daniel Kravetz]\n- DO-73 :green_heart: Disable continue_on_failure for CI jobs. Fix mypy. 
[Daniel Kravetz]\n\n## [0.4.0] - 2020-02-04\n\nFlatten model weights instead of using lists.\nFix minor issues, update documentation.\n\n- PB-116: pin docutils version (#259) [Corentin Henry]\n- PB-119 update isort config and calls (#260) [janpetschexain]\n- PB-351 Store participant metrics (#244) [Robert Steiner]\n- Adjust isort config (#258) [Robert Steiner]\n- PB-366 flatten weights (#253) [janpetschexain]\n- PB-379 Update black setup (#255) [Anastasiia Tymoshchuk]\n- PB-387 simplify serve module (#251) [Corentin Henry]\n- PB-104: make the tests fast again (#252) [Corentin Henry]\n- PB-122: handle sigint properly (#250) [Corentin Henry]\n- PB-383 write aggregated weights after each round (#246) [Corentin Henry]\n- PB-104: Fix exception in monitor_hearbeats() (#248) [Corentin Henry]\n- DO-57 Update docker-compose files for provisioning InfluxDB (#249) [Ricardo Saffi Marques]\n- DO-59 Provision Redis 5.x for persisting states for the Coordinator (#247) [Ricardo Saffi Marques]\n- PB-381: make the log level configurable (#243) [Corentin Henry]\n- PB-382: cleanup storage (#245) [Corentin Henry]\n- PB-380: split get_logger() (#242) [Corentin Henry]\n- XP-332: grpc resource exhausted (#238) [Robert Steiner]\n- XP-456: fix coordinator command (#241) [Corentin Henry]\n- XP-485 Document revised state machine (#240) [kwok]\n- XP-456: replace CLI argument with a config file (#221) [Corentin Henry]\n- DO-48 :snowflake: :rocket: Build stable package on git tag with SemVer (#234) [Daniel Kravetz]\n- XP-407 update documentation (#239) [janpetschexain]\n- XP-406 remove numpy file cli (#237) [janpetschexain]\n- XP-544 fix aggregate module (#235) [janpetschexain]\n- DO-58: cache xain-fl dependencies in Docker (#232) [Corentin Henry]\n- XP-479 Start training rounds from 0 (#226) [kwok]\n\n## [0.3.0] - 2020-01-21\n\n- XP-505 cleanup docstrings in xain_fl.coordinator (#228)\n- XP-498 more generic shebangs (#229)\n- XP-510 allow for zero epochs on cli (#227)\n- XP-508 Replace circleci 
badge (#225)\n- XP-505 docstrings cleanup (#224)\n- XP-333 Replace numproto with xain-proto (#220)\n- XP-499 Remove conftest, exclude tests folder (#223)\n- XP-480 revise message names (#222)\n- XP-436 Reinstate FINISHED heartbeat from Coordinator (#219)\n- XP-308 store aggregated weights in S3 buckets (#215)\n- XP-308 store aggregated weights in S3 buckets (#215)\n- XP-422 ai metrics (#216)\n- XP-119 Fix gRPC testing setup so that it can run on macOS (#217)\n- XP-433 Fix docker headings (#218)\n- Xp 373 add sdk as dependency in fl (#214)\n- DO-49  Create initial buckets (#213)\n- XP-424 Remove unused packages (#212)\n- XP-271 fix pylint issues (#210)\n- XP-374 Clean up docs (#211)\n- DO-43  docker compose minio (#208)\n- XP-384 remove unused files (#209)\n- XP-357 make controller parametrisable (#201)\n- XP 273 scripts cleanup (#206)\n- XP-385 Fix docs badge (#204)\n- XP-354 Remove proto files (#200)\n- DO-17  Add Dockerfiles, dockerignore and docs (#202)\n- XP-241 remove legacy participant and sdk dir (#199)\n- XP-168 update setup.py (#191)\n- XP-261 move tests to own dir (#197)\n- XP-257 cleanup cproto dir (#198)\n- XP-265 move benchmarks to separate repo (#193)\n- XP-255 update codeowners and authors in setup (#195)\n- XP-255 update codeowners and authors in setup (#195)\n- XP-229 Update Readme.md (#189)\n- XP-337 Clean up docs before generation (#188)\n- XP-264 put coordinator as own package (#183)\n- XP-272 Archive rust code (#186)\n- Xp 238 add participant selection (#179)\n- XP-229 Update readme (#185)\n- XP-334 Add make docs into docs make file (#184)\n- XP-291 harmonize docs styles (#181)\n- XP-300 Update docs makefile (#180)\n- XP-228 Update readme (#178)\n- XP-248 use structlog (#173)\n- XP-207 model framework agnostic (#166)\n- XAIN-284 rename package name (#176)\n- XP-251 Add ability to pass params per cmd args to coordinator (#174)\n- XP-167 Add gitter badge (#171)\n- Hotfix badge versions and style (#170)\n- Integrate docs with readthedocs (#169)\n- 
add pull request template (#168)\n\n## [0.2.0] - 2019-12-02\n\n### Changed\n\n- Renamed package from xain to xain-fl\n\n## [0.1.0] - 2019-09-25\n\nThe first public release of **XAIN**\n\n### Added\n\n- FedML implementation on well known\n  [benchmarks](https://github.com/xaynetwork/xaynet/tree/v0.1.0/xain/benchmark) using\n  a realistic deep learning model structure.\n\n[Unreleased]: https://github.com/xaynetwork/xaynet/compare/v0.11.0...HEAD\n[0.11.0]: https://github.com/xaynetwork/xaynet/compare/v0.10.0...v0.11.0\n[0.10.0]: https://github.com/xaynetwork/xaynet/compare/v0.9.0...v0.10.0\n[0.9.0]: https://github.com/xaynetwork/xaynet/compare/v0.8.0...v0.9.0\n[0.8.0]: https://github.com/xaynetwork/xaynet/compare/v0.7.0...v0.8.0\n[0.7.0]: https://github.com/xaynetwork/xaynet/compare/v0.6.0...v0.7.0\n[0.6.0]: https://github.com/xaynetwork/xaynet/compare/v0.5.0...v0.6.0\n[0.5.0]: https://github.com/xaynetwork/xaynet/compare/v0.4.0...v0.5.0\n[0.4.0]: https://github.com/xaynetwork/xaynet/compare/v0.3.0...v0.4.0\n[0.3.0]: https://github.com/xaynetwork/xaynet/compare/v0.2.0...v0.3.0\n[0.2.1]: https://github.com/xaynetwork/xaynet/compare/v0.2.0...v0.2.1\n[0.2.0]: https://github.com/xaynetwork/xaynet/compare/v0.1.0...v0.2.0\n[0.1.0]: https://github.com/xaynetwork/xaynet/tree/v0.1.0\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "[![crates.io badge](https://img.shields.io/crates/v/xaynet.svg)](https://crates.io/crates/xaynet) [![docs.rs badge](https://docs.rs/xaynet/badge.svg)](https://docs.rs/xaynet) [![rustc badge](https://img.shields.io/badge/rustc-1.51.0+-lightgray.svg)](https://www.rust-lang.org/learn/get-started) [![Coverage Status](https://codecov.io/gh/xaynetwork/xaynet/branch/master/graph/badge.svg)](https://codecov.io/gh/xaynetwork/xaynet)\n![Maintenance](https://img.shields.io/badge/maintenance-activly--developed-brightgreen.svg) [![roadmap badge](https://img.shields.io/badge/Roadmap-2021-blue)](./ROADMAP.md)\n\n![Xaynet banner](./assets/xaynet_banner.png)\n\n# xaynet\n\n## Xaynet: Train on the Edge with Federated Learning\n\nWant a framework that supports federated learning on the edge, in\ndesktop browsers, integrates well with mobile apps, is performant, and\npreserves privacy? Welcome to XayNet, written entirely in Rust!\n\n### Making federated learning easy for developers\n\nFrameworks for machine learning - including those expressly for\nfederated learning - exist already. These frameworks typically\nfacilitate federated learning of cross-silo use cases - for example in\ncollaborative learning across a limited number of hospitals or for\ninstance across multiple banks working on a common use case without\nthe need to share valuable and sensitive data.\n\nThis repository focusses on masked cross-device federated learning to\nenable the orchestration of machine learning in millions of low-power\nedge devices, such as smartphones or even cars. By doing this, we hope\nto also increase the pace and scope of adoption of federated learning\nin practice and especially allow the protection of end user data. All\ndata remains in private local premises, whereby only encrypted AI\nmodels get automatically and asynchronously aggregated. Thus, we\nprovide a solution to the AI privacy dilemma and bridge the\noften-existing gap between privacy and convenience. 
Imagine, for\nexample, a voice assistant that learns new words directly on the device\nand shares this knowledge with all other instances, without recording\nand collecting your voice input centrally. Or think about a search\nengine that learns to personalise search results without collecting\nyour often sensitive search queries centrally… There are thousands of\nsuch use cases that today still trade privacy for\nconvenience. We think this shouldn’t be the case, and we want to\nprovide an alternative to overcome this dilemma.\n\nConcretely, we provide developers with:\n\n- **App dev tools**: An SDK to integrate federated learning into\n  apps written in Dart or other languages of choice for mobile development,\n  as well as frameworks like Flutter.\n- **Privacy via cross-device federated learning**: Train your AI\n  models locally on edge devices such as mobile phones, browsers,\n  or even in cars. Federated learning automatically aggregates the\n  local models into a global model. Thus, all insights inherent in\n  the local models are captured, while the user data stays\n  private on end devices.\n- **Security and privacy via homomorphic encryption**: Aggregate\n  models with the highest security and trust. Xayn’s masking\n  protocol encrypts all models homomorphically. This enables you\n  to aggregate encrypted local models into a global one – without\n  having to decrypt local models at all. This protects private and\n  even the most sensitive data.\n\n### The case for writing this framework in Rust\n\nOur framework for federated learning is not only a framework for\nmachine learning as such.
Rather, it supports the federation of\nmachine learning that takes place on possibly heterogeneous devices\nand where use cases involve many such devices.\n\nThe programming language in which this framework is written should\ntherefore give us strong support for the following:\n\n- **Runs \"everywhere\"**: the language should not require its own\n  runtime and code should compile on a wide range of devices.\n- **Memory and concurrency safety**: code that compiles should be both\n  memory safe and free of data races.\n- **Secure communication**: state of the art cryptography should be\n  available in vetted implementations.\n- **Asynchronous communication**: abstractions for asynchronous\n  communication should exist that make federated learning scale.\n- **Fast and functional**: the language should offer functional\n  abstractions but also compile code into fast executables.\n\nRust is one of the very few choices of modern programming languages\nthat meets these requirements:\n\n- its concepts of Ownership and Borrowing make it both memory and\n  thread-safe (hence avoiding many common concurrency issues).\n- it has a strong and static type discipline and traits, which\n  describe shareable functionality of a type.\n- it is a modern systems programming language, with some functional\n  style features such as pattern matching, closures and iterators.\n- its idiomatic code compares favourably to idiomatic C in performance.\n- it compiles to WASM and can therefore be applied natively in browser\n  settings.\n- it is widely deployable and doesn't necessarily depend on a runtime,\n  unlike languages such as Java and their need for a virtual machine\n  to run its code. 
Foreign Function Interfaces support calls from\n  other languages/frameworks, including Dart, Python and Flutter.\n- it compiles into LLVM, and so it can draw from the abundant tool\n  suites for LLVM.\n\n---\n\n# Getting Started\n\n## Minimum supported rust version\n\nrustc 1.51.0\n\n## Running the platform\n\nThere are a few different ways to run the backend: via docker, or by deploying it to\na Kubernetes cluster or by compiling the code and running the binary manually.\n\n1. Everything described below assumes your shell's working directory to be the root\nof the repository.\n2. The following instructions assume you have pre-existing knowledge on some\nof the referenced software (like `docker` and `docker-compose`) and/or a working\nsetup (if you decide to compile the Rust code and run the binary manually).\n3. In case you need help with setting up your system accordingly, we recommend you\nrefer to the official documentation of each tool, as supporting them here would be\nbeyond the scope of this project:\n   * [Rust](https://www.rust-lang.org/tools/install)\n   * [Docker](https://docs.docker.com/) and [Docker Compose](https://docs.docker.com/compose/)\n   * [Kubernetes](https://kubernetes.io/docs/home/)\n\n**Note:**\n\nWith Xaynet `v0.11` the coordinator needs a connection to a redis instance in order to save its state.\n\n**Don't connect the coordinator to a Redis instance that is used in production!**\n\nWe recommend connecting the coordinator to its own Redis instance. 
We have invested a lot of\ntime to make sure that the coordinator only deletes its own data, but in the current state of\ndevelopment, we cannot guarantee that this will always be the case.\n\n### Using Docker\n\nThe convenience of using the docker setup is that there's no need to set up a working Rust\nenvironment on your system, as everything is done inside the container.\n\n#### Run an image from Docker Hub\n\nDocker images of the latest releases are provided on\n[Docker Hub](https://hub.docker.com/r/xaynetwork/xaynet).\n\nYou can try them out with the default `configs/docker-dev.toml` by running:\n\n**Xaynet below v0.11**\n\n```bash\ndocker run -v ${PWD}/configs/docker-dev.toml:/app/config.toml -p 8081:8081 xaynetwork/xaynet:v0.10.0 /app/coordinator -c /app/config.toml\n```\n\n**Xaynet v0.11+**\n\n```bash\n# don't forget to adjust the Redis URL in configs/docker-dev.toml\ndocker run -v ${PWD}/configs/docker-dev.toml:/app/config.toml -p 8081:8081 xaynetwork/xaynet:v0.11.0\n```\n\nThe docker image contains a release build of the coordinator without optional features.\n\n#### Run a coordinator with additional infrastructure\n\nStart the coordinator by pointing to the `docker/docker-compose.yml` file.
It spins up all\ninfrastructure that is essential to run the coordinator with default or optional features.\nKeep in mind that this file is used for development only.\n\n```bash\ndocker-compose -f docker/docker-compose.yml up --build\n```\n\n#### Create a release build\n\nIf you would like, you can create an optimized release build of the coordinator,\nbut keep in mind that the compilation will be slower.\n\n```bash\ndocker build --build-arg RELEASE_BUILD=1 -f ./docker/Dockerfile .\n```\n\n#### Build a coordinator with optional features\n\nOptional features can be specified via the build argument `COORDINATOR_FEATURES`.\n\n```bash\ndocker build --build-arg COORDINATOR_FEATURES=tls,metrics -f ./docker/Dockerfile .\n```\n\n### Using Kubernetes\n\nTo deploy an instance of the coordinator to your Kubernetes cluster, use the manifests that are\nlocated inside the `k8s/coordinator` folder. The manifests rely on `kustomize` to be generated\n(`kustomize` is officially supported by `kubectl` since v1.14). 
We recommend you thoroughly go\nthrough the manifests and adjust them according to your own setup (namespace, ingress, etc.).\n\nRemember to also check (and adjust if necessary) the default configuration for the coordinator, available\nat `k8s/coordinator/development/config.toml`.\n\nPlease adjust the domain used in the `k8s/coordinator/development/ingress.yaml` file so it matches\nyour needs (you can also skip `ingress` altogether; just make sure you remove its reference from\n`k8s/coordinator/development/kustomization.yaml`).\n\nKeep in mind that the `ingress` configuration that is shown in `k8s/coordinator/development/ingress.yaml`\nrelies on resources that aren't available in this repository, due to their sensitive nature\n(TLS key and certificate, for instance).\n\nTo verify the generated manifests, run:\n\n```bash\nkubectl kustomize k8s/coordinator/development\n```\n\nTo apply them:\n\n```bash\nkubectl apply -k k8s/coordinator/development\n```\n\nIn case you are not exposing your coordinator via `ingress`, you can still reach it using a port-forward.\nThe example below creates a port-forward at port `8081` assuming the coordinator pod is still using the\n`app=coordinator` label:\n\n```bash\nkubectl port-forward $(kubectl get pods -l \"app=coordinator\" -o jsonpath=\"{.items[0].metadata.name}\") 8081\n```\n\n### Building the project manually\n\nThe coordinator without optional features can be built and started with:\n\n```bash\ncd rust\ncargo run --bin coordinator -- -c ../configs/config.toml\n```\n\n## Running the example\n\nThe example can be found under [rust/examples/](./rust/examples/).
It uses a dummy model\nbut is network-capable, so it's a good starting point for checking connectivity with\nthe coordinator.\n\n### `test-drive`\n\nMake sure you have a running instance of the coordinator and that the clients\nyou will spawn with the command below are able to reach it through the network.\n\nHere is an example of how to start `20` participants that will connect to a coordinator\nrunning on `127.0.0.1:8081`:\n\n```bash\ncd rust\nRUST_LOG=info cargo run --example test-drive -- -n 20 -u http://127.0.0.1:8081\n```\n\nFor more in-depth details on how to run examples, see the accompanying Getting\nStarted guide under [rust/xaynet-server/src/examples.rs](./rust/xaynet-server/src/examples.rs).\n\n## Troubleshooting\n\nIf you have any difficulties running the project, please reach out to us by\n[opening an issue](https://github.com/xaynetwork/xaynet/issues/new) and describing your setup\nand the problems you're facing.\n"
  },
  {
    "path": "README.tpl",
    "content": "[![crates.io badge](https://img.shields.io/crates/v/xaynet.svg)](https://crates.io/crates/xaynet) [![docs.rs badge](https://docs.rs/xaynet/badge.svg)](https://docs.rs/xaynet) [![rustc badge](https://img.shields.io/badge/rustc-1.51.0+-lightgray.svg)](https://www.rust-lang.org/learn/get-started) {{badges}} [![roadmap badge](https://img.shields.io/badge/Roadmap-2021-blue)](./ROADMAP.md)\n\n![Xaynet banner](./assets/xaynet_banner.png)\n\n# {{crate}}\n\n{{readme}}\n\n---\n\n# Getting Started\n\n## Minimum supported rust version\n\nrustc 1.51.0\n\n## Running the platform\n\nThere are a few different ways to run the backend: via docker, or by deploying it to\na Kubernetes cluster or by compiling the code and running the binary manually.\n\n1. Everything described below assumes your shell's working directory to be the root\nof the repository.\n2. The following instructions assume you have pre-existing knowledge on some\nof the referenced software (like `docker` and `docker-compose`) and/or a working\nsetup (if you decide to compile the Rust code and run the binary manually).\n3. In case you need help with setting up your system accordingly, we recommend you\nrefer to the official documentation of each tool, as supporting them here would be\nbeyond the scope of this project:\n   * [Rust](https://www.rust-lang.org/tools/install)\n   * [Docker](https://docs.docker.com/) and [Docker Compose](https://docs.docker.com/compose/)\n   * [Kubernetes](https://kubernetes.io/docs/home/)\n\n**Note:**\n\nWith Xaynet `v0.11` the coordinator needs a connection to a redis instance in order to save its state.\n\n**Don't connect the coordinator to a Redis instance that is used in production!**\n\nWe recommend connecting the coordinator to its own Redis instance. 
We have invested a lot of\ntime to make sure that the coordinator only deletes its own data, but in the current state of\ndevelopment, we cannot guarantee that this will always be the case.\n\n### Using Docker\n\nThe convenience of using the docker setup is that there's no need to set up a working Rust\nenvironment on your system, as everything is done inside the container.\n\n#### Run an image from Docker Hub\n\nDocker images of the latest releases are provided on\n[Docker Hub](https://hub.docker.com/r/xaynetwork/xaynet).\n\nYou can try them out with the default `configs/docker-dev.toml` by running:\n\n**Xaynet below v0.11**\n\n```bash\ndocker run -v ${PWD}/configs/docker-dev.toml:/app/config.toml -p 8081:8081 xaynetwork/xaynet:v0.10.0 /app/coordinator -c /app/config.toml\n```\n\n**Xaynet v0.11+**\n\n```bash\n# don't forget to adjust the Redis URL in configs/docker-dev.toml\ndocker run -v ${PWD}/configs/docker-dev.toml:/app/config.toml -p 8081:8081 xaynetwork/xaynet:v0.11.0\n```\n\nThe docker image contains a release build of the coordinator without optional features.\n\n#### Run a coordinator with additional infrastructure\n\nStart the coordinator by pointing to the `docker/docker-compose.yml` file.
It spins up all\ninfrastructure that is essential to run the coordinator with default or optional features.\nKeep in mind that this file is used for development only.\n\n```bash\ndocker-compose -f docker/docker-compose.yml up --build\n```\n\n#### Create a release build\n\nIf you would like, you can create an optimized release build of the coordinator,\nbut keep in mind that the compilation will be slower.\n\n```bash\ndocker build --build-arg RELEASE_BUILD=1 -f ./docker/Dockerfile .\n```\n\n#### Build a coordinator with optional features\n\nOptional features can be specified via the build argument `COORDINATOR_FEATURES`.\n\n```bash\ndocker build --build-arg COORDINATOR_FEATURES=tls,metrics -f ./docker/Dockerfile .\n```\n\n### Using Kubernetes\n\nTo deploy an instance of the coordinator to your Kubernetes cluster, use the manifests that are\nlocated inside the `k8s/coordinator` folder. The manifests rely on `kustomize` to be generated\n(`kustomize` is officially supported by `kubectl` since v1.14). 
We recommend you thoroughly go\nthrough the manifests and adjust them according to your own setup (namespace, ingress, etc.).\n\nRemember to also check (and adjust if necessary) the default configuration for the coordinator, available\nat `k8s/coordinator/development/config.toml`.\n\nPlease adjust the domain used in the `k8s/coordinator/development/ingress.yaml` file so it matches\nyour needs (you can also skip `ingress` altogether; just make sure you remove its reference from\n`k8s/coordinator/development/kustomization.yaml`).\n\nKeep in mind that the `ingress` configuration that is shown in `k8s/coordinator/development/ingress.yaml`\nrelies on resources that aren't available in this repository, due to their sensitive nature\n(TLS key and certificate, for instance).\n\nTo verify the generated manifests, run:\n\n```bash\nkubectl kustomize k8s/coordinator/development\n```\n\nTo apply them:\n\n```bash\nkubectl apply -k k8s/coordinator/development\n```\n\nIn case you are not exposing your coordinator via `ingress`, you can still reach it using a port-forward.\nThe example below creates a port-forward at port `8081` assuming the coordinator pod is still using the\n`app=coordinator` label:\n\n```bash\nkubectl port-forward $(kubectl get pods -l \"app=coordinator\" -o jsonpath=\"{.items[0].metadata.name}\") 8081\n```\n\n### Building the project manually\n\nThe coordinator without optional features can be built and started with:\n\n```bash\ncd rust\ncargo run --bin coordinator -- -c ../configs/config.toml\n```\n\n## Running the example\n\nThe example can be found under [rust/examples/](./rust/examples/).
It uses a dummy model\nbut is network-capable, so it's a good starting point for checking connectivity with\nthe coordinator.\n\n### `test-drive`\n\nMake sure you have a running instance of the coordinator and that the clients\nyou will spawn with the command below are able to reach it through the network.\n\nHere is an example of how to start `20` participants that will connect to a coordinator\nrunning on `127.0.0.1:8081`:\n\n```bash\ncd rust\nRUST_LOG=info cargo run --example test-drive -- -n 20 -u http://127.0.0.1:8081\n```\n\nFor more in-depth details on how to run examples, see the accompanying Getting\nStarted guide under [rust/xaynet-server/src/examples.rs](./rust/xaynet-server/src/examples.rs).\n\n## Troubleshooting\n\nIf you have any difficulties running the project, please reach out to us by\n[opening an issue](https://github.com/xaynetwork/xaynet/issues/new) and describing your setup\nand the problems you're facing.\n"
  },
  {
    "path": "ROADMAP.md",
    "content": "# Roadmap 2021\n\n![Roadmap Q1](./assets/roadmap_q1.png)\n\nIn Q1 we focus entirely on using XayNet for the [Xayn app] in terms of federated learning and\nfirst simple analytics, such as gathering relevant AI performance data like [NDCG metrics],\nbecause we want to know how our AI models perform without violating the privacy of our users.\nAs you know, our framework originated with the aim of aggregating machine learning models securely\nand privately between edge devices. Since the models are transformed into one-dimensional lists,\nso that in the end we only aggregate a list of numbers, why not also aggregate other numerical\nanalytics data, like AI performance metrics or user behaviour, such as screen times in our app,\nall of course with the privacy guarantees of XayNet? As such, we focus predominantly on mobile\ncross-device learning but also extend our framework to cover such use cases. In Q1, however, we\nmostly take care of the internal mobile case and of testing, so that we lay the foundation for further\ngeneralisation to external cases in the community during the rest of the year.\n\n![Roadmap Q2](./assets/roadmap_q2.png)\n\nIn Q2 we have three main focus points: extending XayNet to also support web applications, since\nour [Xayn app] will also be provided as a web version via [WASM]; integrating our product analytics\nextensions into our [Xayn app]; and optimising the client for higher performance, which is one of the\nmajor bottlenecks.\n\n![Roadmap Q3](./assets/roadmap_q3.png)\n\nIn Q3, we can imagine also opening up the analytics layer to more general use cases outside of\nXayn itself. Until then, our core focus is predominantly internal; of course, we still hope to get\ncommunity and external feature suggestions and reviews.
Also we want to make the coordinator more\nobservable as a foundation for further optimisations.\n\n[Xayn app]: https://www.xayn.com/\n[NDCG metrics]: https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG\n[WASM]: https://webassembly.org/\n"
  },
  {
    "path": "bindings/python/.gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\nglobal_model.bin\nstate.bin"
  },
  {
    "path": "bindings/python/.isort.cfg",
    "content": "[settings]\ncombine_as_imports=True\nforce_grid_wrap=0\nforce_sort_within_sections=True\ninclude_trailing_comma=True\nindent=4\nline_length=88\nmulti_line_output=3\nuse_parentheses=True\n"
  },
  {
    "path": "bindings/python/.pylintrc",
    "content": "[MASTER]\n\n# A comma-separated list of package or module names from where C extensions may\n# be loaded. Extensions are loading into the active Python interpreter and may\n# run arbitrary code.\nextension-pkg-whitelist=\n\n# Add files or directories to the blacklist. They should be base names, not\n# paths.\nignore=grpc\n\n# Add files or directories matching the regex patterns to the blacklist. The\n# regex matches against base names, not paths.\n# ignore-patterns=\n\n# Python code to execute, usually for sys.path manipulation such as\n# pygtk.require().\n#init-hook=\n\n# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the\n# number of processors available to use.\n# Use only 1 because of https://github.com/PyCQA/pylint/issues/374\njobs=1\n\n# Control the amount of potential inferred values when inferring a single\n# object. This can help the performance when dealing with large functions or\n# complex, nested conditions.\nlimit-inference-results=100\n\n# Pickle collected data for later comparisons.\npersistent=yes\n\n# Specify a configuration file.\n#rcfile=\n\n# When enabled, pylint would attempt to guess common misconfiguration and emit\n# user-friendly hints instead of false-positive error messages.\nsuggestion-mode=yes\n\n# Allow loading of arbitrary C extensions. Extensions are imported into the\n# active Python interpreter and may run arbitrary code.\nunsafe-load-any-extension=no\n\n\n[MESSAGES CONTROL]\n\n# Only show warnings with the listed confidence levels. Leave empty to show\n# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.\nconfidence=\n\n# Disable the message, report, category or checker with the given id(s). You\n# can either give multiple identifiers separated by comma (,) or put this\n# option multiple times (only on the command line, not in the configuration\n# file where it should appear only once). 
You can also use \"--disable=all\" to\n# disable everything first and then reenable specific checks. For example, if\n# you want to run only the similarities checker, you can use \"--disable=all\n# --enable=similarities\". If you want to run only the classes checker, but have\n# no Warning level messages displayed, use \"--disable=all --enable=classes\n# --disable=W\".\ndisable=print-statement,\n        old-raise-syntax,\n        backtick,\n        long-suffix,\n        old-ne-operator,\n        old-octal-literal,\n        import-star-module-level,\n        non-ascii-bytes-literal,\n        raw-checker-failed,\n        bad-inline-option,\n        locally-disabled,\n        file-ignored,\n        suppressed-message,\n        useless-suppression,\n        deprecated-pragma,\n        apply-builtin,\n        basestring-builtin,\n        buffer-builtin,\n        cmp-builtin,\n        coerce-builtin,\n        execfile-builtin,\n        file-builtin,\n        long-builtin,\n        raw_input-builtin,\n        reduce-builtin,\n        standarderror-builtin,\n        unicode-builtin,\n        xrange-builtin,\n        coerce-method,\n        delslice-method,\n        getslice-method,\n        setslice-method,\n        no-absolute-import,\n        old-division,\n        dict-iter-method,\n        dict-view-method,\n        next-method-called,\n        metaclass-assignment,\n        indexing-exception,\n        raising-string,\n        reload-builtin,\n        oct-method,\n        hex-method,\n        nonzero-method,\n        cmp-method,\n        input-builtin,\n        round-builtin,\n        intern-builtin,\n        unichr-builtin,\n        map-builtin-not-iterating,\n        zip-builtin-not-iterating,\n        range-builtin-not-iterating,\n        filter-builtin-not-iterating,\n        using-cmp-argument,\n        eq-without-hash,\n        div-method,\n        idiv-method,\n        rdiv-method,\n        exception-message-attribute,\n        invalid-str-codec,\n        
sys-max-int,\n        bad-python3-import,\n        deprecated-string-function,\n        deprecated-str-translate-call,\n        deprecated-itertools-function,\n        deprecated-types-field,\n        next-method-defined,\n        dict-items-not-iterating,\n        dict-keys-not-iterating,\n        dict-values-not-iterating,\n        deprecated-operator-function,\n        deprecated-urllib-function,\n        xreadlines-attribute,\n        deprecated-sys-function,\n        exception-escape,\n        comprehension-escape,\n        c-extension-no-member,\n        duplicate-code,\n        bad-continuation,\n        fixme,\n        redefined-builtin,\n        missing-docstring,\n\n# Enable the message, report, category or checker with the given id(s). You can\n# either give multiple identifier separated by comma (,) or put this option\n# multiple time (only on the command line, not in the configuration file where\n# it should appear only once). See also the \"--disable\" option for examples.\nenable=\n\n\n[REFACTORING]\n\n# Maximum number of nested blocks for function / method body\nmax-nested-blocks=5\n\n# Complete name of functions that never returns. When checking for\n# inconsistent-return-statements if a never returning function is called then\n# it will be considered as an explicit return statement and no message will be\n# printed.\nnever-returning-functions=sys.exit\n\n\n[LOGGING]\n\n# Format style used to check logging format string. `old` means using %\n# formatting, while `new` is for `{}` formatting.\nlogging-format-style=old\n\n# Logging modules to check that the string format arguments are in logging\n# function parameter format.\nlogging-modules=logging\n\n\n[SPELLING]\n\n# Limits count of emitted suggestions for spelling mistakes.\nmax-spelling-suggestions=4\n\n# Spelling dictionary name. Available dictionaries: none. 
To make it working\n# install python-enchant package..\nspelling-dict=\n\n# List of comma separated words that should not be checked.\nspelling-ignore-words=\n\n# A path to a file that contains private dictionary; one word per line.\nspelling-private-dict-file=\n\n# Tells whether to store unknown words to indicated private dictionary in\n# --spelling-private-dict-file option instead of raising a message.\nspelling-store-unknown-words=no\n\n\n[MISCELLANEOUS]\n\n# List of note tags to take in consideration, separated by a comma.\nnotes=FIXME,\n      XXX,\n      TODO\n\n\n[TYPECHECK]\n\n# List of decorators that produce context managers, such as\n# contextlib.contextmanager. Add to this list to register other decorators that\n# produce valid context managers.\ncontextmanager-decorators=contextlib.contextmanager\n\n# List of members which are set dynamically and missed by pylint inference\n# system, and so shouldn't trigger E1101 when accessed. Python regular\n# expressions are accepted.\n\n# regular expressions currently don't work https://github.com/PyCQA/pylint/issues/2498.\n\ngenerated-members=\n\n\n# Tells whether missing members accessed in mixin class should be ignored. A\n# mixin class is detected if its name ends with \"mixin\" (case insensitive).\nignore-mixin-members=yes\n\n# Tells whether to warn about missing members when the owner of the attribute\n# is inferred to be None.\nignore-none=yes\n\n# This flag controls whether pylint should warn about no-member and similar\n# checks whenever an opaque object is returned when inferring. The inference\n# can return multiple potential results while evaluating a Python object, but\n# some branches might not be evaluated, which results in partial inference. 
In\n# that case, it might be useful to still emit no-member and other checks for\n# the rest of the inferred objects.\nignore-on-opaque-inference=yes\n\n# List of class names for which member attributes should not be checked (useful\n# for classes with dynamically set attributes). This supports the use of\n# qualified names.\nignored-classes=optparse.Values,thread._local,_thread._local,SQLAlchemy\n\n# List of module names for which member attributes should not be checked\n# (useful for modules/projects where namespaces are manipulated during runtime\n# and thus existing member attributes cannot be deduced by static analysis. It\n# supports qualified module names, as well as Unix pattern matching.\nignored-modules=\n\n# Show a hint with possible names when a member name was not found. The aspect\n# of finding the hint is based on edit distance.\nmissing-member-hint=yes\n\n# The minimum edit distance a name should have in order to be considered a\n# similar match for a missing member name.\nmissing-member-hint-distance=1\n\n# The total number of similar names that should be taken in consideration when\n# showing a hint for a missing member.\nmissing-member-max-choices=1\n\n\n[VARIABLES]\n\n# List of additional names supposed to be defined in builtins. Remember that\n# you should avoid defining new builtins when possible.\nadditional-builtins=\n\n# Tells whether unused global variables should be treated as a violation.\nallow-global-unused-variables=yes\n\n# List of strings which can identify a callback function by name. A callback\n# name must start or end with one of those strings.\ncallbacks=cb_,\n          _cb\n\n# A regular expression matching the name of dummy variables (i.e. expected to\n# not be used).\ndummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_\n\n# Argument names that match this expression will be ignored. 
Default to name\n# with leading underscore.\nignored-argument-names=_.*|^ignored_|^unused_\n\n# Tells whether we should check for unused import in __init__ files.\ninit-import=no\n\n# List of qualified module names which can have objects that can redefine\n# builtins.\nredefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io\n\n\n[FORMAT]\n\n# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.\nexpected-line-ending-format=\n\n# Regexp for a line that is allowed to be longer than the limit.\nignore-long-lines=^\\s*(# )?<?https?://\\S+>?$\n\n# Number of spaces of indent required inside a hanging or continued line.\nindent-after-paren=4\n\n# String used as indentation unit. This is usually \"    \" (4 spaces) or \"\\t\" (1\n# tab).\nindent-string='    '\n\n# Maximum number of characters on a single line.\nmax-line-length=100\n\n# Maximum number of lines in a module.\nmax-module-lines=2000\n\n# List of optional constructs for which whitespace checking is disabled. `dict-\n# separator` is used to allow tabulation in dicts, etc.: {1  : 1,\\n222: 2}.\n# `trailing-comma` allows a space between comma and closing bracket: (a, ).\n# `empty-line` allows space-only lines.\nno-space-check=trailing-comma,\n               dict-separator\n\n# Allow the body of a class to be on the same line as the declaration if body\n# contains single statement.\nsingle-line-class-stmt=no\n\n# Allow the body of an if to be on the same line as the test if there is no\n# else.\nsingle-line-if-stmt=no\n\n\n[SIMILARITIES]\n\n# Ignore comments when computing similarities.\nignore-comments=yes\n\n# Ignore docstrings when computing similarities.\nignore-docstrings=yes\n\n# Ignore imports when computing similarities.\nignore-imports=no\n\n# Minimum lines number of a similarity.\nmin-similarity-lines=10\n\n\n[BASIC]\n\n# Naming style matching correct argument names.\nargument-naming-style=snake_case\n\n# Regular expression matching correct argument names. 
Overrides argument-\n# naming-style.\nargument-rgx=[a-z_][a-z0-9_]{2,30}$\n\n# Naming style matching correct attribute names.\nattr-naming-style=snake_case\n\n# Regular expression matching correct attribute names. Overrides attr-naming-\n# style.\nattr-rgx=[a-z_][a-z0-9_]{2,}$\n\n# Bad variable names which should always be refused, separated by a comma.\nbad-names=foo,\n          bar,\n          baz,\n          toto,\n          tutu,\n          tata\n\n# Naming style matching correct class attribute names.\nclass-attribute-naming-style=any\n\n# Regular expression matching correct class attribute names. Overrides class-\n# attribute-naming-style.\nclass-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$\n\n# Naming style matching correct class names.\nclass-naming-style=PascalCase\n\n# Regular expression matching correct class names. Overrides class-naming-\n# style.\nclass-rgx=[A-Z_][a-zA-Z0-9]+$\n\n# Naming style matching correct constant names.\nconst-naming-style=UPPER_CASE\n\n# Regular expression matching correct constant names. Overrides const-naming-\n# style.\nconst-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$\n\n# Minimum line length for functions/classes that require docstrings, shorter\n# ones are exempt.\ndocstring-min-length=1\n\n# Naming style matching correct function names.\nfunction-naming-style=snake_case\n\n# Regular expression matching correct function names. Overrides function-\n# naming-style.\nfunction-rgx=[a-z_][a-z0-9_]{2,}$\n\n# Good variable names which should always be accepted, separated by a comma.\ngood-names=i,\n           j,\n           k,\n           ex,\n           Run,\n           _,\n           logger,\n           _y\n\n# Include a hint for the correct naming format with invalid-name.\ninclude-naming-hint=no\n\n# Naming style matching correct inline iteration names.\ninlinevar-naming-style=any\n\n# Regular expression matching correct inline iteration names. 
Overrides\n# inlinevar-naming-style.\ninlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$\n\n# Naming style matching correct method names.\nmethod-naming-style=snake_case\n\n# Regular expression matching correct method names. Overrides method-naming-\n# style.\nmethod-rgx=[a-z_][a-z0-9_]{2,}$\n\n# Naming style matching correct module names.\nmodule-naming-style=snake_case\n\n# Regular expression matching correct module names. Overrides module-naming-\n# style.\nmodule-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$\n\n# Colon-delimited sets of names that determine each other's naming style when\n# the name regexes allow several styles.\nname-group=\n\n# Regular expression which should only match function or class names that do\n# not require a docstring.\nno-docstring-rgx=^_\n\n# List of decorators that produce properties, such as abc.abstractproperty. Add\n# to this list to register other decorators that produce valid properties.\n# These decorators are taken in consideration only for invalid-name.\nproperty-classes=abc.abstractproperty\n\n# Naming style matching correct variable names.\nvariable-naming-style=snake_case\n\n# Regular expression matching correct variable names. Overrides variable-\n# naming-style.\nvariable-rgx=[a-z_][a-z0-9_]{2,30}$\n\n\n[STRING]\n\n# This flag controls whether the implicit-str-concat-in-sequence should\n# generate a warning on implicit string concatenation in sequences defined over\n# several lines.\ncheck-str-concat-over-line-jumps=no\n\n\n[STRING_QUOTES]\n\n# The quote character for triple-quoted docstrings.\ndocstring-quote=double\n\n# The quote character for string literals.\nstring-quote=double-avoid-escape\n\n# The quote character for triple-quoted strings (non-docstring).\ntriple-quote=double\n\n\n[IMPORTS]\n\n# Allow wildcard imports from modules that define __all__.\nallow-wildcard-with-all=no\n\n# Analyse import fallback blocks. 
This can be used to support both Python 2 and\n# 3 compatible code, which means that the block might have code that exists\n# only in one or another interpreter, leading to false positives when analysed.\nanalyse-fallback-blocks=no\n\n# Deprecated modules which should not be used, separated by a comma.\ndeprecated-modules=optparse,tkinter.tix\n\n# Create a graph of external dependencies in the given file (report RP0402 must\n# not be disabled).\next-import-graph=\n\n# Create a graph of every (i.e. internal and external) dependencies in the\n# given file (report RP0402 must not be disabled).\nimport-graph=\n\n# Create a graph of internal dependencies in the given file (report RP0402 must\n# not be disabled).\nint-import-graph=\n\n# Force import order to recognize a module as part of the standard\n# compatibility libraries.\nknown-standard-library=\n\n# Force import order to recognize a module as part of a third party library.\nknown-third-party=enchant\n\n\n[CLASSES]\n\n# List of method names used to declare (i.e. 
assign) instance attributes.\ndefining-attr-methods=__init__,\n                      __new__,\n                      setUp\n\n# List of member names, which should be excluded from the protected access\n# warning.\nexclude-protected=_asdict,\n                  _fields,\n                  _replace,\n                  _source,\n                  _make\n\n# List of valid names for the first argument in a class method.\nvalid-classmethod-first-arg=cls\n\n# List of valid names for the first argument in a metaclass class method.\nvalid-metaclass-classmethod-first-arg=mcs\n\n\n[DESIGN]\n\n# Maximum number of arguments for function / method.\nmax-args=10\n\n# Maximum number of attributes for a class (see R0902).\nmax-attributes=11\n\n# Maximum number of boolean expressions in an if statement.\nmax-bool-expr=5\n\n# Maximum number of branch for function / method body.\nmax-branches=26\n\n# Maximum number of locals for function / method body.\nmax-locals=25\n\n# Maximum number of parents for a class (see R0901).\nmax-parents=7\n\n# Maximum number of public methods for a class (see R0904).\nmax-public-methods=25\n\n# Maximum number of return / yield for function / method body.\nmax-returns=6\n\n# Maximum number of statements in function / method body.\nmax-statements=100\n\n# Minimum number of public methods for a class (see R0903).\nmin-public-methods=0\n\n\n[EXCEPTIONS]\n\n# Exceptions that will emit a warning when being caught. Defaults to\n# \"BaseException, Exception\".\novergeneral-exceptions=Exception\n"
  },
  {
    "path": "bindings/python/Cargo.toml",
    "content": "[package]\nname = \"xaynet-sdk-python\"\nversion = \"0.1.0\"\nauthors = [\"Xayn Engineering <engineering@xaynet.dev>\"]\nedition = \"2018\"\ndescription = \"The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant.\"\nreadme = \"README.md\"\nhomepage = \"https://xaynet.dev/\"\nrepository = \"https://github.com/xaynetwork/xaynet/\"\nlicense = \"Apache-2.0\"\nkeywords = [\"federated-learning\", \"fl\", \"ai\", \"machine-learning\"]\ncategories = [\"science\", \"cryptography\"]\n\n[package.metadata.maturin]\nclassifiers = [\n        \"Development Status :: 3 - Alpha\",\n        \"Intended Audience :: Developers\",\n        \"Intended Audience :: Information Technology\",\n        \"Intended Audience :: Science/Research\",\n        \"Topic :: Scientific/Engineering\",\n        \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n        \"Topic :: Software Development\",\n        \"Topic :: Software Development :: Libraries\",\n        \"Topic :: Software Development :: Libraries :: Application Frameworks\",\n        \"Topic :: Software Development :: Libraries :: Python Modules\",\n        \"License :: OSI Approved :: Apache Software License\",\n        \"Programming Language :: Python :: 3 :: Only\",\n        \"Programming Language :: Python :: 3.6\",\n        \"Programming Language :: Python :: 3.7\",\n        \"Programming Language :: Python :: 3.8\",\n        \"Programming Language :: Python :: 3.9\",\n        \"Operating System :: MacOS :: MacOS X\",\n        \"Operating System :: POSIX :: Linux\",\n]\nrequires-python = \">=3.6\"\nrequires-dist = [\n    \"justbackoff (==0.6.0)\",\n]\n\n[package.metadata]\n# minimum supported rust version\nmsrv = \"1.51.0\"\n\n[dependencies]\nsodiumoxide = \"0.2.7\"\ntracing = 
\"0.1.36\"\ntracing-subscriber = { version = \"0.3.15\", features = [\"env-filter\"] }\npyo3 = {version = \"=0.13.2\", features = [\"abi3-py36\", \"extension-module\"]}\nxaynet-core = { path = \"../../rust/xaynet-core\", version = \"0.2.0\"}\nxaynet-mobile = { path = \"../../rust/xaynet-mobile\", version = \"0.1.0\"}\nxaynet-sdk = { path = \"../../rust/xaynet-sdk\", version = \"0.1.0\"}\n\n[lib]\nname = \"xaynet_sdk\"\ncrate-type = [\"cdylib\"]\n"
  },
  {
    "path": "bindings/python/README.md",
    "content": "![Xaynet banner](../../assets/xaynet_banner.png)\n\n## Installation\n\n**Prerequisites**\n\n- Python 3.6 or higher\n\n**1. Install it via `pip`**\n\n```bash\n# create and activate a virtual environment e.g.\npyenv virtualenv xaynet\npyenv activate xaynet\n\npip install xaynet-sdk-python\n```\n\n**2. Build it from source**\n\n```bash\n# first install rust via https://rustup.rs/\ncurl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh\n\n# clone the xaynet repository\ngit clone https://github.com/xaynetwork/xaynet.git\ncd xaynet/bindings/python\n\n# create and activate a virtual environment e.g.\npyenv virtualenv xaynet\npyenv activate xaynet\n\n# install maturin\npip install maturin==0.9.1\npip install justbackoff\n\n# install xaynet-sdk\nmaturin develop\n```\n\n## Participant API(s)\n\nThe Python SDK that consists of two experimental Xaynet participants `ParticipantABC`\nand `AsyncParticipant`.\n\nThe word `Async` does not refer to either `asyncio` or asynchronous federated learning.\nIt refers to the property when a local model can be set. 
In `ParticipantABC`\nthe local model can only be set if the participant was selected as an update participant,\nwhile in `AsyncParticipant` the model can be set at any time.\n\n### `ParticipantABC`\n\nThe `ParticipantABC` API is similar to the old one which we introduced in\n[`v0.8.0`](https://github.com/xaynetwork/xaynet/blob/v0.8.0/python/sdk/xain_sdk/participant.py#L24).\nAside from some changes to the method signatures, the biggest change is that the participant\nnow runs in its own thread.\n\nTo migrate from `v0.8.0` to `v0.11.0` please follow the [migration guide](./migration_guide.md).\n\n![ParticipantABC](../../assets/python_participant.svg)\n\n**Public API of `ParticipantABC` and `InternalParticipant`**\n\n```python\ndef spawn_participant(\n    coordinator_url: str,\n    participant: ParticipantABC,\n    args: Tuple = (),\n    kwargs: dict = {},\n    state: Optional[List[int]] = None,\n    scalar: float = 1.0,\n):\n    \"\"\"\n    Spawns an `InternalParticipant` in a separate thread and returns a participant handle.\n    If a `state` is passed, this state is restored, otherwise a new `InternalParticipant`\n    is created.\n\n    Args:\n        coordinator_url: The url of the coordinator.\n        participant: A class that implements `ParticipantABC`.\n        args: The args that get passed to the constructor of the `participant` class.\n        kwargs: The kwargs that get passed to the constructor of the `participant` class.\n        state: A serialized participant state. Defaults to `None`.\n        scalar: The scalar used for masking. 
Defaults to `1.0`.\n\n    Note:\n        The `scalar` is used later when the models are aggregated in order to scale their weights.\n        It can be used when you want to weight the participants' updates differently.\n\n        For example:\n        If not all participant updates should be weighted equally but proportionally to their\n        training samples, the scalar would be set to `scalar = 1 / number_of_samples`.\n\n    Returns:\n        The `InternalParticipant`.\n\n    Raises:\n        CryptoInit: If the initialization of the underlying crypto library has failed.\n        ParticipantInit: If the participant cannot be initialized. This is most\n            likely caused by an invalid `coordinator_url`.\n        ParticipantRestore: If the participant cannot be restored due to invalid\n            serialized state. This exception can never be thrown if the `state` is `None`.\n        Exception: Any exception that can be thrown during the instantiation of `participant`.\n    \"\"\"\n\nclass ParticipantABC(ABC):\n    def train_round(self, training_input: Optional[TrainingInput]) -> TrainingResult:\n        \"\"\"\n        Trains a model. `training_input` is the deserialized global model\n        (see `deserialize_training_input`). If no global model exists\n        (usually in the first round), `training_input` will be `None`.\n        In this case the weights of the model should be initialized and returned.\n\n        Args:\n            self: The participant.\n            training_input: The deserialized global model (weights of the global model) or None.\n\n        Returns:\n            The updated model weights (the local model).\n        \"\"\"\n\n    def serialize_training_result(self, training_result: TrainingResult) -> list:\n        \"\"\"\n        Serializes the `training_result` into a `list`. 
The data type of the\n        elements must match the data type defined in the coordinator configuration.\n\n        Args:\n            self: The participant.\n            training_result: The `TrainingResult` of `train_round`.\n\n        Returns:\n            The `training_result` as a `list`.\n        \"\"\"\n\n    def deserialize_training_input(self, global_model: list) -> TrainingInput:\n        \"\"\"\n        Deserializes the `global_model` from a `list` to the type of `TrainingInput`.\n        The data type of the elements matches the data type defined in the coordinator\n        configuration. If no global model exists (usually in the first round), the method will\n        not be called by the `InternalParticipant`.\n\n        Args:\n            self: The participant.\n            global_model: The global model.\n\n        Returns:\n            The `TrainingInput` for `train_round`.\n        \"\"\"\n\n    def participate_in_update_task(self) -> bool:\n        \"\"\"\n        A callback used by the `InternalParticipant` to determine whether the\n        `train_round` method should be called. This callback is only called\n        if the participant is selected as an update participant. If `participate_in_update_task`\n        returns `False`, `train_round` will not be called by the `InternalParticipant`.\n\n        If the method is not overridden, it returns `True` by default.\n\n        Returns:\n            Whether the `train_round` method should be called when the participant\n            is an update participant.\n        \"\"\"\n\n    def on_new_global_model(self, global_model: Optional[TrainingInput]) -> None:\n        \"\"\"\n        A callback that is called by the `InternalParticipant` once a new global model is\n        available. If no global model exists (usually in the first round), `global_model` will\n        be `None`. If a global model exists, `global_model` is already the deserialized\n        global model. 
(See `deserialize_training_input`)\n\n        If the method is not overridden, it does nothing by default.\n\n        Args:\n            self: The participant.\n            global_model: The deserialized global model or `None`.\n        \"\"\"\n\n    def on_stop(self) -> None:\n        \"\"\"\n        A callback that is called by the `InternalParticipant` before the `InternalParticipant`\n        thread is stopped.\n\n        This callback can be used, for example, to show performance values that have been\n        collected in the participant over the course of the training rounds.\n\n        If the method is not overridden, it does nothing by default.\n\n        Args:\n            self: The participant.\n        \"\"\"\n\nclass InternalParticipant:\n    def stop(self) -> List[int]:\n        \"\"\"\n        Stops the execution of the participant and returns its serialized state.\n        The serialized state can be passed to the `spawn_participant` function\n        to restore a participant.\n\n        After calling `stop`, the participant is consumed. Every further method\n        call on the handle of `InternalParticipant` leads to an `UninitializedParticipant`\n        exception.\n\n        Note:\n            The serialized state contains unencrypted **private key(s)**. If used\n            in production, it is important that the serialized state is securely saved.\n\n        Returns:\n            The serialized state of the participant.\n        \"\"\"\n```\n\n### `AsyncParticipant`\n\nWe noticed that the API of `ParticipantABC`/`InternalParticipant` reduces a fair amount of\ncode on the user side; however, it may not be flexible enough to cover some of the following\nuse cases:\n\n1. The user wants to use the global/local model in a different thread.\n\n    It is possible to provide methods for this on the `InternalParticipant` but they are not\n    straightforward to implement. 
To make them thread-safe, it is probably necessary to use\n    synchronization primitives but this would make the `InternalParticipant` more complicated.\n    In addition, questions arise such as: Would the user want to be able to get\n    the current local model at any time or would they like to be notified as soon as a new\n    local model is available.\n\n2. Train a model without the participant\n\n    Since the training of the model is embedded in the `ParticipantABC`, this will probably lead to\n    code duplication if the user wants to perform the training without the participant. Furthermore,\n    the embedding of the training in the `ParticipantABC` can also be a problem once the participant\n    is integrated into an existing application, considering the code for the training has to be\n    moved into the `train_round` method, which can lead to significant changes to the existing code.\n\n3. Custom exception handling\n\n    Last but not least, the question arises how we can inform the user that an exception has been\n    thrown. We do not want the participant to be terminated with every exception but we want to\n    give the user the opportunity to respond appropriately.\n\nThe main issue we saw is that the participant is responsible for training the model\nand to run the PET protocol. Therefore, we offer a second API in which the training\nof the model is no longer part of the participant. 
This results in a simpler and more flexible API,\nbut it comes with the tradeoff that the user needs to perform the de/serialization of the\nglobal/local models on their side.\n\n![AsyncParticipant](../../assets/python_async_participant.svg)\n\n**Public API of `AsyncParticipant`**\n\n```python\ndef spawn_async_participant(coordinator_url: str, state: Optional[List[int]] = None, scalar: float = 1.0)\n    -> (AsyncParticipant, threading.Event):\n    \"\"\"\n    Spawns an `AsyncParticipant` in a separate thread and returns a participant handle\n    together with a global model notifier. If a `state` is passed, this state is restored,\n    otherwise a new participant is created.\n\n    The global model notifier sets the flag once a new global model is available.\n    The flag is also set when the global model is `None` (usually in the first round).\n    The flag is reset once the method `get_global_model` has been called but it is also possible\n    to reset the flag manually by calling\n    [`clear()`](https://docs.python.org/3/library/threading.html#threading.Event.clear).\n\n    Args:\n        coordinator_url: The url of the coordinator.\n        state: A serialized participant state. Defaults to `None`.\n        scalar: The scalar used for masking. Defaults to `1.0`.\n\n    Note:\n        The `scalar` is used later when the models are aggregated in order to scale their weights.\n        It can be used when you want to weight the participants' updates differently.\n\n        For example:\n        If not all participant updates should be weighted equally but proportionally to their\n        training samples, the scalar would be set to `scalar = 1 / number_of_samples`.\n\n    Returns:\n        A tuple which consists of an `AsyncParticipant` and a global model notifier.\n\n    Raises:\n        CryptoInit: If the initialization of the underlying crypto library has failed.\n        ParticipantInit: If the participant cannot be initialized. 
This is most\n            likely caused by an invalid `coordinator_url`.\n        ParticipantRestore: If the participant cannot be restored due to invalid\n            serialized state. This exception can never be thrown if the `state` is `None`.\n    \"\"\"\n\nclass AsyncParticipant:\n    def get_global_model(self) -> Optional[list]:\n        \"\"\"\n        Fetches the current global model. This method can be called at any time. If no global\n        model exists (usually in the first round), the method returns `None`.\n\n        Returns:\n            The current global model or `None`. The data type of the elements matches the data\n            type defined in the coordinator configuration.\n\n        Raises:\n            GlobalModelUnavailable: If the participant cannot connect to the coordinator to get\n                the global model.\n            GlobalModelDataTypeMisMatch: If the data type of the global model does not match\n                the data type defined in the coordinator configuration.\n        \"\"\"\n\n    def set_local_model(self, local_model: list):\n        \"\"\"\n        Sets a local model. This method can be called at any time. Internally the\n        participant first caches the local model. As soon as the participant is selected as an\n        update participant, the currently cached local model is used. This means that the cache\n        is empty after this operation.\n\n        If a local model is already in the cache and `set_local_model` is called with a new local\n        model, the current cached local model will be replaced by the new one.\n        If the participant is an update participant and there is no local model in the cache,\n        the participant waits until a local model is set or until a new round has been started.\n\n        Args:\n            local_model: The local model. 
The data type of the elements must match the data\n            type defined in the coordinator configuration.\n\n        Raises:\n            LocalModelLengthMisMatch: If the length of the local model does not match the\n                length defined in the coordinator configuration.\n            LocalModelDataTypeMisMatch: If the data type of the local model does not match\n                the data type defined in the coordinator configuration.\n        \"\"\"\n\n    def stop(self) -> List[int]:\n        \"\"\"\n        Stops the execution of the participant and returns its serialized state.\n        The serialized state can be passed to the `spawn_async_participant` function\n        to restore a participant.\n\n        After calling `stop`, the participant is consumed. Every further method\n        call on the handle of `AsyncParticipant` leads to an `UninitializedParticipant`\n        exception.\n\n        Note:\n            The serialized state contains unencrypted **private key(s)**. If used\n            in production, it is important that the serialized state is securely saved.\n\n        Returns:\n            The serialized state of the participant.\n        \"\"\"\n```\n\n## Enable logging of `xaynet-mobile`\n\nIf you are interested in what `xaynet-mobile` is doing under the hood,\nyou can turn on the logging via the environment variable `XAYNET__CLIENT`.\n\nFor example:\n\n`XAYNET__CLIENT=info python examples/participate_in_update.py`\n\n## How can I ... ?\n\nWe have created a few [examples](./examples/README.md) that show the basic methods in action.\nBut if something is missing, not very clear or not working properly, please let us know\nby opening an issue.\n\nWe are happy to help and open to ideas or feedback :)\n"
  },
  {
    "path": "bindings/python/examples/README.md",
    "content": "# Examples\n\nSome examples that show how the `ParticipantABC` or `AsyncParticipant` can be used.\n\n## Getting Started\n\nAll examples in this section work without changing the coordinator\n[config.toml](../../../configs/config.toml) or [docker-dev.toml](../../../configs/docker-dev.toml).\n\n- [`hello_world.py`](./hello_world.py) A basic `ParticipantABC` example\n- [`hello_world_async.py`](./hello_world_async.py) A basic `AsyncParticipant` example\n- [`download_global_model.py`](./download_global_model.py) A `ParticipantABC` that only downloads the latest global model\n- [`download_global_model_async.py`](./download_global_model_async.py) An `AsyncParticipant` that only downloads the latest global model\n- [`multiple_participants.py`](./download_global_model_async.py) Spawn multiple `ParticipantABC`s in a single process\n- [`participate_in_update.py`](./participate_in_update.py) Only train a model when there is enough battery left\n- [`restore.py`](./restore.py) Save and restore the state of an `AsyncParticipant`\n\n## Keras House Prices\n\n- [`keras_house_prices`](./keras_house_prices/) A full machine learning example\n"
  },
  {
    "path": "bindings/python/examples/download_global_model.py",
    "content": "\"\"\"A `ParticipantABC` that only downloads the latest global model\"\"\"\n\nimport json\nimport logging\nfrom typing import Optional\n\nimport xaynet_sdk\n\nLOG = logging.getLogger(__name__)\n\n\nclass Participant(xaynet_sdk.ParticipantABC):\n    def __init__(self, model: list) -> None:\n        self.model = model\n        super().__init__()\n\n    def deserialize_training_input(self, global_model: list) -> list:\n        return global_model\n\n    def train_round(self, training_input: Optional[list]) -> list:\n        pass\n\n    def serialize_training_result(self, training_result: list) -> list:\n        pass\n\n    def participate_in_update_task(self) -> bool:\n        return False\n\n    def on_new_global_model(self, global_model: Optional[list]) -> None:\n        LOG.info(\"new global model\")\n        if global_model is not None:\n            with open(\"global_model.bin\", \"w\") as filehandle:\n                filehandle.write(json.dumps(global_model))\n\n\ndef main() -> None:\n    logging.basicConfig(\n        format=\"%(asctime)s.%(msecs)03d %(levelname)8s %(message)s\",\n        level=logging.DEBUG,\n        datefmt=\"%b %d %H:%M:%S\",\n    )\n\n    participant = xaynet_sdk.spawn_participant(\n        \"http://127.0.0.1:8081\", Participant, args=([0.1, 0.2, 0.345, 0.3],)\n    )\n\n    try:\n        participant.join()\n    except KeyboardInterrupt:\n        participant.stop()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bindings/python/examples/download_global_model_async.py",
    "content": "\"\"\"An `AsyncParticipant` that only downloads the latest global model\"\"\"\n\nimport json\nimport logging\n\nimport xaynet_sdk\n\nLOG = logging.getLogger(__name__)\n\n\ndef main() -> None:\n    logging.basicConfig(\n        format=\"%(asctime)s.%(msecs)03d %(levelname)8s %(message)s\",\n        level=logging.DEBUG,\n        datefmt=\"%b %d %H:%M:%S\",\n    )\n\n    (participant, global_model_notifier) = xaynet_sdk.spawn_async_participant(\n        \"http://127.0.0.1:8081\"\n    )\n\n    try:\n        while global_model_notifier.wait():\n            LOG.info(\"a new global model\")\n            global_model = participant.get_global_model()\n            if global_model is not None:\n                with open(\"global_model.bin\", \"w\") as filehandle:\n                    filehandle.write(json.dumps(global_model))\n\n    except KeyboardInterrupt:\n        participant.stop()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bindings/python/examples/hello_world.py",
    "content": "\"\"\"A basic `ParticipantABC` example\"\"\"\n\nimport json\nimport logging\nimport time\nfrom typing import Optional\n\nimport xaynet_sdk\n\nLOG = logging.getLogger(__name__)\n\n\nclass Participant(xaynet_sdk.ParticipantABC):\n    def __init__(self, model: list) -> None:\n        self.model = model\n        super().__init__()\n\n    def deserialize_training_input(self, global_model: list) -> list:\n        return global_model\n\n    def train_round(self, training_input: Optional[list]) -> list:\n        LOG.info(\"training\")\n        time.sleep(3.0)\n        LOG.info(\"training done\")\n        return self.model\n\n    def serialize_training_result(self, training_result: list) -> list:\n        return training_result\n\n    def participate_in_update_task(self) -> bool:\n        return True\n\n    def on_new_global_model(self, global_model: Optional[list]) -> None:\n        if global_model is not None:\n            with open(\"global_model.bin\", \"w\") as filehandle:\n                filehandle.write(json.dumps(global_model))\n\n\ndef main() -> None:\n    logging.basicConfig(\n        format=\"%(asctime)s.%(msecs)03d %(levelname)8s %(message)s\",\n        level=logging.DEBUG,\n        datefmt=\"%b %d %H:%M:%S\",\n    )\n\n    participant = xaynet_sdk.spawn_participant(\n        \"http://127.0.0.1:8081\", Participant, args=([0.1, 0.2, 0.345, 0.3],)\n    )\n\n    try:\n        participant.join()\n    except KeyboardInterrupt:\n        participant.stop()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bindings/python/examples/hello_world_async.py",
    "content": "\"\"\"A basic `AsyncParticipant` example\"\"\"\n\nimport logging\nimport time\n\nimport xaynet_sdk\n\nLOG = logging.getLogger(__name__)\n\n\ndef training():\n    LOG.info(\"training\")\n    time.sleep(10.0)\n    LOG.info(\"training done\")\n\n\ndef main() -> None:\n    logging.basicConfig(\n        format=\"%(asctime)s.%(msecs)03d %(levelname)8s %(message)s\",\n        level=logging.DEBUG,\n        datefmt=\"%b %d %H:%M:%S\",\n    )\n\n    (participant, global_model_notifier) = xaynet_sdk.spawn_async_participant(\n        \"http://127.0.0.1:8081\"\n    )\n\n    try:\n        while global_model_notifier.wait():\n            LOG.info(\"a new global model\")\n            participant.get_global_model()\n            training()\n            participant.set_local_model([0.1, 0.2, 0.345, 0.3])\n\n    except KeyboardInterrupt:\n        participant.stop()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bindings/python/examples/keras_house_prices/.gitignore",
    "content": "data/\n"
  },
  {
    "path": "bindings/python/examples/keras_house_prices/README.md",
    "content": "# `keras_house_prices` Example\n\n**Prerequisites**\n\n- Python >=3.7.1 <=3.8\n\n1. Adjust the coordinator settings\n\nChange the model length to `55117` and the `bound_type` to `B2`\nin [`docker-dev.toml`](../../../../configs/docker-dev.toml).\n\n```toml\n[model]\nlength = 55117\n\n[mask]\nbound_type = \"B2\"\n```\n\nCurious what the `bond_type` is? You can find an explanation [here](https://docs.rs/xaynet-core/0.2.0/xaynet_core/mask/index.html#bound-type).\n\n2. Start the coordinator\n\n```shell\n# in the root of the repository\ndocker-compose -f docker/docker-compose.yml up --build\n```\n\n**All the commands in this section are run from the\n`bindings/python/examples/keras_house_prices` directory.**\n\n3. Install the SDK:\n\nFollow the installation steps described in [bindings/python/README.md](../../README.md).\n\n4. Install the example:\n\n```shell\npip install -e .\n```\n\n5. Download the dataset from Kaggle:\n   https://www.kaggle.com/c/house-prices-advanced-regression-techniques/data\n\n6. Extract the data (into\n   `python/examples/keras_house_prices/data/` here, but the\n   location doesn't matter):\n\n```shell\n(cd ./data ; unzip house-prices-advanced-regression-techniques.zip)\n```\n\n7. Prepare the data:\n\n```shell\nsplit-data --data-directory data --number-of-participants 10\n```\n\n8.  Run one participant:\n\n```shell\nXAYNET__CLIENT=info run-participant --data-directory data --coordinator-url http://127.0.0.1:8081\n```\n\n9. Repeat the previous step to run more participants\n"
  },
  {
    "path": "bindings/python/examples/keras_house_prices/keras_house_prices/__init__.py",
    "content": ""
  },
  {
    "path": "bindings/python/examples/keras_house_prices/keras_house_prices/data_handlers/__init__.py",
    "content": ""
  },
  {
    "path": "bindings/python/examples/keras_house_prices/keras_house_prices/data_handlers/data_handler.py",
    "content": "\"\"\"DataHandler base class to read, preprocess and split data for each example.\"\"\"\n\nfrom abc import ABC, abstractmethod\nimport logging\nimport os\nfrom typing import Dict, List, Optional\n\nimport numpy as np\nimport pandas as pd\n\nLOG = logging.getLogger(__name__)\n\n\nclass DataHandler(ABC):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Base class to handle data preparation\n\n    Args:\n\n         data_directory: path to the directory where the data is\n            stored\n\n         homogeneity: The level of homogeneity in the assignment\n            of training samples to each participants. It can take\n            three values:\n\n                - `iid`: meaning samples are randomly assigned to\n                  participants.\n                - `intermediate`: half of the samples are randomly\n                    assigned to participants, half of the samples\n                    follow the 'total_split' logic.\n                - `total_split`: if there are more participants than\n                  labels, samples are split among participants so that\n                  each participant has samples from only one class.\n                  if there are more classes than participants, samples\n                  are split so that no class is repeated between\n                  participants.\n\n         n_participants: The number of participants into which the\n            dataset will be split.\n\n    NOTE: the random seed is set in the initialisation and will make\n    the results reproducible.\n\n    \"\"\"\n\n    TEST_RATIO: float = 0.1\n    MINIMUM_PARTICIPANT_N_SAMPLES: int = 20\n\n    def __init__(\n        self,\n        data_directory: str,\n        homogeneity: str = \"iid\",\n        n_participants: int = 10,\n    ) -> None:\n        self.homogeneity: str = homogeneity\n        self.n_participants: int = n_participants\n        self.participant_ids: List[str] = [str(p) for p in range(self.n_participants)]\n        
self.data_dir: str = data_directory\n        self.parts_dir: str = os.path.join(self.data_dir, \"split_data\")\n        if not os.path.exists(self.parts_dir):\n            os.mkdir(self.parts_dir)\n            LOG.info(\"created %s dir\", self.parts_dir)\n        self.train_file_path: str = os.path.join(self.data_dir, \"train.csv\")\n        self.test_file_path: str = os.path.join(self.data_dir, \"test.csv\")\n        self.train_df: pd.DataFrame = pd.DataFrame()\n        self.test_df: pd.DataFrame = pd.DataFrame()\n        self.labels: List[str] = []\n\n        # set the seed that will be used by numpy to make the results reproducible.\n        np.random.seed(42)\n\n    def read_data(self) -> None:\n        \"\"\"Find the train_set CSV file and load it into a dataframe\"\"\"\n        self.train_df = pd.read_csv(self.train_file_path, index_col=None)\n\n    @abstractmethod\n    def preprocess_data(self) -> None:\n        \"\"\"Abstract method to be implemented by the testcase data handling\n        subclass, to preprocess the data.\n\n        \"\"\"\n        raise NotImplementedError()\n\n    def create_testset(self) -> None:\n        \"\"\"Create testset by sampling and removing a TEST_RATIO percentage of\n        samples from self.train_df. 
Save the data locally.\n\n        \"\"\"\n\n        n_test_samples: int = int(len(self.train_df) * self.TEST_RATIO)\n        test_indexes: np.ndarray = np.random.choice(\n            self.train_df.index, n_test_samples, replace=False\n        )\n        self.test_df = self.train_df.loc[test_indexes, :]\n        self.train_df = self.train_df.drop(test_indexes)\n        self.test_df.to_csv(self.test_file_path)\n\n    def make_discrete_y(self) -> pd.Series:\n        \"\"\"Split a continuous Y variable into discrete bins, one per\n        participant.\n\n        Returns:\n\n            discrete_y: The discrete dependent variable.\n\n        \"\"\"\n\n        discrete_y: pd.Series = pd.cut(\n            self.train_df[\"Y\"],\n            bins=self.n_participants,\n            labels=range(self.n_participants),\n        )\n\n        self.labels = list(set(discrete_y))\n        return discrete_y\n\n    def make_iid_split(\n        self,\n        input_df: pd.DataFrame,\n        target_length: int,\n        assigned_samples: Optional[List[str]] = None,\n    ) -> np.ndarray:\n        \"\"\"Randomly select samples so that each participant has a similar\n        amount of samples.\n\n        Args:\n\n            input_df: DataFrame containing the samples to be selected.\n            target_length: Length of the full dataset considered for\n                IID split.\n            assigned_samples: List of sample IDs already assigned to\n                previous participants.\n\n        Returns:\n\n            The selected sample indexes.\n\n        \"\"\"\n\n        if assigned_samples is not None:\n            input_df = input_df.drop(assigned_samples)\n        samples_ids_per_participant: int = int(target_length / self.n_participants)\n        selected_sample_ids: np.ndarray = np.random.choice(\n            input_df.index, samples_ids_per_participant, replace=False\n        )\n        return selected_sample_ids\n\n    @staticmethod\n    def split_lists(\n        longer_list: 
List[str], shorter_list: List[str]\n    ) -> Dict[str, List[str]]:\n        \"\"\"Split the lists of labels and participant IDs.\n\n        We use longer and shorter list to make sure that the elements of the longer list\n        are distributed to the elements of the shorter.\n\n        For example:\n        - If there are more participants than labels, the samples of each label will be\n        distributed to different participants, and each participant will have samples\n        from only one label.\n        - If there are more labels than participants, each participant will have samples\n        from more than one label, but samples from a single label will belong to only one\n        participant.\n\n        Args:\n\n            longer_list: List of either labels or participant IDS,\n                whichever is longer.\n            shorter_list: List of either labels or participant IDS,\n                whichever is shorter.\n\n        Returns:\n\n            Dictionary whose keys are the elements of the shorted\n            list, and its values are a sample without replacement of\n            the elements of the longer list.\n\n        \"\"\"\n\n        ratio: int = len(longer_list) // len(shorter_list)\n        splits: List[List[str]] = [\n            longer_list[i : i + ratio] for i in range(0, len(longer_list), ratio)\n        ]\n        splits_by_shorter_element: Dict[str, List[str]] = {\n            item: splits[i] for i, item in enumerate(shorter_list)\n        }\n        return splits_by_shorter_element\n\n    def make_total_split(\n        self, discrete_y: pd.Series, participant_id: str, participant_ids: List[str]\n    ) -> np.ndarray:\n        \"\"\"Select labels for one participant.\n\n        If there are more labels than participants, it will select a\n        list of labels not assigned to any other participant. 
If there\n        are more participants than labels, it will select one label\n        only for this participant (the label may re-occur for other\n        participants).\n\n        Args:\n\n            discrete_y: The discrete dependent variable.\n            participant_id: The ID of the participant for which we are\n                currently selecting the samples for its dataset.\n            participant_ids: List of all participant IDs.\n\n        Returns:\n\n            List of selected samples for the current participant.\n\n        \"\"\"\n\n        labels_by_participant_id: Dict[str, List[str]]\n        selected_labels: List[str]\n        if len(self.labels) >= self.n_participants:\n            labels_by_participant_id = self.split_lists(\n                list(self.labels), participant_ids\n            )\n            selected_labels = labels_by_participant_id[participant_id]\n        else:\n            participant_ids_by_label = self.split_lists(participant_ids, self.labels)\n            selected_labels = [\n                label\n                for label, ids in participant_ids_by_label.items()\n                if participant_id in ids\n            ]\n        selected_samples: np.ndarray = np.array(\n            [i for i, label in discrete_y.items() if label in selected_labels]\n        )\n        return selected_samples\n\n    def make_intermediate_split(\n        self, assigned_samples: List[str], participant_id: str, discrete_y: pd.Series\n    ) -> np.ndarray:\n        \"\"\"Handles an intermediate split, 50% IID and 50% total_split.\n\n        Args:\n\n            assigned_samples: Samples that have already been assigned\n                to a participant.\n            participant_id: The ID of the participant that will have\n                samples assigned to.\n            discrete_y: The discrete dependent variable.\n\n        Raises:\n\n            AssertionError: If the selected samples are not\n                unique. 
Typically if there was replacement, or the\n                random seed had not been set.\n\n        Returns:\n\n            The IDs of the selected samples for this participant.\n\n        \"\"\"\n\n        remaining_samples_df: pd.DataFrame = self.train_df.drop(assigned_samples)\n        first_half_df: pd.DataFrame = remaining_samples_df.sample(frac=0.5)\n        second_half_df: pd.DataFrame = remaining_samples_df.drop(first_half_df.index)\n        target_length: int = len(self.train_df) // 2\n        iid_samples: np.ndarray = self.make_iid_split(first_half_df, target_length)\n        second_half_y: pd.Series = discrete_y.loc[second_half_df.index]\n        total_split_samples: np.ndarray = self.make_total_split(\n            second_half_y, participant_id, self.participant_ids\n        )\n        selected_samples: np.ndarray = np.concatenate(\n            (iid_samples, total_split_samples)\n        )\n        if len(set(selected_samples)) != len(selected_samples):\n            raise AssertionError\n        return selected_samples\n\n    def split_data(self) -> None:\n        \"\"\"Split the data.\n\n        Continuous variables (for regression) are made discrete only\n        for the purpose of splitting the data (not for analysis).\n\n        For each participant ID, it performs the data split according\n        to the level of homogeneity selected.\n\n        Saves the dataframe for each participant locally.\n\n        \"\"\"\n\n        discrete_y: pd.Series = self.make_discrete_y()\n        np.random.shuffle(self.labels)\n        np.random.shuffle(self.participant_ids)\n        assigned_samples: List[str] = []\n        selected_samples: np.ndarray\n        for participant_id in self.participant_ids:\n            if self.homogeneity == \"iid\":\n                selected_samples = self.make_iid_split(\n                    self.train_df, len(self.train_df), assigned_samples\n                )\n            elif self.homogeneity == \"total_split\":\n                
selected_samples = self.make_total_split(\n                    discrete_y, participant_id, self.participant_ids\n                )\n            else:\n                selected_samples = self.make_intermediate_split(\n                    assigned_samples, participant_id, discrete_y\n                )\n            participant_df: pd.DataFrame = self.train_df.loc[selected_samples, :]\n            LOG.info(\n                \"participant %s df has shape %s\", participant_id, participant_df.shape\n            )\n            if len(participant_df) < self.MINIMUM_PARTICIPANT_N_SAMPLES:\n                LOG.info(\n                    \"participant %s has only %d samples.\",\n                    participant_id,\n                    len(participant_df),\n                )\n                LOG.info(\"consider decreasing the number of participants\")\n                # TODO: edge case: non-IID splits (especially 'total_split') with\n                #  too many participants may lead to an empty df. Pandas will save\n                #  the CSV anyway, but we may have problems reading the files later.\n                #  Solve this with: https://xainag.atlassian.net/browse/AP-154\n            output_filepath: str = os.path.join(\n                self.parts_dir, f\"data_part_{participant_id}.csv\"\n            )\n            participant_df.to_csv(output_filepath, index=False)\n            LOG.info(\"participant df saved to %s\", output_filepath)\n            assigned_samples.extend(participant_df.index)\n\n    def run(self) -> None:\n        \"\"\"One function to run them all.\"\"\"\n\n        self.read_data()\n        self.preprocess_data()\n        self.create_testset()\n        self.split_data()\n"
  },
  {
    "path": "bindings/python/examples/keras_house_prices/keras_house_prices/data_handlers/regression_data.py",
    "content": "\"\"\"Implementation of the RegressionData subclass, to handle the data of regression examples.\"\"\"\n\nimport argparse\nimport logging\n\nfrom keras_house_prices.data_handlers.data_handler import DataHandler\nimport numpy as np\nimport pandas as pd\nfrom sklearn.preprocessing import MinMaxScaler\n\nLOG = logging.getLogger(__name__)\n\n\nclass RegressionData(DataHandler):\n    \"\"\"Data processing logic that is specific to the house prices dataset.\"\"\"\n\n    def __init__(\n        self, data_directory: str, homogeneity: str, n_participants: int\n    ) -> None:\n        super().__init__(\n            data_directory, homogeneity=homogeneity, n_participants=n_participants\n        )\n\n    def fill_nan(self) -> None:\n        \"\"\"Filling missing data in the dataframe.\"\"\"\n\n        self.train_df[\"PoolQC\"] = self.train_df[\"PoolQC\"].fillna(\"None\")\n        self.train_df[\"MiscFeature\"] = self.train_df[\"MiscFeature\"].fillna(\"None\")\n        self.train_df[\"Alley\"] = self.train_df[\"Alley\"].fillna(\"None\")\n        self.train_df[\"Fence\"] = self.train_df[\"Fence\"].fillna(\"None\")\n        self.train_df[\"FireplaceQu\"] = self.train_df[\"FireplaceQu\"].fillna(\"None\")\n        self.train_df[\"LotFrontage\"] = self.train_df.groupby(\"Neighborhood\")[\n            \"LotFrontage\"\n        ].transform(lambda x: x.fillna(x.median()))\n        for col in (\"GarageType\", \"GarageFinish\", \"GarageQual\", \"GarageCond\"):\n            self.train_df[col] = self.train_df[col].fillna(\"None\")\n        for col in (\"GarageYrBlt\", \"GarageArea\", \"GarageCars\"):\n            self.train_df[col] = self.train_df[col].fillna(0)\n        for col in (\n            \"BsmtFinSF1\",\n            \"BsmtFinSF2\",\n            \"BsmtUnfSF\",\n            \"TotalBsmtSF\",\n            \"BsmtFullBath\",\n            \"BsmtHalfBath\",\n        ):\n            self.train_df[col] = self.train_df[col].fillna(0)\n        for col in (\n            
\"BsmtQual\",\n            \"BsmtCond\",\n            \"BsmtExposure\",\n            \"BsmtFinType1\",\n            \"BsmtFinType2\",\n        ):\n            self.train_df[col] = self.train_df[col].fillna(\"None\")\n        self.train_df[\"MSZoning\"] = self.train_df[\"MSZoning\"].fillna(\n            self.train_df[\"MSZoning\"].mode()[0]\n        )\n\n        self.train_df[\"MasVnrType\"] = self.train_df[\"MasVnrType\"].fillna(\"None\")\n        self.train_df[\"MasVnrArea\"] = self.train_df[\"MasVnrArea\"].fillna(0)\n        self.train_df = self.train_df.drop([\"Utilities\"], axis=1)\n        self.train_df[\"Functional\"] = self.train_df[\"Functional\"].fillna(\"Typ\")\n        self.train_df[\"Electrical\"] = self.train_df[\"Electrical\"].fillna(\n            self.train_df[\"Electrical\"].mode()[0]\n        )\n        self.train_df[\"KitchenQual\"] = self.train_df[\"KitchenQual\"].fillna(\n            self.train_df[\"KitchenQual\"].mode()[0]\n        )\n        self.train_df[\"Exterior1st\"] = self.train_df[\"Exterior1st\"].fillna(\n            self.train_df[\"Exterior1st\"].mode()[0]\n        )\n        self.train_df[\"Exterior2nd\"] = self.train_df[\"Exterior2nd\"].fillna(\n            self.train_df[\"Exterior2nd\"].mode()[0]\n        )\n        self.train_df[\"SaleType\"] = self.train_df[\"SaleType\"].fillna(\n            self.train_df[\"SaleType\"].mode()[0]\n        )\n        self.train_df[\"MSSubClass\"] = self.train_df[\"MSSubClass\"].fillna(\"None\")\n\n        no_nulls_in_dataset = not self.train_df.isnull().values.any()\n        if no_nulls_in_dataset:\n            LOG.info(\"No missing values\")\n            LOG.info(\"data shape is %s\", self.train_df.shape)\n\n    def hot_encoding(self) -> None:\n        \"\"\"Hot encoding of the categorical features.\"\"\"\n\n        self.train_df: pd.DataFrame = pd.get_dummies(\n            self.train_df, dummy_na=True, drop_first=True\n        )\n        LOG.info(\"data shape is %s\", self.train_df.shape)\n\n    
def scaling(self) -> None:\n        \"\"\"Scales the features in minmax way and the process in log(1+x).\"\"\"\n\n        self.train_df = self.train_df.rename(columns={\"SalePrice\": \"Y\"})\n        self.train_df[\"Y\"] = np.log1p(self.train_df[\"Y\"])\n        scaler = MinMaxScaler()\n        cols = self.train_df.drop(\"Y\", axis=1).columns\n        train = pd.DataFrame(\n            scaler.fit_transform(self.train_df.drop(\"Y\", axis=1)), columns=cols\n        )\n        self.train_df[cols] = train\n\n    def preprocess_data(self) -> None:\n        \"\"\"Call methods that execute the preprocessing.\"\"\"\n        self.train_df.drop(\"Id\", axis=1, inplace=True)\n        self.fill_nan()\n        self.hot_encoding()\n        self.scaling()\n\n\ndef main() -> None:\n    \"\"\"Initialise and run the regression data preparation.\"\"\"\n    logging.basicConfig(level=logging.DEBUG)\n\n    parser = argparse.ArgumentParser(description=\"Prepare data for regression\")\n    parser.add_argument(\n        \"--data-directory\",\n        type=str,\n        help=\"path to the directory that contains the raw data\",\n    )\n    parser.add_argument(\n        \"--number-of-participants\",\n        type=int,\n        help=\"number of participants into which the dataset will be split\",\n    )\n    args = parser.parse_args()\n\n    regression_data = RegressionData(\n        args.data_directory,\n        \"total_split\",\n        args.number_of_participants,\n    )\n    regression_data.run()\n"
  },
  {
    "path": "bindings/python/examples/keras_house_prices/keras_house_prices/participant.py",
    "content": "\"\"\"Tensorflow Keras regression test case\"\"\"\n\nimport argparse\nimport logging\nimport os\nimport random\nfrom typing import List, Optional, Tuple\n\nfrom keras_house_prices.regressor import Regressor\nimport numpy as np\nimport pandas as pd\nfrom tabulate import tabulate\n\nfrom xaynet_sdk import ParticipantABC, spawn_participant\n\nLOG = logging.getLogger(__name__)\n\n\nclass Participant(  # pylint: disable=too-few-public-methods,too-many-instance-attributes\n    ParticipantABC\n):\n    \"\"\"An example of a Keras implementation of a participant for federated\n    learning.\n\n    The attributes for the model and the datasets are only for\n    convenience, they might as well be loaded elsewhere.\n\n    Attributes:\n\n        regressor: The model to be trained.\n        trainset_x: A dataset for training.\n        trainset_y: Labels for training.\n        testset_x: A dataset for test.\n        testset_y: Labels for test.\n        number_samples: The number of samples in the training dataset.\n        performance_metrics: metrics collected after each round of training\n\n    \"\"\"\n\n    def __init__(self, dataset_dir: str) -> None:\n        \"\"\"Initialize a custom participant.\"\"\"\n        super().__init__()\n        self.load_random_dataset(dataset_dir)\n        self.regressor = Regressor(len(self.trainset_x.columns))\n        self.performance_metrics: List[Tuple[float, float]] = []\n\n    def load_random_dataset(self, dataset_dir: str) -> None:\n        \"\"\"Load a random dataset from the data directory\"\"\"\n        i = random.randrange(0, 10, 1)\n\n        LOG.info(\"Train on sample number %d\", i)\n        trainset_file_path = os.path.join(\n            dataset_dir, \"split_data\", f\"data_part_{i}.csv\"\n        )\n\n        trainset = pd.read_csv(trainset_file_path, index_col=None)\n        self.trainset_x = trainset.drop(\"Y\", axis=1)\n        self.trainset_y = trainset[\"Y\"]\n        self.number_of_samples = 
len(trainset)\n\n        testset_file_path = os.path.join(dataset_dir, \"test.csv\")\n        testset = pd.read_csv(testset_file_path, index_col=None)\n        testset_x = testset.drop(\"Y\", axis=1)\n        self.testset_x: pd.DataFrame = testset_x.drop(testset_x.columns[0], axis=1)\n        self.testset_y = testset[\"Y\"]\n\n    def train_round(self, training_input: Optional[np.ndarray]) -> np.ndarray:\n        \"\"\"Train a model in a federated learning round.\n\n        A model is given in terms of its weights and the model is\n        trained on the participant's dataset for a number of\n        epochs. The weights of the updated model are returned.\n\n        Args:\n\n            weights: The weights of the model to be trained.\n\n        Returns:\n\n            The updated model weights .\n        \"\"\"\n        if training_input is None:\n            # This is the first round: the coordinator doesn't have a\n            # global model yet, so we need to initialize the weights\n            self.regressor = Regressor(len(self.trainset_x.columns))\n            return self.regressor.get_weights()\n\n        weights = training_input\n        epochs = 10\n        self.regressor.set_weights(weights)\n        self.regressor.train_n_epochs(epochs, self.trainset_x, self.trainset_y)\n\n        loss: float\n        r_squared: float\n        loss, r_squared = self.regressor.evaluate_on_test(\n            self.testset_x, self.testset_y\n        )\n        LOG.info(\"loss = %f, R² = %f\", loss, r_squared)\n        self.performance_metrics.append((loss, r_squared))\n\n        return self.regressor.get_weights()\n\n    def deserialize_training_input(self, global_model: list) -> np.ndarray:\n        return np.array(global_model)\n\n    def serialize_training_result(self, training_result: np.ndarray) -> list:\n        return training_result.tolist()\n\n    def on_stop(self) -> None:\n        table = tabulate(self.performance_metrics, headers=[\"Loss\", \"R²\"])\n        
print(table)\n\n\ndef main() -> None:\n    \"\"\"Entry point to start a participant.\"\"\"\n    parser = argparse.ArgumentParser(description=\"Prepare data for regression\")\n    parser.add_argument(\n        \"--data-directory\",\n        type=str,\n        help=\"path to the directory that contains the data\",\n    )\n    parser.add_argument(\n        \"--coordinator-url\",\n        type=str,\n        required=True,\n        help=\"URL of the coordinator\",\n    )\n    args = parser.parse_args()\n\n    # pylint: disable=invalid-name\n    logging.basicConfig(\n        format=\"%(asctime)s.%(msecs)03d %(levelname)8s %(message)s\",\n        level=logging.DEBUG,\n        datefmt=\"%b %d %H:%M:%S\",\n    )\n\n    participant = spawn_participant(\n        args.coordinator_url, Participant, args=(args.data_directory,)\n    )\n\n    try:\n        participant.join()\n    except KeyboardInterrupt:\n        participant.stop()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bindings/python/examples/keras_house_prices/keras_house_prices/regressor.py",
    "content": "\"\"\"Wrapper for tensorflow regression neural network.\"\"\"\nfrom typing import List, Tuple\n\nimport numpy as np\nimport pandas as pd\nfrom sklearn.metrics import r2_score\nfrom tensorflow.keras import Sequential  # pylint: disable=import-error\nfrom tensorflow.keras.layers import Dense  # pylint: disable=import-error\n\n\nclass Regressor:\n    \"\"\"Neural network class for the Boston pricing house problem.\n\n    Attributes:\n        model: Keras Sequential model\n    \"\"\"\n\n    def __init__(self, dim: int):\n        self.model = Sequential()\n        self.model.add(Dense(144, input_dim=dim, activation=\"relu\"))\n        self.model.add(Dense(72, activation=\"relu\"))\n        self.model.add(Dense(18, activation=\"relu\"))\n        self.model.add(Dense(1, activation=\"linear\"))\n\n        self.model.compile(optimizer=\"adam\", loss=\"mean_squared_error\")\n\n    def train_n_epochs(\n        self, n_epochs: int, x_train: pd.DataFrame, y_train: pd.DataFrame\n    ) -> None:\n        \"\"\"Training function for the built in model.\n\n        Args:\n            n_epochs (int): Number of epochs to be trained.\n            x_train (~pd.dataframe): Features dataset for training.\n            y_train(~pd.dataframe): Labels for training.\n        \"\"\"\n\n        self.model.fit(x_train, y_train, epochs=n_epochs, verbose=0)\n\n    def evaluate_on_test(\n        self, x_test: pd.DataFrame, y_test: pd.DataFrame\n    ) -> Tuple[float, float]:\n        \"\"\"Evaluating on testset.\n\n        Args:\n             x_test (dataframe): Feature set for evaluation.\n             y_test (dataframe): Dependent variable for evaluation.\n\n        Returns:\n            test_loss: Value of the testing loss.\n            r_squared: Value of R-squared,\n                to be shown as 'accuracy' metric to the Coordinator\n        \"\"\"\n\n        y_pred: np.ndarray = self.model.predict(x_test)\n        r_squared: float = r2_score(y_test, y_pred)\n        test_loss: 
float = self.model.evaluate(x_test, y_test)\n        return test_loss, r_squared\n\n    def get_shapes(self) -> List[Tuple[int, ...]]:\n        return [weight.shape for weight in self.model.get_weights()]\n\n    def get_weights(self) -> np.ndarray:\n        return np.concatenate(self.model.get_weights(), axis=None)\n\n    def set_weights(self, weights: np.ndarray) -> None:\n        shapes = self.get_shapes()\n        # expand the flat weights\n        indices: np.ndarray = np.cumsum([np.prod(shape) for shape in shapes])\n        tensorflow_weights: List[np.ndarray] = np.split(\n            weights, indices_or_sections=indices\n        )\n        tensorflow_weights = [\n            np.reshape(weight, newshape=shape)\n            for weight, shape in zip(tensorflow_weights, shapes)\n        ]\n\n        # apply the weights to the tensorflow model\n        self.model.set_weights(tensorflow_weights)\n"
  },
  {
    "path": "bindings/python/examples/keras_house_prices/setup.py",
    "content": "# pylint: disable=invalid-name\nfrom setuptools import find_packages, setup\n\nsetup(\n    name=\"keras_house_prices\",\n    version=\"0.1\",\n    author=[\"Xayn Engineering\"],\n    author_email=\"engineering@xaynet.dev\",\n    license=\"Apache License Version 2.0\",\n    python_requires=\">=3.7.1, <=3.8\",\n    packages=find_packages(),\n    install_requires=[\n        \"pandas==1.4.3\",\n        \"scikit-learn==1.1.2\",\n        \"tensorflow==2.9.1\",\n        \"numpy>=1.19.2,<1.24.0\",\n        \"tabulate~=0.8.7\",\n    ],\n    entry_points={\n        \"console_scripts\": [\n            \"run-participant=keras_house_prices.participant:main\",\n            \"split-data=keras_house_prices.data_handlers.regression_data:main\",\n        ]\n    },\n)\n"
  },
  {
    "path": "bindings/python/examples/multiple_participants.py",
    "content": "\"\"\"Spawn multiple `ParticipantABC`s in a single process\"\"\"\n\nimport json\nimport logging\nimport time\nfrom typing import Optional\n\nimport xaynet_sdk\n\nLOG = logging.getLogger(__name__)\n\n\nclass Participant(xaynet_sdk.ParticipantABC):\n    def __init__(self, p_id: int, model: list) -> None:\n        self.p_id = p_id\n        self.model = model\n        super().__init__()\n\n    def deserialize_training_input(self, global_model: list) -> list:\n        return global_model\n\n    def train_round(self, training_input: Optional[list]) -> list:\n        LOG.info(\"participant %s: start training\", self.p_id)\n        time.sleep(5.0)\n        LOG.info(\"participant %s: training done\", self.p_id)\n        return self.model\n\n    def serialize_training_result(self, training_result: list) -> list:\n        return training_result\n\n    def participate_in_update_task(self) -> bool:\n        return True\n\n    def on_new_global_model(self, global_model: Optional[list]) -> None:\n        if global_model is not None:\n            with open(\"global_model.bin\", \"w\") as filehandle:\n                filehandle.write(json.dumps(global_model))\n\n\ndef main() -> None:\n    logging.basicConfig(\n        format=\"%(asctime)s.%(msecs)03d %(levelname)8s %(message)s\",\n        level=logging.DEBUG,\n        datefmt=\"%b %d %H:%M:%S\",\n    )\n\n    participant = xaynet_sdk.spawn_participant(\n        \"http://127.0.0.1:8081\",\n        Participant,\n        args=(\n            1,\n            [0.1, 0.2, 0.345, 0.3],\n        ),\n    )\n\n    participant_2 = xaynet_sdk.spawn_participant(\n        \"http://127.0.0.1:8081\",\n        Participant,\n        args=(\n            2,\n            [0.3, 0.4, 0.45, 0.1],\n        ),\n    )\n\n    participant_3 = xaynet_sdk.spawn_participant(\n        \"http://127.0.0.1:8081\",\n        Participant,\n        args=(\n            3,\n            [0.123, 0.1567, 0.123, 0.46],\n        ),\n    )\n\n    try:\n        
participant.join()\n        participant_2.join()\n        participant_3.join()\n    except KeyboardInterrupt:\n        participant.stop()\n        participant_2.stop()\n        participant_3.stop()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bindings/python/examples/participate_in_update.py",
    "content": "\"\"\"Only train a model when there is enough battery left\"\"\"\n\nimport json\nimport logging\nfrom random import randint\nimport time\nfrom typing import Optional\n\nimport xaynet_sdk\n\nLOG = logging.getLogger(__name__)\n\n\ndef get_battery_level():\n    return randint(1, 100)\n\n\nclass Participant(xaynet_sdk.ParticipantABC):\n    def __init__(self, model: list) -> None:\n        self.model = model\n        super().__init__()\n\n    def deserialize_training_input(self, global_model: list) -> list:\n        return global_model\n\n    def train_round(self, training_input: Optional[list]) -> list:\n        LOG.info(\"training\")\n        time.sleep(3.0)\n        LOG.info(\"training done\")\n        return self.model\n\n    def serialize_training_result(self, training_result: list) -> list:\n        return training_result\n\n    def participate_in_update_task(self) -> bool:\n        if get_battery_level() < 20:\n            LOG.info(\"low battery, skip training\")\n            return False\n        LOG.info(\"enough battery, participate in update task\")\n        return True\n\n    def on_new_global_model(self, global_model: Optional[list]) -> None:\n        if global_model is not None:\n            with open(\"global_model.bin\", \"w\") as filehandle:\n                filehandle.write(json.dumps(global_model))\n\n\ndef main() -> None:\n    logging.basicConfig(\n        format=\"%(asctime)s.%(msecs)03d %(levelname)8s %(message)s\",\n        level=logging.DEBUG,\n        datefmt=\"%b %d %H:%M:%S\",\n    )\n\n    participant = xaynet_sdk.spawn_participant(\n        \"http://127.0.0.1:8081\", Participant, args=([0.1, 0.2, 0.345, 0.3],)\n    )\n\n    try:\n        participant.join()\n    except KeyboardInterrupt:\n        participant.stop()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bindings/python/examples/restore.py",
    "content": "\"\"\"Save and restore the state of an `AsyncParticipant`\"\"\"\n\nimport json\nimport logging\n\nimport xaynet_sdk\n\nLOG = logging.getLogger(__name__)\n\n\ndef main() -> None:\n    logging.basicConfig(\n        format=\"%(asctime)s.%(msecs)03d %(levelname)8s %(message)s\",\n        level=logging.DEBUG,\n        datefmt=\"%b %d %H:%M:%S\",\n    )\n\n    try:\n        with open(\"state.bin\", \"r\") as filehandle:\n            restored_state = json.loads(filehandle.read())\n    except IOError:\n        LOG.info(\"no saved state available, initialize new participant\")\n        restored_state = None\n\n    (participant, _) = xaynet_sdk.spawn_async_participant(\n        \"http://127.0.0.1:8081\", restored_state\n    )\n\n    state = participant.stop()\n    with open(\"state.bin\", \"w\") as filehandle:\n        filehandle.write(json.dumps(state))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bindings/python/migration_guide.md",
    "content": "# Migration from `v0.8.0` to `v0.11.0`\n\nTo demonstrate the API changes from `v0.8.0` to `v0.11.0`, we will use the keras example\nwhich is available in both versions. For reasons of clarity, some parts of the code have\nbeen removed.\n\n## [`v0.8.0`](https://github.com/xaynetwork/xaynet/blob/v0.8.0/python/sdk/xain_sdk/participant.py#L24)\n\n```bash\npip install xain-sdk\n```\n\n```python\nfrom xain_sdk import ParticipantABC, configure_logging, run_participant\n\nclass Participant(ParticipantABC):\n    def train_round(\n        self, training_input: Optional[np.ndarray]\n    ) -> Tuple[np.ndarray, int]:\n        if training_input is None:\n            self.regressor = Regressor(len(self.trainset_x.columns))\n            return (self.regressor.get_weights(), 0)\n\n        return (self.regressor.get_weights(), self.number_of_samples)\n\n    def deserialize_training_input(self, data: bytes) -> Optional[np.ndarray]:\n        if not data:\n            return None\n\n        reader = BytesIO(data)\n        return np.load(reader, allow_pickle=False)\n\n    def serialize_training_result(\n        self, training_result: Tuple[np.ndarray, int]\n    ) -> bytes:\n        (weights, number_of_samples) = training_result\n\n        writer = BytesIO()\n        writer.write(number_of_samples.to_bytes(4, byteorder=\"big\"))\n        np.save(writer, weights, allow_pickle=False)\n        return writer.getbuffer()[:]\n\ndef main() -> None:\n    participant = Participant(args.data_directory)\n\n    run_participant(\n        participant, args.coordinator_url, heartbeat_period=args.heartbeat_period\n    )\n```\n\n## [`v0.11.0`](https://github.com/xaynetwork/xaynet/blob/v0.11.0/bindings/python/xaynet_sdk/participant.py)\n\n```bash\npip install xaynet-sdk-python\n```\n\n```python\n# - renamed `run_participant` to `spawn_participant`\n# - removed `configure_logging`\nfrom xaynet_sdk import ParticipantABC, spawn_participant\n\nclass Participant(ParticipantABC):\n    # 
Returns:\n    #   - returns a `np.ndarray` instead of `Tuple[np.ndarray, int]`\n    #     The scalar has been moved to the `spawn_participant` function.\n    #     This change is only temporary. In a future version it will again\n    #     be possible to set the scalar in the `train_round` method.\n    def train_round(self, training_input: Optional[np.ndarray]) -> np.ndarray:\n        if training_input is None:\n            self.regressor = Regressor(len(self.trainset_x.columns))\n            return self.regressor.get_weights()\n\n        return self.regressor.get_weights()\n\n    # Args:\n    #   - renamed `data` to `global_model`\n    #   - provides a `list` instead of `Optional[bytes]`\n    #   - `deserialize_training_input` is not called if `global_model` is `None`\n    #     therefore the `None` case no longer needs to be handled.\n    #\n    # Returns:\n    #   - returns a `np.ndarray` instead of `Optional[np.ndarray]`\n    def deserialize_training_input(self, global_model: list) -> np.ndarray:\n        return np.array(global_model)\n\n    # Args:\n    #   - provides a `np.ndarray` instead of `Tuple[np.ndarray, int]`\n    #\n    # Returns:\n    #   - returns a `list` instead of `bytes`\n    def serialize_training_result(self, training_result: np.ndarray) -> list:\n        return training_result.tolist()\n\ndef main() -> None:\n    # - `spawn_participant` spawns the participant in a separate thread instead of the main thread.\n    #\n    # Args:\n    #   - removed `heartbeat_period`\n    #   - `Participant` is instantiated in the participant thread instead of the main thread.\n    #     This ensures that both the participant as well as the model of `Participant` live on\n    #     the same thread. 
If they don't live on the same thread, it can cause problems with some\n    #     of the ml frameworks.\n    participant = spawn_participant(\n        args.coordinator_url,\n        Participant,\n        args=(args.data_directory,),\n        scalar=1 / number_of_samples,\n    )\n\n    try:\n        participant.join()\n    except KeyboardInterrupt:\n        participant.stop()\n```\n"
  },
  {
    "path": "bindings/python/src/lib.rs",
    "content": "pub mod python_ffi;\n"
  },
  {
    "path": "bindings/python/src/python_ffi.rs",
    "content": "use pyo3::create_exception;\nuse pyo3::exceptions::PyException;\nuse pyo3::types::PyList;\nuse pyo3::{prelude::*, wrap_pyfunction};\nuse tracing::debug;\nuse tracing_subscriber::{EnvFilter, FmtSubscriber};\n\nuse xaynet_core::mask::IntoPrimitives;\nuse xaynet_core::mask::{DataType, FromPrimitives, Model};\nuse xaynet_sdk::settings::MaxMessageSize;\n\nuse crate::from_primitives;\nuse crate::into_primitives;\n\ncreate_exception!(xaynet_sdk, CryptoInit, PyException);\ncreate_exception!(xaynet_sdk, ParticipantInit, PyException);\ncreate_exception!(xaynet_sdk, ParticipantRestore, PyException);\ncreate_exception!(xaynet_sdk, UninitializedParticipant, PyException);\ncreate_exception!(xaynet_sdk, LocalModelLengthMisMatch, PyException);\ncreate_exception!(xaynet_sdk, LocalModelDataTypeMisMatch, PyException);\ncreate_exception!(xaynet_sdk, GlobalModelUnavailable, PyException);\ncreate_exception!(xaynet_sdk, GlobalModelDataTypeMisMatch, PyException);\n\n#[pymodule]\nfn xaynet_sdk(py: Python, m: &PyModule) -> PyResult<()> {\n    m.add_class::<Participant>()?;\n    m.add_function(wrap_pyfunction!(init_logging, m)?)?;\n\n    m.add(\"CryptoInit\", py.get_type::<CryptoInit>())?;\n    m.add(\"ParticipantInit\", py.get_type::<ParticipantInit>())?;\n    m.add(\"ParticipantRestore\", py.get_type::<ParticipantRestore>())?;\n    m.add(\n        \"UninitializedParticipant\",\n        py.get_type::<UninitializedParticipant>(),\n    )?;\n    m.add(\n        \"LocalModelLengthMisMatch\",\n        py.get_type::<LocalModelLengthMisMatch>(),\n    )?;\n    m.add(\n        \"LocalModelDataTypeMisMatch\",\n        py.get_type::<LocalModelDataTypeMisMatch>(),\n    )?;\n    m.add(\n        \"GlobalModelUnavailable\",\n        py.get_type::<GlobalModelUnavailable>(),\n    )?;\n    m.add(\n        \"GlobalModelDataTypeMisMatch\",\n        py.get_type::<GlobalModelDataTypeMisMatch>(),\n    )?;\n\n    Ok(())\n}\n\n#[pyclass]\n#[text_signature = \"(url, scalar, /)\"]\nstruct Participant 
{\n    inner: Option<xaynet_mobile::Participant>,\n}\n\n#[pymethods]\nimpl Participant {\n    #[new]\n    pub fn new(url: String, scalar: f64, state: Option<Vec<u8>>) -> PyResult<Self> {\n        sodiumoxide::init()\n            .map_err(|_| CryptoInit::new_err(\"failed to initialize crypto library\"))?;\n\n        let inner = if let Some(state) = state {\n            debug!(\"restore participant\");\n            xaynet_mobile::Participant::restore(&state, &url).map_err(|err| {\n                ParticipantRestore::new_err(format!(\"failed to restore participant: {}\", err))\n            })?\n        } else {\n            debug!(\"initialize participant\");\n            let mut settings = xaynet_mobile::Settings::new();\n            settings.set_url(url);\n            settings.set_keys(xaynet_core::crypto::SigningKeyPair::generate());\n            settings.set_scalar(scalar);\n            settings.set_max_message_size(MaxMessageSize::unlimited());\n\n            xaynet_mobile::Participant::new(settings).map_err(|err| {\n                ParticipantInit::new_err(format!(\"failed to initialize participant: {}\", err))\n            })?\n        };\n\n        Ok(Self { inner: Some(inner) })\n    }\n\n    #[text_signature = \"($self)\"]\n    pub fn tick(&mut self) -> PyResult<()> {\n        let inner = match self.inner {\n            Some(ref mut inner) => inner,\n            None => {\n                return Err(UninitializedParticipant::new_err(\n                    \"called 'tick' on an uninitialized participant. 
this is a bug.\",\n                ))\n            }\n        };\n\n        inner.tick();\n        Ok(())\n    }\n\n    #[text_signature = \"($self, local_model)\"]\n    pub fn set_model(&mut self, local_model: &PyList) -> PyResult<()> {\n        let inner = match self.inner {\n            Some(ref mut inner) => inner,\n            None => {\n                return Err(UninitializedParticipant::new_err(\n                    \"called 'set_model' on an uninitialized participant. this is a bug.\",\n                ))\n            }\n        };\n\n        let local_model_config = inner.local_model_config();\n\n        if local_model.len() != local_model_config.len {\n            return Err(LocalModelLengthMisMatch::new_err(format!(\n                \"the local model length is incompatible with the model length of the current model configuration {} != {}\",\n                local_model.len(),\n                local_model_config.len\n            )));\n        }\n\n        debug!(\n            \"convert local model to {:?} datatype\",\n            local_model_config.data_type\n        );\n\n        match local_model_config.data_type {\n            DataType::F32 => from_primitives!(inner, local_model, f32),\n            DataType::F64 => from_primitives!(inner, local_model, f64),\n            DataType::I32 => from_primitives!(inner, local_model, i32),\n            DataType::I64 => from_primitives!(inner, local_model, i64),\n        }\n    }\n\n    /// Check whether the participant internal state machine made progress while\n    /// executing the PET protocol. If so, the participant state likely changed.\n    #[text_signature = \"($self)\"]\n    pub fn made_progress(&self) -> PyResult<bool> {\n        let inner = match self.inner {\n            Some(ref inner) => inner,\n            None => {\n                return Err(UninitializedParticipant::new_err(\n                    \"called 'made_progress' on an uninitialized participant. 
this is a bug.\",\n                ))\n            }\n        };\n\n        Ok(inner.made_progress())\n    }\n\n    /// Check whether the participant internal state machine is waiting for the\n    /// participant to load its model into the store. If this method returns `true`, the\n    /// caller should make sure to call [`Participant::set_model()`] at some point.\n    #[text_signature = \"($self)\"]\n    pub fn should_set_model(&self) -> PyResult<bool> {\n        let inner = match self.inner {\n            Some(ref inner) => inner,\n            None => {\n                return Err(UninitializedParticipant::new_err(\n                    \"called 'should_set_model' on an uninitialized participant. this is a bug.\",\n                ))\n            }\n        };\n\n        Ok(inner.should_set_model())\n    }\n\n    #[text_signature = \"($self)\"]\n    pub fn task(&self) -> PyResult<u8> {\n        let inner = match self.inner {\n            Some(ref inner) => inner,\n            None => {\n                return Err(UninitializedParticipant::new_err(\n                    \"called 'task' on an uninitialized participant. this is a bug.\",\n                ))\n            }\n        };\n\n        // FIXME:\n        // Returning an enum is currently not supported: https://github.com/PyO3/pyo3/pull/1045\n        let task_as_u8 = match inner.task() {\n            xaynet_mobile::Task::None => 0,\n            xaynet_mobile::Task::Sum => 1,\n            xaynet_mobile::Task::Update => 2,\n        };\n\n        Ok(task_as_u8)\n    }\n\n    #[text_signature = \"($self)\"]\n    pub fn new_global_model(&self) -> PyResult<bool> {\n        let inner = match self.inner {\n            Some(ref inner) => inner,\n            None => {\n                return Err(UninitializedParticipant::new_err(\n                    \"called 'new_global_model' on an uninitialized participant. 
this is a bug.\",\n                ))\n            }\n        };\n\n        Ok(inner.new_global_model())\n    }\n\n    #[text_signature = \"($self)\"]\n    pub fn global_model(&mut self, py: Python) -> PyResult<Option<Py<PyList>>> {\n        let inner = match self.inner {\n            Some(ref mut inner) => inner,\n            None => {\n                return Err(UninitializedParticipant::new_err(\n                    \"called 'global_model' on an uninitialized participant. this is a bug.\",\n                ))\n            }\n        };\n\n        let global_model = inner\n            .global_model()\n            .map_err(|_| GlobalModelUnavailable::new_err(\"failed to fetch global model\"))?;\n\n        let global_model = match global_model {\n            Some(global_model) => global_model,\n            None => return Ok(None),\n        };\n\n        match inner.local_model_config().data_type {\n            DataType::F32 => into_primitives!(py, global_model, f32),\n            DataType::F64 => into_primitives!(py, global_model, f64),\n            DataType::I32 => into_primitives!(py, global_model, i32),\n            DataType::I64 => into_primitives!(py, global_model, i64),\n        }\n    }\n\n    #[text_signature = \"($self)\"]\n    pub fn save(&mut self) -> PyResult<Vec<u8>> {\n        let inner = match self.inner.take() {\n            Some(inner) => inner,\n            None => {\n                return Err(UninitializedParticipant::new_err(\n                    \"called 'save' on an uninitialized participant. this is a bug.\",\n                ))\n            }\n        };\n\n        Ok(inner.save())\n    }\n}\n\n#[macro_export]\nmacro_rules! 
into_primitives {\n    ($py:expr, $global_model:expr, $data_type:ty) => {\n        if let Ok(global_model) = $global_model\n            .into_primitives()\n            .collect::<Result<Vec<$data_type>, _>>()\n        {\n            let py_list = PyList::new($py, global_model.into_iter());\n            Ok(Some(py_list.into()))\n        } else {\n            Err(GlobalModelDataTypeMisMatch::new_err(\n                \"the global model data type is incompatible with the data type of the current model configuration\",\n            ))\n        }\n    };\n}\n\n#[macro_export]\nmacro_rules! from_primitives {\n    ($participant:expr, $local_model:expr, $data_type:ty) => {{\n            let model: Vec<$data_type> = $local_model.extract()\n                .map_err(|err| LocalModelDataTypeMisMatch::new_err(format!(\"{}\", err)))?;\n            let converted_model = Model::from_primitives(model.into_iter());\n            if let Ok(converted_model) = converted_model {\n                $participant.set_model(converted_model);\n                Ok(())\n            } else {\n                Err(LocalModelDataTypeMisMatch::new_err(\n                    \"the local model data type is incompatible with the data type of the current model configuration\"\n                ))\n            }}\n    };\n}\n\n#[pyfunction]\nfn init_logging() {\n    let env_filter = EnvFilter::try_from_env(\"XAYNET__CLIENT\");\n    if let Ok(filter) = env_filter {\n        let _fmt_subscriber = FmtSubscriber::builder()\n            .with_env_filter(filter)\n            .with_ansi(true)\n            .try_init();\n    }\n}\n"
  },
  {
    "path": "bindings/python/xaynet_sdk/__init__.py",
    "content": "import threading\nfrom typing import List, Optional, Tuple\n\nfrom .async_participant import *\nfrom .participant import *\n\n\ndef spawn_participant(\n    coordinator_url: str,\n    participant: ParticipantABC,\n    args: Tuple = (),\n    kwargs: dict = {},\n    state: Optional[List[int]] = None,\n    scalar: float = 1.0,\n):\n    \"\"\"\n    Spawns an `InternalParticipant` in a separate thread and returns a participant handle.\n    If a `state` is passed, this state is restored, otherwise a new `InternalParticipant`\n    is created.\n\n    Args:\n        coordinator_url: The url of the coordinator.\n        participant: A class that implements `ParticipantABC`.\n        args: The args that get passed to the constructor of the `participant` class.\n        kwargs: The kwargs that get passed to the constructor of the `participant` class.\n        state: A serialized participant state. Defaults to `None`.\n        scalar: The scalar used for masking. Defaults to `1.0`.\n\n    Note:\n        The `scalar` is used later when the models are aggregated in order to scale their weights.\n        It can be used when you want to weight the participants' updates differently.\n\n        For example:\n        If not all participant updates should be weighted equally but proportionally to their\n        training samples, the scalar would be set to `scalar = 1 / number_of_samples`.\n\n    Returns:\n        The `InternalParticipant`.\n\n    Raises:\n        CryptoInit: If the initialization of the underlying crypto library has failed.\n        ParticipantInit: If the participant cannot be initialized. This is most\n            likely caused by an invalid `coordinator_url`.\n        ParticipantRestore: If the participant cannot be restored due to invalid\n            serialized state. 
This exception can never be thrown if the `state` is `None`.\n\n    Note:\n        The `participant` class is instantiated on the spawned thread, so exceptions raised\n        by its constructor are not propagated to the caller of `spawn_participant`.\n    \"\"\"\n    internal_participant = InternalParticipant(\n        coordinator_url, participant, args, kwargs, state, scalar\n    )\n    # spawns the internal participant in a thread.\n    # `start` calls the `run` method of `InternalParticipant`\n    # https://docs.python.org/3.8/library/threading.html#threading.Thread.start\n    # https://docs.python.org/3.8/library/threading.html#threading.Thread.run\n    internal_participant.start()\n    return internal_participant\n\n\ndef spawn_async_participant(\n    coordinator_url: str, state: Optional[List[int]] = None, scalar: float = 1.0\n) -> (AsyncParticipant, threading.Event):\n    \"\"\"\n    Spawns an `AsyncParticipant` in a separate thread and returns a participant handle\n    together with a global model notifier. If a `state` is passed, this state is restored,\n    otherwise a new participant is created.\n\n    Args:\n        coordinator_url: The url of the coordinator.\n        state: A serialized participant state. Defaults to `None`.\n        scalar: The scalar used for masking. Defaults to `1.0`.\n\n    Note:\n        The `scalar` is used later when the models are aggregated in order to scale their weights.\n        It can be used when you want to weight the participants' updates differently.\n\n        For example:\n        If not all participant updates should be weighted equally but proportionally to their\n        training samples, the scalar would be set to `scalar = 1 / number_of_samples`.\n\n    Returns:\n        A tuple which consists of an `AsyncParticipant` and a global model notifier.\n\n    Raises:\n        CryptoInit: If the initialization of the underlying crypto library has failed.\n        ParticipantInit: If the participant cannot be initialized. 
This is most\n            likely caused by an invalid `coordinator_url`.\n        ParticipantRestore: If the participant cannot be restored due to invalid\n            serialized state. This exception can never be thrown if the `state` is `None`.\n    \"\"\"\n    notifier = threading.Event()\n    async_participant = AsyncParticipant(coordinator_url, notifier, state, scalar)\n    async_participant.start()\n    return (async_participant, notifier)\n"
  },
  {
    "path": "bindings/python/xaynet_sdk/async_participant.py",
    "content": "import logging\nimport threading\nfrom typing import List, Optional\n\nfrom justbackoff import Backoff\n\nfrom xaynet_sdk import xaynet_sdk\n\n# rust participant logging\nxaynet_sdk.init_logging()\n# python participant logging\nLOG = logging.getLogger(\"participant\")\n\n\nclass AsyncParticipant(threading.Thread):\n    def __init__(\n        self,\n        coordinator_url: str,\n        notifier,\n        state,\n        scalar,\n    ):\n        # xaynet rust participant\n        self._xaynet_participant = xaynet_sdk.Participant(\n            coordinator_url, scalar, state\n        )\n\n        self._exit_event = threading.Event()\n        self._poll_period = Backoff(min_ms=100, max_ms=10000, factor=1.2, jitter=False)\n\n        # new global model notifier\n        self._notifier = notifier\n\n        # calls to an external lib are thread-safe https://stackoverflow.com/a/42023362\n        # however, if a user calls `stop` in the middle of the `_tick` call, the\n        # `save` method will be executed (which consumes the participant) and every following call\n        # will fail with a call on an uninitialized participant. 
Therefore we lock during `tick`.\n        self._tick_lock = threading.Lock()\n\n        super().__init__(daemon=True)\n\n    def run(self):\n        try:\n            self._run()\n        except Exception as err:  # pylint: disable=broad-except\n            LOG.error(\"unrecoverable error: %s shut down participant\", err)\n            self._exit_event.set()\n\n    def _notify(self):\n        if self._notifier.is_set() is False:\n            LOG.debug(\"notify that a new global model is available\")\n            self._notifier.set()\n\n    def _run(self):\n        while not self._exit_event.is_set():\n            self._tick()\n\n    def _tick(self):\n        with self._tick_lock:\n            self._xaynet_participant.tick()\n            new_global_model = self._xaynet_participant.new_global_model()\n            made_progress = self._xaynet_participant.made_progress()\n\n        if new_global_model:\n            self._notify()\n\n        if made_progress:\n            self._poll_period.reset()\n            self._exit_event.wait(timeout=self._poll_period.duration())\n        else:\n            self._exit_event.wait(timeout=self._poll_period.duration())\n\n    def get_global_model(self) -> Optional[list]:\n        \"\"\"\n        Fetches the current global model. This method can be called at any time. If no global\n        model exists (usually in the first round), the method returns `None`.\n\n        Returns:\n            The current global model in the form of a list or `None`. 
The data type of the\n            elements match the data type defined in the coordinator configuration.\n\n        Raises:\n            GlobalModelUnavailable: If the participant cannot connect to the coordinator to get\n                the global model.\n            GlobalModelDataTypeMisMatch: If the data type of the global model does not match\n                the data type defined in the coordinator configuration.\n        \"\"\"\n        LOG.debug(\"get global model\")\n        self._notifier.clear()\n        with self._tick_lock:\n            return self._xaynet_participant.global_model()\n\n    def set_local_model(self, local_model: list):\n        \"\"\"\n        Sets a local model. This method can be called at any time. Internally the\n        participant first caches the local model. As soon as the participant is selected as an\n        update participant, the currently cached local model is used. This means that the cache\n        is empty after this operation.\n\n        If a local model is already in the cache and `set_local_model` is called with a new local\n        model, the current cached local model will be replaced by the new one.\n        If the participant is an update participant and there is no local model in the cache,\n        the participant waits until a local model is set or until a new round has been started.\n\n        Args:\n            local_model: The local model in the form of a list. 
The data type of the\n                elements must match the data type defined in the coordinator configuration.\n\n        Raises:\n            LocalModelLengthMisMatch: If the length of the local model does not match the\n                length defined in the coordinator configuration.\n            LocalModelDataTypeMisMatch: If the data type of the local model does not match\n                the data type defined in the coordinator configuration.\n        \"\"\"\n        LOG.debug(\"set local model in model store\")\n        with self._tick_lock:\n            self._xaynet_participant.set_model(local_model)\n\n    def stop(self) -> List[int]:\n        \"\"\"\n        Stops the execution of the participant and returns its serialized state.\n        The serialized state can be passed to the `spawn_async_participant` function\n        to restore a participant.\n\n        After calling `stop`, the participant is consumed. Every further method\n        call on the handle of `AsyncParticipant` leads to an `UninitializedParticipant`\n        exception.\n\n        Note:\n            The serialized state contains unencrypted **private key(s)**. If used\n            in production, it is important that the serialized state is securely saved.\n\n        Returns:\n            The serialized state of the participant.\n        \"\"\"\n        LOG.debug(\"stop participant\")\n        self._exit_event.set()\n        self._notifier.clear()\n        with self._tick_lock:\n            return self._xaynet_participant.save()\n"
  },
  {
    "path": "bindings/python/xaynet_sdk/participant.py",
    "content": "from abc import ABC, abstractmethod\nimport logging\nimport threading\nfrom typing import List, Optional, TypeVar\n\nfrom justbackoff import Backoff\n\nfrom xaynet_sdk import xaynet_sdk\n\n# rust participant logging\nxaynet_sdk.init_logging()\n# python participant logging\nLOG = logging.getLogger(\"participant\")\n\nTrainingResult = TypeVar(\"TrainingResult\")\nTrainingInput = TypeVar(\"TrainingInput\")\n\n\nclass ParticipantABC(ABC):\n    @abstractmethod\n    def train_round(self, training_input: Optional[TrainingInput]) -> TrainingResult:\n        \"\"\"\n        Trains a model. `training_input` is the deserialized global model\n        (see `deserialize_training_input`). If no global model exists\n        (usually in the first round), `training_input` will be `None`.\n        In this case the weights of the model should be initialized and returned.\n\n        Args:\n            self: The participant.\n            training_input: The deserialized global model (weights of the global model) or None.\n\n        Returns:\n            The updated model weights (the local model).\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def serialize_training_result(self, training_result: TrainingResult) -> list:\n        \"\"\"\n        Serializes the `training_result` into a `list`. 
The data type of the\n        elements must match the data type defined in the coordinator configuration.\n\n        Args:\n            self: The participant.\n            training_result: The `TrainingResult` of `train_round`.\n\n        Returns:\n            The `training_result` as a `list`.\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def deserialize_training_input(self, global_model: list) -> TrainingInput:\n        \"\"\"\n        Deserializes the `global_model` from a `list` to the type of `TrainingInput`.\n        The data type of the elements matches the data type defined in the coordinator\n        configuration. If no global model exists (usually in the first round), the method will\n        not be called by the `InternalParticipant`.\n\n        Args:\n            self: The participant.\n            global_model: The global model.\n\n        Returns:\n            The `TrainingInput` for `train_round`.\n        \"\"\"\n        raise NotImplementedError()\n\n    def participate_in_update_task(self) -> bool:\n        \"\"\"\n        A callback used by the `InternalParticipant` to determine whether the\n        `train_round` method should be called. This callback is only called\n        if the participant is selected as an update participant. If `participate_in_update_task`\n        returns `False`, `train_round` will not be called by the `InternalParticipant`.\n\n        If the method is not overridden, it returns `True` by default.\n\n        Returns:\n            Whether the `train_round` method should be called when the participant\n            is an update participant.\n        \"\"\"\n        return True\n\n    def on_new_global_model(self, global_model: Optional[TrainingInput]) -> None:\n        \"\"\"\n        A callback that is called by the `InternalParticipant` once a new global model is\n        available. If no global model exists (usually in the first round), `global_model` will\n        be `None`. 
If a global model exists, `global_model` is already the deserialized\n        global model. (See `deserialize_training_input`)\n\n        If the method is not overridden, it does nothing by default.\n\n        Args:\n            self: The participant.\n            global_model: The deserialized global model or `None`.\n        \"\"\"\n\n    def on_stop(self) -> None:\n        \"\"\"\n        A callback that is called by the `InternalParticipant` before the `InternalParticipant`\n        thread is stopped.\n\n        This callback can be used, for example, to show performance values that have been\n        collected in the participant over the course of the training rounds.\n\n        If the method is not overridden, it does nothing by default.\n\n        Args:\n            self: The participant.\n        \"\"\"\n\n\nclass InternalParticipant(threading.Thread):\n    def __init__(\n        self,\n        coordinator_url: str,\n        participant,\n        p_args,\n        p_kwargs,\n        state,\n        scalar,\n    ):\n        # xaynet rust participant\n        self._xaynet_participant = xaynet_sdk.Participant(\n            coordinator_url, scalar, state\n        )\n\n        # https://github.com/python/cpython/blob/3.9/Lib/multiprocessing/process.py#L80\n        # stores the Participant class with its args and kwargs\n        # the participant is created in the `run` method to ensure that the participant/ ml\n        # model is initialized on the participant thread otherwise the participant lives on the main\n        # thread which can create issues with some of the ml frameworks.\n        self._participant = participant\n        self._p_args = tuple(p_args)\n        self._p_kwargs = dict(p_kwargs)\n\n        self._exit_event = threading.Event()\n        self._poll_period = Backoff(min_ms=100, max_ms=10000, factor=1.2, jitter=False)\n\n        # global model cache\n        self._global_model = None\n        self._error_on_fetch_global_model = False\n\n       
 self._tick_lock = threading.Lock()\n\n        super().__init__(daemon=True)\n\n    def run(self):\n        self._participant = self._participant(*self._p_args, **self._p_kwargs)\n\n        try:\n            self._run()\n        except Exception as err:  # pylint: disable=broad-except\n            LOG.error(\"unrecoverable error: %s shut down participant\", err)\n            self._exit_event.set()\n\n    def _fetch_global_model(self):\n        LOG.debug(\"fetch global model\")\n        try:\n            global_model = self._xaynet_participant.global_model()\n        except (\n            xaynet_sdk.GlobalModelUnavailable,\n            xaynet_sdk.GlobalModelDataTypeMisMatch,\n        ) as err:\n            LOG.warning(\"failed to get global model: %s\", err)\n            self._error_on_fetch_global_model = True\n        else:\n            if global_model is not None:\n                self._global_model = self._participant.deserialize_training_input(\n                    global_model\n                )\n            else:\n                self._global_model = None\n            self._error_on_fetch_global_model = False\n\n    def _train(self):\n        LOG.debug(\"train model\")\n        data = self._participant.train_round(self._global_model)\n        local_model = self._participant.serialize_training_result(data)\n        try:\n            self._xaynet_participant.set_model(local_model)\n        except (\n            xaynet_sdk.LocalModelLengthMisMatch,\n            xaynet_sdk.LocalModelDataTypeMisMatch,\n        ) as err:\n            LOG.warning(\"failed to set local model: %s\", err)\n\n    def _run(self):\n        while not self._exit_event.is_set():\n            self._tick()\n\n    def _tick(self):\n        with self._tick_lock:\n            self._xaynet_participant.tick()\n\n            if (\n                self._xaynet_participant.new_global_model()\n                or self._error_on_fetch_global_model\n            ):\n                
self._fetch_global_model()\n\n                if not self._error_on_fetch_global_model:\n                    self._participant.on_new_global_model(self._global_model)\n\n            if (\n                self._xaynet_participant.should_set_model()\n                and self._participant.participate_in_update_task()\n                and not self._error_on_fetch_global_model\n            ):\n                self._train()\n\n            made_progress = self._xaynet_participant.made_progress()\n\n        if made_progress:\n            self._poll_period.reset()\n            self._exit_event.wait(timeout=self._poll_period.duration())\n        else:\n            self._exit_event.wait(timeout=self._poll_period.duration())\n\n    def stop(self) -> List[int]:\n        \"\"\"\n        Stops the execution of the participant and returns its serialized state.\n        The serialized state can be passed to the `spawn_participant` function\n        to restore a participant.\n\n        After calling `stop`, the participant is consumed. Every further method\n        call on the handle of `InternalParticipant` leads to an `UninitializedParticipant`\n        exception.\n\n        Note:\n            The serialized state contains unencrypted **private key(s)**. If used\n            in production, it is important that the serialized state is securely saved.\n\n        Returns:\n            The serialized state of the participant.\n        \"\"\"\n        LOG.debug(\"stopping participant\")\n        self._exit_event.set()\n        with self._tick_lock:\n            state = self._xaynet_participant.save()\n            LOG.debug(\"participant stopped\")\n        self._participant.on_stop()\n        return state\n"
  },
  {
    "path": "configs/config.toml",
    "content": "[log]\nfilter = \"xaynet=debug,http=warn,info\"\n\n[api]\nbind_address = \"127.0.0.1:8081\"\ntls_certificate = \"/app/ssl/tls.pem\"\ntls_key = \"/app/ssl/tls.key\"\n# tls_client_auth = \"/app/ssl/trust_anchor.pem\"\n\n[pet.sum]\nprob = 0.5\ncount = { min = 1, max = 100 }\ntime = { min = 5, max = 3600 }\n\n[pet.update]\nprob = 0.9\ncount = { min = 3, max = 10000 }\ntime = { min = 10, max = 3600 }\n\n[pet.sum2]\ncount = { min = 1, max = 100 }\ntime = { min = 5, max = 3600 }\n\n[mask]\ngroup_type = \"Prime\"\ndata_type = \"F32\"\nbound_type = \"B0\"\nmodel_type = \"M3\"\n\n[model]\nlength = 4\n\n[metrics.influxdb]\nurl = \"http://127.0.0.1:8086\"\ndb = \"metrics\"\n\n[redis]\nurl = \"redis://127.0.0.1/\"\n\n[s3]\naccess_key = \"minio\"\nsecret_access_key = \"minio123\"\nregion = [\"minio\", \"http://localhost:9000\"]\n\n[restore]\nenable = true\n"
  },
  {
    "path": "configs/docker-dev.toml",
    "content": "[log]\nfilter = \"xaynet=debug,http=warn,info\"\n\n[api]\nbind_address = \"0.0.0.0:8081\"\ntls_certificate = \"/app/ssl/tls.pem\"\ntls_key = \"/app/ssl/tls.key\"\n# tls_client_auth = \"/app/ssl/trust_anchor.pem\"\n\n[pet.sum]\nprob = 0.01\ncount = { min = 1, max = 100 }\ntime = { min = 5, max = 3600 }\n\n[pet.update]\nprob = 0.1\ncount = { min = 3, max = 10000 }\ntime = { min = 10, max = 3600 }\n\n[pet.sum2]\ncount = { min = 1, max = 100 }\ntime = { min = 5, max = 3600 }\n\n[mask]\ngroup_type = \"Prime\"\ndata_type = \"F32\"\nbound_type = \"B0\"\nmodel_type = \"M3\"\n\n[model]\nlength = 4\n\n[metrics.influxdb]\nurl = \"http://influxdb:8086\"\ndb = \"metrics\"\n\n[redis]\nurl = \"redis://redis\"\n\n[s3]\naccess_key = \"minio\"\nsecret_access_key = \"minio123\"\nregion = [\"minio\", \"http://minio:9000\"]\n\n[restore]\nenable = true\n"
  },
  {
    "path": "docker/.dev.env",
    "content": "MINIO_ACCESS_KEY=minio\nMINIO_SECRET_KEY=minio123\n"
  },
  {
    "path": "docker/Dockerfile",
    "content": "FROM buildpack-deps:stable-curl AS builder\n\nRUN apt update\n\n# Install Rust\nENV RUSTUP_HOME=/usr/local/rustup \\\n    CARGO_HOME=/usr/local/cargo \\\n    PATH=/usr/local/cargo/bin:$PATH\nRUN curl https://sh.rustup.rs -sSf | sh -s -- -y --profile minimal\n\n# install build dependencies: libc, openssl\nRUN apt install -y build-essential libssl-dev pkg-config\n\nCOPY rust/ /rust/\nWORKDIR /rust/xaynet-server\n\n# https://github.com/linkerd/linkerd2-proxy/blob/main/Dockerfile#L31\n\n# Controls which profile the coordinator is compiled with.\n# If set to RELEASE_BUILD=1, the coordinator is compiled using the release profile.\n# Default is development profile.\nARG RELEASE_BUILD=0\n\n# Controls which optional features the coordinator is compiled with.\n# Syntax:\n# default features:     -\n# single feature:       COORDINATOR_FEATURES=tls\n# multiple features:    COORDINATOR_FEATURES=tls,metrics\n# all features:         COORDINATOR_FEATURES=full\nARG COORDINATOR_FEATURES\n\nRUN mkdir -p /out && \\\n  echo \"RELEASE_BUILD=$RELEASE_BUILD COORDINATOR_FEATURES=$COORDINATOR_FEATURES\" && \\\n  if [ \"$RELEASE_BUILD\" -eq \"0\" ]; \\\n  then \\\n    cargo build --features=\"$COORDINATOR_FEATURES\" && \\\n    mv /rust/target/debug/coordinator /out/coordinator; \\\n  else \\\n    cargo build --features=\"$COORDINATOR_FEATURES\" --release && \\\n    mv /rust/target/release/coordinator /out/coordinator; \\\n  fi\n\nFROM ubuntu:20.04\nRUN apt update && apt install -y --no-install-recommends libssl-dev\nCOPY --from=builder /out/coordinator /app/coordinator\n\nENTRYPOINT [\"/app/coordinator\", \"-c\", \"/app/config.toml\"]\n"
  },
  {
    "path": "docker/docker-compose.yml",
    "content": "version: \"3.8\"\nservices:\n  coordinator:\n    image: xaynetwork/xaynet:development\n    build:\n      context: ..\n      dockerfile: docker/Dockerfile\n    depends_on:\n      - minio\n      - redis\n      - influxdb\n    volumes:\n      - ${PWD}/configs/docker-dev.toml:/app/config.toml\n    networks:\n      - xaynet\n    ports:\n      - \"8081:8081\"\n    # temporary fix:\n    # The coordinator crashes if Redis is not ready or busy at startup\n    restart: unless-stopped\n\n  influxdb:\n    image: influxdb:1.8\n    hostname: influxdb\n    container_name: influxdb\n    environment:\n      INFLUXDB_DB: metrics\n      INFLUXDB_DATA_QUERY_LOG_ENABLED: 'false'\n      INFLUXDB_HTTP_LOG_ENABLED: 'false'\n    volumes:\n      - influxdb-data:/var/lib/influxdb\n    networks:\n      - xaynet\n    ports:\n      - \"8086:8086\"\n\n  minio:\n    image: minio/minio\n    hostname: minio\n    container_name: minio\n    env_file:\n      - .dev.env\n    command: server /data\n    volumes:\n      - minio-data:/data\n    networks:\n      - xaynet\n    ports:\n      - \"9000:9000\"\n\n  redis:\n    image: redis:6\n    hostname: redis\n    container_name: redis\n    entrypoint: /usr/local/bin/redis-server --appendonly yes --appendfsync everysec # using combination of RDB and AOF for persistence: https://redis.io/topics/persistence\n    volumes:\n      - redis-data:/data\n    networks:\n      - xaynet\n    ports:\n      - \"6379:6379\"\n\nvolumes:\n  minio-data:\n  redis-data:\n  influxdb-data:\n\nnetworks:\n  xaynet:\n"
  },
  {
    "path": "k8s/coordinator/base/deployment.yaml",
    "content": "apiVersion: apps/v1\nkind: Deployment\nmetadata:\n  name: coordinator-deployment\nspec:\n  selector:\n    matchLabels:\n      app: coordinator\n  replicas: 1\n  strategy:\n    type: Recreate\n  template:\n    metadata:\n      labels:\n        app: coordinator\n    spec:\n      containers:\n        - name: coordinator\n          image: coordinator\n          imagePullPolicy: Always\n          ports:\n            - containerPort: 8081\n              protocol: TCP\n          env:\n            - name: REDIS_AUTH\n              valueFrom:\n                secretKeyRef:\n                  name: redis-auth\n                  key: redis-password\n            - name: XAYNET__REDIS__URL\n              value: \"redis://:$(REDIS_AUTH)@redis-master\"\n            - name: XAYNET__S3__ACCESS_KEY\n              valueFrom:\n                secretKeyRef:\n                  name: minio-auth\n                  key: accesskey\n            - name: XAYNET__S3__SECRET_ACCESS_KEY\n              valueFrom:\n                secretKeyRef:\n                  name: minio-auth\n                  key: secretkey\n"
  },
  {
    "path": "k8s/coordinator/base/kustomization.yaml",
    "content": "apiVersion: kustomize.config.k8s.io/v1beta1\nkind: Kustomization\n\ncommonLabels:\n  app.kubernetes.io/component: backend\n  app.kubernetes.io/name: coordinator\n  app.kubernetes.io/part-of: xaynet\n\nresources:\n- deployment.yaml\n- service.yaml\n"
  },
  {
    "path": "k8s/coordinator/base/service.yaml",
    "content": "apiVersion: v1\nkind: Service\nmetadata:\n  name: coordinator-service\nspec:\n  type: ClusterIP\n  ports:\n    - port: 8081\n      targetPort: 8081\n      name: http-port\n  selector:\n    app: coordinator\n"
  },
  {
    "path": "k8s/coordinator/development/cert-volume-mount.yaml",
    "content": "apiVersion: apps/v1\nkind: Deployment\nmetadata:\n  name: coordinator-deployment\nspec:\n  template:\n    spec:\n      volumes:\n        - name: tls-certificate\n          secret:\n            secretName: dev-coordinator\n            items:\n              - key: tls.key\n                path: tls.key\n                mode: 0400\n              - key: tls.crt\n                path: tls.pem\n                mode: 0444\n      containers:\n        - name: coordinator\n          volumeMounts:\n            - name: tls-certificate \n              mountPath: \"/app/ssl\"\n              readOnly: true\n"
  },
  {
    "path": "k8s/coordinator/development/config-volume-mount.yaml",
    "content": "apiVersion: apps/v1\nkind: Deployment\nmetadata:\n  name: coordinator-deployment\nspec:\n  template:\n    spec:\n      volumes:\n      - name: config-volume\n        configMap:\n          name: config-toml\n          items:\n            - key: config.toml\n              path: config.toml\n      containers:\n        - name: coordinator\n          volumeMounts:\n            - name: config-volume\n              mountPath: /app/config.toml\n              subPath: config.toml\n"
  },
  {
    "path": "k8s/coordinator/development/config.toml",
    "content": "[log]\nfilter = \"xaynet=debug,http=warn,info\"\n\n[api]\nbind_address = \"0.0.0.0:8081\"\ntls_certificate = \"/app/ssl/tls.pem\"\ntls_key = \"/app/ssl/tls.key\"\n\n[pet.sum]\nprob = 0.5\ncount = { min = 1, max = 100 }\ntime = { min = 5, max = 3600 }\n\n[pet.update]\nprob = 0.9\ncount = { min = 3, max = 10000 }\ntime = { min = 10, max = 3600 }\n\n[pet.sum2]\ncount = { min = 1, max = 100 }\ntime = { min = 5, max = 3600 }\n\n[mask]\ngroup_type = \"Prime\"\ndata_type = \"F32\"\nbound_type = \"B0\"\nmodel_type = \"M3\"\n\n[model]\nlength = 4\n\n[metrics.influxdb]\nurl = \"http://influxdb:8086\"\ndb = \"metrics\"\n\n[redis]\n# The url is configured via the environment variable `XAYNET__REDIS__URL`.\n# `XAYNET__REDIS__URL` depends on the environment variable `REDIS_AUTH`,\n# which is defined as a Kubernetes secret and exposed to the coordinator pod.\n# See: k8s/coordinator/base/deployment.yaml\n\n[s3]\n# The access_key and secret_access_key are configured via the environment variables\n# `XAYNET__S3__ACCESS_KEY` and `XAYNET__S3__SECRET_ACCESS_KEY`.\n# See: k8s/coordinator/base/deployment.yaml\nregion = [\"minio\", \"http://minio:9000\"]\n\n[restore]\nenable = true\n"
  },
  {
    "path": "k8s/coordinator/development/history-limit.yaml",
    "content": "apiVersion: apps/v1\nkind: Deployment\nmetadata:\n  name: coordinator-deployment\nspec:\n  revisionHistoryLimit: 0\n"
  },
  {
    "path": "k8s/coordinator/development/ingress.yaml",
    "content": "apiVersion: networking.k8s.io/v1\r\nkind: Ingress\r\nmetadata:\r\n  name: coordinator-ingress\r\n  annotations:\r\n    kubernetes.io/ingress.class: \"nginx\"\r\n    cert-manager.io/cluster-issuer: \"letsencrypt-production\"\r\nspec:\r\n  tls:\r\n    - hosts:\r\n        - dev-coordinator.xaynet.dev\r\n      secretName: dev-coordinator\r\n  rules:\r\n    - host: dev-coordinator.xaynet.dev\r\n      http:\r\n        paths:\r\n          - path: /\r\n            pathType: Prefix\r\n            backend:\r\n              service:\r\n                name: coordinator-service\r\n                port:\r\n                  number: 8081\r\n"
  },
  {
    "path": "k8s/coordinator/development/kustomization.yaml",
    "content": "apiVersion: kustomize.config.k8s.io/v1beta1\nkind: Kustomization\n\nnamespace: xaynet\n\nimages:\n  - name: coordinator\n    newName: xaynetwork/xaynet\n    newTag: development\n\nconfigMapGenerator:\n  - name: config-toml\n    files:\n    - config.toml\n\nbases:\n  - ../base\n\npatchesStrategicMerge:\n  - history-limit.yaml\n  - config-volume-mount.yaml\n  - cert-volume-mount.yaml\nresources:\n  - ingress.yaml\n"
  },
  {
    "path": "rust/.gitignore",
    "content": "# https://github.com/github/gitignore/blob/master/Rust.gitignore\n# Generated by Cargo\n# will have compiled files and executables\n/target/\n\n# These are backup files generated by rustfmt\n**/*.rs.bk\n/benches/target/"
  },
  {
    "path": "rust/Cargo.toml",
    "content": "[workspace]\nmembers = [\n    \"xaynet\",\n    # \"xaynet-analytics\",\n    \"xaynet-core\",\n    \"xaynet-mobile\",\n    \"xaynet-server\",\n    \"xaynet-sdk\",\n\n    # internals\n    \"benches\",\n    \"examples\",\n]\n\n[workspace.metadata]\n# minimum supported rust version\nmsrv = \"1.51.0\"\n"
  },
  {
    "path": "rust/benches/Cargo.toml",
    "content": "[package]\nname = \"benches\"\nversion = \"0.0.0\"\nauthors = [\"Xayn Engineering <engineering@xaynet.dev>\"]\nedition = \"2018\"\ndescription = \"The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant.\"\nreadme = \"../../README.md\"\nhomepage = \"https://xaynet.dev/\"\nrepository = \"https://github.com/xaynetwork/xaynet/\"\nlicense-file = \"../../LICENSE\"\nkeywords = [\"federated-learning\", \"fl\", \"ai\", \"machine-learning\"]\ncategories = [\"science\", \"cryptography\"]\npublish = false\n\n[dev-dependencies]\ncriterion = { version = \"0.3.6\", features = [\"html_reports\"] }\nnum = \"0.4.0\"\npaste = \"1.0.8\"\nxaynet-core = { path = \"../xaynet-core\", features = [\"testutils\"] }\n\n[[bench]]\nname = \"sum_message\"\npath = \"messages/sum.rs\"\nharness = false\n\n[[bench]]\nname = \"update_message\"\npath = \"messages/update.rs\"\nharness = false\n\n[[bench]]\nname = \"models_from_primitives\"\npath = \"models/from_primitives.rs\"\nharness = false\n\n[[bench]]\nname = \"models_to_primitives\"\npath = \"models/to_primitives.rs\"\nharness = false\n"
  },
  {
    "path": "rust/benches/messages/sum.rs",
    "content": "use std::time::Duration;\n\nuse criterion::{black_box, criterion_group, criterion_main, Criterion};\n\nuse xaynet_core::{\n    crypto::{ByteObject, SecretSigningKey},\n    message::Message,\n    testutils::messages as helpers,\n};\n\n// `Message::to_bytes` takes a secret key as argument. It is not\n// actually used, since the message we generate already contains a\n// (dummy) signature.\nfn participant_sk() -> SecretSigningKey {\n    SecretSigningKey::from_slice(vec![2; 64].as_slice()).unwrap()\n}\n\nfn to_bytes(crit: &mut Criterion) {\n    let (sum_message, _) = helpers::message(helpers::sum::payload);\n    let buf_len = sum_message.buffer_length();\n    let mut pre_allocated_buf = vec![0; buf_len];\n\n    // the benchmarks run under 20 ns. The results for such\n    // benchmarks can vary a bit more so we:\n    //   - eliminate outliers a bit more aggressively (confidence level)\n    //   - increase the noise threshold\n    //\n    // Note: criterion always reports p = 0.0 so lowering the\n    // significance level doesn't change anything\n    let mut crit = crit.benchmark_group(\"serialize sum message to bytes\");\n    crit.confidence_level(0.9).noise_threshold(0.05);\n\n    crit.bench_function(\"compute sum message buffer length\", |bench| {\n        bench.iter(|| black_box(&sum_message).buffer_length())\n    });\n\n    crit.bench_function(\"serialize sum message to bytes\", |bench| {\n        bench.iter(|| {\n            sum_message.to_bytes(\n                black_box(&mut pre_allocated_buf),\n                black_box(&participant_sk()),\n            )\n        })\n    });\n}\n\nfn from_bytes(crit: &mut Criterion) {\n    let sum_message = helpers::message(helpers::sum::payload).0;\n    let mut bytes = vec![0; sum_message.buffer_length()];\n    sum_message.to_bytes(&mut bytes, &participant_sk());\n\n    // This benchmark is also quite unstable so make it a bit more\n    // relaxed\n    let mut crit = crit.benchmark_group(\"deserialize sum 
message from bytes\");\n    crit.confidence_level(0.9).noise_threshold(0.05);\n\n    crit.bench_function(\"deserialize sum message from bytes\", |bench| {\n        bench.iter(|| Message::from_byte_slice(&black_box(bytes.as_slice())))\n    });\n}\n\ncriterion_group!(\n    name = bench_sum_message;\n    // By default criterion collects 100 samples and the\n    // measurement time is 5 seconds, but the results are\n    // quite unstable with this configuration. This\n    // config makes the benchmarks run longer but\n    // provides more reliable results.\n    config = Criterion::default().sample_size(1000).measurement_time(Duration::new(10, 0));\n    targets =\n        to_bytes,\n        from_bytes,\n);\ncriterion_main!(bench_sum_message);\n"
  },
  {
    "path": "rust/benches/messages/update.rs",
    "content": "use criterion::{black_box, criterion_group, criterion_main, Criterion};\nuse paste::paste;\n\nuse xaynet_core::{\n    message::{FromBytes, ToBytes, Update},\n    testutils::multipart as helpers,\n};\n\nfn make_update(dict_len: usize, mask_len: usize, total_expected_len: usize) -> (Update, Vec<u8>) {\n    let update = helpers::update(dict_len, mask_len);\n    // just check that we made our calculation right\n    // message size = dict_len + mask_len + 64*2\n    assert_eq!(update.buffer_length(), total_expected_len);\n    let mut bytes = vec![0; update.buffer_length()];\n    update.to_bytes(&mut bytes);\n    (update, bytes)\n}\n\nmacro_rules! fn_from_bytes {\n    ($name: ident, $dict_len: expr, $mask_len: expr, $total_len: expr) => {\n        paste! {\n            #[allow(non_snake_case)]\n            fn [<from_bytes $name>](crit: &mut Criterion) {\n                let (_, bytes) = make_update($dict_len, $mask_len, $total_len);\n                let name = &stringify!($name)[1..];\n                let mut crit = crit.benchmark_group(format!(\"deserialize {} update from bytes\", name));\n\n                crit.bench_function(\n                    format!(\"deserialize {} update from bytes slice\", name).as_str(),\n                    |bench| {\n                        bench.iter(|| Update::from_byte_slice(&black_box(bytes.as_slice())))\n                    },\n                );\n\n                // it's less overhead to clone the iterator of bytes instead of re-creating it\n                // again in every benchmark iteration\n                let iter = bytes.into_iter();\n                crit.bench_function(\n                    format!(\"deserialize {} update from bytes stream\", name).as_str(),\n                    |bench| {\n                        bench.iter(|| Update::from_byte_stream(black_box(&mut iter.clone())))\n                    },\n                );\n            }\n        }\n    };\n}\n\n// Get an update that corresponds to:\n// - 1 
sum participant (1 entry in the seed dict)\n// - a 42 bytes serialized masked model\nfn_from_bytes!(_tiny, 116, 42, 286);\n\n// Get an update that corresponds to:\n// - 1k sum participants (1k entries in the seed dict)\n// - a 6kB serialized masked model\nfn_from_bytes!(_100kB, 112_004, 6_018, 118_150);\n\n// Get an update that corresponds to:\n// - 10k sum participants (10k entries in the seed dict)\n// - a 60kB serialized masked model\nfn_from_bytes!(_1MB, 1_120_004, 60_018, 1_180_150);\n\n// Get an update that corresponds to:\n// - 10k sum participants (10k entries in the seed dict)\n// - a ~1MB serialized masked model\nfn_from_bytes!(_2MB, 1_120_004, 1_000_020, 2_120_152);\n\n// Get an update that corresponds to:\n// - 10k sum participants (10k entries in the seed dict)\n// - a ~9MB serialized masked model\nfn_from_bytes!(_10MB, 1_120_004, 9_000_018, 10_120_150);\n\ncriterion_group!(\n    name = bench_update_message;\n    config = Criterion::default();\n    targets =\n        from_bytes_tiny,\n        from_bytes_100kB,\n        from_bytes_1MB,\n        from_bytes_2MB,\n        from_bytes_10MB,\n);\ncriterion_main!(bench_update_message);\n"
  },
  {
    "path": "rust/benches/models/from_primitives.rs",
    "content": "use std::time::Duration;\n\nuse criterion::{black_box, criterion_group, criterion_main, Criterion};\nuse paste::paste;\n\nuse xaynet_core::mask::{FromPrimitives, Model};\n\nfn make_vector(bytes_size: usize) -> Vec<i32> {\n    // 1 i32 -> 4 bytes\n    assert_eq!(bytes_size % 4, 0);\n    let n_elements = bytes_size / 4;\n    vec![0_i32; n_elements]\n}\n\nmacro_rules! fn_from_primitives {\n    ($name: ident, $size: expr) => {\n        paste! {\n            #[allow(non_snake_case)]\n            fn [<from_primitives $name>](crit: &mut Criterion) {\n                let vector = make_vector($size);\n                let name = &stringify!($name)[1..];\n\n                let iter = vector.into_iter();\n                crit.bench_function(\n                    format!(\"convert {} model from primitive vector\", name).as_str(),\n                    |bench| {\n                        bench.iter(|| Model::from_primitives(black_box(iter.clone())))\n                    },\n                );\n            }\n        }\n    };\n}\n\n// 4 bytes\nfn_from_primitives!(_tiny, 4);\n\n// 100kB = 102_400 bytes\nfn_from_primitives!(_100kB, 102_400);\n\n// 1MB = 1_024_000 bytes\nfn_from_primitives!(_1MB, 1_024_000);\n\ncriterion_group!(\n    name = bench_model_from_primitives;\n    config = Criterion::default().sample_size(1000).measurement_time(Duration::new(10, 0));\n    targets =\n        from_primitives_tiny,\n        from_primitives_100kB,\n        from_primitives_1MB,\n);\ncriterion_main!(bench_model_from_primitives);\n"
  },
  {
    "path": "rust/benches/models/to_primitives.rs",
    "content": "use std::{iter, time::Duration};\n\nuse criterion::{black_box, criterion_group, criterion_main, Criterion};\nuse num::{bigint::BigInt, rational::Ratio};\nuse paste::paste;\n\nuse xaynet_core::mask::{IntoPrimitives, Model};\n\nfn make_model(bytes_size: usize) -> Model {\n    // 1 i32 -> 4 bytes\n    assert_eq!(bytes_size % 4, 0);\n    let n_elements = bytes_size / 4;\n    iter::repeat(Ratio::from(BigInt::from(0)))\n        .take(n_elements)\n        .collect()\n}\n\nmacro_rules! fn_to_primitives {\n    ($name: ident, $size: expr) => {\n        paste! {\n            #[allow(non_snake_case)]\n            fn [<to_primitives $name>](crit: &mut Criterion) {\n                let model = make_model($size);\n                let name = &stringify!($name)[1..];\n\n                crit.bench_function(\n                    format!(\"convert {} model to primitive vector\", name).as_str(),\n                    |bench| {\n                        bench.iter(|| black_box(&model).to_primitives().collect::<Result<Vec<i32>, _>>())\n                    }\n                );\n            }\n        }\n    };\n}\n\n// 4 bytes\nfn_to_primitives!(_tiny, 4);\n\n// 100kB = 102_400 bytes\nfn_to_primitives!(_100kB, 102_400);\n\n// 1MB = 1_024_000 bytes\nfn_to_primitives!(_1MB, 1_024_000);\n\ncriterion_group!(\n    name = bench_model_to_primitives;\n    config = Criterion::default().sample_size(1000).measurement_time(Duration::new(10, 0));\n    targets =\n        to_primitives_tiny,\n        to_primitives_100kB,\n        to_primitives_1MB,\n);\ncriterion_main!(bench_model_to_primitives);\n"
  },
  {
    "path": "rust/examples/Cargo.toml",
    "content": "[package]\nname = \"examples\"\nversion = \"0.0.0\"\nauthors = [\"Xayn Engineering <engineering@xaynet.dev>\"]\nedition = \"2018\"\ndescription = \"The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant.\"\nreadme = \"../../README.md\"\nhomepage = \"https://xaynet.dev/\"\nrepository = \"https://github.com/xaynetwork/xaynet/\"\nlicense-file = \"../../LICENSE\"\nkeywords = [\"federated-learning\", \"fl\", \"ai\", \"machine-learning\"]\ncategories = [\"science\", \"cryptography\"]\npublish = false\n\n# https://github.com/http-rs/tide/issues/225\n# https://github.com/dependabot/dependabot-core/issues/1156\nautobins = false\n\n[dev-dependencies]\nasync-trait = \"0.1.57\"\nreqwest = { version = \"0.11.10\", default-features = false, features = [\"rustls-tls\"] }\nstructopt = \"0.3.26\"\ntokio = { version = \"1.20.1\", features = [\"sync\", \"time\", \"macros\", \"rt-multi-thread\", \"signal\"] }\ntracing = \"0.1.36\"\ntracing-futures = \"0.2.5\"\ntracing-subscriber = { version = \"0.3.15\", features = [\"env-filter\"] }\nxaynet-core = { path = \"../xaynet-core\" }\nxaynet-sdk = { path = \"../xaynet-sdk\", features = [\"reqwest-client\"] }\n\n[[example]]\nname = \"test-drive\"\npath = \"test-drive/main.rs\"\n"
  },
  {
    "path": "rust/examples/test-drive/main.rs",
    "content": "use std::{fs::File, io::Read, sync::Arc, time::Duration};\n\nuse structopt::StructOpt;\nuse tracing::error_span;\nuse tracing_futures::Instrument;\nuse tracing_subscriber::{EnvFilter, FmtSubscriber};\n\nuse xaynet_core::{\n    crypto::SigningKeyPair,\n    mask::{FromPrimitives, Model},\n};\nuse xaynet_sdk::{\n    client::{Client, ClientError},\n    settings::PetSettings,\n};\n\nmod participant;\nmod settings;\n\n#[tokio::main]\nasync fn main() -> Result<(), ClientError> {\n    let _fmt_subscriber = FmtSubscriber::builder()\n        .with_env_filter(EnvFilter::from_default_env())\n        .with_ansi(true)\n        .init();\n\n    let opt = settings::Opt::from_args();\n\n    // dummy local model for clients\n    let len = opt.len as usize;\n    let model = Arc::new(Model::from_primitives(vec![0; len].into_iter()).unwrap());\n\n    for id in 0..opt.nb_client {\n        spawn_participant(id as u32, &opt, model.clone())?;\n    }\n\n    tokio::signal::ctrl_c().await.unwrap();\n    Ok(())\n}\n\nfn generate_agent_config() -> PetSettings {\n    let keys = SigningKeyPair::generate();\n    PetSettings::new(keys)\n}\n\nfn build_http_client(settings: &settings::Opt) -> reqwest::Client {\n    let builder = reqwest::ClientBuilder::new();\n\n    let builder = if let Some(ref path) = settings.certificate {\n        let mut buf = Vec::new();\n        File::open(path).unwrap().read_to_end(&mut buf).unwrap();\n        let root_cert = reqwest::Certificate::from_pem(&buf).unwrap();\n        builder.use_rustls_tls().add_root_certificate(root_cert)\n    } else {\n        builder\n    };\n\n    let builder = if let Some(ref path) = settings.identity {\n        let mut buf = Vec::new();\n        File::open(path).unwrap().read_to_end(&mut buf).unwrap();\n        let identity = reqwest::Identity::from_pem(&buf).unwrap();\n        builder.use_rustls_tls().identity(identity)\n    } else {\n        builder\n    };\n\n    builder.build().unwrap()\n}\n\nfn spawn_participant(\n    
id: u32,\n    settings: &settings::Opt,\n    model: Arc<Model>,\n) -> Result<(), ClientError> {\n    let config = generate_agent_config();\n    let http_client = build_http_client(settings);\n    let client = Client::new(http_client, &settings.url).unwrap();\n\n    let (participant, agent) = participant::Participant::new(config, client, model);\n    tokio::spawn(async move {\n        participant\n            .run()\n            .instrument(error_span!(\"participant\", id = id))\n            .await;\n    });\n    tokio::spawn(async move {\n        agent\n            .run(Duration::from_secs(1))\n            .instrument(error_span!(\"agent\", id = id))\n            .await;\n    });\n    Ok(())\n}\n"
  },
  {
    "path": "rust/examples/test-drive/participant.rs",
    "content": "use std::{sync::Arc, time::Duration};\n\nuse async_trait::async_trait;\nuse tokio::{sync::mpsc, time::sleep};\nuse tracing::{info, warn};\n\nuse xaynet_core::mask::Model;\nuse xaynet_sdk::{\n    client::Client,\n    settings::PetSettings,\n    ModelStore,\n    Notify,\n    StateMachine,\n    TransitionOutcome,\n    XaynetClient,\n};\n\nenum Event {\n    Update,\n    Sum,\n    NewRound,\n    Idle,\n}\n\npub struct Participant {\n    // FIXME: XaynetClient requires the client to be mutable. This may\n    // make it easier to implement clients, but as a result we can't\n    // wrap the client in an Arc, which would allow us to share the\n    // same client with all the participants. Maybe XaynetClient\n    // should have methods that take &self?\n    xaynet_client: Client<reqwest::Client>,\n    notifications: mpsc::Receiver<Event>,\n}\n\npub struct Agent(StateMachine);\n\nimpl Agent {\n    fn new<X, M, N>(settings: PetSettings, xaynet_client: X, model_store: M, notify: N) -> Self\n    where\n        X: XaynetClient + Send + 'static,\n        M: ModelStore + Send + 'static,\n        N: Notify + Send + 'static,\n    {\n        Agent(StateMachine::new(\n            settings,\n            xaynet_client,\n            model_store,\n            notify,\n        ))\n    }\n\n    pub async fn run(mut self, tick: Duration) {\n        loop {\n            self = match self.0.transition().await {\n                TransitionOutcome::Pending(state_machine) => {\n                    sleep(tick).await;\n                    Self(state_machine)\n                }\n                TransitionOutcome::Complete(state_machine) => Self(state_machine),\n            };\n        }\n    }\n}\n\nimpl Participant {\n    pub fn new(\n        settings: PetSettings,\n        xaynet_client: Client<reqwest::Client>,\n        model: Arc<Model>,\n    ) -> (Self, Agent) {\n        let (tx, rx) = mpsc::channel::<Event>(10);\n        let notifier = Notifier(tx);\n        let agent = 
Agent::new(settings, xaynet_client.clone(), LocalModel(model), notifier);\n        let participant = Self {\n            xaynet_client,\n            notifications: rx,\n        };\n        (participant, agent)\n    }\n\n    pub async fn run(mut self) {\n        use Event::*;\n        loop {\n            match self.notifications.recv().await {\n                Some(Sum) => {\n                    info!(\"taking part in the sum task\");\n                }\n                Some(Update) => {\n                    info!(\"taking part to the update task\");\n                }\n                Some(Idle) => {\n                    info!(\"waiting\");\n                }\n                Some(NewRound) => {\n                    info!(\"new round started, downloading latest global model\");\n                    if let Err(e) = self.xaynet_client.get_model().await {\n                        warn!(\"failed to download latest model: {}\", e);\n                    }\n                }\n                None => {\n                    warn!(\"notifications channel closed, terminating\");\n                    return;\n                }\n            }\n        }\n    }\n}\n\nstruct Notifier(mpsc::Sender<Event>);\n\nimpl Notify for Notifier {\n    fn new_round(&mut self) {\n        if let Err(e) = self.0.try_send(Event::NewRound) {\n            warn!(\"failed to notify participant: {}\", e);\n        }\n    }\n\n    fn sum(&mut self) {\n        if let Err(e) = self.0.try_send(Event::Sum) {\n            warn!(\"failed to notify participant: {}\", e);\n        }\n    }\n\n    fn update(&mut self) {\n        if let Err(e) = self.0.try_send(Event::Update) {\n            warn!(\"failed to notify participant: {}\", e);\n        }\n    }\n\n    fn idle(&mut self) {\n        if let Err(e) = self.0.try_send(Event::Idle) {\n            warn!(\"failed to notify participant: {}\", e);\n        }\n    }\n}\n\npub struct LocalModel(Arc<Model>);\n\n#[async_trait]\nimpl ModelStore for LocalModel {\n    
type Model = Arc<Model>;\n    type Error = std::convert::Infallible;\n\n    async fn load_model(&mut self) -> Result<Option<Self::Model>, Self::Error> {\n        Ok(Some(self.0.clone()))\n    }\n}\n"
  },
  {
    "path": "rust/examples/test-drive/settings.rs",
    "content": "use std::path::PathBuf;\n\nuse structopt::StructOpt;\n\n#[derive(Debug, StructOpt)]\n#[structopt(name = \"Test Drive\")]\npub struct Opt {\n    #[structopt(\n        default_value = \"http://127.0.0.1:8081\",\n        short,\n        help = \"The URL of the coordinator\"\n    )]\n    pub url: String,\n\n    #[structopt(default_value = \"4\", short, help = \"The length of the model\")]\n    pub len: u32,\n\n    #[structopt(\n        default_value = \"1\",\n        short,\n        help = \"The time period at which to poll for service data, in seconds\"\n    )]\n    pub period: u64,\n\n    #[structopt(default_value = \"10\", short, help = \"The number of clients\")]\n    pub nb_client: u32,\n\n    #[structopt(\n        short,\n        long,\n        parse(from_os_str),\n        help = \"Trusted DER/PEM encoded TLS server certificate\"\n    )]\n    pub certificate: Option<PathBuf>,\n\n    #[structopt(\n        short,\n        long,\n        parse(from_os_str),\n        help = \"The PEM encoded TLS client identity\"\n    )]\n    pub identity: Option<PathBuf>,\n}\n"
  },
  {
    "path": "rust/rustfmt.toml",
    "content": "# requires nightly rustfmt until the options are stabilized\nformat_code_in_doc_comments = true\nimports_granularity = \"Crate\"\nimports_layout = \"HorizontalVertical\"\n"
  },
  {
    "path": "rust/xaynet/Cargo.toml",
    "content": "[package]\nname = \"xaynet\"\nversion = \"0.11.0\"\nauthors = [\"Xayn Engineering <engineering@xaynet.dev>\"]\nedition = \"2018\"\ndescription = \"The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant.\"\nreadme = \"../../README.md\"\nhomepage = \"https://xaynet.dev/\"\nrepository = \"https://github.com/xaynetwork/xaynet/\"\nlicense-file = \"../../LICENSE\"\nkeywords = [\"federated-learning\", \"fl\", \"ai\", \"machine-learning\"]\ncategories = [\"science\", \"cryptography\"]\n\n[package.metadata.docs.rs]\nall-features = true\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n\n[badges]\ncodecov = { repository = \"xaynetwork/xaynet\", branch = \"master\", service = \"github\" }\nmaintenance = { status = \"actively-developed\" }\n\n[dependencies]\nxaynet-core = { path = \"../xaynet-core\", version = \"0.2.0\" }\n\n# feature: mobile\nxaynet-mobile = { path = \"../xaynet-mobile\", version = \"0.1.0\", optional = true }\n\n# feature: sdk\nxaynet-sdk = { path = \"../xaynet-sdk\", version = \"0.1.0\", optional = true }\n\n# feature: server\nxaynet-server = { path = \"../xaynet-server\", version = \"0.2.0\", optional = true }\n\n[features]\ndefault = []\nfull = [\"mobile\", \"sdk\", \"server\"]\nmobile = [\"xaynet-mobile\"]\nsdk = [\"xaynet-sdk\"]\nserver = [\"xaynet-server\"]\n"
  },
  {
    "path": "rust/xaynet/src/lib.rs",
    "content": "#![cfg_attr(docsrs, feature(doc_cfg))]\n#![cfg_attr(\n    doc,\n    forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links)\n)]\n#![doc(\n    html_logo_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png\",\n    html_favicon_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png\",\n    issue_tracker_base_url = \"https://github.com/xaynetwork/xaynet/issues\"\n)]\n//! # Xaynet: Train on the Edge with Federated Learning\n//!\n//! Want a framework that supports federated learning on the edge, in\n//! desktop browsers, integrates well with mobile apps, is performant, and\n//! preserves privacy? Welcome to XayNet, written entirely in Rust!\n//!\n//! ## Making federated learning easy for developers\n//!\n//! Frameworks for machine learning - including those expressly for\n//! federated learning - exist already. These frameworks typically\n//! facilitate federated learning of cross-silo use cases - for example in\n//! collaborative learning across a limited number of hospitals or for\n//! instance across multiple banks working on a common use case without\n//! the need to share valuable and sensitive data.\n//!\n//! This repository focusses on masked cross-device federated learning to\n//! enable the orchestration of machine learning in millions of low-power\n//! edge devices, such as smartphones or even cars. By doing this, we hope\n//! to also increase the pace and scope of adoption of federated learning\n//! in practice and especially allow the protection of end user data. All\n//! data remains in private local premises, whereby only encrypted AI\n//! models get automatically and asynchronously aggregated. Thus, we\n//! provide a solution to the AI privacy dilemma and bridge the\n//! often-existing gap between privacy and convenience. Imagine, for\n//! example, a voice assistant to learn new words directly on device level\n//! 
and share this knowledge with all other instances, without recording\n//! and collecting your voice input centrally. Or, think about a search\n//! engine that learns to personalise search results without collecting\n//! your often sensitive search queries centrally… There are thousands of\n//! such use cases that today still trade privacy for\n//! convenience. We think this shouldn’t be the case and we want to\n//! provide an alternative to overcome this dilemma.\n//!\n//! Concretely, we provide developers with:\n//!\n//! - **App dev tools**: An SDK to integrate federated learning into\n//!   apps written in Dart or other languages of choice for mobile development,\n//!   as well as frameworks like Flutter.\n//! - **Privacy via cross-device federated learning**: Train your AI\n//!   models locally on edge devices such as mobile phones, browsers,\n//!   or even in cars. Federated learning automatically aggregates the\n//!   local models into a global model. Thus, all insights inherent in\n//!   the local models are captured, while the user data stays\n//!   private on end devices.\n//! - **Security and privacy via homomorphic encryption**: Aggregate\n//!   models with the highest security and trust. Xayn’s masking\n//!   protocol encrypts all models homomorphically. This enables you\n//!   to aggregate encrypted local models into a global one – without\n//!   having to decrypt local models at all. This protects private and\n//!   even the most sensitive data.\n//!\n//! ## The case for writing this framework in Rust\n//!\n//! Our framework for federated learning is not only a framework for\n//! machine learning as such. Rather, it supports the federation of\n//! machine learning that takes place on possibly heterogeneous devices\n//! and where use cases involve many such devices.\n//!\n//! The programming language in which this framework is written should\n//! therefore give us strong support for the following:\n//!\n//! 
- **Runs \"everywhere\"**: the language should not require its own\n//!   runtime and code should compile on a wide range of devices.\n//! - **Memory and concurrency safety**: code that compiles should be both\n//!   memory safe and free of data races.\n//! - **Secure communication**: state of the art cryptography should be\n//!   available in vetted implementations.\n//! - **Asynchronous communication**: abstractions for asynchronous\n//!   communication should exist that make federated learning scale.\n//! - **Fast and functional**: the language should offer functional\n//!   abstractions but also compile code into fast executables.\n//!\n//! Rust is one of the very few choices of modern programming languages\n//! that meets these requirements:\n//!\n//! - its concepts of Ownership and Borrowing make it both memory and\n//!   thread-safe (hence avoiding many common concurrency issues).\n//! - it has a strong and static type discipline and traits, which\n//!   describe shareable functionality of a type.\n//! - it is a modern systems programming language, with some functional\n//!   style features such as pattern matching, closures and iterators.\n//! - its idiomatic code compares favourably to idiomatic C in performance.\n//! - it compiles to WASM and can therefore be applied natively in browser\n//!   settings.\n//! - it is widely deployable and doesn't necessarily depend on a runtime,\n//!   unlike languages such as Java and their need for a virtual machine\n//!   to run its code. Foreign Function Interfaces support calls from\n//!   other languages/frameworks, including Dart, Python and Flutter.\n//! - it compiles into LLVM, and so it can draw from the abundant tool\n//!   
suites for LLVM.\npub use xaynet_core as core;\n\n#[cfg(feature = \"mobile\")]\n#[cfg_attr(docsrs, doc(cfg(feature = \"mobile\")))]\npub use xaynet_mobile as mobile;\n\n#[cfg(feature = \"sdk\")]\n#[cfg_attr(docsrs, doc(cfg(feature = \"sdk\")))]\npub use xaynet_sdk as sdk;\n\n#[cfg(feature = \"server\")]\n#[cfg_attr(docsrs, doc(cfg(feature = \"server\")))]\npub use xaynet_server as server;\n"
  },
  {
    "path": "rust/xaynet-analytics/Cargo.toml",
    "content": "[package]\nname = \"xaynet-analytics\"\nversion = \"0.1.0\"\nauthors = [\"Xayn Engineering <engineering@xaynet.dev>\"]\nedition = \"2018\"\ndescription = \"The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant.\"\nreadme = \"../../README.md\"\nhomepage = \"https://xaynet.dev/\"\nrepository = \"https://github.com/xaynetwork/xaynet/\"\nlicense-file = \"../../LICENSE\"\nkeywords = [\"federated-learning\", \"fl\", \"ai\", \"machine-learning\"]\ncategories = [\"science\", \"cryptography\"]\npublish = false\n\n[package.metadata.docs.rs]\nall-features = true\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n\n[dependencies]\nanyhow = \"1.0.59\"\nchrono = \"0.4.19\"\nisar-core = { git = \"https://github.com/isar/isar-core\", rev = \"59d9008be33343d1fd313c659e50e2835365a19d\" }\n"
  },
  {
    "path": "rust/xaynet-analytics/src/controller.rs",
    "content": "//! In this file the `AnalyticsController` is defined.\n\nuse anyhow::{anyhow, Error, Result};\nuse chrono::{DateTime, Datelike, Duration, NaiveDate, Utc};\n\nuse crate::{\n    data_combination::data_combiner::DataCombiner,\n    database::{\n        analytics_event::{\n            adapter::AnalyticsEventAdapter,\n            data_model::{AnalyticsEvent, AnalyticsEventType},\n        },\n        common::{CollectionNames, Repo, SchemaGenerator},\n        controller_data::{adapter::ControllerDataAdapter, data_model::ControllerData},\n        isar::IsarDb,\n        screen_route::{adapter::ScreenRouteAdapter, data_model::ScreenRoute},\n    },\n    sender::Sender,\n};\n\n/// The `AnalyticsController` is the core component of the library. It exposes public functions to the FFI wrapper, and it’s responsible for:\n/// - Instantiating the other necessary components (`DataCombiner`, `Sender` and `IsarDb`)\n/// - Receiving incoming data recorded by the mobile framework (via FFI of course) and saving them to the db via `IsarDb`.\n/// - Checking if the library needs to send data to the XayNet coordinator via `Sender`.\n/// - Holding some simple state (`self.is_charging`, `self.is_connected_to_wifi`) so that it knows whether it’s appropriate to send data to XayNet.\n///\n/// ## Arguments\n///\n/// * `db` - Singleton instance of `IsarDb`, used to operate with the database.\n/// * `is_charging` - Boolean flag representing whether the phone is currently charging or not.\n/// * `is_connected_to_wifi` - Boolean flag representing whether the phone is currently connected to the wifi or not.\n/// * `last_time_data_sent` - Timestamp representing when analytics data was last sent to the coordinator. 
If `None`, data was never sent before.\n/// * `combiner` - `DataCombiner` component responsible for calculating `DataPoints` based on `AnalyticsEvents` and `ScreenRoutes`.\n/// * `sender` - `Sender` component responsible for preparing the message to be sent to the coordinator for aggregation.\n/// * `send_frequency_hours` - `Duration` in hours representing periods within which we want to send data to the coordinator only once.\nstruct AnalyticsController {\n    db: IsarDb,\n    is_charging: bool,\n    is_connected_to_wifi: bool,\n    last_time_data_sent: Option<DateTime<Utc>>,\n    combiner: DataCombiner,\n    sender: Sender,\n    send_frequency_hours: Duration,\n}\n\n// TODO: remove allow dead code when AnalyticsController is integrated with FFI layer: https://xainag.atlassian.net/browse/XN-1415\n#[allow(dead_code)]\nimpl AnalyticsController {\n    const MAX_SEND_FREQUENCY_HOURS: u8 = 24;\n\n    pub fn init(\n        path: String,\n        is_charging: bool,\n        is_connected_to_wifi: bool,\n        input_send_frequency_hours: Option<u8>,\n    ) -> Result<Self, Error> {\n        let schemas = vec![\n            AnalyticsEventAdapter::get_schema(&CollectionNames::ANALYTICS_EVENTS)?,\n            ControllerDataAdapter::get_schema(&CollectionNames::CONTROLLER_DATA)?,\n            ScreenRouteAdapter::get_schema(&CollectionNames::SCREEN_ROUTES)?,\n        ];\n        let db = IsarDb::new(&path, schemas)?;\n        let last_time_data_sent = Self::get_last_time_data_sent(&db)?;\n        let send_frequency_hours = Self::validate_send_frequency(input_send_frequency_hours)?;\n\n        Ok(AnalyticsController {\n            db,\n            is_charging,\n            is_connected_to_wifi,\n            last_time_data_sent,\n            combiner: DataCombiner,\n            sender: Sender,\n            send_frequency_hours,\n        })\n    }\n\n    pub fn dispose(self) -> Result<(), Error> {\n        self.db.dispose()\n    }\n\n    pub fn save_analytics_event(\n        
&self,\n        name: &str,\n        event_type: AnalyticsEventType,\n        timestamp: DateTime<Utc>,\n        option_screen_route_name: Option<&str>,\n    ) -> Result<(), Error> {\n        let option_screen_route = option_screen_route_name\n            .map(|screen_route_name| self.add_screen_route_if_new(screen_route_name, timestamp))\n            .transpose()?;\n        let event = AnalyticsEvent::new(name, event_type, timestamp, option_screen_route);\n        event.save(&self.db, &CollectionNames::ANALYTICS_EVENTS)?;\n        Ok(())\n    }\n\n    pub fn change_connectivity_status(&mut self) {\n        self.is_connected_to_wifi = !self.is_connected_to_wifi;\n    }\n\n    pub fn change_state_of_charge(&mut self) {\n        self.is_charging = !self.is_charging;\n    }\n\n    pub fn maybe_send_data(&mut self) -> Result<(), Error> {\n        if self.should_send_data() {\n            self.send_data()\n        } else {\n            Ok(())\n        }\n    }\n\n    #[cfg(test)]\n    fn db(&self) -> &IsarDb {\n        &self.db\n    }\n\n    /// Check whether `input_send_frequency_hours` is at most `MAX_SEND_FREQUENCY_HOURS`, otherwise return an `Error`.\n    /// If it's within range (0 up to and including `MAX_SEND_FREQUENCY_HOURS`), return it as a `Duration`.\n    /// If it's `None`, assign `Self::MAX_SEND_FREQUENCY_HOURS` and turn it into a `Duration` as well.\n    fn validate_send_frequency(input_send_frequency_hours: Option<u8>) -> Result<Duration, Error> {\n        let send_frequency_hours =\n            input_send_frequency_hours.unwrap_or(Self::MAX_SEND_FREQUENCY_HOURS);\n        if send_frequency_hours > Self::MAX_SEND_FREQUENCY_HOURS {\n            Err(anyhow!(\n                \"input_send_frequency_hours must be between 0 and {}\",\n                Self::MAX_SEND_FREQUENCY_HOURS\n            ))\n        } else {\n            Ok(Duration::hours(send_frequency_hours as i64))\n        }\n    }\n\n    fn should_send_data(&self) -> bool {\n        let can_send_data = self.is_charging && self.is_connected_to_wifi;\n        
can_send_data && !self.did_send_already_in_this_period()\n    }\n\n    /// Check whether the new incoming `screen_route_name` already exists in the `ScreenRoutes` saved to the db.\n    /// If it exists, return the existing `ScreenRoute` object from the db.\n    /// If it doesn't exist, create the new `ScreenRoute` object, save a copy of it to the db, and return it.\n    fn add_screen_route_if_new(\n        &self,\n        screen_route_name: &str,\n        timestamp: DateTime<Utc>,\n    ) -> Result<ScreenRoute, Error> {\n        let existing_screen_routes =\n            ScreenRoute::get_all(&self.db, &CollectionNames::SCREEN_ROUTES)?;\n        if let Some(existing_screen_route) = existing_screen_routes\n            .into_iter()\n            .find(|existing_route| existing_route.name == screen_route_name)\n        {\n            Ok(existing_screen_route)\n        } else {\n            let screen_route = ScreenRoute::new(screen_route_name, timestamp);\n            screen_route\n                .clone()\n                .save(&self.db, &CollectionNames::SCREEN_ROUTES)?;\n            Ok(screen_route)\n        }\n    }\n\n    fn get_last_time_data_sent(db: &IsarDb) -> Result<Option<DateTime<Utc>>, Error> {\n        Ok(\n            ControllerData::get_all(db, &CollectionNames::CONTROLLER_DATA)?\n                .last()\n                .map(|data| data.time_data_sent),\n        )\n    }\n\n    /// This method checks whether we have already sent data in the current 'time window'. The windows are fixed\n    /// (not sliding): they start at midnight UTC of the current day and follow each other every `self.send_frequency_hours`.\n    ///\n    /// An alternative implementation (i.e. a sliding window) could be based on simply checking whether:\n    /// `last_time_data_sent > Utc::now() - self.send_frequency_hours`\n    ///\n    /// In the current implementation, it might be easier to then group the aggregated data on the coordinator side,\n    /// to then be displayed in the UI, especially if `self.send_frequency_hours == Duration::hours(24)`.\n    ///\n    /// 
The window-based approach also means that if, for example, `self.send_frequency_hours == Duration::hours(6)`,\n    /// and the last time we sent the data was at 5AM, we would be able to send again right after the next window starts at 6AM,\n    /// while with the simpler solution we wouldn't be able to send again until 11AM.\n    ///\n    /// The correct approach to be chosen should very much depend on the amount of data available for aggregation,\n    /// and it's possible that `MAX_SEND_FREQUENCY_HOURS` should be increased to more than 24.\n    /// In that case, this function below will need to be reworked, because it's coupled with `MAX_SEND_FREQUENCY_HOURS` being 24.\n    ///\n    /// Only once it's more clear how the aggregation will work on the coordinator side, there will be more information\n    /// to decide the approach here.\n    ///\n    /// NOTE(review): `validate_send_frequency` accepts `Some(0)`. With a zero `self.send_frequency_hours`, the `while`\n    /// loop below never advances `end_of_current_period`, so this method effectively never returns once\n    /// `last_time_data_sent` is set. A zero frequency should be rejected or special-cased.\n    fn did_send_already_in_this_period(&self) -> bool {\n        self.last_time_data_sent\n            .map(|last_time_data_sent| {\n                let now = Utc::now();\n                let start_of_day: DateTime<Utc> = DateTime::from_utc(\n                    NaiveDate::from_ymd(now.year(), now.month(), now.day()).and_hms(0, 0, 0),\n                    Utc,\n                );\n                let mut end_of_current_period = start_of_day + self.send_frequency_hours;\n                while now > end_of_current_period {\n                    end_of_current_period = end_of_current_period + self.send_frequency_hours;\n                }\n                let start_of_current_period = end_of_current_period - self.send_frequency_hours;\n                last_time_data_sent > start_of_current_period\n            })\n            .unwrap_or(false)\n    }\n\n    /// Retrieve all `AnalyticsEvents` and `ScreenRoutes` from the db and pass them to the `DataCombiner`.\n    /// The `DataCombiner` will init all `DataPoints` and pack them in a `Vec<DataPoint>`, which will be the input to the `Sender`.\n    /// After that, save the new time_data_sent inside `ControllerData`, and 
cache it in `self.last_time_data_sent`\n    fn send_data(&mut self) -> Result<(), Error> {\n        let events = AnalyticsEvent::get_all(&self.db, &CollectionNames::ANALYTICS_EVENTS)?;\n        let screen_routes = ScreenRoute::get_all(&self.db, &CollectionNames::SCREEN_ROUTES)?;\n        let time_data_sent = Utc::now();\n        self.sender\n            .send(self.combiner.init_data_points(&events, &screen_routes)?)\n            .and_then(|_| {\n                ControllerData::new(time_data_sent)\n                    .save(&self.db, &CollectionNames::CONTROLLER_DATA)\n            })\n            .map(|_| self.last_time_data_sent = Some(time_data_sent))\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    use std::{env, fs, path::PathBuf};\n\n    fn get_path(test_name: &str) -> PathBuf {\n        let temp_dir = env::temp_dir();\n        temp_dir.join(test_name)\n    }\n\n    fn get_controller(\n        test_name: &str,\n        input_send_data_frequency: Option<u8>,\n    ) -> AnalyticsController {\n        let path_buf = get_path(test_name);\n        let path = path_buf.to_str().unwrap().to_string();\n        if !path_buf.exists() {\n            fs::create_dir(path.clone()).unwrap();\n        }\n        AnalyticsController::init(path, true, true, input_send_data_frequency).unwrap()\n    }\n\n    fn remove_dir(test_name: &str) {\n        let path = get_path(test_name);\n        std::fs::remove_dir_all(path).unwrap();\n    }\n\n    fn cleanup(controller: AnalyticsController, test_name: &str) {\n        remove_dir(test_name);\n        controller.dispose().unwrap();\n    }\n\n    #[test]\n    fn test_dispose() {\n        let test_name = \"test_dispose\";\n        let controller = get_controller(test_name, None);\n        assert!(controller.dispose().is_ok());\n        remove_dir(test_name);\n    }\n\n    #[test]\n    fn test_save_analytics_event_no_screen_route() {\n        let test_name = \"test_save_analytics_event_no_screen_route\";\n        let 
controller = get_controller(test_name, None);\n        let name = \"test\";\n        let event_type = AnalyticsEventType::AppEvent;\n        let timestamp = DateTime::parse_from_rfc3339(\"2021-01-01T01:01:00+00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let existing_analytics_events =\n            AnalyticsEvent::get_all(controller.db(), CollectionNames::ANALYTICS_EVENTS).unwrap();\n        assert!(existing_analytics_events.is_empty());\n        assert!(controller\n            .save_analytics_event(name, event_type, timestamp, None)\n            .is_ok());\n\n        let analytics_event = AnalyticsEvent::new(name, event_type, timestamp, None);\n        let all_analytics_events =\n            AnalyticsEvent::get_all(controller.db(), CollectionNames::ANALYTICS_EVENTS).unwrap();\n        assert_eq!(all_analytics_events.len(), 1);\n        assert_eq!(all_analytics_events.first(), Some(&analytics_event));\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_save_analytics_event_with_screen_route() {\n        let test_name = \"test_save_analytics_event_with_screen_route\";\n        let controller = get_controller(test_name, None);\n        let name = \"test\";\n        let event_type = AnalyticsEventType::ScreenEnter;\n        let timestamp = DateTime::parse_from_rfc3339(\"2021-01-01T01:01:00+00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let screen_route_name = \"route\";\n        assert!(controller\n            .save_analytics_event(name, event_type, timestamp, Some(screen_route_name))\n            .is_ok());\n\n        let screen_route = ScreenRoute::new(screen_route_name, timestamp);\n        let analytics_event = AnalyticsEvent::new(name, event_type, timestamp, Some(screen_route));\n        let all_analytics_events =\n            AnalyticsEvent::get_all(controller.db(), CollectionNames::ANALYTICS_EVENTS).unwrap();\n        assert_eq!(all_analytics_events.len(), 1);\n        
assert_eq!(all_analytics_events.first(), Some(&analytics_event));\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_change_connectivity_status() {\n        let test_name = \"test_change_connectivity_status\";\n        let mut controller = get_controller(test_name, None);\n        assert!(controller.is_connected_to_wifi);\n        controller.change_connectivity_status();\n        assert!(!controller.is_connected_to_wifi);\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_change_state_of_charge() {\n        let test_name = \"test_change_state_of_charge\";\n        let mut controller = get_controller(test_name, None);\n        assert!(controller.is_charging);\n        controller.change_state_of_charge();\n        assert!(!controller.is_charging);\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_validate_send_data_frequency_when_none() {\n        assert_eq!(\n            AnalyticsController::validate_send_frequency(None).unwrap(),\n            Duration::hours(AnalyticsController::MAX_SEND_FREQUENCY_HOURS as i64)\n        )\n    }\n\n    #[test]\n    fn test_validate_send_data_frequency_when_more_than_24() {\n        assert!(AnalyticsController::validate_send_frequency(Some(25)).is_err());\n    }\n\n    #[test]\n    fn test_validate_send_data_frequency_when_less_than_24() {\n        assert_eq!(\n            AnalyticsController::validate_send_frequency(Some(6)).unwrap(),\n            Duration::hours(6)\n        )\n    }\n\n    #[test]\n    fn test_validate_send_data_frequency_when_0() {\n        assert_eq!(\n            AnalyticsController::validate_send_frequency(Some(0)).unwrap(),\n            Duration::hours(0)\n        )\n    }\n\n    #[test]\n    fn test_add_screen_route_if_new_with_new_route() {\n        let test_name = \"test_add_screen_route_if_new_with_new_route\";\n        let controller = get_controller(test_name, None);\n        let screen_route_name = \"route\";\n        let 
timestamp = DateTime::parse_from_rfc3339(\"2021-01-01T01:01:00+00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let screen_route = ScreenRoute::new(screen_route_name, timestamp);\n        let existing_screen_routes =\n            ScreenRoute::get_all(controller.db(), CollectionNames::SCREEN_ROUTES).unwrap();\n        assert!(existing_screen_routes.is_empty());\n        assert_eq!(\n            controller\n                .add_screen_route_if_new(screen_route_name, timestamp)\n                .unwrap(),\n            screen_route\n        );\n\n        let retrieved_screen_routes =\n            ScreenRoute::get_all(controller.db(), CollectionNames::SCREEN_ROUTES).unwrap();\n        assert_eq!(retrieved_screen_routes.len(), 1);\n        assert_eq!(retrieved_screen_routes.first(), Some(&screen_route));\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_add_screen_route_if_new_without_new_route() {\n        let test_name = \"test_add_screen_route_if_new_without_new_route\";\n        let controller = get_controller(test_name, None);\n        let screen_route_name = \"route\";\n        let first_timestamp = DateTime::parse_from_rfc3339(\"2021-01-01T01:01:00+00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let first_screen_route = ScreenRoute::new(screen_route_name, first_timestamp);\n        assert!(controller\n            .add_screen_route_if_new(screen_route_name, first_timestamp)\n            .is_ok());\n\n        // if we call controller.add_screen_route_if_new() with the same screen_route_name, but a new_timestamp,\n        // we expect to get the first_screen_route back, with the first_timestamp\n        let new_timestamp = DateTime::parse_from_rfc3339(\"2021-02-02T02:02:00+00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        assert_eq!(\n            controller\n                .add_screen_route_if_new(screen_route_name, new_timestamp)\n                
.unwrap(),\n            first_screen_route\n        );\n\n        let retrieved_screen_routes =\n            ScreenRoute::get_all(controller.db(), CollectionNames::SCREEN_ROUTES).unwrap();\n        assert_eq!(retrieved_screen_routes.len(), 1);\n        assert_eq!(retrieved_screen_routes.first(), Some(&first_screen_route));\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_get_last_time_data_sent() {\n        let test_name = \"test_get_last_time_data_sent_is_none\";\n        let controller = get_controller(test_name, None);\n\n        let last_time_data_sent = AnalyticsController::get_last_time_data_sent(controller.db());\n        assert!(last_time_data_sent.is_ok());\n        assert!(last_time_data_sent.unwrap().is_none());\n\n        let timestamp = DateTime::parse_from_rfc3339(\"2021-03-03T03:03:00+00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let controller_data = ControllerData::new(timestamp);\n        let existing_controller_data =\n            ControllerData::get_all(controller.db(), CollectionNames::CONTROLLER_DATA).unwrap();\n        assert!(existing_controller_data.is_empty());\n        assert!(controller_data\n            .save(controller.db(), CollectionNames::CONTROLLER_DATA)\n            .is_ok());\n\n        let last_time_data_sent = AnalyticsController::get_last_time_data_sent(controller.db());\n        assert!(last_time_data_sent.is_ok());\n        assert_eq!(last_time_data_sent.unwrap(), Some(timestamp));\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_did_send_already_in_this_period_never_sent_before() {\n        let test_name = \"test_did_send_already_in_this_period_never_sent_before\";\n        let controller = get_controller(test_name, Some(24));\n        assert!(!controller.did_send_already_in_this_period());\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_did_send_already_in_this_period_inside_24h() {\n        let test_name = 
\"test_did_send_already_in_this_period_inside_24h\";\n        let initial_controller = get_controller(test_name, Some(24));\n\n        let timestamp = Utc::now();\n        let controller_data = ControllerData::new(timestamp);\n        assert!(controller_data\n            .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA)\n            .is_ok());\n\n        // init controller again, to read self.last_time_data_sent from db\n        assert!(initial_controller.dispose().is_ok());\n        let controller = get_controller(test_name, Some(24));\n        assert!(controller.did_send_already_in_this_period());\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_did_send_already_in_this_period_outside_24h() {\n        let test_name = \"test_did_send_already_in_this_period_outside_24h\";\n        let initial_controller = get_controller(test_name, Some(24));\n\n        let timestamp = Utc::now() - Duration::hours(25);\n        let controller_data = ControllerData::new(timestamp);\n        assert!(controller_data\n            .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA)\n            .is_ok());\n\n        // init controller again, to read self.last_time_data_sent from db\n        assert!(initial_controller.dispose().is_ok());\n        let controller = get_controller(test_name, Some(24));\n        assert!(!controller.did_send_already_in_this_period());\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_did_send_already_in_this_period_inside_12h() {\n        let test_name = \"test_did_send_already_in_this_period_inside_12h\";\n        let initial_controller = get_controller(test_name, Some(12));\n\n        let timestamp = Utc::now();\n        let controller_data = ControllerData::new(timestamp);\n        assert!(controller_data\n            .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA)\n            .is_ok());\n\n        // init controller again, to read self.last_time_data_sent 
from db\n        assert!(initial_controller.dispose().is_ok());\n        let controller = get_controller(test_name, Some(12));\n        assert!(controller.did_send_already_in_this_period());\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_did_send_already_in_this_period_outside_12h() {\n        let test_name = \"test_did_send_already_in_this_period_outside_12h\";\n        let initial_controller = get_controller(test_name, Some(12));\n\n        let timestamp = Utc::now() - Duration::hours(13);\n        let controller_data = ControllerData::new(timestamp);\n        assert!(controller_data\n            .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA)\n            .is_ok());\n\n        // init controller again, to read self.last_time_data_sent from db\n        assert!(initial_controller.dispose().is_ok());\n        let controller = get_controller(test_name, Some(12));\n        assert!(!controller.did_send_already_in_this_period());\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_did_send_already_in_this_period_inside_6h() {\n        let test_name = \"test_did_send_already_in_this_period_inside_6h\";\n        let initial_controller = get_controller(test_name, Some(6));\n\n        let timestamp = Utc::now();\n        let controller_data = ControllerData::new(timestamp);\n        assert!(controller_data\n            .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA)\n            .is_ok());\n\n        // init controller again, to read self.last_time_data_sent from db\n        assert!(initial_controller.dispose().is_ok());\n        let controller = get_controller(test_name, Some(6));\n        assert!(controller.did_send_already_in_this_period());\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_did_send_already_in_this_period_outside_6h() {\n        let test_name = \"test_did_send_already_in_this_period_outside_6h\";\n        let initial_controller = 
get_controller(test_name, Some(6));\n\n        let timestamp = Utc::now() - Duration::hours(7);\n        let controller_data = ControllerData::new(timestamp);\n        assert!(controller_data\n            .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA)\n            .is_ok());\n\n        // init controller again, to read self.last_time_data_sent from db\n        assert!(initial_controller.dispose().is_ok());\n        let controller = get_controller(test_name, Some(6));\n        assert!(!controller.did_send_already_in_this_period());\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_did_send_already_in_this_period_outside_twice_6h() {\n        let test_name = \"test_did_send_already_in_this_period_outside_twice_6h\";\n        let initial_controller = get_controller(test_name, Some(6));\n\n        let timestamp = Utc::now() - Duration::hours(13);\n        let controller_data = ControllerData::new(timestamp);\n        assert!(controller_data\n            .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA)\n            .is_ok());\n\n        // init controller again, to read self.last_time_data_sent from db\n        assert!(initial_controller.dispose().is_ok());\n        let controller = get_controller(test_name, Some(6));\n        assert!(!controller.did_send_already_in_this_period());\n\n        cleanup(controller, test_name);\n    }\n\n    #[test]\n    fn test_did_send_already_in_this_period_outside_thrice_6h() {\n        let test_name = \"test_did_send_already_in_this_period_outside_thrice_6h\";\n        let initial_controller = get_controller(test_name, Some(6));\n\n        let timestamp = Utc::now() - Duration::hours(19);\n        let controller_data = ControllerData::new(timestamp);\n        assert!(controller_data\n            .save(initial_controller.db(), CollectionNames::CONTROLLER_DATA)\n            .is_ok());\n\n        // init controller again, to read self.last_time_data_sent from db\n        
assert!(initial_controller.dispose().is_ok());\n        let controller = get_controller(test_name, Some(6));\n        assert!(!controller.did_send_already_in_this_period());\n\n        cleanup(controller, test_name);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/data_combination/data_combiner.rs",
    "content": "//! Declaration and implementation of `DataCombiner`.\n\nuse anyhow::{Error, Result};\nuse chrono::{DateTime, Datelike, Duration, NaiveDate, Utc};\nuse std::iter::empty;\n\nuse crate::{\n    data_combination::data_points::data_point::{\n        CalcScreenActiveTime,\n        CalcScreenEnterCount,\n        CalcWasActiveEachPastPeriod,\n        CalcWasActivePastNDays,\n        DataPoint,\n        DataPointMetadata,\n        Period,\n        PeriodUnit,\n    },\n    database::{\n        analytics_event::data_model::AnalyticsEvent,\n        screen_route::data_model::ScreenRoute,\n    },\n};\n\n/// `DataCombiner` is responsible for instantiating the `DataPoint` variants. When it’s time to send the data to XayNet,\n/// the `AnalyticsEvents` and `ScreenRoutes` are retrieved from the db (by the `AnalyticsController`) and passed to the `DataCombiner`,\n/// which then instantiates the various `DataPoint` variants and packs them in a `Vec`, which will be utilised by the `Sender`.\n///\n/// Possible improvements include:\n/// - Move the `DataPointMetadatas` to a sort of config, and pass them to the `DataCombiner`.\n/// - Turn `DataCombiner` into a trait on each `DataPoint`.\n/// See: https://xainag.atlassian.net/browse/XN-1651\npub struct DataCombiner;\n\nimpl<'screen> DataCombiner {\n    pub fn init_data_points(\n        &self,\n        events: &[AnalyticsEvent],\n        screen_routes: &[ScreenRoute],\n    ) -> Result<Vec<DataPoint>, Error> {\n        let end_period = Utc::now();\n\n        let one_day_period_metadata =\n            DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let was_active_each_period_metadatas = vec![\n            DataPointMetadata::new(Period::new(PeriodUnit::Days, 7), end_period),\n            DataPointMetadata::new(Period::new(PeriodUnit::Weeks, 6), end_period),\n            DataPointMetadata::new(Period::new(PeriodUnit::Months, 3), end_period),\n        ];\n        let was_active_past_days_metadatas = 
vec![\n            one_day_period_metadata,\n            DataPointMetadata::new(Period::new(PeriodUnit::Days, 7), end_period),\n            DataPointMetadata::new(Period::new(PeriodUnit::Days, 28), end_period),\n        ];\n\n        let data_points = empty::<DataPoint>()\n            .chain(Self::init_screen_active_time_vars(\n                one_day_period_metadata,\n                events,\n                screen_routes,\n            ))\n            .chain(Self::init_screen_enter_count_vars(\n                one_day_period_metadata,\n                events,\n                screen_routes,\n            ))\n            .chain(Self::init_was_active_each_past_period_vars(\n                was_active_each_period_metadatas,\n                events,\n            ))\n            .chain(Self::init_was_active_past_n_days_vars(\n                was_active_past_days_metadatas,\n                events,\n            ))\n            .collect();\n        Ok(data_points)\n    }\n\n    fn init_screen_active_time_vars(\n        metadata: DataPointMetadata,\n        events: &[AnalyticsEvent],\n        screen_routes: &[ScreenRoute],\n    ) -> Vec<DataPoint> {\n        let mut screen_active_time_vars: Vec<DataPoint> = screen_routes\n            .iter()\n            .map(|route| {\n                let events_this_route = Self::get_events_single_route(route, events);\n                CalcScreenActiveTime::new(\n                    metadata,\n                    Self::filter_events_in_this_period(metadata, events_this_route.as_slice()),\n                )\n            })\n            .map(DataPoint::ScreenActiveTime)\n            .collect();\n        screen_active_time_vars.push(DataPoint::ScreenActiveTime(CalcScreenActiveTime::new(\n            metadata,\n            Self::filter_events_in_this_period(metadata, events),\n        )));\n        screen_active_time_vars\n    }\n\n    fn init_screen_enter_count_vars(\n        metadata: DataPointMetadata,\n        events: 
&[AnalyticsEvent],\n        screen_routes: &[ScreenRoute],\n    ) -> Vec<DataPoint> {\n        screen_routes\n            .iter()\n            .map(|route| {\n                let events_this_route = Self::get_events_single_route(&route, events);\n                CalcScreenEnterCount::new(\n                    metadata,\n                    Self::filter_events_in_this_period(metadata, events_this_route.as_slice()),\n                )\n            })\n            .map(DataPoint::ScreenEnterCount)\n            .collect()\n    }\n\n    fn init_was_active_each_past_period_vars(\n        metadatas: Vec<DataPointMetadata>,\n        events: &[AnalyticsEvent],\n    ) -> Vec<DataPoint> {\n        metadatas\n            .iter()\n            .map(|metadata| {\n                let period_thresholds = (0..metadata.period.n)\n                    .map(|i| Self::get_start_of_period(*metadata, Some(i)))\n                    .collect();\n                CalcWasActiveEachPastPeriod::new(\n                    *metadata,\n                    Self::filter_events_in_this_period(*metadata, events),\n                    period_thresholds,\n                )\n            })\n            .map(DataPoint::WasActiveEachPastPeriod)\n            .collect()\n    }\n\n    fn init_was_active_past_n_days_vars(\n        metadatas: Vec<DataPointMetadata>,\n        events: &[AnalyticsEvent],\n    ) -> Vec<DataPoint> {\n        metadatas\n            .iter()\n            .map(|metadata| {\n                CalcWasActivePastNDays::new(\n                    *metadata,\n                    Self::filter_events_in_this_period(*metadata, events),\n                )\n            })\n            .map(DataPoint::WasActivePastNDays)\n            .collect()\n    }\n\n    // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517\n    fn filter_events_in_this_period(\n        metadata: DataPointMetadata,\n        events: &[AnalyticsEvent],\n    ) -> Vec<AnalyticsEvent> {\n        let 
start_of_period = Self::get_start_of_period(metadata, None);\n        Self::filter_events_before_end_of_period(metadata.end, events)\n            .iter()\n            .filter(|event| event.timestamp > start_of_period)\n            .cloned()\n            .collect()\n    }\n\n    fn get_start_of_period(\n        metadata: DataPointMetadata,\n        n_periods_override: Option<u32>,\n    ) -> DateTime<Utc> {\n        let n_periods = if let Some(n_periods) = n_periods_override {\n            n_periods\n        } else {\n            metadata.period.n\n        };\n        let avg_days_per_month = 365.0 / 12.0;\n        let midnight_end_of_period = get_midnight(metadata.end);\n        match metadata.period.unit {\n            PeriodUnit::Days => midnight_end_of_period - Duration::days(n_periods as i64),\n            PeriodUnit::Weeks => midnight_end_of_period - Duration::weeks(n_periods as i64),\n            PeriodUnit::Months => {\n                midnight_end_of_period\n                    - Duration::days((n_periods as f64 * avg_days_per_month) as i64)\n            }\n        }\n    }\n\n    // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517\n    fn filter_events_before_end_of_period(\n        end_of_period: DateTime<Utc>,\n        events: &[AnalyticsEvent],\n    ) -> Vec<AnalyticsEvent> {\n        let midnight_end_of_period = get_midnight(end_of_period);\n        events\n            .iter()\n            .filter(|event| event.timestamp < midnight_end_of_period)\n            .cloned()\n            .collect()\n    }\n\n    // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517\n    fn get_events_single_route(\n        route: &ScreenRoute,\n        all_events: &[AnalyticsEvent],\n    ) -> Vec<AnalyticsEvent> {\n        all_events\n            .iter()\n            .filter(|event| event.screen_route.as_ref() == Some(route))\n            .cloned()\n            .collect()\n    }\n}\n\nfn 
get_midnight(timestamp: DateTime<Utc>) -> DateTime<Utc> {\n    DateTime::from_utc(\n        NaiveDate::from_ymd(timestamp.year(), timestamp.month(), timestamp.day()).and_hms(0, 0, 0),\n        Utc,\n    )\n}\n\n#[cfg(test)]\nmod tests {\n    use chrono::{DateTime, Duration, Utc};\n\n    use crate::{\n        data_combination::{\n            data_combiner::{get_midnight, DataCombiner},\n            data_points::data_point::{\n                CalcScreenActiveTime,\n                CalcScreenEnterCount,\n                CalcWasActiveEachPastPeriod,\n                CalcWasActivePastNDays,\n                DataPoint,\n                DataPointMetadata,\n                Period,\n                PeriodUnit,\n            },\n        },\n        database::{\n            analytics_event::data_model::{AnalyticsEvent, AnalyticsEventType},\n            screen_route::data_model::ScreenRoute,\n        },\n    };\n\n    #[test]\n    fn test_init_screen_active_time_vars() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-01-01T01:01:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let screen_route = ScreenRoute::new(\"home_screen\", end_period + Duration::days(1));\n        let first_event = AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::ScreenEnter,\n            end_period - Duration::hours(12),\n            Some(screen_route.clone()),\n        );\n        let all_events = vec![\n            first_event.clone(),\n            AnalyticsEvent::new(\n                \"test1\",\n                AnalyticsEventType::AppEvent,\n                end_period - Duration::hours(13),\n                None,\n            ),\n        ];\n        let expected_output = vec![\n            DataPoint::ScreenActiveTime(CalcScreenActiveTime::new(metadata, vec![first_event])),\n            
DataPoint::ScreenActiveTime(CalcScreenActiveTime::new(metadata, all_events.clone())),\n        ];\n        let actual_output =\n            DataCombiner::init_screen_active_time_vars(metadata, &all_events, &[screen_route]);\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn test_init_screen_enter_count_vars() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-02-02T02:02:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let screen_route = ScreenRoute::new(\"home_screen\", end_period + Duration::days(1));\n        let events = vec![AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::ScreenEnter,\n            end_period - Duration::hours(12),\n            Some(screen_route.clone()),\n        )];\n        let expected_output = vec![DataPoint::ScreenEnterCount(CalcScreenEnterCount::new(\n            metadata,\n            events.clone(),\n        ))];\n        let actual_output =\n            DataCombiner::init_screen_enter_count_vars(metadata, &events, &[screen_route]);\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn test_init_was_active_each_past_period_vars() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-03-03T03:03:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let events = vec![AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::AppEvent,\n            end_period - Duration::hours(12),\n            None,\n        )];\n        let period_thresholds = vec![get_midnight(end_period)];\n        let expected_output = vec![DataPoint::WasActiveEachPastPeriod(\n            CalcWasActiveEachPastPeriod::new(metadata, events.clone(), period_thresholds),\n        )];\n        let 
actual_output =\n            DataCombiner::init_was_active_each_past_period_vars(vec![metadata], &events);\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn test_init_was_active_past_n_days_vars() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-04-04T04:04:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let events = vec![AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::AppEvent,\n            end_period - Duration::hours(12),\n            None,\n        )];\n        let expected_output = vec![DataPoint::WasActivePastNDays(CalcWasActivePastNDays::new(\n            metadata,\n            events.clone(),\n        ))];\n        let actual_output = DataCombiner::init_was_active_past_n_days_vars(vec![metadata], &events);\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn test_filter_events_in_this_period() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-05-05T05:05:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 3), end_period);\n        let event_before = AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::AppEvent,\n            end_period - Duration::days(5),\n            None,\n        );\n        let event_during = AnalyticsEvent::new(\n            \"test2\",\n            AnalyticsEventType::AppEvent,\n            end_period - Duration::days(1),\n            None,\n        );\n        let event_after = AnalyticsEvent::new(\n            \"test3\",\n            AnalyticsEventType::AppEvent,\n            end_period + Duration::days(2),\n            None,\n        );\n        let events = vec![event_before, event_during.clone(), event_after];\n        let expected_output = vec![event_during];\n        
let actual_output = DataCombiner::filter_events_in_this_period(metadata, &events);\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn test_get_start_of_period_one_day() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-01-01T00:00:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let expected_output = end_period - Duration::days(1);\n        let actual_output = DataCombiner::get_start_of_period(metadata, None);\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn test_get_start_of_period_one_day_with_override() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-02-02T00:00:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 2), end_period);\n        let expected_output = end_period - Duration::days(1);\n        let actual_output = DataCombiner::get_start_of_period(metadata, Some(1));\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn test_get_start_of_period_one_week() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-03-03T00:00:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Weeks, 1), end_period);\n        let expected_output = end_period - Duration::weeks(1);\n        let actual_output = DataCombiner::get_start_of_period(metadata, None);\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn test_get_start_of_period_one_month() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-04-04T00:00:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Months, 1), end_period);\n        let expected_output = 
end_period - Duration::days(30);\n        let actual_output = DataCombiner::get_start_of_period(metadata, None);\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn text_filter_events_before_end_of_period() {\n        let end_of_period = Utc::now();\n        let event_before = AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::AppEvent,\n            end_of_period - Duration::days(1),\n            None,\n        );\n        let event_after = AnalyticsEvent::new(\n            \"test2\",\n            AnalyticsEventType::AppEvent,\n            end_of_period + Duration::days(1),\n            None,\n        );\n        let events = vec![event_before.clone(), event_after];\n        let expected_output = vec![event_before];\n        let actual_output =\n            DataCombiner::filter_events_before_end_of_period(end_of_period, &events);\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn test_get_events_single_route() {\n        let timestamp = Utc::now();\n        let home_route = ScreenRoute::new(\"home_screen\", timestamp + Duration::days(1));\n        let other_route = ScreenRoute::new(\"other_screen\", timestamp + Duration::days(2));\n        let home_route_event = AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::AppEvent,\n            timestamp,\n            Some(home_route.clone()),\n        );\n        let other_route_event = AnalyticsEvent::new(\n            \"test2\",\n            AnalyticsEventType::ScreenEnter,\n            timestamp,\n            Some(other_route),\n        );\n        let all_events = [home_route_event.clone(), other_route_event];\n        let expected_output = vec![home_route_event];\n        let actual_output = DataCombiner::get_events_single_route(&home_route, &all_events);\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn test_get_midnight() {\n        let timestamp = 
DateTime::parse_from_rfc3339(\"2021-01-01T21:21:21-02:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let expected_output = DateTime::parse_from_rfc3339(\"2021-01-01T00:00:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let actual_output = get_midnight(timestamp);\n        assert_eq!(actual_output, expected_output);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/data_combination/data_points/data_point.rs",
    "content": "//! File containing various structs used to define `DataPoints`.\n\nuse chrono::{DateTime, Utc};\n\nuse crate::database::analytics_event::data_model::AnalyticsEvent;\n\n#[derive(Debug, PartialEq, Eq, Clone, Copy)]\npub enum PeriodUnit {\n    Days,\n    Weeks,\n    Months,\n}\n\n/// Period combines information about the unit of this period, and the number of periods.\n/// For example a `Period` of three weeks can be represented with `Period::new(unit: PeriodUnit::Weeks, n: 3)`\n#[derive(Debug, PartialEq, Eq, Clone, Copy)]\npub struct Period {\n    pub unit: PeriodUnit,\n    pub n: u32,\n}\n\nimpl Period {\n    pub fn new(unit: PeriodUnit, n: u32) -> Self {\n        Self { unit, n }\n    }\n}\n\n/// `DataPointMetadata` contains information about `Period` and when the period ends. It is used to\n/// define which `AnalyticsEvents` fall inside a `Period` and must therefore be included in the calculation\n/// of a specific `DataPoint`.\n#[derive(Debug, PartialEq, Eq, Clone, Copy)]\npub struct DataPointMetadata {\n    pub period: Period,\n    pub end: DateTime<Utc>,\n}\n\nimpl DataPointMetadata {\n    pub fn new(period: Period, end: DateTime<Utc>) -> Self {\n        Self { period, end }\n    }\n}\n\npub trait CalculateDataPoints {\n    fn metadata(&self) -> DataPointMetadata;\n\n    fn calculate(&self) -> Vec<u32>;\n}\n\n/// `DataPoint` is an enum whose variants represent data points that will need to be aggregated and shown to the user.\n/// They are the actual analytics information that is valuable to the user. 
Each `DataPoint` refers to a specific `Period`.\n/// ## Variants:\n/// * `ScreenActiveTime`: How much time was spent on a specific screen.\n/// * `ScreenEnterCount`: How many times the user entered a specific screen.\n/// * `WasActiveEachPastPeriod`: Whether the user was active or not in each specified period (in general, not by screen).\n/// * `WasActivePastNDays`: Whether the user was active or not in the past N days (in general, not by screen).\n///\n/// There are still more variants to be implemented: https://xainag.atlassian.net/browse/XN-1687\n#[derive(Debug, PartialEq, Eq)]\npub enum DataPoint {\n    ScreenActiveTime(CalcScreenActiveTime),\n    ScreenEnterCount(CalcScreenEnterCount),\n    WasActiveEachPastPeriod(CalcWasActiveEachPastPeriod),\n    WasActivePastNDays(CalcWasActivePastNDays),\n}\n\n#[allow(dead_code)]\n// TODO: will be called when preparing the data to be sent to the coordinator\nimpl DataPoint {\n    fn calculate(&self) -> Vec<u32> {\n        match self {\n            DataPoint::ScreenActiveTime(data) => data.calculate(),\n            DataPoint::ScreenEnterCount(data) => data.calculate(),\n            DataPoint::WasActiveEachPastPeriod(data) => data.calculate(),\n            DataPoint::WasActivePastNDays(data) => data.calculate(),\n        }\n    }\n}\n\n#[derive(Debug, PartialEq, Eq)]\n// TODO: accept an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517\npub struct CalcScreenActiveTime {\n    pub metadata: DataPointMetadata,\n    pub events: Vec<AnalyticsEvent>,\n}\n\n#[derive(Debug, PartialEq, Eq)]\n// TODO: accept an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517\npub struct CalcScreenEnterCount {\n    pub metadata: DataPointMetadata,\n    pub events: Vec<AnalyticsEvent>,\n}\n\n#[derive(Debug, PartialEq, Eq)]\n// TODO: accept an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517\npub struct CalcWasActiveEachPastPeriod {\n    pub metadata: DataPointMetadata,\n    pub events: 
Vec<AnalyticsEvent>,\n    pub period_thresholds: Vec<DateTime<Utc>>,\n}\n\n#[derive(Debug, PartialEq, Eq)]\n// TODO: accept an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517\npub struct CalcWasActivePastNDays {\n    pub metadata: DataPointMetadata,\n    pub events: Vec<AnalyticsEvent>,\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/data_combination/data_points/mod.rs",
    "content": "pub mod data_point;\npub mod screen_active_time;\npub mod screen_enter_count;\npub mod was_active_each_past_period;\npub mod was_active_past_n_days;\n"
  },
  {
    "path": "rust/xaynet-analytics/src/data_combination/data_points/screen_active_time.rs",
    "content": "use chrono::Duration;\n\nuse crate::{\n    data_combination::data_points::data_point::{\n        CalcScreenActiveTime,\n        CalculateDataPoints,\n        DataPointMetadata,\n    },\n    database::analytics_event::data_model::{AnalyticsEvent, AnalyticsEventType},\n};\n\nimpl CalcScreenActiveTime {\n    pub fn new(metadata: DataPointMetadata, events: Vec<AnalyticsEvent>) -> Self {\n        Self { metadata, events }\n    }\n\n    // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517\n    fn get_screen_and_app_events(&self) -> Vec<AnalyticsEvent> {\n        self.events\n            .iter()\n            .filter(|event| {\n                matches!(\n                    event.event_type,\n                    AnalyticsEventType::ScreenEnter | AnalyticsEventType::AppEvent\n                )\n            })\n            .cloned()\n            .collect()\n    }\n}\n\nimpl CalculateDataPoints for CalcScreenActiveTime {\n    fn metadata(&self) -> DataPointMetadata {\n        self.metadata\n    }\n\n    fn calculate(&self) -> Vec<u32> {\n        let screen_and_app_events = self.get_screen_and_app_events();\n        let value = if screen_and_app_events.is_empty() {\n            0\n        } else {\n            screen_and_app_events\n                .iter()\n                .scan(\n                    screen_and_app_events.first().unwrap().timestamp,\n                    |last_timestamp, event| {\n                        let duration = if event.event_type == AnalyticsEventType::ScreenEnter {\n                            last_timestamp.signed_duration_since(event.timestamp)\n                        } else {\n                            Duration::zero()\n                        };\n                        *last_timestamp = event.timestamp;\n                        Some(duration)\n                    },\n                )\n                .map(|duration| duration.num_milliseconds() as u32)\n                .sum()\n        };\n     
   vec![value]\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use chrono::{DateTime, Duration, Utc};\n\n    use super::*;\n    use crate::{\n        data_combination::data_points::data_point::{Period, PeriodUnit},\n        database::screen_route::data_model::ScreenRoute,\n    };\n\n    #[test]\n    fn test_get_screen_and_app_events() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-01-01T01:01:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let screen_route = ScreenRoute::new(\"home_screen\", end_period + Duration::days(1));\n        let screen_enter_event = AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::ScreenEnter,\n            end_period - Duration::hours(10),\n            Some(screen_route),\n        );\n        let app_event = AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::AppEvent,\n            end_period - Duration::hours(12),\n            None,\n        );\n        let events = vec![\n            screen_enter_event.clone(),\n            AnalyticsEvent::new(\n                \"test1\",\n                AnalyticsEventType::AppError,\n                end_period - Duration::hours(11),\n                None,\n            ),\n            app_event.clone(),\n            AnalyticsEvent::new(\n                \"test1\",\n                AnalyticsEventType::UserAction,\n                end_period - Duration::hours(13),\n                None,\n            ),\n        ];\n        let screen_active_time = CalcScreenActiveTime::new(metadata, events);\n        let expected_output = vec![screen_enter_event, app_event];\n        let actual_output = screen_active_time.get_screen_and_app_events();\n        assert_eq!(actual_output, expected_output);\n    }\n\n    #[test]\n    fn test_calculate_when_no_events() {\n        let metadata = 
DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), Utc::now());\n        let screen_active_time = CalcScreenActiveTime::new(metadata, Vec::new());\n        assert_eq!(screen_active_time.calculate(), vec![0]);\n    }\n\n    #[test]\n    fn test_calculate_when_one_screen_enter_event() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-03-03T03:03:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let screen_route = ScreenRoute::new(\"home_screen\", end_period + Duration::days(1));\n        let events = vec![AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::ScreenEnter,\n            end_period - Duration::hours(12),\n            Some(screen_route),\n        )];\n        let screen_active_time = CalcScreenActiveTime::new(metadata, events);\n        assert_eq!(screen_active_time.calculate(), vec![0]);\n    }\n\n    #[test]\n    fn test_calculate_when_two_screen_enter_events() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-03-03T03:03:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let screen_route = ScreenRoute::new(\"home_screen\", end_period + Duration::days(1));\n        let events = vec![\n            AnalyticsEvent::new(\n                \"test1\",\n                AnalyticsEventType::ScreenEnter,\n                end_period - Duration::hours(12),\n                Some(screen_route.clone()),\n            ),\n            AnalyticsEvent::new(\n                \"test2\",\n                AnalyticsEventType::ScreenEnter,\n                end_period - Duration::hours(15),\n                Some(screen_route),\n            ),\n        ];\n        let time_between_events =\n            events.first().unwrap().timestamp - events.last().unwrap().timestamp;\n        let 
screen_active_time = CalcScreenActiveTime::new(metadata, events);\n        assert_eq!(\n            screen_active_time.calculate(),\n            vec![time_between_events.num_milliseconds() as u32]\n        );\n    }\n\n    #[test]\n    fn test_calculate_when_mixed_type_events() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-04-04T04:04:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let screen_route = ScreenRoute::new(\"home_screen\", end_period + Duration::days(1));\n        let first = AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::ScreenEnter,\n            end_period - Duration::hours(12),\n            Some(screen_route.clone()),\n        );\n        let second = AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::AppEvent,\n            end_period - Duration::hours(13),\n            None,\n        );\n        let third = AnalyticsEvent::new(\n            \"test2\",\n            AnalyticsEventType::ScreenEnter,\n            end_period - Duration::hours(14),\n            Some(screen_route.clone()),\n        );\n        let fourth = AnalyticsEvent::new(\n            \"test2\",\n            AnalyticsEventType::ScreenEnter,\n            end_period - Duration::hours(14),\n            Some(screen_route),\n        );\n        let events = vec![first.clone(), second.clone(), third.clone(), fourth.clone()];\n        let time_between_events =\n            first.timestamp - second.timestamp + (third.timestamp - fourth.timestamp);\n        let screen_active_time = CalcScreenActiveTime::new(metadata, events);\n        assert_eq!(\n            screen_active_time.calculate(),\n            vec![time_between_events.num_milliseconds() as u32]\n        );\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/data_combination/data_points/screen_enter_count.rs",
    "content": "use crate::{\n    data_combination::data_points::data_point::{\n        CalcScreenEnterCount,\n        CalculateDataPoints,\n        DataPointMetadata,\n    },\n    database::analytics_event::data_model::{AnalyticsEvent, AnalyticsEventType},\n};\n\nimpl CalcScreenEnterCount {\n    pub fn new(metadata: DataPointMetadata, events: Vec<AnalyticsEvent>) -> Self {\n        Self { metadata, events }\n    }\n}\n\nimpl CalculateDataPoints for CalcScreenEnterCount {\n    fn metadata(&self) -> DataPointMetadata {\n        self.metadata\n    }\n\n    fn calculate(&self) -> Vec<u32> {\n        let value = self\n            .events\n            .iter()\n            .filter(|event| event.event_type == AnalyticsEventType::ScreenEnter)\n            .count() as u32;\n        vec![value]\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use chrono::{DateTime, Duration, Utc};\n\n    use super::*;\n    use crate::{\n        data_combination::data_points::data_point::{Period, PeriodUnit},\n        database::screen_route::data_model::ScreenRoute,\n    };\n\n    #[test]\n    fn test_calculate_when_no_events() {\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), Utc::now());\n        let screen_enter_count = CalcScreenEnterCount::new(metadata, Vec::new());\n        assert_eq!(screen_enter_count.calculate(), vec![0]);\n    }\n\n    #[test]\n    fn test_calculate_when_one_event() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-01-01T01:01:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let screen_route = ScreenRoute::new(\"home_screen\", end_period + Duration::days(1));\n        let events = vec![AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::ScreenEnter,\n            end_period - Duration::hours(12),\n            Some(screen_route),\n        )];\n        let screen_enter_count = 
CalcScreenEnterCount::new(metadata, events);\n        assert_eq!(screen_enter_count.calculate(), vec![1]);\n    }\n\n    #[test]\n    fn test_calculate_when_two_events() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-02-02T02:02:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let screen_route = ScreenRoute::new(\"home_screen\", end_period + Duration::days(1));\n        let events = vec![\n            AnalyticsEvent::new(\n                \"test1\",\n                AnalyticsEventType::ScreenEnter,\n                end_period - Duration::hours(9),\n                Some(screen_route.clone()),\n            ),\n            AnalyticsEvent::new(\n                \"test2\",\n                AnalyticsEventType::ScreenEnter,\n                end_period - Duration::hours(18),\n                Some(screen_route),\n            ),\n        ];\n        let screen_enter_count = CalcScreenEnterCount::new(metadata, events);\n        assert_eq!(screen_enter_count.calculate(), vec![2]);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/data_combination/data_points/was_active_each_past_period.rs",
    "content": "use chrono::{DateTime, Utc};\nuse std::collections::BTreeMap;\n\nuse crate::{\n    data_combination::data_points::data_point::{\n        CalcWasActiveEachPastPeriod,\n        CalculateDataPoints,\n        DataPointMetadata,\n    },\n    database::analytics_event::data_model::AnalyticsEvent,\n};\n\nimpl CalcWasActiveEachPastPeriod {\n    pub fn new(\n        metadata: DataPointMetadata,\n        events: Vec<AnalyticsEvent>,\n        period_thresholds: Vec<DateTime<Utc>>,\n    ) -> Self {\n        Self {\n            metadata,\n            events,\n            period_thresholds,\n        }\n    }\n\n    // TODO: this could possibly later be optimised by sorting events by timestamp and caching the last timestamp added to the HashMap\n    fn group_timestamps_by_period_threshold(&self) -> BTreeMap<DateTime<Utc>, Vec<DateTime<Utc>>> {\n        let mut timestamps_by_period_threshold = BTreeMap::new();\n        for these_thresholds in self.period_thresholds.windows(2) {\n            // safe unwrap: `windows` guarantees that there are at least two elements.\n            // If `period_thresholds` contains less than two elements, this code block is not executed\n            let newer_threshold = these_thresholds.first().unwrap();\n            let older_threshold = these_thresholds.last().unwrap();\n            let timestamps: Vec<DateTime<Utc>> = self\n                .events\n                .iter()\n                .filter(|event| {\n                    event.timestamp < *newer_threshold && event.timestamp > *older_threshold\n                })\n                .map(|event| event.timestamp)\n                .collect();\n            timestamps_by_period_threshold.insert(*newer_threshold, timestamps);\n        }\n        timestamps_by_period_threshold\n    }\n}\n\nimpl CalculateDataPoints for CalcWasActiveEachPastPeriod {\n    fn metadata(&self) -> DataPointMetadata {\n        self.metadata\n    }\n\n    fn calculate(&self) -> Vec<u32> {\n        let 
timestamps_by_period_threshold = self.group_timestamps_by_period_threshold();\n        // since we are travelling 'back in time' we need to reverse the order of the values of the BTreeMap\n        timestamps_by_period_threshold\n            .values()\n            .rev()\n            .map(|timestamps| !timestamps.is_empty() as u32)\n            .collect::<Vec<u32>>()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use chrono::{DateTime, Duration, Utc};\n\n    use super::*;\n    use crate::{\n        data_combination::data_points::data_point::{Period, PeriodUnit},\n        database::analytics_event::data_model::AnalyticsEventType,\n    };\n\n    #[test]\n    fn test_calculate_no_events_in_a_period() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-02-02T00:00:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let period_thresholds = vec![end_period, end_period - Duration::days(1)];\n        let was_active_each_past_period =\n            CalcWasActiveEachPastPeriod::new(metadata, Vec::new(), period_thresholds);\n        assert_eq!(was_active_each_past_period.calculate(), vec![0]);\n    }\n\n    #[test]\n    fn test_calculate_one_event_in_a_period() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-03-03T00:00:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), end_period);\n        let events = vec![AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::UserAction,\n            end_period - Duration::hours(12),\n            None,\n        )];\n        let period_thresholds = vec![end_period, end_period - Duration::days(1)];\n        let was_active_each_past_period =\n            CalcWasActiveEachPastPeriod::new(metadata, events, period_thresholds);\n        
assert_eq!(was_active_each_past_period.calculate(), vec![1]);\n    }\n\n    #[test]\n    fn test_calculate_no_events_in_two_periods() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-04-04T00:00:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 2), end_period);\n        let period_thresholds = vec![\n            end_period,\n            end_period - Duration::days(1),\n            end_period - Duration::days(2),\n        ];\n        let was_active_each_past_period =\n            CalcWasActiveEachPastPeriod::new(metadata, Vec::new(), period_thresholds);\n        assert_eq!(was_active_each_past_period.calculate(), vec![0, 0]);\n    }\n\n    #[test]\n    fn test_calculate_one_event_in_one_period_zero_in_another() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-05-05T00:00:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 2), end_period);\n        let events = vec![AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::UserAction,\n            end_period - Duration::hours(12),\n            None,\n        )];\n        let period_thresholds = vec![\n            end_period,\n            end_period - Duration::days(1),\n            end_period - Duration::days(2),\n        ];\n        let was_active_each_past_period =\n            CalcWasActiveEachPastPeriod::new(metadata, events, period_thresholds);\n        assert_eq!(was_active_each_past_period.calculate(), vec![1, 0]);\n    }\n\n    #[test]\n    fn test_calculate_two_events_in_one_period_zero_in_another() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-06-06T00:00:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 2), end_period);\n        let events = vec![\n        
    AnalyticsEvent::new(\n                \"test1\",\n                AnalyticsEventType::UserAction,\n                end_period - Duration::hours(12),\n                None,\n            ),\n            AnalyticsEvent::new(\n                \"test2\",\n                AnalyticsEventType::AppError,\n                end_period - Duration::hours(15),\n                None,\n            ),\n        ];\n        let period_thresholds = vec![\n            end_period,\n            end_period - Duration::days(1),\n            end_period - Duration::days(2),\n        ];\n        let was_active_each_past_period =\n            CalcWasActiveEachPastPeriod::new(metadata, events, period_thresholds);\n        assert_eq!(was_active_each_past_period.calculate(), vec![1, 0]);\n    }\n\n    #[test]\n    fn test_calculate_two_periods_with_one_event_each() {\n        let end_period = DateTime::parse_from_rfc3339(\"2021-07-07T00:00:00-00:00\")\n            .unwrap()\n            .with_timezone(&Utc);\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 2), end_period);\n        let events = vec![\n            AnalyticsEvent::new(\n                \"test1\",\n                AnalyticsEventType::UserAction,\n                end_period - Duration::hours(12),\n                None,\n            ),\n            AnalyticsEvent::new(\n                \"test2\",\n                AnalyticsEventType::AppError,\n                end_period - Duration::hours(36),\n                None,\n            ),\n        ];\n        let period_thresholds = vec![\n            end_period,\n            end_period - Duration::days(1),\n            end_period - Duration::days(2),\n        ];\n        let was_active_each_past_period =\n            CalcWasActiveEachPastPeriod::new(metadata, events, period_thresholds);\n        assert_eq!(was_active_each_past_period.calculate(), vec![1, 1]);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/data_combination/data_points/was_active_past_n_days.rs",
    "content": "use crate::{\n    data_combination::data_points::data_point::{\n        CalcWasActivePastNDays,\n        CalculateDataPoints,\n        DataPointMetadata,\n    },\n    database::analytics_event::data_model::AnalyticsEvent,\n};\n\nimpl CalcWasActivePastNDays {\n    pub fn new(metadata: DataPointMetadata, events: Vec<AnalyticsEvent>) -> Self {\n        Self { metadata, events }\n    }\n}\n\nimpl CalculateDataPoints for CalcWasActivePastNDays {\n    fn metadata(&self) -> DataPointMetadata {\n        self.metadata\n    }\n\n    fn calculate(&self) -> Vec<u32> {\n        vec![!self.events.is_empty() as u32]\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use chrono::{Duration, Utc};\n\n    use super::*;\n    use crate::{\n        data_combination::data_points::data_point::{Period, PeriodUnit},\n        database::analytics_event::data_model::AnalyticsEventType,\n    };\n\n    #[test]\n    fn test_calculate_without_events() {\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), Utc::now());\n        let was_active_past_n_days = CalcWasActivePastNDays::new(metadata, Vec::new());\n        assert_eq!(was_active_past_n_days.calculate(), vec![0]);\n    }\n\n    #[test]\n    fn test_calculate_with_events() {\n        let metadata = DataPointMetadata::new(Period::new(PeriodUnit::Days, 1), Utc::now());\n        let events = vec![AnalyticsEvent::new(\n            \"test1\",\n            AnalyticsEventType::AppEvent,\n            metadata.end - Duration::hours(12),\n            None,\n        )];\n        let was_active_past_n_days = CalcWasActivePastNDays::new(metadata, events);\n        assert_eq!(was_active_past_n_days.calculate(), vec![1]);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/data_combination/mod.rs",
    "content": "pub mod data_combiner;\npub mod data_points;\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/analytics_event/adapter.rs",
    "content": "//! This file contains struct and impls for `AnalyticsEventAdapter` and `AnalyticsEventRelationalAdapter`,\n//! as well as the implementation of `IsarAdapter` for `AnalyticsEventAdapter`.\n\nuse anyhow::{anyhow, Error, Result};\nuse isar_core::object::{\n    data_type::DataType,\n    isar_object::{IsarObject, Property},\n    object_builder::ObjectBuilder,\n};\nuse std::{convert::TryFrom, vec::IntoIter};\n\nuse crate::database::{\n    common::{FieldProperty, IsarAdapter, RelationalField, Repo, SchemaGenerator},\n    isar::IsarDb,\n    screen_route::data_model::ScreenRoute,\n};\n\n/// `AnalyticsEventAdapter` allows to convert an `IsarObject` from the db to an `AnalyticsEvent`. It is an intermediate\n/// representation.\n#[derive(Debug, PartialEq, Eq, Clone)]\npub struct AnalyticsEventAdapter {\n    pub name: String,\n    pub event_type: i32,\n    pub timestamp: String,\n    pub screen_route_field: Option<String>,\n}\n\nimpl AnalyticsEventAdapter {\n    pub fn new<N: Into<String>>(\n        name: N,\n        event_type: i32,\n        timestamp: String,\n        screen_route_field: Option<RelationalField>,\n    ) -> Self {\n        Self {\n            name: name.into(),\n            event_type,\n            timestamp,\n            screen_route_field: screen_route_field.map(|field| field.into()),\n        }\n    }\n}\n\nimpl<'event> IsarAdapter<'event> for AnalyticsEventAdapter {\n    fn get_oid(&self) -> String {\n        format!(\"{}-{}\", self.name, self.timestamp)\n    }\n\n    fn into_field_properties() -> IntoIter<FieldProperty> {\n        vec![\n            FieldProperty::new(\"oid\", DataType::String, true),\n            FieldProperty::new(\"name\", DataType::String, false),\n            FieldProperty::new(\"event_type\", DataType::Int, false),\n            FieldProperty::new(\"timestamp\", DataType::String, false),\n            FieldProperty::new(\"screen_route_field\", DataType::String, false),\n        ]\n        .into_iter()\n    }\n\n    fn 
write_with_object_builder(&self, object_builder: &mut ObjectBuilder) {\n        object_builder.write_string(Some(&self.get_oid()));\n        object_builder.write_string(Some(&self.name));\n        object_builder.write_int(self.event_type);\n        object_builder.write_string(Some(&self.timestamp));\n        object_builder.write_string(self.screen_route_field.as_deref());\n    }\n\n    fn read(\n        isar_object: &'event IsarObject,\n        isar_properties: &'event [(String, Property)],\n    ) -> Result<AnalyticsEventAdapter, Error> {\n        let name_property = Self::find_property_by_name(\"name\", isar_properties)?;\n        let event_type_property = Self::find_property_by_name(\"event_type\", isar_properties)?;\n        let timestamp_property = Self::find_property_by_name(\"timestamp\", isar_properties)?;\n        let screen_route_field_property =\n            Self::find_property_by_name(\"screen_route_field\", isar_properties)?;\n\n        let name_field = isar_object\n            .read_string(name_property)\n            .ok_or_else(|| anyhow!(\"unable to read name\"))?;\n        let event_type_field = isar_object.read_int(event_type_property);\n        let timestamp_field = isar_object\n            .read_string(timestamp_property)\n            .ok_or_else(|| anyhow!(\"unable to read timestamp\"))?\n            .to_string();\n        let screen_route_field = isar_object\n            .read_string(screen_route_field_property)\n            .map(RelationalField::try_from)\n            .transpose()?;\n\n        Ok(AnalyticsEventAdapter::new(\n            name_field,\n            event_type_field,\n            timestamp_field,\n            screen_route_field,\n        ))\n    }\n}\n\nimpl<'event> SchemaGenerator<'event, AnalyticsEventAdapter> for AnalyticsEventAdapter {}\n\n/// `AnalyticsEventRelationalAdapter` is needed as an intermediate step when saving/retrieving events\n/// from the db because `AnalyticsEvent` contains an `Option<ScreenRoute>`, which, if 
`Some`, needs to be retrieved\n/// from a different collection in Isar.\npub struct AnalyticsEventRelationalAdapter {\n    pub name: String,\n    pub event_type: i32,\n    pub timestamp: String,\n    pub screen_route: Option<ScreenRoute>,\n}\n\nimpl AnalyticsEventRelationalAdapter {\n    pub fn new(adapter: AnalyticsEventAdapter, db: &IsarDb) -> Result<Self, Error> {\n        let screen_route = adapter\n            .screen_route_field\n            .map(|screen_route_field| {\n                let relational_field = RelationalField::try_from(screen_route_field.as_str())?;\n                ScreenRoute::get(\n                    &relational_field.value,\n                    db,\n                    &relational_field.collection_name,\n                )\n            })\n            .transpose()?;\n\n        Ok(Self {\n            name: adapter.name,\n            event_type: adapter.event_type,\n            timestamp: adapter.timestamp,\n            screen_route,\n        })\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/analytics_event/data_model.rs",
    "content": "//! In this file `AnalyticsEvent` and `AnalyticsEventType` are declared, together with some conversion methods to/from adapters.\n\nuse anyhow::{anyhow, Result};\nuse chrono::{DateTime, Utc};\nuse std::convert::{From, Into, TryFrom, TryInto};\n\nuse crate::database::{\n    analytics_event::adapter::{AnalyticsEventAdapter, AnalyticsEventRelationalAdapter},\n    common::RelationalField,\n    screen_route::data_model::ScreenRoute,\n};\n\n/// The type of `AnalyticsEvent` recorded on the framework side.\n/// ## Variants:\n/// * `AppEvent`: It refers to Flutter's app lifecycle events (`AppLifecycleState`), or the equivalent in other frameworks:\n///   https://flutter.dev/docs/get-started/flutter-for/android-devs#how-do-i-listen-to-android-activity-lifecycle-events\n/// * `AppError`: A known error logged by the developers\n/// * `ScreenEnter`: Registers when the user enters a specific screen\n/// * `UserAction`: A custom event logged by the developer (e.g. a click on a specific button)\n#[derive(Debug, PartialEq, Eq, Clone, Copy)]\npub enum AnalyticsEventType {\n    AppEvent = 0,\n    AppError = 1,\n    ScreenEnter = 2,\n    UserAction = 3,\n}\n\nimpl TryFrom<i32> for AnalyticsEventType {\n    type Error = anyhow::Error;\n\n    fn try_from(v: i32) -> Result<Self, Self::Error> {\n        match v {\n            x if x == AnalyticsEventType::AppEvent as i32 => Ok(AnalyticsEventType::AppEvent),\n            x if x == AnalyticsEventType::AppError as i32 => Ok(AnalyticsEventType::AppError),\n            x if x == AnalyticsEventType::ScreenEnter as i32 => Ok(AnalyticsEventType::ScreenEnter),\n            x if x == AnalyticsEventType::UserAction as i32 => Ok(AnalyticsEventType::UserAction),\n            _ => Err(anyhow!(\n                \"i32 value {:?} is not mapped to an AnalyticsEventType variant\",\n                v\n            )),\n        }\n    }\n}\n\n/// The core data model of the library. It represents an event recorded on the mobile framework side.\n/// It can be logged manually by the developers, or automatically detected by Flutter/the mobile framework side.\n/// ## Fields:\n/// * `name`: The name of the event.\n/// * `event_type`: The type of event.\n/// * `timestamp`: When the event was created.\n/// * `screen_route`: Optional field representing the screen on which the event was recorded.\n#[derive(Debug, PartialEq, Eq, Clone)]\npub struct AnalyticsEvent {\n    pub name: String,\n    pub event_type: AnalyticsEventType,\n    pub timestamp: DateTime<Utc>,\n    pub screen_route: Option<ScreenRoute>,\n}\n\nimpl AnalyticsEvent {\n    pub fn new<N: Into<String>>(\n        name: N,\n        event_type: AnalyticsEventType,\n        timestamp: DateTime<Utc>,\n        screen_route: Option<ScreenRoute>,\n    ) -> Self {\n        Self {\n            name: name.into(),\n            event_type,\n            timestamp,\n            screen_route,\n        }\n    }\n}\n\nimpl TryFrom<AnalyticsEventRelationalAdapter> for AnalyticsEvent {\n    type Error = anyhow::Error;\n\n    fn try_from(adapter: AnalyticsEventRelationalAdapter) -> Result<Self, Self::Error> {\n        let event = AnalyticsEvent::new(\n            adapter.name,\n            adapter\n                .event_type\n                .try_into()\n                .map_err(|_| anyhow!(\"unable to convert event_type into enum\"))?,\n            DateTime::parse_from_rfc3339(&adapter.timestamp)?.with_timezone(&Utc),\n            adapter.screen_route,\n        );\n        Ok(event)\n    }\n}\n\nimpl From<AnalyticsEvent> for AnalyticsEventAdapter {\n    fn from(ae: AnalyticsEvent) -> Self {\n        AnalyticsEventAdapter::new(\n            ae.name,\n            ae.event_type as i32,\n            ae.timestamp.to_rfc3339(),\n            ae.screen_route.map(RelationalField::from),\n        )\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::database::common::CollectionNames;\n\n    #[test]\n    fn test_analytics_event_type_try_from_valid_i32() {\n        assert_eq!(\n            AnalyticsEventType::try_from(0).unwrap(),\n            AnalyticsEventType::AppEvent\n        );\n        assert_eq!(\n            AnalyticsEventType::try_from(1).unwrap(),\n            AnalyticsEventType::AppError\n        );\n        assert_eq!(\n            AnalyticsEventType::try_from(2).unwrap(),\n            AnalyticsEventType::ScreenEnter\n        );\n        assert_eq!(\n            AnalyticsEventType::try_from(3).unwrap(),\n            AnalyticsEventType::UserAction\n        );\n    }\n\n    #[test]\n    fn test_analytics_event_type_invalid_i32() {\n        assert!(AnalyticsEventType::try_from(42).is_err());\n    }\n\n    #[test]\n    fn test_analytics_event_try_from_relational_adapter_without_screen_route() {\n        let timestamp = \"2021-01-01T01:01:00+00:00\";\n        let relational_adapter = AnalyticsEventRelationalAdapter {\n            name: \"test\".to_string(),\n            event_type: 0,\n            timestamp: timestamp.to_string(),\n            screen_route: None,\n        };\n        let analytics_event = AnalyticsEvent::new(\n            \"test\",\n            AnalyticsEventType::AppEvent,\n            DateTime::parse_from_rfc3339(timestamp)\n                .unwrap()\n                .with_timezone(&Utc),\n            None,\n        );\n        assert_eq!(\n            AnalyticsEvent::try_from(relational_adapter).unwrap(),\n            analytics_event\n        );\n    }\n\n    #[test]\n    fn test_analytics_event_try_from_relational_adapter_with_screen_route() {\n        let timestamp_str = \"2021-01-01T01:01:00+00:00\";\n        let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str)\n            .unwrap()\n            .with_timezone(&Utc);\n        let screen_route = ScreenRoute::new(\"route\", timestamp_parsed);\n        let relational_adapter = AnalyticsEventRelationalAdapter {\n            name: \"test\".to_string(),\n            event_type: 2,\n            timestamp: timestamp_str.to_string(),\n            screen_route: Some(screen_route.clone()),\n        };\n        let analytics_event = AnalyticsEvent::new(\n            \"test\",\n            AnalyticsEventType::ScreenEnter,\n            timestamp_parsed,\n            Some(screen_route),\n        );\n        assert_eq!(\n            AnalyticsEvent::try_from(relational_adapter).unwrap(),\n            analytics_event\n        );\n    }\n\n    #[test]\n    fn test_analytics_event_try_into_adapter_without_screen_route() {\n        let timestamp_str = \"2021-01-01T01:01:00+00:00\";\n        let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str)\n            .unwrap()\n            .with_timezone(&Utc);\n        let analytics_event =\n            AnalyticsEvent::new(\"test\", AnalyticsEventType::AppError, timestamp_parsed, None);\n\n        let actual_analytics_event_adapter: AnalyticsEventAdapter =\n            analytics_event.try_into().unwrap();\n        let expected_analytics_event_adapter =\n            AnalyticsEventAdapter::new(\"test\", 1, timestamp_str.to_string(), None);\n        assert_eq!(\n            actual_analytics_event_adapter,\n            expected_analytics_event_adapter\n        );\n    }\n\n    #[test]\n    fn test_analytics_event_try_into_adapter_with_screen_route() {\n        let timestamp_str = \"2021-01-01T01:01:00+00:00\";\n        let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str)\n            .unwrap()\n            .with_timezone(&Utc);\n        let screen_route = ScreenRoute::new(\"route\", timestamp_parsed);\n        let relationa_field = RelationalField {\n            value: \"route\".to_string(),\n            collection_name: CollectionNames::SCREEN_ROUTES.to_string(),\n        };\n        let analytics_event = AnalyticsEvent::new(\n            \"test\",\n            AnalyticsEventType::UserAction,\n            timestamp_parsed,\n            Some(screen_route),\n        );\n\n        let actual_analytics_event_adapter: AnalyticsEventAdapter =\n            analytics_event.try_into().unwrap();\n        let expected_analytics_event_adapter =\n            AnalyticsEventAdapter::new(\"test\", 3, timestamp_str.to_string(), Some(relationa_field));\n        assert_eq!(\n            actual_analytics_event_adapter,\n            expected_analytics_event_adapter\n        );\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/analytics_event/mod.rs",
    "content": "pub mod adapter;\npub mod data_model;\npub mod repo;\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/analytics_event/repo.rs",
    "content": "//! Implementations of the methods needed to save and get `AnalyticsEvents` to/from Isar.\n\nuse anyhow::{anyhow, Error, Result};\nuse std::convert::{Into, TryFrom};\n\nuse crate::database::{\n    analytics_event::{\n        adapter::{AnalyticsEventAdapter, AnalyticsEventRelationalAdapter},\n        data_model::AnalyticsEvent,\n    },\n    common::{IsarAdapter, Repo},\n    isar::IsarDb,\n};\n\n/// Inside `get()` and `get_all()` there is an intermediate conversion from `Adapter` to `RelationalAdapter`,\n/// and then to data model (`AnalyticsEvent`), which is different than other data models where\n/// they can be converted directly from Adapter to data model.\nimpl<'db> Repo<'db, AnalyticsEvent> for AnalyticsEvent {\n    fn save(self, db: &'db IsarDb, collection_name: &str) -> Result<(), Error> {\n        let mut object_builder = db.get_object_builder(collection_name)?;\n        let event_adapter: AnalyticsEventAdapter = self.into();\n        event_adapter.write_with_object_builder(&mut object_builder);\n        db.put(collection_name, object_builder.finish().as_bytes())\n    }\n\n    // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517\n    fn get_all(db: &'db IsarDb, collection_name: &str) -> Result<Vec<Self>, Error> {\n        let isar_properties = db.get_collection_properties(collection_name)?;\n        db.get_all_isar_objects(collection_name)?\n            .into_iter()\n            .map(|(_, isar_object)| AnalyticsEventAdapter::read(&isar_object, isar_properties))\n            .map(|adapter| AnalyticsEventRelationalAdapter::new(adapter?, &db))\n            .map(|relational_adapter| AnalyticsEvent::try_from(relational_adapter?))\n            .collect()\n    }\n\n    fn get(oid: &str, db: &'db IsarDb, collection_name: &str) -> Result<Self, Error> {\n        let isar_properties = db.get_collection_properties(collection_name)?;\n        let object_id = db.get_object_id_from_str(collection_name, oid)?;\n        let 
mut transaction = db.get_read_transaction()?;\n        let isar_object =\n            db.get_isar_object_by_id(&object_id, collection_name, &mut transaction)?;\n        if let Some(isar_object) = isar_object {\n            let adapter = AnalyticsEventAdapter::read(&isar_object, isar_properties)?;\n            let relational_adapter = AnalyticsEventRelationalAdapter::new(adapter, &db)?;\n            AnalyticsEvent::try_from(relational_adapter)\n        } else {\n            Err(anyhow!(\"unable to get {:?} object\", object_id))\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/common.rs",
    "content": "//! This file contains traits and structs that are common to other components involved with the database.\n//! It could be split up in smaller files, especially if more traits receive a default implementation.\n//! See: https://xainag.atlassian.net/browse/XN-1692\n\nuse anyhow::{anyhow, Error, Result};\nuse isar_core::{\n    index::IndexType,\n    object::{\n        data_type::DataType,\n        isar_object::{IsarObject, Property},\n        object_builder::ObjectBuilder,\n    },\n    schema::collection_schema::{\n        CollectionSchema,\n        IndexPropertySchema,\n        IndexSchema,\n        PropertySchema,\n    },\n};\nuse std::{convert::TryFrom, vec::IntoIter};\n\nuse crate::database::isar::IsarDb;\n\n/// `IsarAdapter` trait needs to be implemented for each data model adapters.\n/// This is needed to be able to tell Isar how to write/read objects to/from a collection.\n///\n/// The implementations of these methods could actually be automated by a macro, since they are always the same.\n/// See: https://xainag.atlassian.net/browse/XN-1689\npub trait IsarAdapter<'object>: Sized {\n    fn get_oid(&self) -> String;\n\n    fn into_field_properties() -> IntoIter<FieldProperty>;\n\n    fn write_with_object_builder(&self, object_builder: &mut ObjectBuilder);\n\n    fn read(\n        isar_object: &'object IsarObject,\n        isar_properties: &'object [(String, Property)],\n    ) -> Result<Self, Error>;\n\n    fn find_property_by_name(\n        name: &str,\n        isar_properties: &[(String, Property)],\n    ) -> Result<Property, Error> {\n        isar_properties\n            .iter()\n            .find(|(isar_property_name, _)| isar_property_name == name)\n            .map(|(_, property)| *property)\n            .ok_or_else(|| anyhow!(\"failed to retrieve property {:?}\", name))\n    }\n}\n\n/// This trait is implemented directly for each data model to have a high level API for `AnalyticsController` to\n/// save/get objects from the db.\n///\n/// 
Consider using default implementations here, to reduce boiler plate code in repo.rs files.\n/// See: https://xainag.atlassian.net/browse/XN-1688\npub trait Repo<'db, M>\nwhere\n    M: Sized,\n{\n    fn save(self, db: &'db IsarDb, collection_name: &str) -> Result<(), Error>;\n\n    fn get_all(db: &'db IsarDb, collection_name: &str) -> Result<Vec<M>, Error>;\n\n    fn get(object_id: &str, db: &'db IsarDb, collection_name: &str) -> Result<M, Error>;\n}\n\n/// `FieldProperty` is a simple struct that holds data used to register properties and indexes for Isar schemas.\npub struct FieldProperty {\n    pub name: String,\n    pub data_type: DataType,\n    pub is_oid: bool,\n    pub index_type: IndexType,\n    pub is_case_sensitive: bool,\n    pub is_unique: bool,\n}\n\nimpl FieldProperty {\n    pub fn new<N: Into<String>>(name: N, data_type: DataType, is_oid: bool) -> Self {\n        Self {\n            name: name.into(),\n            data_type,\n            is_oid,\n            index_type: IndexType::Value,\n            is_case_sensitive: data_type == DataType::String,\n            is_unique: true,\n        }\n    }\n}\n\n/// `SchemaGenerator` is needed to register the `PropertySchema` and `IndexSchema` for each `FieldProperty`.\n/// `PropertySchema` and `IndexSchema` are imported from Isar, while `FieldProperty` is an internal struct to\n/// make it convenient to iterate through each property (see the fold below).\n///\n/// When `Ok` it returns a `CollectionSchema` that is needed by Isar to manage a collection.\npub trait SchemaGenerator<'object, A>\nwhere\n    A: IsarAdapter<'object>,\n{\n    fn get_schema(name: &str) -> Result<CollectionSchema, Error> {\n        let (properties, indexes) = A::into_field_properties().fold(\n            (Vec::new(), Vec::new()),\n            |(mut properties, mut indexes), prop| {\n                let property_schema = PropertySchema::new(&prop.name, prop.data_type, prop.is_oid);\n                let is_index_case_sensitive =\n           
         Some(true).filter(|_| prop.data_type == DataType::String);\n                let index_property_schema = vec![IndexPropertySchema::new(\n                    &prop.name,\n                    prop.index_type,\n                    is_index_case_sensitive,\n                )];\n                let index_schema = IndexSchema::new(index_property_schema, prop.is_unique);\n                properties.push(property_schema);\n                indexes.push(index_schema);\n                (properties, indexes)\n            },\n        );\n        Ok(CollectionSchema::new(name, properties, indexes))\n    }\n}\n\n/// `RelationalField` is the struct that allows to save data model instances inside other data models.\n///\n/// ## Arguments\n/// * `value` - is a `String` representing an id with which the data model can be identified\n/// * `collection_name` - is the name of the collection where the object is saved\n#[derive(Debug, PartialEq, Eq, Clone)]\npub struct RelationalField {\n    pub value: String,\n    pub collection_name: String,\n}\n\n// NOTE: when split_once gets to stable, it would be a much better solution for this\n// https://doc.rust-lang.org/std/string/struct.String.html#method.split_once\nimpl TryFrom<&str> for RelationalField {\n    type Error = anyhow::Error;\n\n    fn try_from(data: &str) -> Result<Self, Error> {\n        let data_split: Vec<&str> = data.split('=').collect();\n        if data_split.len() != 2 {\n            return Err(anyhow!(\n                \"data {:?} is not a str made of two elements separated by '='\",\n                data\n            ));\n        }\n\n        Ok(Self {\n            value: data_split[0].to_string(),\n            collection_name: data_split[1].to_string(),\n        })\n    }\n}\n\nimpl From<RelationalField> for String {\n    fn from(rf: RelationalField) -> String {\n        [rf.value, rf.collection_name].join(\"=\")\n    }\n}\n\n/// Stores the name of each collection. 
Whenever you need to make an operation on an `IsarCollection`,\n/// these `str`s are needed.\npub struct CollectionNames;\n\nimpl CollectionNames {\n    pub const ANALYTICS_EVENTS: &'static str = \"analytics_events\";\n    pub const CONTROLLER_DATA: &'static str = \"controller_data\";\n    pub const SCREEN_ROUTES: &'static str = \"screen_routes\";\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/controller_data/adapter.rs",
    "content": "//! This file contains the `ControllerDataAdapter` struct, together with its implementations\n//! of `IsarAdapter` and `SchemaGenerator`.\n\nuse anyhow::{anyhow, Error, Result};\nuse isar_core::object::{\n    data_type::DataType,\n    isar_object::{IsarObject, Property},\n    object_builder::ObjectBuilder,\n};\nuse std::vec::IntoIter;\n\nuse crate::database::common::{FieldProperty, IsarAdapter, SchemaGenerator};\n\n/// Allows converting an `IsarObject` from the db to a `ControllerData`.\n#[derive(Debug, PartialEq, Eq, Clone)]\npub struct ControllerDataAdapter {\n    pub time_data_sent: String,\n}\n\nimpl ControllerDataAdapter {\n    pub fn new<T: Into<String>>(time_data_sent: T) -> Self {\n        Self {\n            time_data_sent: time_data_sent.into(),\n        }\n    }\n}\n\nimpl<'ctrl> IsarAdapter<'ctrl> for ControllerDataAdapter {\n    fn get_oid(&self) -> String {\n        self.time_data_sent.clone()\n    }\n\n    fn into_field_properties() -> IntoIter<FieldProperty> {\n        vec![\n            FieldProperty::new(\"oid\", DataType::String, true),\n            FieldProperty::new(\"time_data_sent\", DataType::String, false),\n        ]\n        .into_iter()\n    }\n\n    fn write_with_object_builder(&self, object_builder: &mut ObjectBuilder) {\n        object_builder.write_string(Some(&self.get_oid()));\n        object_builder.write_string(Some(&self.time_data_sent));\n    }\n\n    fn read(\n        isar_object: &'ctrl IsarObject,\n        isar_properties: &'ctrl [(String, Property)],\n    ) -> Result<ControllerDataAdapter, Error> {\n        let time_data_sent_property =\n            Self::find_property_by_name(\"time_data_sent\", isar_properties)?;\n\n        let time_data_sent_data = isar_object\n            .read_string(time_data_sent_property)\n            .ok_or_else(|| anyhow!(\"unable to read time_data_sent\"))?;\n\n        Ok(ControllerDataAdapter::new(time_data_sent_data))\n    }\n}\n\nimpl<'ctrl> SchemaGenerator<'ctrl, 
ControllerDataAdapter> for ControllerDataAdapter {}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/controller_data/data_model.rs",
    "content": "//! In this file `ControllerData` is declared, together with some conversion methods to/from adapters.\n\nuse anyhow::Result;\nuse chrono::{DateTime, Utc};\nuse std::convert::TryFrom;\n\nuse crate::database::controller_data::adapter::ControllerDataAdapter;\n\n/// Holds some metadata useful for the `AnalyticsController`. For now it only contains `time_data_sent`,\n/// which is the time when analytics data was last sent to the coordinator for aggregation.\n#[derive(Debug, PartialEq, Eq, Clone)]\npub struct ControllerData {\n    pub time_data_sent: DateTime<Utc>,\n}\n\nimpl ControllerData {\n    pub fn new(time_data_sent: DateTime<Utc>) -> Self {\n        Self { time_data_sent }\n    }\n}\n\nimpl TryFrom<ControllerDataAdapter> for ControllerData {\n    type Error = anyhow::Error;\n\n    fn try_from(adapter: ControllerDataAdapter) -> Result<Self, Self::Error> {\n        Ok(ControllerData::new(\n            DateTime::parse_from_rfc3339(&adapter.time_data_sent)?.with_timezone(&Utc),\n        ))\n    }\n}\n\nimpl From<ControllerData> for ControllerDataAdapter {\n    fn from(cd: ControllerData) -> Self {\n        ControllerDataAdapter::new(cd.time_data_sent.to_rfc3339())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_controller_data_try_from_adapter() {\n        let timestamp_str = \"2021-01-01T01:01:00+00:00\";\n        let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str)\n            .unwrap()\n            .with_timezone(&Utc);\n        let controller_data = ControllerData::new(timestamp_parsed);\n        let adapter = ControllerDataAdapter::new(timestamp_str);\n        assert_eq!(ControllerData::try_from(adapter).unwrap(), controller_data);\n    }\n\n    #[test]\n    fn test_adapter_into_controller_data() {\n        let timestamp_str = \"2021-01-01T01:01:00+00:00\";\n        let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str)\n            .unwrap()\n            .with_timezone(&Utc);\n    
    let controller_data = ControllerData::new(timestamp_parsed);\n        let actual_adapter: ControllerDataAdapter = controller_data.into();\n        let expected_adapter = ControllerDataAdapter::new(timestamp_str);\n        assert_eq!(actual_adapter, expected_adapter);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/controller_data/mod.rs",
    "content": "pub mod adapter;\npub mod data_model;\npub mod repo;\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/controller_data/repo.rs",
    "content": "//! Implementations of the methods needed to save and get `ControllerData` to/from Isar.\n\nuse anyhow::{anyhow, Error, Result};\nuse std::convert::{Into, TryFrom};\n\nuse crate::database::{\n    common::{IsarAdapter, Repo},\n    controller_data::{adapter::ControllerDataAdapter, data_model::ControllerData},\n    isar::IsarDb,\n};\n\nimpl<'db> Repo<'db, ControllerData> for ControllerData {\n    fn save(self, db: &'db IsarDb, collection_name: &str) -> Result<(), Error> {\n        let mut object_builder = db.get_object_builder(collection_name)?;\n        let data_adapter: ControllerDataAdapter = self.into();\n        data_adapter.write_with_object_builder(&mut object_builder);\n        db.put(collection_name, object_builder.finish().as_bytes())\n    }\n\n    // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517\n    fn get_all(db: &'db IsarDb, collection_name: &str) -> Result<Vec<Self>, Error> {\n        let isar_properties = db.get_collection_properties(collection_name)?;\n        db.get_all_isar_objects(collection_name)?\n            .into_iter()\n            .map(|(_, isar_object)| ControllerDataAdapter::read(&isar_object, isar_properties))\n            .map(|data_adapter| ControllerData::try_from(data_adapter?))\n            .collect()\n    }\n\n    fn get(oid: &str, db: &'db IsarDb, collection_name: &str) -> Result<Self, Error> {\n        let isar_properties = db.get_collection_properties(collection_name)?;\n        let object_id = db.get_object_id_from_str(collection_name, oid)?;\n        let mut transaction = db.get_read_transaction()?;\n        let isar_object =\n            db.get_isar_object_by_id(&object_id, collection_name, &mut transaction)?;\n        if let Some(isar_object) = isar_object {\n            let data_adapter = ControllerDataAdapter::read(&isar_object, isar_properties)?;\n            ControllerData::try_from(data_adapter)\n        } else {\n            Err(anyhow!(\"unable to get {:?} object\", 
object_id))\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/isar.rs",
    "content": "//! `IsarDb` is an internal abstraction on top of Isar that wraps `IsarInstance`, the main singleton from Isar.\n\nuse anyhow::{anyhow, Error, Result};\nuse isar_core::{\n    collection::IsarCollection,\n    instance::IsarInstance,\n    object::{\n        isar_object::{IsarObject, Property},\n        object_builder::ObjectBuilder,\n        object_id::ObjectId,\n    },\n    schema::{collection_schema::CollectionSchema, Schema},\n    txn::IsarTxn,\n};\nuse std::sync::Arc;\n\n/// `IsarDb` is the internal singleton wrapping the `IsarInstance`, which is the singleton coming from Isar.\n/// `IsarDb` exposes public methods for the `AnalyticsController` to save/get models via the `Repo` impls and the adapters.\npub struct IsarDb {\n    instance: Arc<IsarInstance>,\n}\n\nimpl IsarDb {\n    const MAX_SIZE: usize = 10000000;\n\n    /// `IsarInstance` is the singleton from Isar that coordinates the whole database.\n    ///\n    /// `Vec<CollectionSchema>` is required by Isar to register each data model `IsarCollection`.\n    /// A `IsarCollection` organises data for a single data model (eg: `AnalyticsEvents`).\n    pub fn new(path: &str, collection_schemas: Vec<CollectionSchema>) -> Result<IsarDb, Error> {\n        IsarInstance::open(\n            path,\n            IsarDb::MAX_SIZE,\n            IsarDb::get_schema(collection_schemas)?,\n        )\n        .map_err(|error| anyhow!(\"failed to create IsarInstance: {:?}\", error))\n        .map(|instance| IsarDb { instance })\n    }\n\n    pub fn get_all_isar_objects(\n        &self,\n        collection_name: &str,\n    ) -> Result<Vec<(ObjectId, IsarObject)>, Error> {\n        self.get_collection(collection_name)?\n            .new_query_builder()\n            .build()\n            .find_all_vec(&mut self.begin_txn(false)?)\n            .map_err(|error| {\n                anyhow!(\n                    \"failed to find all objects from collection {}: {:?}\",\n                    collection_name,\n                 
   error,\n                )\n            })\n    }\n\n    /// Transactions are needed to write and read from Isar.\n    /// This method is public because it's called inside `Repo::get()`, before passing it to `get_isar_object_by_id()`,\n    /// so that the transaction is in scope when called, and the lifetimes are valid.\n    pub fn get_read_transaction(&self) -> Result<IsarTxn, Error> {\n        self.begin_txn(false)\n    }\n\n    pub fn get_isar_object_by_id<'txn>(\n        &self,\n        object_id: &ObjectId,\n        collection_name: &str,\n        transaction: &'txn mut IsarTxn,\n    ) -> Result<Option<IsarObject<'txn>>, Error> {\n        self.get_collection(collection_name)?\n            .get(transaction, object_id)\n            .map_err(|error| anyhow!(\"unable to get {:?} object ({:?})\", object_id, error))\n    }\n\n    pub fn put(&self, collection_name: &str, object: &[u8]) -> Result<(), Error> {\n        let mut transaction = self.begin_txn(true)?;\n        self.get_collection(collection_name)?\n            .put(&mut transaction, IsarObject::new(object))\n            .and_then(|_| transaction.commit())\n            .map_err(|error| {\n                anyhow!(\n                    \"failed to add object {:?} to collection: {} | {:?}\",\n                    object,\n                    collection_name,\n                    error,\n                )\n            })\n    }\n\n    pub fn get_object_builder(&self, collection_name: &str) -> Result<ObjectBuilder, Error> {\n        Ok(self\n            .get_collection(collection_name)?\n            .new_object_builder(None))\n    }\n\n    /// When `Ok`, this method returns a valid `ObjectId` that can be used to retrieve a single object from a collection.\n    pub fn get_object_id_from_str(\n        &self,\n        collection_name: &str,\n        oid: &str,\n    ) -> Result<ObjectId, Error> {\n        self.get_collection(collection_name)?\n            .new_string_oid(oid)\n            .map_err(|error| 
anyhow!(\"could not get the object id from {:?}: {:?}\", oid, error))\n    }\n\n    /// Returns the properties from a collection that were registered via the `CollectionSchema`, and are needed to\n    /// read/write objects to/from the collection.\n    pub fn get_collection_properties(\n        &self,\n        collection_name: &str,\n    ) -> Result<&[(String, Property)], Error> {\n        Ok(self.get_collection(collection_name)?.get_properties())\n    }\n\n    pub fn dispose(self) -> Result<(), Error> {\n        match self.instance.close() {\n            Some(_) => Err(anyhow!(\"could not close the IsarInstance\")),\n            None => Ok(()),\n        }\n    }\n\n    /// The `Schema` is needed to open the `IsarInstance` and is automatically produced by Isar\n    /// based on the `Vec<CollectionSchema>` provided when calling `IsarDb::new()`.\n    fn get_schema(collection_schemas: Vec<CollectionSchema>) -> Result<Schema, Error> {\n        Schema::new(collection_schemas).map_err(|error| {\n            anyhow!(\n                \"failed to add collection schemas to instance schema: {:?}\",\n                error\n            )\n        })\n    }\n\n    fn get_collection(&self, collection_name: &str) -> Result<&IsarCollection, Error> {\n        self.instance\n            .get_collection_by_name(collection_name)\n            .ok_or_else(|| anyhow!(\"wrong collection name: {}\", collection_name))\n    }\n\n    /// Transactions are needed to read/write objects from Isar. Write transactions should stay private.\n    fn begin_txn(&self, is_write: bool) -> Result<IsarTxn, Error> {\n        self.instance\n            .begin_txn(is_write)\n            .map_err(|error| anyhow!(\"failed to begin transaction: {:?}\", error))\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/mod.rs",
    "content": "pub mod analytics_event;\npub mod common;\npub mod controller_data;\npub mod isar;\npub mod screen_route;\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/screen_route/adapter.rs",
    "content": "//! This file contains the `ScreenRouteAdapter` struct, together with its implementations\n//! of `IsarAdapter` and `SchemaGenerator`.\n\nuse anyhow::{anyhow, Error, Result};\nuse isar_core::object::{\n    data_type::DataType,\n    isar_object::{IsarObject, Property},\n    object_builder::ObjectBuilder,\n};\nuse std::vec::IntoIter;\n\nuse crate::database::common::{FieldProperty, IsarAdapter, SchemaGenerator};\n\n/// Allows converting an `IsarObject` from the db to a `ScreenRoute`.\n#[derive(Debug, PartialEq, Eq, Clone)]\npub struct ScreenRouteAdapter {\n    pub name: String,\n    pub created_at: String,\n}\n\nimpl ScreenRouteAdapter {\n    pub fn new<S: Into<String>>(name: S, created_at: S) -> Self {\n        Self {\n            name: name.into(),\n            created_at: created_at.into(),\n        }\n    }\n}\n\nimpl<'screen> IsarAdapter<'screen> for ScreenRouteAdapter {\n    fn get_oid(&self) -> String {\n        self.name.clone()\n    }\n\n    fn into_field_properties() -> IntoIter<FieldProperty> {\n        vec![\n            FieldProperty::new(\"oid\", DataType::String, true),\n            FieldProperty::new(\"name\", DataType::String, false),\n            FieldProperty::new(\"created_at\", DataType::String, false),\n        ]\n        .into_iter()\n    }\n\n    fn write_with_object_builder(&self, object_builder: &mut ObjectBuilder) {\n        object_builder.write_string(Some(&self.get_oid()));\n        object_builder.write_string(Some(&self.name));\n        object_builder.write_string(Some(&self.created_at));\n    }\n\n    fn read(\n        isar_object: &'screen IsarObject,\n        isar_properties: &'screen [(String, Property)],\n    ) -> Result<ScreenRouteAdapter, Error> {\n        let name_property = Self::find_property_by_name(\"name\", isar_properties)?;\n        let created_at_property = Self::find_property_by_name(\"created_at\", isar_properties)?;\n\n        let name_data = isar_object\n            .read_string(name_property)\n            
.ok_or_else(|| anyhow!(\"unable to read name\"))?;\n        let created_at_data = isar_object\n            .read_string(created_at_property)\n            .ok_or_else(|| anyhow!(\"unable to read created_at\"))?;\n\n        Ok(ScreenRouteAdapter::new(\n            name_data.to_string(),\n            created_at_data.to_string(),\n        ))\n    }\n}\n\nimpl<'screen> SchemaGenerator<'screen, ScreenRouteAdapter> for ScreenRouteAdapter {}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/screen_route/data_model.rs",
    "content": "//! In this file `ScreenRoute` is declared, together with some conversion methods to/from adapters.\n\nuse anyhow::Result;\nuse chrono::{DateTime, Utc};\nuse std::convert::{Into, TryFrom};\n\nuse crate::database::{\n    common::{CollectionNames, RelationalField},\n    screen_route::adapter::ScreenRouteAdapter,\n};\n\n/// A `ScreenRoute` is the internal representation of a screen in the app.\n#[derive(Debug, PartialEq, Eq, Clone)]\npub struct ScreenRoute {\n    pub name: String,\n    pub created_at: DateTime<Utc>,\n}\n\nimpl ScreenRoute {\n    pub fn new<N: Into<String>>(name: N, created_at: DateTime<Utc>) -> Self {\n        Self {\n            name: name.into(),\n            created_at,\n        }\n    }\n}\n\nimpl TryFrom<ScreenRouteAdapter> for ScreenRoute {\n    type Error = anyhow::Error;\n\n    fn try_from(adapter: ScreenRouteAdapter) -> Result<Self, Self::Error> {\n        Ok(ScreenRoute::new(\n            adapter.name,\n            DateTime::parse_from_rfc3339(&adapter.created_at)?.with_timezone(&Utc),\n        ))\n    }\n}\n\nimpl From<ScreenRoute> for ScreenRouteAdapter {\n    fn from(sr: ScreenRoute) -> Self {\n        ScreenRouteAdapter::new(sr.name, sr.created_at.to_rfc3339())\n    }\n}\n\nimpl From<ScreenRoute> for RelationalField {\n    fn from(screen_route: ScreenRoute) -> Self {\n        Self {\n            value: screen_route.name,\n            collection_name: CollectionNames::SCREEN_ROUTES.to_string(),\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_screen_route_try_from_adapter() {\n        let timestamp_str = \"2021-01-01T01:01:00+00:00\";\n        let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str)\n            .unwrap()\n            .with_timezone(&Utc);\n        let screen_route = ScreenRoute::new(\"route\", timestamp_parsed);\n        let adapter = ScreenRouteAdapter::new(\"route\", timestamp_str);\n        assert_eq!(ScreenRoute::try_from(adapter).unwrap(), 
screen_route);\n    }\n\n    #[test]\n    fn test_adapter_into_screen_route() {\n        let timestamp_str = \"2021-01-01T01:01:00+00:00\";\n        let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str)\n            .unwrap()\n            .with_timezone(&Utc);\n        let screen_route = ScreenRoute::new(\"route\", timestamp_parsed);\n        let actual_adapter: ScreenRouteAdapter = screen_route.into();\n        let expected_adapter = ScreenRouteAdapter::new(\"route\", timestamp_str);\n        assert_eq!(actual_adapter, expected_adapter);\n    }\n\n    #[test]\n    fn test_screen_route_from_relational_field() {\n        let timestamp_str = \"2021-01-01T01:01:00+00:00\";\n        let timestamp_parsed = DateTime::parse_from_rfc3339(timestamp_str)\n            .unwrap()\n            .with_timezone(&Utc);\n        let screen_route = ScreenRoute::new(\"route\", timestamp_parsed);\n        let relational_field = RelationalField {\n            value: \"route\".to_string(),\n            collection_name: CollectionNames::SCREEN_ROUTES.to_string(),\n        };\n        assert_eq!(RelationalField::from(screen_route), relational_field);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/screen_route/mod.rs",
    "content": "pub mod adapter;\npub mod data_model;\npub mod repo;\n"
  },
  {
    "path": "rust/xaynet-analytics/src/database/screen_route/repo.rs",
    "content": "//! Implementations of the methods needed to save and get ScreenRoute to/from Isar.\n\nuse anyhow::{anyhow, Error, Result};\nuse std::convert::{Into, TryFrom};\n\nuse crate::database::{\n    common::{IsarAdapter, Repo},\n    isar::IsarDb,\n    screen_route::{adapter::ScreenRouteAdapter, data_model::ScreenRoute},\n};\n\nimpl<'db> Repo<'db, ScreenRoute> for ScreenRoute {\n    fn save(self, db: &'db IsarDb, collection_name: &str) -> Result<(), Error> {\n        let mut object_builder = db.get_object_builder(collection_name)?;\n        let route_adapter: ScreenRouteAdapter = self.into();\n        route_adapter.write_with_object_builder(&mut object_builder);\n        db.put(collection_name, object_builder.finish().as_bytes())\n    }\n\n    // TODO: return an iterator instead of Vec: https://xainag.atlassian.net/browse/XN-1517\n    fn get_all(db: &'db IsarDb, collection_name: &str) -> Result<Vec<Self>, Error> {\n        let isar_properties = db.get_collection_properties(collection_name)?;\n        db.get_all_isar_objects(collection_name)?\n            .into_iter()\n            .map(|(_, isar_object)| ScreenRouteAdapter::read(&isar_object, isar_properties))\n            .map(|screen_route_adapter| ScreenRoute::try_from(screen_route_adapter?))\n            .collect()\n    }\n\n    fn get(oid: &str, db: &'db IsarDb, collection_name: &str) -> Result<Self, Error> {\n        let isar_properties = db.get_collection_properties(collection_name)?;\n        let object_id = db.get_object_id_from_str(collection_name, oid)?;\n        let mut transaction = db.get_read_transaction()?;\n        let isar_object =\n            db.get_isar_object_by_id(&object_id, collection_name, &mut transaction)?;\n        if let Some(isar_object) = isar_object {\n            let screen_route_adapter = ScreenRouteAdapter::read(&isar_object, isar_properties)?;\n            ScreenRoute::try_from(screen_route_adapter)\n        } else {\n            Err(anyhow!(\"unable to get {:?} object\", 
object_id))\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-analytics/src/lib.rs",
    "content": "#![cfg_attr(doc, forbid(broken_intra_doc_links, private_intra_doc_links))]\n#![doc(\n    html_logo_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png\",\n    html_favicon_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png\",\n    issue_tracker_base_url = \"https://github.com/xaynetwork/xaynet/issues\"\n)]\n//! This crate contains the Rust component of Federated Analytics,\n//! a framework that allows mobile applications to collect and aggregate\n//! analytics data via the _Privacy-Enhancing Technology_ (PET) protocol.\n\n#[cfg(not(tarpaulin))]\npub mod controller;\n#[cfg(not(tarpaulin))]\npub mod data_combination;\n#[cfg(not(tarpaulin))]\npub mod database;\n#[cfg(not(tarpaulin))]\npub mod sender;\n"
  },
  {
    "path": "rust/xaynet-analytics/src/sender.rs",
    "content": "//! In this file `Sender` is just stubbed and will need to be implemented.\n\nuse anyhow::{Error, Result};\n\nuse crate::data_combination::data_points::data_point::DataPoint;\n\n/// `Sender` receives a `Vec<DataPoint>` from the `DataCombiner`.\n///\n/// It will need to call the exposed `calculate()` method on each `DataPoint` variant and compose the messages\n/// that will then need to reach the XayNet coordinator.\n///\n/// These messages should contain not only the actual data that is the output of calling `calculate()` on the variant,\n/// but also some extra data so that the coordinator knows how to aggregate each `DataPoint` variant.\n/// This is in line with the research done on the “global spec” idea.\npub struct Sender;\n\nimpl Sender {\n    pub fn send(&self, _data_points: Vec<DataPoint>) -> Result<(), Error> {\n        // TODO: https://xainag.atlassian.net/browse/XN-1647\n        todo!()\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/Cargo.toml",
    "content": "[package]\nname = \"xaynet-core\"\nversion = \"0.2.0\"\nauthors = [\"Xayn Engineering <engineering@xaynet.dev>\"]\nedition = \"2018\"\ndescription = \"The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant.\"\nreadme = \"../../README.md\"\nhomepage = \"https://xaynet.dev/\"\nrepository = \"https://github.com/xaynetwork/xaynet/\"\nlicense-file = \"../../LICENSE\"\nkeywords = [\"federated-learning\", \"fl\", \"ai\", \"machine-learning\"]\ncategories = [\"science\", \"cryptography\"]\n\n[package.metadata.docs.rs]\nall-features = true\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n\n[dependencies]\nanyhow = \"1.0.62\"\nbitflags = \"1.3.2\"\nderive_more = { version = \"0.99.17\", default-features = false, features = [\n    \"as_ref\",\n    \"as_mut\",\n    \"display\",\n    \"from\",\n    \"index\",\n    \"index_mut\",\n    \"into\",\n] }\nnum = { version = \"0.4.0\", features = [\"serde\"] }\nrand = \"0.8.5\"\nrand_chacha = \"0.3.1\"\nserde = { version = \"1.0.144\", features = [\"derive\"] }\nsodiumoxide = \"0.2.7\"\nthiserror = \"1.0.32\"\n\n[features]\ntestutils = []\n\n[dev-dependencies]\npaste = \"1.0.8\"\n"
  },
  {
    "path": "rust/xaynet-core/src/common.rs",
    "content": "use serde::{Deserialize, Serialize};\nuse sodiumoxide::{self, crypto::box_};\n\nuse crate::{crypto::ByteObject, mask::MaskConfigPair, CoordinatorPublicKey};\n\n/// The round parameters.\n#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]\npub struct RoundParameters {\n    /// The public key of the coordinator used for encryption.\n    pub pk: CoordinatorPublicKey,\n    /// Fraction of participants to be selected for the sum task.\n    pub sum: f64,\n    /// Fraction of participants to be selected for the update task.\n    pub update: f64,\n    /// The random round seed.\n    pub seed: RoundSeed,\n    /// The masking configuration\n    pub mask_config: MaskConfigPair,\n    /// The length of the model.\n    pub model_length: usize,\n}\n\n#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]\n/// A seed for a round.\npub struct RoundSeed(box_::Seed);\n\nimpl ByteObject for RoundSeed {\n    const LENGTH: usize = box_::SEEDBYTES;\n\n    /// Creates a round seed from a slice of bytes.\n    ///\n    /// # Errors\n    /// Fails if the length of the input is invalid.\n    fn from_slice(bytes: &[u8]) -> Option<Self> {\n        box_::Seed::from_slice(bytes).map(Self)\n    }\n\n    /// Creates a round seed initialized to zero.\n    fn zeroed() -> Self {\n        Self(box_::Seed([0_u8; Self::LENGTH]))\n    }\n\n    /// Gets the round seed as a slice.\n    fn as_slice(&self) -> &[u8] {\n        self.0.as_ref()\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/crypto/encrypt.rs",
    "content": "//! Wrappers around some of the [sodiumoxide] encryption primitives.\n//!\n//! See the [crypto module] documentation since this is a private module anyways.\n//!\n//! [sodiumoxide]: https://docs.rs/sodiumoxide/\n//! [crypto module]: crate::crypto\n\nuse derive_more::{AsMut, AsRef, From};\nuse serde::{Deserialize, Serialize};\nuse sodiumoxide::crypto::{box_, sealedbox};\n\nuse super::ByteObject;\n\n/// Number of additional bytes in a ciphertext compared to the corresponding plaintext.\npub const SEALBYTES: usize = sealedbox::SEALBYTES;\n\n#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]\n/// A `C25519` key pair for asymmetric authenticated encryption.\npub struct EncryptKeyPair {\n    /// The `C25519` public key.\n    pub public: PublicEncryptKey,\n    /// The `C25519` secret key.\n    pub secret: SecretEncryptKey,\n}\n\nimpl EncryptKeyPair {\n    /// Generates a new random `C25519` key pair for encryption.\n    pub fn generate() -> Self {\n        let (pk, sk) = box_::gen_keypair();\n        Self {\n            public: PublicEncryptKey(pk),\n            secret: SecretEncryptKey(sk),\n        }\n    }\n\n    /// Deterministically derives a new `C25519` key pair for encryption from a seed.\n    pub fn derive_from_seed(seed: &EncryptKeySeed) -> Self {\n        let (pk, sk) = seed.derive_encrypt_key_pair();\n        Self {\n            public: pk,\n            secret: sk,\n        }\n    }\n}\n\n#[derive(\n    AsRef,\n    AsMut,\n    From,\n    Serialize,\n    Deserialize,\n    Hash,\n    Eq,\n    Ord,\n    PartialEq,\n    Copy,\n    Clone,\n    PartialOrd,\n    Debug,\n)]\n/// A `C25519` public key for asymmetric authenticated encryption.\npub struct PublicEncryptKey(box_::PublicKey);\n\nimpl ByteObject for PublicEncryptKey {\n    const LENGTH: usize = box_::PUBLICKEYBYTES;\n\n    fn zeroed() -> Self {\n        Self(box_::PublicKey([0_u8; box_::PUBLICKEYBYTES]))\n    }\n\n    fn as_slice(&self) -> &[u8] {\n        self.0.as_ref()\n    
}\n\n    fn from_slice(bytes: &[u8]) -> Option<Self> {\n        box_::PublicKey::from_slice(bytes).map(Self)\n    }\n}\n\nimpl PublicEncryptKey {\n    /// Encrypts a message `m` with this public key.\n    ///\n    /// The resulting ciphertext length is [`SEALBYTES`]` + m.len()`.\n    ///\n    /// The function creates a new ephemeral key pair for the message and attaches the ephemeral\n    /// public key to the ciphertext. The ephemeral secret key is zeroed out and is not accessible\n    /// after this function returns.\n    pub fn encrypt(&self, m: &[u8]) -> Vec<u8> {\n        sealedbox::seal(m, self.as_ref())\n    }\n}\n\n#[derive(thiserror::Error, Debug)]\n#[error(\"decryption of a message failed\")]\n/// An error related to the decryption of a message.\npub struct DecryptionError;\n\n#[derive(AsRef, AsMut, From, Serialize, Deserialize, Eq, PartialEq, Clone, Debug)]\n/// A `C25519` secret key for asymmetric authenticated encryption.\n///\n/// When this goes out of scope, its contents will be zeroed out.\npub struct SecretEncryptKey(box_::SecretKey);\n\nimpl SecretEncryptKey {\n    /// Decrypts the ciphertext `c` using this secret key and the associated public key, and returns\n    /// the decrypted message.\n    ///\n    /// # Errors\n    /// Returns `Err(DecryptionError)` if decryption fails.\n    pub fn decrypt(&self, c: &[u8], pk: &PublicEncryptKey) -> Result<Vec<u8>, DecryptionError> {\n        sealedbox::open(c, pk.as_ref(), self.as_ref()).map_err(|_| DecryptionError)\n    }\n\n    /// Computes the corresponding public key for this secret key.\n    pub fn public_key(&self) -> PublicEncryptKey {\n        PublicEncryptKey(self.0.public_key())\n    }\n}\n\nimpl ByteObject for SecretEncryptKey {\n    const LENGTH: usize = box_::SECRETKEYBYTES;\n\n    fn zeroed() -> Self {\n        Self(box_::SecretKey([0_u8; box_::SECRETKEYBYTES]))\n    }\n\n    fn as_slice(&self) -> &[u8] {\n        self.0.as_ref()\n    }\n\n    fn from_slice(bytes: &[u8]) -> Option<Self> {\n   
     box_::SecretKey::from_slice(bytes).map(Self)\n    }\n}\n\n#[derive(AsRef, AsMut, From, Serialize, Deserialize, Eq, PartialEq, Clone)]\n/// A seed that can be used for `C25519` encryption key pair generation.\n///\n/// When this goes out of scope, its contents will be zeroed out.\npub struct EncryptKeySeed(box_::Seed);\n\nimpl EncryptKeySeed {\n    /// Deterministically derives a new key pair from this seed.\n    pub fn derive_encrypt_key_pair(&self) -> (PublicEncryptKey, SecretEncryptKey) {\n        let (pk, sk) = box_::keypair_from_seed(self.as_ref());\n        (PublicEncryptKey(pk), SecretEncryptKey(sk))\n    }\n}\n\nimpl ByteObject for EncryptKeySeed {\n    const LENGTH: usize = box_::SEEDBYTES;\n\n    fn from_slice(bytes: &[u8]) -> Option<Self> {\n        box_::Seed::from_slice(bytes).map(Self)\n    }\n\n    fn zeroed() -> Self {\n        Self(box_::Seed([0; box_::SEEDBYTES]))\n    }\n\n    fn as_slice(&self) -> &[u8] {\n        self.0.as_ref()\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/crypto/hash.rs",
    "content": "//! Wrappers around some of the [sodiumoxide] hashing primitives.\n//!\n//! See the [crypto module] documentation since this is a private module anyways.\n//!\n//! [sodiumoxide]: https://docs.rs/sodiumoxide/\n//! [crypto module]: crate::crypto\n\nuse derive_more::{AsMut, AsRef, From};\nuse serde::{Deserialize, Serialize};\nuse sodiumoxide::crypto::hash::sha256;\n\nuse super::ByteObject;\n\n#[derive(\n    AsRef,\n    AsMut,\n    From,\n    Serialize,\n    Deserialize,\n    Hash,\n    Eq,\n    Ord,\n    PartialEq,\n    Copy,\n    Clone,\n    PartialOrd,\n    Debug,\n)]\n/// A digest of the `SHA256` hash function.\npub struct Sha256(sha256::Digest);\n\nimpl ByteObject for Sha256 {\n    const LENGTH: usize = sha256::DIGESTBYTES;\n\n    fn zeroed() -> Self {\n        Self(sha256::Digest([0_u8; sha256::DIGESTBYTES]))\n    }\n\n    fn as_slice(&self) -> &[u8] {\n        self.0.as_ref()\n    }\n\n    fn from_slice(bytes: &[u8]) -> Option<Self> {\n        sha256::Digest::from_slice(bytes).map(Self)\n    }\n}\n\nimpl Sha256 {\n    /// Computes the digest of the message `m`.\n    pub fn hash(m: &[u8]) -> Self {\n        Self(sha256::hash(m))\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/crypto/mod.rs",
    "content": "//! Wrappers around some of the [sodiumoxide] crypto primitives.\n//!\n//! The wrappers provide methods defined on structs instead of the sodiumoxide functions. This is\n//! done for the `C25519` encryption and `Ed25519` signature key pairs and their corresponding seeds\n//! as well as the `SHA256` hash function. Additionally, some methods for slicing and signature\n//! eligibility are available.\n//!\n//! # Examples\n//! ## Encryption of messages\n//! ```\n//! # use xaynet_core::crypto::EncryptKeyPair;\n//! let keys = EncryptKeyPair::generate();\n//! let message = b\"Hello world!\".to_vec();\n//! let cipher = keys.public.encrypt(&message);\n//! assert_eq!(message, keys.secret.decrypt(&cipher, &keys.public).unwrap());\n//! ```\n//!\n//! ## Signing of messages\n//! ```\n//! # use xaynet_core::crypto::SigningKeyPair;\n//! let keys = SigningKeyPair::generate();\n//! let message = b\"Hello world!\".to_vec();\n//! let signature = keys.secret.sign_detached(&message);\n//! assert!(keys.public.verify_detached(&signature, &message));\n//! ```\n//!\n//! 
[sodiumoxide]: https://docs.rs/sodiumoxide/\n\npub(crate) mod encrypt;\npub(crate) mod hash;\npub(crate) mod prng;\npub(crate) mod sign;\n\nuse sodiumoxide::randombytes::randombytes;\n\npub use self::{\n    encrypt::{EncryptKeyPair, EncryptKeySeed, PublicEncryptKey, SecretEncryptKey, SEALBYTES},\n    hash::Sha256,\n    prng::generate_integer,\n    sign::{PublicSigningKey, SecretSigningKey, Signature, SigningKeyPair, SigningKeySeed},\n};\n\n/// An interface for slicing into cryptographic byte objects.\npub trait ByteObject: Sized {\n    /// Length in bytes of this object\n    const LENGTH: usize;\n\n    /// Creates a new object with all the bytes initialized to `0`.\n    fn zeroed() -> Self;\n\n    /// Gets the object byte representation.\n    fn as_slice(&self) -> &[u8];\n\n    /// Creates an object from the given buffer.\n    ///\n    /// # Errors\n    /// Returns `None` if the length of the byte-slice isn't equal to the length of the object.\n    fn from_slice(bytes: &[u8]) -> Option<Self>;\n\n    /// Creates an object from the given buffer.\n    ///\n    /// # Panics\n    /// Panics if the length of the byte-slice isn't equal to the length of the object.\n    fn from_slice_unchecked(bytes: &[u8]) -> Self {\n        Self::from_slice(bytes).unwrap()\n    }\n\n    /// Generates an object with random bytes\n    fn generate() -> Self {\n        // safe unwrap: length of slice is guaranteed by constants\n        Self::from_slice_unchecked(randombytes(Self::LENGTH).as_slice())\n    }\n\n    /// A helper for instantiating an object filled with the given value\n    fn fill_with(value: u8) -> Self {\n        Self::from_slice_unchecked(&vec![value; Self::LENGTH])\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/crypto/prng.rs",
    "content": "//! PRNG utilities for the crypto primitives.\n//!\n//! See the [crypto module] documentation since this is a private module anyways.\n//!\n//! [sodiumoxide]: https://docs.rs/sodiumoxide/\n//! [crypto module]: crate::crypto\n\nuse num::{bigint::BigUint, traits::identities::Zero};\nuse rand::RngCore;\nuse rand_chacha::ChaCha20Rng;\n\n/// Generates a secure pseudo-random integer.\n///\n/// Draws from a uniform distribution over the integers between zero (included) and\n/// `max_int` (excluded). Employs the `ChaCha20` stream cipher as a PRNG.\npub fn generate_integer(prng: &mut ChaCha20Rng, max_int: &BigUint) -> BigUint {\n    if max_int.is_zero() {\n        return BigUint::zero();\n    }\n    let mut bytes = max_int.to_bytes_le();\n    let mut rand_int = max_int.clone();\n    while &rand_int >= max_int {\n        prng.fill_bytes(&mut bytes);\n        rand_int = BigUint::from_bytes_le(&bytes);\n    }\n    rand_int\n}\n\n#[cfg(test)]\nmod tests {\n    use num::traits::{pow::Pow, Num};\n    use rand::SeedableRng;\n\n    use super::*;\n\n    #[test]\n    fn test_generate_integer() {\n        let mut prng = ChaCha20Rng::from_seed([0_u8; 32]);\n        let max_int = BigUint::from(u128::max_value()).pow(2_usize);\n        assert_eq!(\n            generate_integer(&mut prng, &max_int),\n            BigUint::from_str_radix(\n                \"90034050956742099321159087842304570510687605373623064829879336909608119744630\",\n                10\n            )\n            .unwrap()\n        );\n        assert_eq!(\n            generate_integer(&mut prng, &max_int),\n            BigUint::from_str_radix(\n                \"60790020689334235010238064028215988394112077193561636249125918224917556969946\",\n                10\n            )\n            .unwrap()\n        );\n        assert_eq!(\n            generate_integer(&mut prng, &max_int),\n            BigUint::from_str_radix(\n                
\"107415344426328791036720294006773438815099086866510488084511304829720271980447\",\n                10\n            )\n            .unwrap()\n        );\n        assert_eq!(\n            generate_integer(&mut prng, &max_int),\n            BigUint::from_str_radix(\n                \"50343610553303623842889112417183549658912134525854625844144939347139411162921\",\n                10\n            )\n            .unwrap()\n        );\n        assert_eq!(\n            generate_integer(&mut prng, &max_int),\n            BigUint::from_str_radix(\n                \"42382469383990928111449714288937630103705168010724718767641573929365517895981\",\n                10\n            )\n            .unwrap()\n        );\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/crypto/sign.rs",
    "content": "//! Wrappers around some of the [sodiumoxide] signing primitives.\n//!\n//! See the [crypto module] documentation since this is a private module anyways.\n//!\n//! [sodiumoxide]: https://docs.rs/sodiumoxide/\n//! [crypto module]: crate::crypto\n\nuse std::convert::TryInto;\n\nuse derive_more::{AsMut, AsRef, From};\nuse num::{\n    bigint::{BigUint, ToBigInt},\n    rational::Ratio,\n};\nuse serde::{Deserialize, Serialize};\nuse sodiumoxide::crypto::{hash::sha256, sign};\n\nuse super::ByteObject;\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\n/// A `Ed25519` key pair for signatures.\npub struct SigningKeyPair {\n    /// The `Ed25519` public key.\n    pub public: PublicSigningKey,\n    /// The `Ed25519` secret key.\n    pub secret: SecretSigningKey,\n}\n\nimpl SigningKeyPair {\n    /// Generates a new random `Ed25519` key pair for signing.\n    pub fn generate() -> Self {\n        let (pk, sk) = sign::gen_keypair();\n        Self {\n            public: PublicSigningKey(pk),\n            secret: SecretSigningKey(sk),\n        }\n    }\n\n    pub fn derive_from_seed(seed: &SigningKeySeed) -> Self {\n        let (pk, sk) = seed.derive_signing_key_pair();\n        Self {\n            public: pk,\n            secret: sk,\n        }\n    }\n}\n\n#[derive(\n    AsRef,\n    AsMut,\n    From,\n    Serialize,\n    Deserialize,\n    Hash,\n    Eq,\n    Ord,\n    PartialEq,\n    Copy,\n    Clone,\n    PartialOrd,\n    Debug,\n)]\n/// An `Ed25519` public key for signatures.\npub struct PublicSigningKey(sign::PublicKey);\n\nimpl PublicSigningKey {\n    /// Verifies the signature `s` against the message `m` and this public key.\n    ///\n    /// Returns `true` if the signature is valid and `false` otherwise.\n    pub fn verify_detached(&self, s: &Signature, m: &[u8]) -> bool {\n        sign::verify_detached(s.as_ref(), m, self.as_ref())\n    }\n}\n\nimpl ByteObject for PublicSigningKey {\n    const LENGTH: usize = sign::PUBLICKEYBYTES;\n\n    fn zeroed() -> Self 
{\n        Self(sign::PublicKey([0_u8; sign::PUBLICKEYBYTES]))\n    }\n\n    fn as_slice(&self) -> &[u8] {\n        self.0.as_ref()\n    }\n\n    fn from_slice(bytes: &[u8]) -> Option<Self> {\n        sign::PublicKey::from_slice(bytes).map(Self)\n    }\n}\n\n#[derive(AsRef, AsMut, From, Serialize, Deserialize, Eq, PartialEq, Clone, Debug)]\n/// An `Ed25519` secret key for signatures.\n///\n/// When this goes out of scope, its contents will be zeroed out.\npub struct SecretSigningKey(sign::SecretKey);\n\nimpl SecretSigningKey {\n    /// Signs a message `m` with this secret key.\n    pub fn sign_detached(&self, m: &[u8]) -> Signature {\n        sign::sign_detached(m, self.as_ref()).into()\n    }\n\n    /// Computes the corresponding public key for this secret key.\n    pub fn public_key(&self) -> PublicSigningKey {\n        PublicSigningKey(self.0.public_key())\n    }\n}\n\nimpl ByteObject for SecretSigningKey {\n    const LENGTH: usize = sign::SECRETKEYBYTES;\n\n    fn zeroed() -> Self {\n        Self(sign::SecretKey([0_u8; Self::LENGTH]))\n    }\n\n    fn as_slice(&self) -> &[u8] {\n        self.0.as_ref()\n    }\n\n    fn from_slice(bytes: &[u8]) -> Option<Self> {\n        sign::SecretKey::from_slice(bytes).map(Self)\n    }\n}\n\n#[derive(AsRef, AsMut, From, Eq, PartialEq, Copy, Clone, Debug)]\n/// An `Ed25519` signature detached from its message.\npub struct Signature(sign::Signature);\n\nmod manually_derive_serde_for_signature {\n    //! TODO:\n    //! remove this if sodiumoxide decides to reintroduce serialization of signatures\n    //! 
<https://github.com/sodiumoxide/sodiumoxide/pull/434>\n\n    use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};\n\n    use crate::crypto::{sign::Signature, ByteObject};\n\n    impl Serialize for Signature {\n        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>\n        where\n            S: Serializer,\n        {\n            self.as_slice().serialize(serializer)\n        }\n    }\n\n    impl<'de> Deserialize<'de> for Signature {\n        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>\n        where\n            D: Deserializer<'de>,\n        {\n            let bytes = <&[u8] as Deserialize>::deserialize(deserializer)?;\n            Self::from_slice(bytes).ok_or_else(|| {\n                D::Error::custom(format!(\n                    \"invalid length {}, expected {}\",\n                    bytes.len(),\n                    Self::LENGTH,\n                ))\n            })\n        }\n    }\n}\n\nimpl ByteObject for Signature {\n    const LENGTH: usize = sign::SIGNATUREBYTES;\n\n    fn zeroed() -> Self {\n        Self(sign::Signature::new([0_u8; Self::LENGTH]))\n    }\n\n    fn as_slice(&self) -> &[u8] {\n        self.0.as_ref()\n    }\n\n    fn from_slice(bytes: &[u8]) -> Option<Self> {\n        bytes.try_into().ok().map(Self)\n    }\n}\n\nimpl Signature {\n    /// Computes the floating point representation of the hashed signature and ensures that it is\n    /// below the given threshold:\n    /// ```no_rust\n    /// int(hash(signature)) / (2**hashbits - 1) <= threshold.\n    /// ```\n    pub fn is_eligible(&self, threshold: f64) -> bool {\n        if threshold < 0_f64 {\n            return false;\n        } else if threshold > 1_f64 {\n            return true;\n        }\n        // safe unwraps: `to_bigint` never fails for `BigUint`s\n        let numer = BigUint::from_bytes_le(sha256::hash(self.as_slice()).as_ref())\n            .to_bigint()\n            .unwrap();\n        let denom = 
BigUint::from_bytes_le([u8::MAX; sha256::DIGESTBYTES].as_ref())\n            .to_bigint()\n            .unwrap();\n        // safe unwrap: `threshold` is guaranteed to be finite\n        Ratio::new(numer, denom) <= Ratio::from_float(threshold).unwrap()\n    }\n}\n\n#[derive(AsRef, AsMut, From, Serialize, Deserialize, Eq, PartialEq, Clone)]\n/// A seed that can be used for `Ed25519` signing key pair generation.\n///\n/// When this goes out of scope, its contents will be zeroed out.\npub struct SigningKeySeed(sign::Seed);\n\nimpl SigningKeySeed {\n    /// Deterministically derives a new signing key pair from this seed.\n    pub fn derive_signing_key_pair(&self) -> (PublicSigningKey, SecretSigningKey) {\n        let (pk, sk) = sign::keypair_from_seed(&self.0);\n        (PublicSigningKey(pk), SecretSigningKey(sk))\n    }\n}\n\nimpl ByteObject for SigningKeySeed {\n    const LENGTH: usize = sign::SEEDBYTES;\n\n    fn from_slice(bytes: &[u8]) -> Option<Self> {\n        sign::Seed::from_slice(bytes).map(Self)\n    }\n\n    fn zeroed() -> Self {\n        Self(sign::Seed([0; sign::PUBLICKEYBYTES]))\n    }\n\n    fn as_slice(&self) -> &[u8] {\n        self.0.as_ref()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_signature_is_eligible() {\n        // eligible signature\n        let sig = Signature::from_slice_unchecked(&[\n            172, 29, 85, 219, 118, 44, 107, 32, 219, 253, 25, 242, 53, 45, 111, 62, 102, 130, 24,\n            8, 222, 199, 34, 120, 166, 163, 223, 229, 100, 50, 252, 244, 250, 88, 196, 151, 136,\n            48, 39, 198, 166, 86, 29, 151, 13, 81, 69, 198, 40, 148, 134, 126, 7, 202, 1, 56, 174,\n            43, 89, 28, 242, 194, 4, 0,\n        ]);\n        assert!(sig.is_eligible(0.5_f64));\n\n        // ineligible signature\n        let sig = Signature::from_slice_unchecked(&[\n            119, 2, 197, 174, 52, 165, 229, 22, 218, 210, 240, 188, 220, 232, 149, 129, 211, 13,\n            61, 217, 186, 79, 102, 15, 109, 
237, 83, 193, 12, 117, 210, 66, 99, 230, 30, 131, 63,\n            108, 28, 222, 48, 92, 153, 71, 159, 220, 115, 181, 183, 155, 146, 182, 205, 89, 140,\n            234, 100, 40, 199, 248, 23, 147, 172, 0,\n        ]);\n        assert!(!sig.is_eligible(0.5_f64));\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/lib.rs",
    "content": "#![cfg_attr(docsrs, feature(doc_cfg))]\n#![cfg_attr(\n    doc,\n    forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links)\n)]\n#![doc(\n    html_logo_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png\",\n    html_favicon_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png\",\n    issue_tracker_base_url = \"https://github.com/xaynetwork/xaynet/issues\"\n)]\n//! `xaynet_core` provides basic building blocks for implementing the\n//! _Privacy-Enhancing Technology_ (PET), a privacy preserving\n//! protocol for federated machine learning. Download the [whitepaper]\n//! for an introduction.\n//!\n//! [whitepaper]: https://uploads-ssl.webflow.com/5f0c5c0bb18a279f0a62919e/5f157004da6585f299fa542b_XayNet%20Whitepaper%202.1.pdf\n\npub mod common;\npub mod crypto;\npub mod mask;\npub mod message;\n#[cfg(any(feature = \"testutils\", test))]\n#[cfg_attr(docsrs, doc(cfg(feature = \"testutils\")))]\npub mod testutils;\n\nuse std::collections::HashMap;\n\nuse thiserror::Error;\n\nuse self::crypto::{\n    encrypt::{PublicEncryptKey, SecretEncryptKey},\n    sign::{PublicSigningKey, SecretSigningKey, Signature},\n};\n\n#[derive(Error, Debug)]\n#[error(\"initialization failed: insufficient system entropy to generate secrets\")]\n/// An error related to insufficient system entropy for secrets at program startup.\npub struct InitError;\n\n/// A public encryption key that identifies a coordinator.\npub type CoordinatorPublicKey = PublicEncryptKey;\n\n/// A secret encryption key that belongs to the public key of a\n/// coordinator.\npub type CoordinatorSecretKey = SecretEncryptKey;\n\n/// A public signature key that identifies a participant.\npub type ParticipantPublicKey = PublicSigningKey;\n\n/// A secret signature key that belongs to the public key of a\n/// participant.\npub type ParticipantSecretKey = SecretSigningKey;\n\n/// A public signature key that identifies a sum 
participant.\npub type SumParticipantPublicKey = ParticipantPublicKey;\n\n/// A secret signature key that belongs to the public key of a sum\n/// participant.\npub type SumParticipantSecretKey = ParticipantSecretKey;\n\n/// A public encryption key generated by a sum participant. It is used\n/// by the update participants to encrypt their masking seed for each\n/// sum participant.\npub type SumParticipantEphemeralPublicKey = PublicEncryptKey;\n\n/// The secret counterpart of [`SumParticipantEphemeralPublicKey`]\npub type SumParticipantEphemeralSecretKey = SecretEncryptKey;\n\n/// A public signature key that identifies an update participant.\npub type UpdateParticipantPublicKey = ParticipantPublicKey;\n\n/// A secret signature key that belongs to the public key of an update\n/// participant.\npub type UpdateParticipantSecretKey = ParticipantSecretKey;\n\n/// A signature to prove a participant's eligibility for a task.\npub type ParticipantTaskSignature = Signature;\n\n/// A dictionary created during the sum phase of the protocol. It maps the public key of every sum\n/// participant to the ephemeral public key generated by that sum participant.\npub type SumDict = HashMap<SumParticipantPublicKey, SumParticipantEphemeralPublicKey>;\n\n/// Local seed dictionaries are sent by update participants. They contain the participant's masking\n/// seed, encrypted with the ephemeral public key of each sum participant.\npub type LocalSeedDict = HashMap<SumParticipantPublicKey, mask::seed::EncryptedMaskSeed>;\n\n/// A dictionary created during the update phase of the protocol. The global seed dictionary is\n/// built from the local seed dictionaries sent by the update participants. It maps each sum\n/// participant to the encrypted masking seeds of all the update participants.\npub type SeedDict = HashMap<SumParticipantPublicKey, UpdateSeedDict>;\n\n/// Values of [`SeedDict`]. 
Sent to sum participants.\npub type UpdateSeedDict = HashMap<UpdateParticipantPublicKey, mask::seed::EncryptedMaskSeed>;\n"
  },
  {
    "path": "rust/xaynet-core/src/mask/config/mod.rs",
    "content": "//! Masking configuration parameters.\n//!\n//! See the [mask module] documentation since this is a private module anyways.\n//!\n//! [mask module]: crate::mask\n\npub(crate) mod serialization;\n\nuse std::convert::TryFrom;\n\nuse num::{\n    bigint::{BigInt, BigUint},\n    rational::Ratio,\n    traits::{pow::Pow, Num},\n};\nuse serde::{Deserialize, Serialize};\nuse thiserror::Error;\n\n// target dependent maximum bytes per mask object element\n#[cfg(target_pointer_width = \"16\")]\nconst MAX_BPN: u64 = u16::MAX as u64;\n#[cfg(target_pointer_width = \"32\")]\nconst MAX_BPN: u64 = u32::MAX as u64;\n\n#[derive(Debug, Error)]\n/// Errors related to invalid masking configurations.\npub enum InvalidMaskConfigError {\n    #[error(\"invalid group type\")]\n    GroupType,\n    #[error(\"invalid data type\")]\n    DataType,\n    #[error(\"invalid bound type\")]\n    BoundType,\n    #[error(\"invalid model type\")]\n    ModelType,\n}\n\n#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]\n#[repr(u8)]\n/// The order of the finite group.\npub enum GroupType {\n    /// A finite group of exact integer order.\n    Integer = 0,\n    /// A finite group of prime order.\n    Prime = 1,\n    /// A finite group of power-of-two order.\n    Power2 = 2,\n}\n\nimpl TryFrom<u8> for GroupType {\n    type Error = InvalidMaskConfigError;\n\n    fn try_from(byte: u8) -> Result<Self, Self::Error> {\n        match byte {\n            0 => Ok(GroupType::Integer),\n            1 => Ok(GroupType::Prime),\n            2 => Ok(GroupType::Power2),\n            _ => Err(InvalidMaskConfigError::GroupType),\n        }\n    }\n}\n\n#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]\n#[repr(u8)]\n/// The original primitive data type of the numerical values to be masked.\npub enum DataType {\n    /// Numbers of type f32.\n    F32 = 0,\n    /// Numbers of type f64.\n    F64 = 1,\n    /// Numbers of type i32.\n    I32 = 2,\n    /// Numbers of type 
i64.\n    I64 = 3,\n}\n\nimpl TryFrom<u8> for DataType {\n    type Error = InvalidMaskConfigError;\n\n    fn try_from(byte: u8) -> Result<Self, Self::Error> {\n        match byte {\n            0 => Ok(DataType::F32),\n            1 => Ok(DataType::F64),\n            2 => Ok(DataType::I32),\n            3 => Ok(DataType::I64),\n            _ => Err(InvalidMaskConfigError::DataType),\n        }\n    }\n}\n\n#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]\n#[repr(u8)]\n/// The bounds of the numerical values.\n///\n/// For a value `v` to be absolutely bounded by another value `b`, it has to hold that\n/// `-b <= v <= b` or equivalently `|v| <= b`.\npub enum BoundType {\n    /// Numerical values absolutely bounded by 1.\n    B0 = 0,\n    /// Numerical values absolutely bounded by 100.\n    B2 = 2,\n    /// Numerical values absolutely bounded by 10_000.\n    B4 = 4,\n    /// Numerical values absolutely bounded by 1_000_000.\n    B6 = 6,\n    /// Numerical values absolutely bounded by their original primitive data type's maximum absolute\n    /// value.\n    Bmax = 255,\n}\n\nimpl TryFrom<u8> for BoundType {\n    type Error = InvalidMaskConfigError;\n\n    fn try_from(byte: u8) -> Result<Self, Self::Error> {\n        match byte {\n            0 => Ok(BoundType::B0),\n            2 => Ok(BoundType::B2),\n            4 => Ok(BoundType::B4),\n            6 => Ok(BoundType::B6),\n            255 => Ok(BoundType::Bmax),\n            _ => Err(InvalidMaskConfigError::ModelType),\n        }\n    }\n}\n\n#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]\n#[repr(u8)]\n/// The maximum number of models to be aggregated.\npub enum ModelType {\n    /// At most 1_000 models to be aggregated.\n    M3 = 3,\n    /// At most 1_000_000 models to be aggregated.\n    M6 = 6,\n    /// At most 1_000_000_000 models to be aggregated.\n    M9 = 9,\n    /// At most 1_000_000_000_000 models to be aggregated.\n    M12 = 12,\n}\n\nimpl ModelType {\n  
  /// Gets the maximum number of models that can be aggregated for this model type.\n    pub fn max_nb_models(&self) -> usize {\n        10_usize.pow(*self as u8 as u32)\n    }\n}\n\nimpl TryFrom<u8> for ModelType {\n    type Error = InvalidMaskConfigError;\n\n    fn try_from(byte: u8) -> Result<Self, Self::Error> {\n        match byte {\n            3 => Ok(ModelType::M3),\n            6 => Ok(ModelType::M6),\n            9 => Ok(ModelType::M9),\n            12 => Ok(ModelType::M12),\n            _ => Err(InvalidMaskConfigError::ModelType),\n        }\n    }\n}\n\n#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]\n/// A masking configuration.\n///\n/// This configuration is applied for masking, aggregation and unmasking of models.\npub struct MaskConfig {\n    /// The order of the finite group.\n    pub group_type: GroupType,\n    /// The original primitive data type of the numerical values to be masked.\n    pub data_type: DataType,\n    /// The bounds of the numerical values.\n    pub bound_type: BoundType,\n    /// The maximum number of models to be aggregated.\n    pub model_type: ModelType,\n}\n\nimpl MaskConfig {\n    /// Returns the number of bytes needed for an element of a mask object.\n    ///\n    /// # Panics\n    /// Panics if the bytes per number can't be represented as usize.\n    pub(crate) fn bytes_per_number(&self) -> usize {\n        let max_number = self.order() - BigUint::from(1_u8);\n        let bpn = (max_number.bits() + 7) / 8;\n\n        // the largest bpn from the masking configuration catalogue is currently 173, hence this is\n        // almost impossible on 32 bits targets and smaller targets are currently not of interest\n        #[cfg(any(target_pointer_width = \"16\", target_pointer_width = \"32\"))]\n        if bpn > MAX_BPN {\n            panic!(\"the employed masking config is not supported on the target\")\n        }\n\n        bpn as usize\n    }\n\n    /// Gets the additional shift value for 
masking/unmasking.\n    pub fn add_shift(&self) -> Ratio<BigInt> {\n        use BoundType::{Bmax, B0, B2, B4, B6};\n        use DataType::{F32, F64, I32, I64};\n\n        match self.bound_type {\n            B0 => Ratio::from_integer(BigInt::from(1)),\n            B2 => Ratio::from_integer(BigInt::from(100)),\n            B4 => Ratio::from_integer(BigInt::from(10_000)),\n            B6 => Ratio::from_integer(BigInt::from(1_000_000)),\n            Bmax => match self.data_type {\n                // safe unwraps: all numbers are finite\n                F32 => Ratio::from_float(f32::MAX).unwrap(),\n                F64 => Ratio::from_float(f64::MAX).unwrap(),\n                I32 => Ratio::from_integer(-BigInt::from(i32::MIN)),\n                I64 => Ratio::from_integer(-BigInt::from(i64::MIN)),\n            },\n        }\n    }\n\n    /// Gets the exponential shift value for masking/unmasking.\n    pub fn exp_shift(&self) -> BigInt {\n        use BoundType::{Bmax, B0, B2, B4, B6};\n        use DataType::{F32, F64, I32, I64};\n\n        match self.data_type {\n            F32 => match self.bound_type {\n                B0 | B2 | B4 | B6 => BigInt::from(10).pow(10_u8),\n                Bmax => BigInt::from(10).pow(45_u8),\n            },\n            F64 => match self.bound_type {\n                B0 | B2 | B4 | B6 => BigInt::from(10).pow(20_u8),\n                Bmax => BigInt::from(10).pow(324_u16),\n            },\n            I32 | I64 => BigInt::from(10).pow(10_u8),\n        }\n    }\n\n    /// Gets the finite group order value for masking/unmasking.\n    pub fn order(&self) -> BigUint {\n        use BoundType::{Bmax, B0, B2, B4, B6};\n        use DataType::{F32, F64, I32, I64};\n        use GroupType::{Integer, Power2, Prime};\n        use ModelType::{M12, M3, M6, M9};\n\n        let order_str = match self.group_type {\n            Integer => match self.data_type {\n                F32 => match self.bound_type {\n                    B0 => match self.model_type {\n 
                       M3 => \"20_000_000_000_001\",\n                        M6 => \"20_000_000_000_000_001\",\n                        M9 => \"20_000_000_000_000_000_001\",\n                        M12 => \"20_000_000_000_000_000_000_001\",\n                    }\n                    B2 => match self.model_type {\n                        M3 => \"2_000_000_000_000_001\",\n                        M6 => \"2_000_000_000_000_000_001\",\n                        M9 => \"2_000_000_000_000_000_000_001\",\n                        M12 => \"2_000_000_000_000_000_000_000_001\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"200_000_000_000_000_001\",\n                        M6 => \"200_000_000_000_000_000_001\",\n                        M9 => \"200_000_000_000_000_000_000_001\",\n                        M12 => \"200_000_000_000_000_000_000_000_001\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"20_000_000_000_000_000_001\",\n                        M6 => \"20_000_000_000_000_000_000_001\",\n                        M9 => \"20_000_000_000_000_000_000_000_001\",\n                        M12 => \"20_000_000_000_000_000_000_000_000_001\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => \"680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001\",\n                        M6 => \"680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001\",\n                        M9 => \"680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001\",\n                        M12 => \"680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001\",\n                    }\n   
             }\n                F64 => match self.bound_type {\n                    B0 => match self.model_type {\n                        M3 => \"200_000_000_000_000_000_000_001\",\n                        M6 => \"200_000_000_000_000_000_000_000_001\",\n                        M9 => \"200_000_000_000_000_000_000_000_000_001\",\n                        M12 => \"200_000_000_000_000_000_000_000_000_000_001\",\n                    }\n                    B2 => match self.model_type {\n                        M3 => \"20_000_000_000_000_000_000_000_001\",\n                        M6 => \"20_000_000_000_000_000_000_000_000_001\",\n                        M9 => \"20_000_000_000_000_000_000_000_000_000_001\",\n                        M12 => \"20_000_000_000_000_000_000_000_000_000_000_001\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"2_000_000_000_000_000_000_000_000_001\",\n                        M6 => \"2_000_000_000_000_000_000_000_000_000_001\",\n                        M9 => \"2_000_000_000_000_000_000_000_000_000_000_001\",\n                        M12 => \"2_000_000_000_000_000_000_000_000_000_000_000_001\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"200_000_000_000_000_000_000_000_000_001\",\n                        M6 => \"200_000_000_000_000_000_000_000_000_000_001\",\n                        M9 => \"200_000_000_000_000_000_000_000_000_000_000_001\",\n                        M12 => \"200_000_000_000_000_000_000_000_000_000_000_000_001\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => 
\"359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001\",\n                        M6 => \"359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001\",\n                        M9 => 
\"359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001\",\n                        M12 => \"359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_001\",\n                    }\n                }\n                I32 => match self.bound_type {\n                    B0 => match self.model_type {\n                        M3 => \"20_000_000_000_001\",\n                        M6 => 
\"20_000_000_000_000_001\",\n                        M9 => \"20_000_000_000_000_000_001\",\n                        M12 => \"20_000_000_000_000_000_000_001\",\n                    }\n                    B2 => match self.model_type {\n                        M3 => \"2_000_000_000_000_001\",\n                        M6 => \"2_000_000_000_000_000_001\",\n                        M9 => \"2_000_000_000_000_000_000_001\",\n                        M12 => \"2_000_000_000_000_000_000_000_001\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"200_000_000_000_000_001\",\n                        M6 => \"200_000_000_000_000_000_001\",\n                        M9 => \"200_000_000_000_000_000_000_001\",\n                        M12 => \"200_000_000_000_000_000_000_000_001\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"20_000_000_000_000_000_001\",\n                        M6 => \"20_000_000_000_000_000_000_001\",\n                        M9 => \"20_000_000_000_000_000_000_000_001\",\n                        M12 => \"20_000_000_000_000_000_000_000_000_001\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => \"42_949_672_950_000_000_000_001\",\n                        M6 => \"42_949_672_950_000_000_000_000_001\",\n                        M9 => \"42_949_672_950_000_000_000_000_000_001\",\n                        M12 => \"42_949_672_950_000_000_000_000_000_000_001\",\n                    }\n                }\n                I64 => match self.bound_type {\n                    B0 => match self.model_type {\n                        M3 => \"20_000_000_000_001\",\n                        M6 => \"20_000_000_000_000_001\",\n                        M9 => \"20_000_000_000_000_000_001\",\n                        M12 => \"20_000_000_000_000_000_000_001\",\n                    }\n                    B2 => match 
self.model_type {\n                        M3 => \"2_000_000_000_000_001\",\n                        M6 => \"2_000_000_000_000_000_001\",\n                        M9 => \"2_000_000_000_000_000_000_001\",\n                        M12 => \"2_000_000_000_000_000_000_000_001\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"200_000_000_000_000_001\",\n                        M6 => \"200_000_000_000_000_000_001\",\n                        M9 => \"200_000_000_000_000_000_000_001\",\n                        M12 => \"200_000_000_000_000_000_000_000_001\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"20_000_000_000_000_000_001\",\n                        M6 => \"20_000_000_000_000_000_000_001\",\n                        M9 => \"20_000_000_000_000_000_000_000_001\",\n                        M12 => \"20_000_000_000_000_000_000_000_000_001\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => \"184_467_440_737_095_516_150_000_000_000_001\",\n                        M6 => \"184_467_440_737_095_516_150_000_000_000_000_001\",\n                        M9 => \"184_467_440_737_095_516_150_000_000_000_000_000_001\",\n                        M12 => \"184_467_440_737_095_516_150_000_000_000_000_000_000_001\",\n                    }\n                }\n            }\n            Prime => match self.data_type {\n                F32 => match self.bound_type {\n                    B0 => match self.model_type {\n                        M3 => \"20_000_000_000_021\",\n                        M6 => \"20_000_000_000_000_003\",\n                        M9 => \"20_000_000_000_000_000_011\",\n                        M12 => \"20_000_000_000_000_000_000_003\",\n                    }\n                    B2 => match self.model_type {\n                        M3 => \"2_000_000_000_000_021\",\n                        M6 => 
\"2_000_000_000_000_000_057\",\n                        M9 => \"2_000_000_000_000_000_000_069\",\n                        M12 => \"2_000_000_000_000_000_000_000_003\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"200_000_000_000_000_003\",\n                        M6 => \"200_000_000_000_000_000_089\",\n                        M9 => \"200_000_000_000_000_000_000_069\",\n                        M12 => \"200_000_000_000_000_000_000_000_027\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"20_000_000_000_000_000_011\",\n                        M6 => \"20_000_000_000_000_000_000_003\",\n                        M9 => \"20_000_000_000_000_000_000_000_009\",\n                        M12 => \"20_000_000_000_000_000_000_000_000_131\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => \"680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_281\",\n                        M6 => \"680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_323\",\n                        M9 => \"680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_191\",\n                        M12 => \"680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_083\",\n                    }\n                }\n                F64 => match self.bound_type {\n                    B0 => match self.model_type {\n                        M3 => \"200_000_000_000_000_000_000_069\",\n                        M6 => \"200_000_000_000_000_000_000_000_027\",\n                        M9 => \"200_000_000_000_000_000_000_000_000_017\",\n                        M12 => 
\"200_000_000_000_000_000_000_000_000_000_159\",\n                    }\n                    B2 => match self.model_type {\n                        M3 => \"20_000_000_000_000_000_000_000_009\",\n                        M6 => \"20_000_000_000_000_000_000_000_000_131\",\n                        M9 => \"20_000_000_000_000_000_000_000_000_000_047\",\n                        M12 => \"20_000_000_000_000_000_000_000_000_000_000_203\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"2_000_000_000_000_000_000_000_000_039\",\n                        M6 => \"2_000_000_000_000_000_000_000_000_000_071\",\n                        M9 => \"2_000_000_000_000_000_000_000_000_000_000_017\",\n                        M12 => \"2_000_000_000_000_000_000_000_000_000_000_000_041\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"200_000_000_000_000_000_000_000_000_017\",\n                        M6 => \"200_000_000_000_000_000_000_000_000_000_159\",\n                        M9 => \"200_000_000_000_000_000_000_000_000_000_000_003\",\n                        M12 => \"200_000_000_000_000_000_000_000_000_000_000_000_023\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => 
\"359_538_626_972_463_140_000_000_000_000_000_000_000_593_874_019_667_231_666_067_439_096_529_924_969_333_439_983_391_110_599_943_465_644_007_133_099_721_551_828_263_813_044_710_323_667_390_405_279_670_626_898_022_875_314_671_948_577_301_533_414_396_469_719_048_504_306_012_596_386_638_859_340_084_030_210_314_832_025_518_258_115_226_051_894_034_477_843_584_650_149_420_090_374_373_134_876_775_786_923_748_346_298_936_467_612_015_276_401_624_887_654_050_299_443_392_510_555_689_981_501_608_709_494_004_423_956_258_647_440_955_320_257_123_787_935_493_476_104_132_776_728_548_437_783_283_112_428_445_450_269_488_453_346_610_914_359_272_368_862_786_051_728_965_455_746_393_095_846_720_860_347_644_662_201_994_241_194_193_316_457_656_284_847_050_135_299_403_149_697_261_199_957_835_824_000_531_233_031_619_352_921_347_101_423_914_861_961_738_035_659_301\",\n                        M6 => \"359_538_626_972_463_139_999_999_999_999_999_999_999_903_622_106_309_601_840_402_558_296_261_360_055_843_460_163_714_984_640_183_652_353_129_826_112_739_444_431_322_400_938_984_152_600_575_421_591_212_739_537_896_016_542_591_595_727_264_024_538_428_559_469_178_136_611_680_881_710_150_818_089_794_351_154_869_285_409_959_876_691_068_635_451_827_253_162_844_058_791_343_487_286_852_635_234_799_336_668_682_655_217_329_655_102_622_197_942_194_212_857_658_834_043_465_713_831_143_523_811_067_060_369_640_438_677_832_007_091_511_212_788_398_470_391_285_320_720_769_417_737_628_120_102_221_909_739_846_753_580_817_462_645_602_854_496_103_866_327_474_145_187_363_329_320_852_679_912_679_009_543_036_760_757_409_720_574_191_338_832_841_104_183_169_976_025_577_743_061_881_721_861_634_977_765_641_182_996_194_573_448_626_763_720_938_201_976_656_541_039_724_303\",\n                        M9 => 
\"359_538_626_972_463_139_999_999_999_999_999_999_999_904_930_781_891_526_077_660_862_016_966_437_766_478_934_820_885_791_914_528_679_207_262_530_042_483_798_832_910_003_057_874_958_310_694_484_517_139_841_166_977_272_287_522_418_122_134_527_125_053_808_273_636_647_181_903_383_717_418_169_782_215_585_647_900_802_728_035_567_327_931_187_710_919_458_230_957_036_511_507_150_288_137_858_111_024_099_126_399_746_768_695_036_546_643_813_753_385_062_385_762_652_380_150_346_615_796_407_577_297_605_069_883_839_431_646_689_072_072_214_687_584_099_356_273_959_025_519_093_953_786_032_481_175_596_842_406_101_871_239_892_163_505_527_137_519_569_046_747_947_203_065_300_865_116_331_411_924_515_285_552_096_042_635_874_474_960_733_445_241_451_746_509_870_642_272_026_256_695_499_704_624_475_309_137_281_644_358_183_373_160_068_523_639_023_207_643_484_888_657_559_597\",\n                        M12 => \"359_538_626_972_463_139_999_999_999_999_999_999_999_904_931_540_467_867_407_238_817_633_447_114_203_759_664_620_787_471_913_925_990_313_859_370_016_783_101_785_327_523_046_787_247_090_978_931_042_236_128_228_564_142_680_745_383_377_953_776_024_143_512_065_781_667_978_525_748_300_241_659_425_164_472_387_573_470_260_831_720_974_578_793_447_369_507_661_739_490_218_806_790_001_765_109_117_055_431_552_295_585_457_639_803_896_262_637_528_011_897_242_316_426_079_400_392_728_240_523_639_775_219_294_589_603_009_325_941_759_217_573_340_626_063_716_838_671_315_192_395_974_939_441_284_468_885_927_433_422_082_497_928_190_254_190_935_717_337_452_741_850_223_510_814_859_331_413_287_559_285_438_144_477_756_395_583_878_761_313_295_130_567_342_888_620_541_025_745_968_373_350_261_259_032_809_052_052_475_301_496_416_128_372_300_050_762_773_363_722_300_553_930_211_649\",\n                    }\n                }\n                I32 => match self.bound_type {\n                    B0 => match self.model_type {\n                        M3 => \"20_000_000_000_021\",\n                        M6 => 
\"20_000_000_000_000_003\",\n                        M9 => \"20_000_000_000_000_000_011\",\n                        M12 => \"20_000_000_000_000_000_000_003\",\n                    }\n                    B2 => match self.model_type {\n                        M3 => \"2_000_000_000_000_021\",\n                        M6 => \"2_000_000_000_000_000_057\",\n                        M9 => \"2_000_000_000_000_000_000_069\",\n                        M12 => \"2_000_000_000_000_000_000_000_003\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"200_000_000_000_000_003\",\n                        M6 => \"200_000_000_000_000_000_089\",\n                        M9 => \"200_000_000_000_000_000_000_069\",\n                        M12 => \"200_000_000_000_000_000_000_000_027\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"20_000_000_000_000_000_011\",\n                        M6 => \"20_000_000_000_000_000_000_003\",\n                        M9 => \"20_000_000_000_000_000_000_000_009\",\n                        M12 => \"20_000_000_000_000_000_000_000_000_131\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => \"42_949_672_950_000_000_000_029\",\n                        M6 => \"42_949_672_950_000_000_000_000_049\",\n                        M9 => \"42_949_672_950_000_000_000_000_000_043\",\n                        M12 => \"42_949_672_950_000_000_000_000_000_000_109\",\n                    }\n                }\n                I64 => match self.bound_type {\n                    B0 => match self.model_type {\n                        M3 => \"20_000_000_000_021\",\n                        M6 => \"20_000_000_000_000_003\",\n                        M9 => \"20_000_000_000_000_000_011\",\n                        M12 => \"20_000_000_000_000_000_000_003\",\n                    }\n                    B2 => match 
self.model_type {\n                        M3 => \"2_000_000_000_000_021\",\n                        M6 => \"2_000_000_000_000_000_057\",\n                        M9 => \"2_000_000_000_000_000_000_069\",\n                        M12 => \"2_000_000_000_000_000_000_000_003\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"200_000_000_000_000_003\",\n                        M6 => \"200_000_000_000_000_000_089\",\n                        M9 => \"200_000_000_000_000_000_000_069\",\n                        M12 => \"200_000_000_000_000_000_000_000_027\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"20_000_000_000_000_000_011\",\n                        M6 => \"20_000_000_000_000_000_000_003\",\n                        M9 => \"20_000_000_000_000_000_000_000_009\",\n                        M12 => \"20_000_000_000_000_000_000_000_000_131\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => \"184_467_440_737_095_516_150_000_000_000_073\",\n                        M6 => \"184_467_440_737_095_516_150_000_000_000_000_013\",\n                        M9 => \"184_467_440_737_095_516_150_000_000_000_000_000_167\",\n                        M12 => \"184_467_440_737_095_516_150_000_000_000_000_000_000_089\",\n                    }\n                }\n            },\n            Power2 => match self.data_type {\n                F32 => match self.bound_type {\n                    B0 => match self.model_type {\n                        M3 => \"35_184_372_088_832\",\n                        M6 => \"36_028_797_018_963_968\",\n                        M9 => \"36_893_488_147_419_103_232\",\n                        M12 => \"37_778_931_862_957_161_709_568\",\n                    }\n                    B2 => match self.model_type {\n                        M3 => \"2_251_799_813_685_248\",\n                        M6 => 
\"2_305_843_009_213_693_952\",\n                        M9 => \"2_361_183_241_434_822_606_848\",\n                        M12 => \"2_417_851_639_229_258_349_412_352\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"288_230_376_151_711_744\",\n                        M6 => \"295_147_905_179_352_825_856\",\n                        M9 => \"302_231_454_903_657_293_676_544\",\n                        M12 => \"309_485_009_821_345_068_724_781_056\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"36_893_488_147_419_103_232\",\n                        M6 => \"37_778_931_862_957_161_709_568\",\n                        M9 => \"38_685_626_227_668_133_590_597_632\",\n                        M12 => \"39_614_081_257_132_168_796_771_975_168\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => \"994_646_472_819_573_284_310_764_496_293_641_680_200_912_301_594_695_434_880_927_953_786_318_994_025_066_751_066_112\",\n                        M6 => \"1_018_517_988_167_243_043_134_222_844_204_689_080_525_734_196_832_968_125_318_070_224_677_190_649_881_668_353_091_698_688\",\n                        M9 => \"1_042_962_419_883_256_876_169_444_192_465_601_618_458_351_817_556_959_360_325_703_910_069_443_225_478_828_393_565_899_456_512\",\n                        M12 => \"1_067_993_517_960_455_041_197_510_853_084_776_057_301_352_261_178_326_384_973_520_803_911_109_862_890_320_275_011_481_043_468_288\",\n                    }\n                }\n                F64 => match self.bound_type {\n                    B0 => match self.model_type {\n                        M3 => \"302_231_454_903_657_293_676_544\",\n                        M6 => \"309_485_009_821_345_068_724_781_056\",\n                        M9 => \"316_912_650_057_057_350_374_175_801_344\",\n                        M12 => 
\"324_518_553_658_426_726_783_156_020_576_256\",\n                    }\n                    B2 => match self.model_type {\n                        M3 => \"38_685_626_227_668_133_590_597_632\",\n                        M6 => \"39_614_081_257_132_168_796_771_975_168\",\n                        M9 => \"20_282_409_603_651_670_423_947_251_286_016\",\n                        M12 => \"20_769_187_434_139_310_514_121_985_316_880_384\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"2_475_880_078_570_760_549_798_248_448\",\n                        M6 => \"2_535_301_200_456_458_802_993_406_410_752\",\n                        M9 => \"2_596_148_429_267_413_814_265_248_164_610_048\",\n                        M12 => \"2_658_455_991_569_831_745_807_614_120_560_689_152\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"316_912_650_057_057_350_374_175_801_344\",\n                        M6 => \"324_518_553_658_426_726_783_156_020_576_256\",\n                        M9 => \"332_306_998_946_228_968_225_951_765_070_086_144\",\n                        M12 => \"340_282_366_920_938_463_463_374_607_431_768_211_456\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => 
\"596_143_540_225_991_923_146_302_416_688_458_341_289_203_474_674_553_062_792_993_127_033_853_365_765_018_588_197_722_567_551_977_295_508_215_323_031_793_155_057_153_946_025_631_943_349_443_566_464_703_583_960_364_782_216_884_718_655_637_955_371_883_889_285_523_680_681_542_682_622_992_485_998_454_422_254_346_205_188_269_982_058_330_848_165_814_218_528_432_304_958_458_516_472_675_321_199_923_576_436_128_746_194_040_030_388_187_813_654_706_961_312_852_788_047_760_914_640_519_973_439_182_188_222_756_017_424_664_821_230_981_616_162_111_762_973_371_192_278_908_910_941_031_147_045_555_738_506_834_254_728_517_124_812_756_790_583_181_174_762_115_337_827_697_771_072_593_076_558_961_853_936_203_969_690_859_453_400_618_497_370_766_001_868_317_217_344_149_071_638_768_630_396_860_838_478_405_181_466_899_321_747_678_290_733_613_480_879_657_473_540_096\",\n                        M6 => \"610_450_985_191_415_729_301_813_674_688_981_341_480_144_358_066_742_336_300_024_962_082_665_846_543_379_034_314_467_909_173_224_750_600_412_490_784_556_190_778_525_640_730_247_109_989_830_212_059_856_469_975_413_536_990_089_951_903_373_266_300_809_102_628_376_249_017_899_707_005_944_305_662_417_328_388_450_514_112_788_461_627_730_788_521_793_759_773_114_680_277_461_520_868_019_528_908_721_742_270_595_836_102_696_991_117_504_321_182_419_928_384_361_254_960_907_176_591_892_452_801_722_560_740_102_161_842_856_776_940_525_174_950_002_445_284_732_100_893_602_724_803_615_894_574_649_076_230_998_276_842_001_535_808_262_953_557_177_522_956_406_105_935_562_517_578_335_310_396_376_938_430_672_864_963_440_080_282_233_341_307_664_385_913_156_830_560_408_649_358_099_077_526_385_498_601_886_905_822_104_905_469_622_569_711_220_204_420_769_252_905_058_304\",\n                        M9 => 
\"625_101_808_836_009_706_805_057_202_881_516_893_675_667_822_660_344_152_371_225_561_172_649_826_860_420_131_138_015_138_993_382_144_614_822_390_563_385_539_357_210_256_107_773_040_629_586_137_149_293_025_254_823_461_877_852_110_749_054_224_692_028_521_091_457_278_994_329_299_974_086_968_998_315_344_269_773_326_451_495_384_706_796_327_446_316_810_007_669_432_604_120_597_368_851_997_602_531_064_085_090_136_169_161_718_904_324_424_890_798_006_665_585_925_079_968_948_830_097_871_668_963_902_197_864_613_727_085_339_587_097_779_148_802_503_971_565_671_315_049_190_198_902_676_044_440_654_060_542_235_486_209_572_667_661_264_442_549_783_507_359_852_478_016_018_000_215_357_845_889_984_953_009_013_722_562_642_209_006_941_499_048_331_175_072_594_493_858_456_942_693_455_387_018_750_568_332_191_561_835_423_200_893_511_384_289_489_326_867_714_974_779_703_296\",\n                        M12 => \"640_104_252_248_073_939_768_378_575_750_673_299_123_883_850_404_192_412_028_134_974_640_793_422_705_070_214_285_327_502_329_223_316_085_578_127_936_906_792_301_783_302_254_359_593_604_696_204_440_876_057_860_939_224_962_920_561_407_031_526_084_637_205_597_652_253_690_193_203_173_465_056_254_274_912_532_247_886_286_331_273_939_759_439_305_028_413_447_853_498_986_619_491_705_704_445_544_991_809_623_132_299_437_221_600_158_028_211_088_177_158_825_559_987_281_888_203_602_020_220_589_019_035_850_613_364_456_535_387_737_188_125_848_373_764_066_883_247_426_610_370_763_676_340_269_507_229_757_995_249_137_878_602_411_685_134_789_170_978_311_536_488_937_488_402_432_220_526_434_191_344_591_881_230_051_904_145_622_023_108_095_025_491_123_274_336_761_711_059_909_318_098_316_307_200_581_972_164_159_319_473_357_714_955_657_512_437_070_712_540_134_174_416_175_104\",\n                    }\n                }\n                I32 => match self.bound_type {\n                    B0 => match self.model_type {\n                        M3 => \"35_184_372_088_832\",\n                        M6 => 
\"36_028_797_018_963_968\",\n                        M9 => \"36_893_488_147_419_103_232\",\n                        M12 => \"37_778_931_862_957_161_709_568\",\n                    }\n                    B2 => match self.model_type {\n                        M3 => \"2_251_799_813_685_248\",\n                        M6 => \"2_305_843_009_213_693_952\",\n                        M9 => \"2_361_183_241_434_822_606_848\",\n                        M12 => \"2_417_851_639_229_258_349_412_352\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"288_230_376_151_711_744\",\n                        M6 => \"295_147_905_179_352_825_856\",\n                        M9 => \"302_231_454_903_657_293_676_544\",\n                        M12 => \"309_485_009_821_345_068_724_781_056\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"36_893_488_147_419_103_232\",\n                        M6 => \"37_778_931_862_957_161_709_568\",\n                        M9 => \"38_685_626_227_668_133_590_597_632\",\n                        M12 => \"39_614_081_257_132_168_796_771_975_168\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => \"75_557_863_725_914_323_419_136\",\n                        M6 => \"77_371_252_455_336_267_181_195_264\",\n                        M9 => \"79_228_162_514_264_337_593_543_950_336\",\n                        M12 => \"81_129_638_414_606_681_695_789_005_144_064\",\n                    }\n                }\n                I64 => match self.bound_type {\n                    B0 => match self.model_type {\n                        M3 => \"35_184_372_088_832\",\n                        M6 => \"36_028_797_018_963_968\",\n                        M9 => \"36_893_488_147_419_103_232\",\n                        M12 => \"37_778_931_862_957_161_709_568\",\n                    }\n                    B2 => match 
self.model_type {\n                        M3 => \"2_251_799_813_685_248\",\n                        M6 => \"2_305_843_009_213_693_952\",\n                        M9 => \"2_361_183_241_434_822_606_848\",\n                        M12 => \"2_417_851_639_229_258_349_412_352\",\n                    }\n                    B4 => match self.model_type {\n                        M3 => \"288_230_376_151_711_744\",\n                        M6 => \"295_147_905_179_352_825_856\",\n                        M9 => \"302_231_454_903_657_293_676_544\",\n                        M12 => \"309_485_009_821_345_068_724_781_056\",\n                    }\n                    B6 => match self.model_type {\n                        M3 => \"36_893_488_147_419_103_232\",\n                        M6 => \"37_778_931_862_957_161_709_568\",\n                        M9 => \"38_685_626_227_668_133_590_597_632\",\n                        M12 => \"39_614_081_257_132_168_796_771_975_168\",\n                    }\n                    Bmax => match self.model_type {\n                        M3 => \"324_518_553_658_426_726_783_156_020_576_256\",\n                        M6 => \"332_306_998_946_228_968_225_951_765_070_086_144\",\n                        M9 => \"340_282_366_920_938_463_463_374_607_431_768_211_456\",\n                        M12 => \"348_449_143_727_040_986_586_495_598_010_130_648_530_944\",\n                    }\n                }\n            }\n        };\n        // safe unwrap: string and radix are valid\n        BigUint::from_str_radix(order_str, 10).unwrap()\n    }\n}\n\n#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]\n/// Convenience struct for a pair of masking configurations.\n///\n/// One configuration is intended for (un)masking a vector of values, the other\n/// for a unit value.\npub struct MaskConfigPair {\n    pub vect: MaskConfig,\n    pub unit: MaskConfig,\n}\n\nimpl From<MaskConfig> for MaskConfigPair {\n    /// Creates two copies of the given 
masking configuration as a pair.\n    fn from(config: MaskConfig) -> Self {\n        Self {\n            vect: config,\n            unit: config,\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/mask/config/serialization.rs",
    "content": "//! Serialization of masking configurations.\n//!\n//! See the [mask module] documentation since this is a private module anyways.\n//!\n//! [mask module]: crate::mask\n\nuse std::convert::TryInto;\n\nuse anyhow::{anyhow, Context};\n\nuse crate::{\n    mask::config::MaskConfig,\n    message::{\n        traits::{FromBytes, ToBytes},\n        DecodeError,\n    },\n};\n\nconst GROUP_TYPE_FIELD: usize = 0;\nconst DATA_TYPE_FIELD: usize = 1;\nconst BOUND_TYPE_FIELD: usize = 2;\nconst MODEL_TYPE_FIELD: usize = 3;\npub(crate) const MASK_CONFIG_BUFFER_LEN: usize = 4;\n\n/// A buffer for serialized masking configurations.\npub struct MaskConfigBuffer<T> {\n    inner: T,\n}\n\nimpl<T: AsRef<[u8]>> MaskConfigBuffer<T> {\n    /// Creates a new buffer from `bytes`.\n    ///\n    /// # Errors\n    /// Fails if the `bytes` don't conform to the required buffer length for masking configurations.\n    pub fn new(bytes: T) -> Result<Self, DecodeError> {\n        let buffer = Self { inner: bytes };\n        buffer\n            .check_buffer_length()\n            .context(\"not a valid MaskConfigBuffer\")?;\n        Ok(buffer)\n    }\n\n    /// Creates a new buffer from `bytes`.\n    pub fn new_unchecked(bytes: T) -> Self {\n        Self { inner: bytes }\n    }\n\n    /// Checks if this buffer conforms to the required buffer length for masking configurations.\n    ///\n    /// # Errors\n    /// Fails if the buffer is too small.\n    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {\n        let len = self.inner.as_ref().len();\n        if len < MASK_CONFIG_BUFFER_LEN {\n            return Err(anyhow!(\n                \"invalid buffer length: {} < {}\",\n                len,\n                MASK_CONFIG_BUFFER_LEN\n            ));\n        }\n        Ok(())\n    }\n\n    /// Gets the serialized group type of the masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn group_type(&self) -> u8 {\n        
self.inner.as_ref()[GROUP_TYPE_FIELD]\n    }\n\n    /// Gets the serialized data type of the masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn data_type(&self) -> u8 {\n        self.inner.as_ref()[DATA_TYPE_FIELD]\n    }\n\n    /// Gets the serialized bound type of the masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn bound_type(&self) -> u8 {\n        self.inner.as_ref()[BOUND_TYPE_FIELD]\n    }\n\n    /// Gets the serialized model type of the masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn model_type(&self) -> u8 {\n        self.inner.as_ref()[MODEL_TYPE_FIELD]\n    }\n}\n\nimpl<T: AsMut<[u8]>> MaskConfigBuffer<T> {\n    /// Sets the serialized group type of the masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn set_group_type(&mut self, value: u8) {\n        self.inner.as_mut()[GROUP_TYPE_FIELD] = value;\n    }\n\n    /// Sets the serialized data type of the masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn set_data_type(&mut self, value: u8) {\n        self.inner.as_mut()[DATA_TYPE_FIELD] = value;\n    }\n\n    /// Sets the serialized bound type of the masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn set_bound_type(&mut self, value: u8) {\n        self.inner.as_mut()[BOUND_TYPE_FIELD] = value;\n    }\n\n    /// Sets the serialized model type of the masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn set_model_type(&mut self, value: u8) {\n        self.inner.as_mut()[MODEL_TYPE_FIELD] = value;\n    }\n}\n\nimpl ToBytes for MaskConfig {\n    fn buffer_length(&self) -> usize {\n        MASK_CONFIG_BUFFER_LEN\n    }\n\n    fn to_bytes<T: 
AsMut<[u8]>>(&self, buffer: &mut T) {\n        let mut writer = MaskConfigBuffer::new_unchecked(buffer.as_mut());\n        writer.set_group_type(self.group_type as u8);\n        writer.set_data_type(self.data_type as u8);\n        writer.set_bound_type(self.bound_type as u8);\n        writer.set_model_type(self.model_type as u8);\n    }\n}\n\nimpl FromBytes for MaskConfig {\n    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        let reader = MaskConfigBuffer::new(buffer.as_ref())?;\n        Ok(Self {\n            group_type: reader\n                .group_type()\n                .try_into()\n                .context(\"invalid masking config\")?,\n            data_type: reader\n                .data_type()\n                .try_into()\n                .context(\"invalid masking config\")?,\n            bound_type: reader\n                .bound_type()\n                .try_into()\n                .context(\"invalid masking config\")?,\n            model_type: reader\n                .model_type()\n                .try_into()\n                .context(\"invalid masking config\")?,\n        })\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        let buf: Vec<u8> = iter.take(MASK_CONFIG_BUFFER_LEN).collect();\n        Self::from_byte_slice(&buf)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::mask::config::{BoundType, DataType, GroupType, MaskConfig, ModelType};\n\n    #[test]\n    fn serialize() {\n        let config = MaskConfig {\n            group_type: GroupType::Prime,\n            data_type: DataType::F64,\n            bound_type: BoundType::Bmax,\n            model_type: ModelType::M9,\n        };\n\n        let mut buf = vec![0xff; 4];\n        config.to_bytes(&mut buf);\n        assert_eq!(buf, vec![1, 1, 255, 9]);\n    }\n\n    #[test]\n    fn deserialize() {\n        let bytes = vec![1, 1, 255, 9];\n   
     let config = MaskConfig::from_byte_slice(&bytes).unwrap();\n        assert_eq!(\n            config,\n            MaskConfig {\n                group_type: GroupType::Prime,\n                data_type: DataType::F64,\n                bound_type: BoundType::Bmax,\n                model_type: ModelType::M9,\n            }\n        );\n    }\n\n    #[test]\n    fn stream_deserialize() {\n        let mut bytes = vec![1, 1, 255, 9].into_iter();\n        let config = MaskConfig::from_byte_stream(&mut bytes).unwrap();\n        assert_eq!(\n            config,\n            MaskConfig {\n                group_type: GroupType::Prime,\n                data_type: DataType::F64,\n                bound_type: BoundType::Bmax,\n                model_type: ModelType::M9,\n            }\n        );\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/mask/masking.rs",
    "content": "//! Masking, aggregation and unmasking of models.\n//!\n//! See the [mask module] documentation since this is a private module anyways.\n//!\n//! [mask module]: crate::mask\n\nuse std::iter::{self, Iterator};\n\nuse num::{\n    bigint::{BigInt, BigUint, ToBigInt},\n    clamp,\n    rational::Ratio,\n    traits::clamp_max,\n};\nuse rand::SeedableRng;\nuse rand_chacha::ChaCha20Rng;\nuse thiserror::Error;\n\nuse crate::{\n    crypto::{prng::generate_integer, ByteObject},\n    mask::{\n        config::MaskConfigPair,\n        model::Model,\n        object::{MaskObject, MaskUnit, MaskVect},\n        scalar::Scalar,\n        seed::MaskSeed,\n    },\n};\n\n#[derive(Debug, Error, Eq, PartialEq)]\n/// Errors related to the unmasking of models.\npub enum UnmaskingError {\n    #[error(\"there is no model to unmask\")]\n    NoModel,\n\n    #[error(\"too many models were aggregated for the current unmasking configuration\")]\n    TooManyModels,\n\n    #[error(\"too many scalars were aggregated for the current unmasking configuration\")]\n    TooManyScalars,\n\n    #[error(\"the masked model is incompatible with the mask used for unmasking\")]\n    MaskManyMismatch,\n\n    #[error(\"the masked scalar is incompatible with the mask used for unmasking\")]\n    MaskOneMismatch,\n\n    #[error(\"the mask is invalid\")]\n    InvalidMask,\n}\n\n#[derive(Debug, Error)]\n/// Errors related to the aggregation of masks and models.\npub enum AggregationError {\n    // TODO rename Model -> Vector; or use MaskMany/One terminology\n    #[error(\"the object to aggregate is invalid\")]\n    InvalidObject,\n\n    #[error(\"too many models were aggregated for the current unmasking configuration\")]\n    TooManyModels,\n\n    #[error(\"too many scalars were aggregated for the current unmasking configuration\")]\n    TooManyScalars,\n\n    #[error(\"the model to aggregate is incompatible with the current aggregated scalar\")]\n    ModelMismatch,\n\n    #[error(\"the scalar to 
aggregate is incompatible with the current aggregated scalar\")]\n    ScalarMismatch,\n}\n\n#[derive(Debug, Clone)]\n/// An aggregator for masks and masked models.\npub struct Aggregation {\n    nb_models: usize,\n    object: MaskObject,\n    object_size: usize,\n}\n\nimpl From<MaskObject> for Aggregation {\n    fn from(object: MaskObject) -> Self {\n        Self {\n            nb_models: 1,\n            object_size: object.vect.data.len(),\n            object,\n        }\n    }\n}\n\nimpl From<Aggregation> for MaskObject {\n    fn from(aggr: Aggregation) -> Self {\n        aggr.object\n    }\n}\n\n#[allow(clippy::len_without_is_empty)]\nimpl Aggregation {\n    /// Creates a new, empty aggregator for masks or masked models.\n    pub fn new(config: MaskConfigPair, object_size: usize) -> Self {\n        Self {\n            nb_models: 0,\n            object: MaskObject::empty(config, object_size),\n            object_size,\n        }\n    }\n\n    /// Gets the length of the aggregated mask object.\n    pub fn len(&self) -> usize {\n        self.object_size\n    }\n\n    /// Gets the masking configurations of the aggregator.\n    pub fn config(&self) -> MaskConfigPair {\n        MaskConfigPair {\n            vect: self.object.vect.config,\n            unit: self.object.unit.config,\n        }\n    }\n\n    /// Validates if unmasking of the aggregated masked model with the given `mask` may be\n    /// safely performed.\n    ///\n    /// This should be checked before calling [`unmask()`], since unmasking may return garbage\n    /// values otherwise.\n    ///\n    /// # Errors\n    /// Fails in one of the following cases:\n    /// - The aggregator has not yet aggregated any models.\n    /// - The number of aggregated masked models is larger than the chosen masking configuration\n    ///   allows.\n    /// - The masking configuration of the aggregator and of the `mask` don't coincide.\n    /// - The length of the aggregated masked model and the `mask` don't coincide.\n    
/// - The `mask` itself is invalid.\n    ///\n    /// Even though it does not produce any meaningful values, it is safe and technically possible\n    /// due to the [`MaskObject`] type to validate, that:\n    /// - a mask may unmask another mask\n    /// - a masked model may unmask a mask\n    /// - a masked model may unmask another masked model\n    ///\n    /// [`unmask()`]: Aggregation::unmask\n    pub fn validate_unmasking(&self, mask: &MaskObject) -> Result<(), UnmaskingError> {\n        // We cannot perform unmasking without at least one real model\n        if self.nb_models == 0 {\n            return Err(UnmaskingError::NoModel);\n        }\n\n        if self.nb_models > self.object.vect.config.model_type.max_nb_models() {\n            return Err(UnmaskingError::TooManyModels);\n        }\n\n        if self.nb_models > self.object.unit.config.model_type.max_nb_models() {\n            return Err(UnmaskingError::TooManyScalars);\n        }\n\n        if self.object.vect.config != mask.vect.config || self.object_size != mask.vect.data.len() {\n            return Err(UnmaskingError::MaskManyMismatch);\n        }\n\n        if self.object.unit.config != mask.unit.config {\n            return Err(UnmaskingError::MaskOneMismatch);\n        }\n\n        if !mask.is_valid() {\n            return Err(UnmaskingError::InvalidMask);\n        }\n\n        Ok(())\n    }\n\n    /// Unmasks the aggregated masked model with the given `mask`.\n    ///\n    /// It should be checked that [`validate_unmasking()`] succeeds before calling this, since\n    /// unmasking may return garbage values otherwise. 
The unmasking is performed in the reverse order of\n    /// the steps described for [`mask()`].\n    ///\n    /// # Panics\n    /// This may only panic if [`validate_unmasking()`] fails.\n    ///\n    /// Even though it does not produce any meaningful values, it is safe and technically possible\n    /// due to the [`MaskObject`] type to unmask:\n    /// - a mask with another mask\n    /// - a mask with a masked model\n    /// - a masked model with another masked model\n    ///\n    /// if [`validate_unmasking()`] succeeds.\n    ///\n    /// [`validate_unmasking()`]: Aggregation::validate_unmasking\n    /// [`mask()`]: Masker::mask\n    pub fn unmask(self, mask_obj: MaskObject) -> Model {\n        let MaskObject { vect, unit } = self.object;\n        let (masked_n, config_n) = (vect.data, vect.config);\n        let (masked_1, config_1) = (unit.data, unit.config);\n        let mask_n = mask_obj.vect.data;\n        let mask_1 = mask_obj.unit.data;\n\n        // unmask scalar sum\n        let scaled_add_shift_1 = config_1.add_shift() * BigInt::from(self.nb_models);\n        let exp_shift_1 = config_1.exp_shift();\n        let order_1 = config_1.order();\n        let n = (masked_1 + &order_1 - mask_1) % &order_1;\n        let ratio = Ratio::<BigInt>::from(n.to_bigint().unwrap());\n        let scalar_sum = ratio / &exp_shift_1 - &scaled_add_shift_1;\n\n        // unmask global model\n        let scaled_add_shift_n = config_n.add_shift() * BigInt::from(self.nb_models);\n        let exp_shift_n = config_n.exp_shift();\n        let order_n = config_n.order();\n        masked_n\n            .into_iter()\n            .zip(mask_n)\n            .map(|(masked, mask)| {\n                // PANIC_SAFE: The subtraction panics if it\n                // underflows, which can only happen if:\n                //\n                //     mask > order_n\n                //\n                // If the mask is valid, we are guaranteed that this\n                // cannot happen. 
Thus this method may panic only if\n                // given an invalid mask.\n                let n = (masked + &order_n - mask) % &order_n;\n\n                // UNWRAP_SAFE: to_bigint never fails for BigUint\n                let ratio = Ratio::<BigInt>::from(n.to_bigint().unwrap());\n                let unmasked = ratio / &exp_shift_n - &scaled_add_shift_n;\n\n                // scaling correction\n                unmasked / &scalar_sum\n            })\n            .collect()\n    }\n\n    /// Validates if aggregation of the aggregated mask object with the given `object` may be safely\n    /// performed.\n    ///\n    /// This should be checked before calling [`aggregate()`], since aggregation may return garbage\n    /// values otherwise.\n    ///\n    /// # Errors\n    /// Fails in one of the following cases:\n    /// - The masking configuration of the aggregator and of the `object` don't coincide.\n    /// - The length of the aggregated masks or masked model and the `object` don't coincide. 
The expected\n    ///   length is fixed when the aggregator is created, so this applies to an empty aggregator\n    ///   as well.\n    /// - The new number of aggregated masks or masked models would exceed the number that the\n    ///   chosen masking configuration allows.\n    /// - The `object` itself is invalid.\n    ///\n    /// Even though it does not produce any meaningful values, it is safe and technically possible\n    /// due to the [`MaskObject`] type to validate that a mask may be aggregated with a masked\n    /// model.\n    ///\n    /// [`aggregate()`]: Aggregation::aggregate\n    pub fn validate_aggregation(&self, object: &MaskObject) -> Result<(), AggregationError> {\n        if self.object.vect.config != object.vect.config {\n            return Err(AggregationError::ModelMismatch);\n        }\n\n        if self.object.unit.config != object.unit.config {\n            return Err(AggregationError::ScalarMismatch);\n        }\n\n        if self.object_size != object.vect.data.len() {\n            return Err(AggregationError::ModelMismatch);\n        }\n\n        if self.nb_models >= self.object.vect.config.model_type.max_nb_models() {\n            return Err(AggregationError::TooManyModels);\n        }\n\n        if self.nb_models >= self.object.unit.config.model_type.max_nb_models() {\n            return Err(AggregationError::TooManyScalars);\n        }\n\n        if !object.is_valid() {\n            return Err(AggregationError::InvalidObject);\n        }\n\n        Ok(())\n    }\n\n    /// Aggregates the aggregated mask object with the given `object`.\n    ///\n    /// It should be checked that [`validate_aggregation()`] succeeds before calling this, since\n    /// aggregation may return garbage values otherwise.\n    ///\n    /// Even though it does not produce any meaningful values, it is safe and technically possible\n    /// due to the [`MaskObject`] type to aggregate a mask with a masked model if\n    /// [`validate_aggregation()`] succeeds.\n    ///\n    /// 
[`validate_aggregation()`]: Aggregation::validate_aggregation\n    pub fn aggregate(&mut self, object: MaskObject) {\n        if self.nb_models == 0 {\n            self.object = object;\n            self.nb_models = 1;\n            return;\n        }\n\n        let order_n = self.object.vect.config.order();\n        for (i, j) in self\n            .object\n            .vect\n            .data\n            .iter_mut()\n            .zip(object.vect.data.into_iter())\n        {\n            *i = (&*i + j) % &order_n\n        }\n\n        let order_1 = self.object.unit.config.order();\n        let a = &mut self.object.unit.data;\n        let b = object.unit.data;\n        *a = (&*a + b) % &order_1;\n\n        self.nb_models += 1;\n    }\n}\n\n/// A masker for models.\npub struct Masker {\n    config: MaskConfigPair,\n    seed: MaskSeed,\n}\n\nimpl Masker {\n    /// Creates a new masker with the given masking `config`uration with a randomly generated seed.\n    pub fn new(config: MaskConfigPair) -> Self {\n        Self {\n            config,\n            seed: MaskSeed::generate(),\n        }\n    }\n\n    /// Creates a new masker with the given masking `config`uration and `seed`.\n    pub fn with_seed(config: MaskConfigPair, seed: MaskSeed) -> Self {\n        Self { config, seed }\n    }\n}\n\nimpl Masker {\n    /// Masks the given `model` wrt the masking configuration. 
Enforces bounds on the scalar and\n    /// weights.\n    ///\n    /// The masking proceeds in the following steps:\n    /// - Clamp the scalar and the weights according to the masking configuration.\n    /// - Scale the weights by the scalar.\n    /// - Shift the weights into the non-negative reals.\n    /// - Shift the weights into the non-negative integers.\n    /// - Shift the weights into the finite group.\n    /// - Mask the weights with random elements from the finite group.\n    ///\n    /// The `scalar` is also masked, following a similar process.\n    ///\n    /// The random elements are derived from a seeded PRNG. Unmasking as performed in [`unmask()`]\n    /// proceeds in reverse order.\n    ///\n    /// [`unmask()`]: Aggregation::unmask\n    pub fn mask(self, scalar: Scalar, model: &Model) -> (MaskSeed, MaskObject) {\n        let (random_int, mut random_ints) = self.random_ints();\n        let Self { config, seed } = self;\n        let MaskConfigPair {\n            vect: config_n,\n            unit: config_1,\n        } = config;\n\n        // clamp the scalar\n        let add_shift_1 = config_1.add_shift();\n        let scalar_ratio = scalar.into();\n        let scalar_clamped = clamp_max(&scalar_ratio, &add_shift_1);\n\n        let exp_shift_n = config_n.exp_shift();\n        let add_shift_n = config_n.add_shift();\n        let order_n = config_n.order();\n        let higher_bound = &add_shift_n;\n        let lower_bound = -&add_shift_n;\n\n        // mask the (scaled) weights\n        let masked_weights = model\n            .iter()\n            .zip(&mut random_ints)\n            .map(|(weight, rand_int)| {\n                let scaled = scalar_clamped * weight;\n                let scaled_clamped = clamp(&scaled, &lower_bound, higher_bound);\n                // PANIC_SAFE: shifted weight is guaranteed to be non-negative\n                let shifted = ((scaled_clamped + &add_shift_n) * &exp_shift_n)\n                    .to_integer()\n                 
   .to_biguint()\n                    .unwrap();\n                (shifted + rand_int) % &order_n\n            })\n            .collect();\n        let masked_model = MaskVect::new_unchecked(config_n, masked_weights);\n\n        // mask the scalar\n        // PANIC_SAFE: shifted scalar is guaranteed to be non-negative\n        let shifted = ((scalar_clamped + &add_shift_1) * config_1.exp_shift())\n            .to_integer()\n            .to_biguint()\n            .unwrap();\n        let masked = (shifted + random_int) % config_1.order();\n        let masked_scalar = MaskUnit::new_unchecked(config_1, masked);\n\n        (seed, MaskObject::new_unchecked(masked_model, masked_scalar))\n    }\n\n    /// Randomly generates integers wrt the masking configurations.\n    ///\n    /// The first is generated wrt the scalar configuration, while the rest are\n    /// wrt the vector configuration and returned as an iterator.\n    fn random_ints(&self) -> (BigUint, impl Iterator<Item = BigUint>) {\n        let order_n = self.config.vect.order();\n        let order_1 = self.config.unit.order();\n        let mut prng = ChaCha20Rng::from_seed(self.seed.as_array());\n        let int = generate_integer(&mut prng, &order_1);\n        let ints = iter::from_fn(move || Some(generate_integer(&mut prng, &order_n)));\n        (int, ints)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use std::iter;\n\n    use num::traits::Signed;\n    use rand::{\n        distributions::{Distribution, Uniform},\n        SeedableRng,\n    };\n    use rand_chacha::ChaCha20Rng;\n\n    use super::*;\n    use crate::mask::{\n        config::{\n            BoundType::{Bmax, B0, B2, B4, B6},\n            DataType::{F32, F64, I32, I64},\n            GroupType::{Integer, Power2, Prime},\n            MaskConfig,\n            ModelType::M3,\n        },\n        model::FromPrimitives,\n        scalar::FromPrimitive,\n    };\n\n    /// Generate tests for masking and unmasking of a single model:\n    /// - generate random 
weights from a uniform distribution with a seeded PRNG\n    /// - create a model from the weights and mask it\n    /// - check that all masked weights belong to the chosen finite group\n    /// - unmask the masked model\n    /// - check that all unmasked weights are equal to the original weights (up to a tolerance\n    ///   determined by the masking configuration)\n    ///\n    /// The arguments to the macro are:\n    /// - a suffix for the test name\n    /// - the group type of the model (variants of `GroupType`)\n    /// - the data type of the model (either primitives or variants of `DataType`)\n    /// - an absolute bound for the weights (optional, choices: 1, 100, 10_000, 1_000_000)\n    /// - the number of weights\n    macro_rules! test_masking {\n        ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr $(,)?) => {\n            paste::item! {\n                #[test]\n                fn [<test_masking_ $suffix>]() {\n                    // Step 1: Build the masking config\n                    let config = MaskConfig {\n                        group_type: $group,\n                        data_type: paste::expr! { [<$data:upper>] },\n                        bound_type: match $bound {\n                            1 => B0,\n                            100 => B2,\n                            10_000 => B4,\n                            1_000_000 => B6,\n                            _ => Bmax,\n                        },\n                        model_type: M3,\n                    };\n                    let vect_len = $len as usize;\n\n                    // Step 2: Generate a random model\n                    let bound = if $bound == 0 {\n                        paste::expr! { [<$data:lower>]::MAX / (2.1 as [<$data:lower>]) }\n                    } else {\n                        paste::expr! 
{ $bound as [<$data:lower>] }\n                    };\n                    let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array());\n                    let random_weights = Uniform::new_inclusive(-bound, bound)\n                        .sample_iter(&mut prng)\n                        .take(vect_len);\n                    let model = Model::from_primitives(random_weights).unwrap();\n                    assert_eq!(model.len(), vect_len);\n\n                    // Step 3 (actual test):\n                    // a. mask the model\n                    // b. derive the mask corresponding to the seed used\n                    // c. unmask the model and check it against the original one.\n                    let (mask_seed, masked_model) =\n                        Masker::new(config.into()).mask(Scalar::unit(), &model);\n                    assert_eq!(masked_model.vect.data.len(), vect_len);\n                    assert!(masked_model.is_valid());\n\n                    let mask = mask_seed.derive_mask(vect_len, config.into());\n                    let aggregation = Aggregation::from(masked_model);\n                    let unmasked_model = aggregation.unmask(mask);\n\n                    let tolerance = Ratio::from_integer(config.exp_shift()).recip();\n                    assert!(\n                        model.iter()\n                            .zip(unmasked_model.iter())\n                            .all(|(weight, unmasked_weight)| {\n                                (weight - unmasked_weight).abs() <= tolerance\n                            })\n                    );\n                }\n            }\n        };\n        ($suffix:ident, $group:ty, $data:ty, $len:expr $(,)?) 
=> {\n            test_masking!($suffix, $group, $data, 0, $len);\n        };\n    }\n\n    test_masking!(int_f32_b0, Integer, f32, 1, 10);\n    test_masking!(int_f32_b2, Integer, f32, 100, 10);\n    test_masking!(int_f32_b4, Integer, f32, 10_000, 10);\n    test_masking!(int_f32_b6, Integer, f32, 1_000_000, 10);\n    test_masking!(int_f32_bmax, Integer, f32, 10);\n\n    test_masking!(prime_f32_b0, Prime, f32, 1, 10);\n    test_masking!(prime_f32_b2, Prime, f32, 100, 10);\n    test_masking!(prime_f32_b4, Prime, f32, 10_000, 10);\n    test_masking!(prime_f32_b6, Prime, f32, 1_000_000, 10);\n    test_masking!(prime_f32_bmax, Prime, f32, 10);\n\n    test_masking!(pow_f32_b0, Power2, f32, 1, 10);\n    test_masking!(pow_f32_b2, Power2, f32, 100, 10);\n    test_masking!(pow_f32_b4, Power2, f32, 10_000, 10);\n    test_masking!(pow_f32_b6, Power2, f32, 1_000_000, 10);\n    test_masking!(pow_f32_bmax, Power2, f32, 10);\n\n    test_masking!(int_f64_b0, Integer, f64, 1, 10);\n    test_masking!(int_f64_b2, Integer, f64, 100, 10);\n    test_masking!(int_f64_b4, Integer, f64, 10_000, 10);\n    test_masking!(int_f64_b6, Integer, f64, 1_000_000, 10);\n    test_masking!(int_f64_bmax, Integer, f64, 10);\n\n    test_masking!(prime_f64_b0, Prime, f64, 1, 10);\n    test_masking!(prime_f64_b2, Prime, f64, 100, 10);\n    test_masking!(prime_f64_b4, Prime, f64, 10_000, 10);\n    test_masking!(prime_f64_b6, Prime, f64, 1_000_000, 10);\n    test_masking!(prime_f64_bmax, Prime, f64, 10);\n\n    test_masking!(pow_f64_b0, Power2, f64, 1, 10);\n    test_masking!(pow_f64_b2, Power2, f64, 100, 10);\n    test_masking!(pow_f64_b4, Power2, f64, 10_000, 10);\n    test_masking!(pow_f64_b6, Power2, f64, 1_000_000, 10);\n    test_masking!(pow_f64_bmax, Power2, f64, 10);\n\n    test_masking!(int_i32_b0, Integer, i32, 1, 10);\n    test_masking!(int_i32_b2, Integer, i32, 100, 10);\n    test_masking!(int_i32_b4, Integer, i32, 10_000, 10);\n    test_masking!(int_i32_b6, Integer, i32, 1_000_000, 10);\n    
test_masking!(int_i32_bmax, Integer, i32, 10);\n\n    test_masking!(prime_i32_b0, Prime, i32, 1, 10);\n    test_masking!(prime_i32_b2, Prime, i32, 100, 10);\n    test_masking!(prime_i32_b4, Prime, i32, 10_000, 10);\n    test_masking!(prime_i32_b6, Prime, i32, 1_000_000, 10);\n    test_masking!(prime_i32_bmax, Prime, i32, 10);\n\n    test_masking!(pow_i32_b0, Power2, i32, 1, 10);\n    test_masking!(pow_i32_b2, Power2, i32, 100, 10);\n    test_masking!(pow_i32_b4, Power2, i32, 10_000, 10);\n    test_masking!(pow_i32_b6, Power2, i32, 1_000_000, 10);\n    test_masking!(pow_i32_bmax, Power2, i32, 10);\n\n    test_masking!(int_i64_b0, Integer, i64, 1, 10);\n    test_masking!(int_i64_b2, Integer, i64, 100, 10);\n    test_masking!(int_i64_b4, Integer, i64, 10_000, 10);\n    test_masking!(int_i64_b6, Integer, i64, 1_000_000, 10);\n    test_masking!(int_i64_bmax, Integer, i64, 10);\n\n    test_masking!(prime_i64_b0, Prime, i64, 1, 10);\n    test_masking!(prime_i64_b2, Prime, i64, 100, 10);\n    test_masking!(prime_i64_b4, Prime, i64, 10_000, 10);\n    test_masking!(prime_i64_b6, Prime, i64, 1_000_000, 10);\n    test_masking!(prime_i64_bmax, Prime, i64, 10);\n\n    test_masking!(pow_i64_b0, Power2, i64, 1, 10);\n    test_masking!(pow_i64_b2, Power2, i64, 100, 10);\n    test_masking!(pow_i64_b4, Power2, i64, 10_000, 10);\n    test_masking!(pow_i64_b6, Power2, i64, 1_000_000, 10);\n    test_masking!(pow_i64_bmax, Power2, i64, 10);\n\n    /// Generate tests for masking and unmasking of a single model:\n    /// - generate random scalar from a uniform distribution with a seeded PRNG\n    /// - scale a model of unit weights and mask it\n    /// - check that all masked weights belong to the chosen finite group\n    /// - unmask the masked model\n    /// - check that all unmasked weights are equal to the original weights (up to a tolerance\n    ///   determined by the masking configuration)\n    ///\n    /// The arguments to the macro are:\n    /// - a suffix for the test name\n    
/// - the group type of the model and scalar (variants of `GroupType`)\n    /// - the data type of the model and scalar (either float primitives or float variants of\n    ///   `DataType`)\n    /// - an absolute bound for the scalar (optional, choices: 1, 100, 10_000, 1_000_000)\n    /// - the number of weights\n    macro_rules! test_masking_scalar {\n        ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr $(,)?) => {\n            paste::item! {\n                #[test]\n                fn [<test_masking_scalar_ $suffix>]() {\n                    // Step 1: Build the masking config\n                    let config = MaskConfig {\n                        group_type: $group,\n                        data_type: paste::expr! { [<$data:upper>] },\n                        bound_type: match $bound {\n                            1 => B0,\n                            100 => B2,\n                            10_000 => B4,\n                            1_000_000 => B6,\n                            _ => Bmax,\n                        },\n                        model_type: M3,\n                    };\n                    let vect_len = $len as usize;\n\n                    // Step 2: Generate a random scalar from (0, bound]\n                    // take vector [1, ..., 1] as the model to scale\n                    let bound = if $bound == 0 {\n                        paste::expr! { [<$data:lower>]::MAX / (2.1 as [<$data:lower>]) }\n                    } else {\n                        paste::expr! 
{ $bound as [<$data:lower>] }\n                    };\n                    let eps = [<$data:lower>]::EPSILON;\n                    let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array());\n                    let random_weight = Uniform::new_inclusive(eps, bound).sample(&mut prng);\n                    let scalar = Scalar::from_primitive(random_weight).unwrap();\n                    let model = Model::from_primitives(iter::repeat(1).take(vect_len)).unwrap();\n                    assert_eq!(model.len(), vect_len);\n\n                    // Step 3 (actual test):\n                    // a. mask the model\n                    // b. derive the mask corresponding to the seed used\n                    // c. unmask the model and check it against the expected [1, ..., 1]\n                    let (mask_seed, masked_model) =\n                        Masker::new(config.into()).mask(scalar, &model);\n                    assert_eq!(masked_model.vect.data.len(), vect_len);\n                    assert!(masked_model.is_valid());\n\n                    let mask = mask_seed.derive_mask(vect_len, config.into());\n                    let unmasked_model = Aggregation::from(masked_model).unmask(mask);\n\n                    let tolerance = Ratio::from_integer(config.exp_shift()).recip();\n                    let expected_weight = Ratio::from_integer(BigInt::from(1));\n                    assert!(\n                        unmasked_model\n                            .iter()\n                            .all(|unmasked_weight| {\n                                (unmasked_weight - &expected_weight).abs() <= tolerance\n                            })\n                    );\n                }\n            }\n        };\n        ($suffix:ident, $group:ty, $data:ty, $len:expr $(,)?) 
=> {\n            test_masking_scalar!($suffix, $group, $data, 0, $len);\n        };\n    }\n\n    test_masking_scalar!(int_f32_b0, Integer, f32, 1, 10);\n    test_masking_scalar!(int_f32_b2, Integer, f32, 100, 10);\n    test_masking_scalar!(int_f32_b4, Integer, f32, 10_000, 10);\n    test_masking_scalar!(int_f32_b6, Integer, f32, 1_000_000, 10);\n    test_masking_scalar!(int_f32_bmax, Integer, f32, 10);\n\n    test_masking_scalar!(prime_f32_b0, Prime, f32, 1, 10);\n    test_masking_scalar!(prime_f32_b2, Prime, f32, 100, 10);\n    test_masking_scalar!(prime_f32_b4, Prime, f32, 10_000, 10);\n    test_masking_scalar!(prime_f32_b6, Prime, f32, 1_000_000, 10);\n    test_masking_scalar!(prime_f32_bmax, Prime, f32, 10);\n\n    test_masking_scalar!(pow_f32_b0, Power2, f32, 1, 10);\n    test_masking_scalar!(pow_f32_b2, Power2, f32, 100, 10);\n    test_masking_scalar!(pow_f32_b4, Power2, f32, 10_000, 10);\n    test_masking_scalar!(pow_f32_b6, Power2, f32, 1_000_000, 10);\n    test_masking_scalar!(pow_f32_bmax, Power2, f32, 10);\n\n    test_masking_scalar!(int_f64_b0, Integer, f64, 1, 10);\n    test_masking_scalar!(int_f64_b2, Integer, f64, 100, 10);\n    test_masking_scalar!(int_f64_b4, Integer, f64, 10_000, 10);\n    test_masking_scalar!(int_f64_b6, Integer, f64, 1_000_000, 10);\n    test_masking_scalar!(int_f64_bmax, Integer, f64, 10);\n\n    test_masking_scalar!(prime_f64_b0, Prime, f64, 1, 10);\n    test_masking_scalar!(prime_f64_b2, Prime, f64, 100, 10);\n    test_masking_scalar!(prime_f64_b4, Prime, f64, 10_000, 10);\n    test_masking_scalar!(prime_f64_b6, Prime, f64, 1_000_000, 10);\n    test_masking_scalar!(prime_f64_bmax, Prime, f64, 10);\n\n    test_masking_scalar!(pow_f64_b0, Power2, f64, 1, 10);\n    test_masking_scalar!(pow_f64_b2, Power2, f64, 100, 10);\n    test_masking_scalar!(pow_f64_b4, Power2, f64, 10_000, 10);\n    test_masking_scalar!(pow_f64_b6, Power2, f64, 1_000_000, 10);\n    test_masking_scalar!(pow_f64_bmax, Power2, f64, 10);\n\n    /// Generate 
tests for aggregation of multiple masked models:\n    /// - generate random integers from a uniform distribution with a seeded PRNG\n    /// - create a masked model from the integers and aggregate it to the aggregated masked models\n    /// - check that all integers belong to the chosen finite group\n    ///\n    /// The arguments to the macro are:\n    /// - a suffix for the test name\n    /// - the group type of the model (variants of `GroupType`)\n    /// - the data type of the model (variants of `DataType`)\n    /// - the bound type of the model (variants of `BoundType`)\n    /// - the number of integers per masked model\n    /// - the number of masked models\n    macro_rules! test_aggregation {\n        ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr, $count:expr $(,)?) => {\n            paste::item! {\n                #[test]\n                fn [<test_aggregation_ $suffix>]() {\n                    // Step 1: Build the masking config\n                    let config = MaskConfig {\n                        group_type: $group,\n                        data_type: $data,\n                        bound_type: $bound,\n                        model_type: M3,\n                    };\n                    let vect_len = $len as usize;\n\n                    // Step 2: generate random masked models\n                    let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array());\n                    let mut masked_models = iter::repeat_with(move || {\n                        let order = config.order();\n                        let integer = generate_integer(&mut prng, &order);\n                        let integers = iter::repeat_with(|| generate_integer(&mut prng, &order))\n                            .take(vect_len)\n                            .collect::<Vec<_>>();\n                        MaskObject::new(config.into(), integers, integer).unwrap()\n                    });\n\n                    // Step 3 (actual test):\n                    // a. 
aggregate the masked models\n                    // b. check the aggregated masked model\n                    let mut aggregated_masked_model = Aggregation::new(config.into(), vect_len);\n                    for nb in 1..$count as usize + 1 {\n                        let masked_model = masked_models.next().unwrap();\n                        assert!(\n                            aggregated_masked_model.validate_aggregation(&masked_model).is_ok()\n                        );\n                        aggregated_masked_model.aggregate(masked_model);\n\n                        assert_eq!(aggregated_masked_model.nb_models, nb);\n                        assert_eq!(aggregated_masked_model.object.vect.data.len(), vect_len);\n                        assert_eq!(aggregated_masked_model.object.vect.config, config);\n                        assert_eq!(aggregated_masked_model.object.unit.config, config);\n                        assert!(aggregated_masked_model.object.is_valid());\n                    }\n                }\n            }\n        };\n    }\n\n    test_aggregation!(int_f32_b0, Integer, F32, B0, 10, 5);\n    test_aggregation!(int_f32_b2, Integer, F32, B2, 10, 5);\n    test_aggregation!(int_f32_b4, Integer, F32, B4, 10, 5);\n    test_aggregation!(int_f32_b6, Integer, F32, B6, 10, 5);\n    test_aggregation!(int_f32_bmax, Integer, F32, Bmax, 10, 5);\n\n    test_aggregation!(prime_f32_b0, Prime, F32, B0, 10, 5);\n    test_aggregation!(prime_f32_b2, Prime, F32, B2, 10, 5);\n    test_aggregation!(prime_f32_b4, Prime, F32, B4, 10, 5);\n    test_aggregation!(prime_f32_b6, Prime, F32, B6, 10, 5);\n    test_aggregation!(prime_f32_bmax, Prime, F32, Bmax, 10, 5);\n\n    test_aggregation!(pow_f32_b0, Power2, F32, B0, 10, 5);\n    test_aggregation!(pow_f32_b2, Power2, F32, B2, 10, 5);\n    test_aggregation!(pow_f32_b4, Power2, F32, B4, 10, 5);\n    test_aggregation!(pow_f32_b6, Power2, F32, B6, 10, 5);\n    test_aggregation!(pow_f32_bmax, Power2, F32, Bmax, 10, 5);\n\n    
test_aggregation!(int_f64_b0, Integer, F64, B0, 10, 5);\n    test_aggregation!(int_f64_b2, Integer, F64, B2, 10, 5);\n    test_aggregation!(int_f64_b4, Integer, F64, B4, 10, 5);\n    test_aggregation!(int_f64_b6, Integer, F64, B6, 10, 5);\n    test_aggregation!(int_f64_bmax, Integer, F64, Bmax, 10, 5);\n\n    test_aggregation!(prime_f64_b0, Prime, F64, B0, 10, 5);\n    test_aggregation!(prime_f64_b2, Prime, F64, B2, 10, 5);\n    test_aggregation!(prime_f64_b4, Prime, F64, B4, 10, 5);\n    test_aggregation!(prime_f64_b6, Prime, F64, B6, 10, 5);\n    test_aggregation!(prime_f64_bmax, Prime, F64, Bmax, 10, 5);\n\n    test_aggregation!(pow_f64_b0, Power2, F64, B0, 10, 5);\n    test_aggregation!(pow_f64_b2, Power2, F64, B2, 10, 5);\n    test_aggregation!(pow_f64_b4, Power2, F64, B4, 10, 5);\n    test_aggregation!(pow_f64_b6, Power2, F64, B6, 10, 5);\n    test_aggregation!(pow_f64_bmax, Power2, F64, Bmax, 10, 5);\n\n    test_aggregation!(int_i32_b0, Integer, I32, B0, 10, 5);\n    test_aggregation!(int_i32_b2, Integer, I32, B2, 10, 5);\n    test_aggregation!(int_i32_b4, Integer, I32, B4, 10, 5);\n    test_aggregation!(int_i32_b6, Integer, I32, B6, 10, 5);\n    test_aggregation!(int_i32_bmax, Integer, I32, Bmax, 10, 5);\n\n    test_aggregation!(prime_i32_b0, Prime, I32, B0, 10, 5);\n    test_aggregation!(prime_i32_b2, Prime, I32, B2, 10, 5);\n    test_aggregation!(prime_i32_b4, Prime, I32, B4, 10, 5);\n    test_aggregation!(prime_i32_b6, Prime, I32, B6, 10, 5);\n    test_aggregation!(prime_i32_bmax, Prime, I32, Bmax, 10, 5);\n\n    test_aggregation!(pow_i32_b0, Power2, I32, B0, 10, 5);\n    test_aggregation!(pow_i32_b2, Power2, I32, B2, 10, 5);\n    test_aggregation!(pow_i32_b4, Power2, I32, B4, 10, 5);\n    test_aggregation!(pow_i32_b6, Power2, I32, B6, 10, 5);\n    test_aggregation!(pow_i32_bmax, Power2, I32, Bmax, 10, 5);\n\n    test_aggregation!(int_i64_b0, Integer, I64, B0, 10, 5);\n    test_aggregation!(int_i64_b2, Integer, I64, B2, 10, 5);\n    
test_aggregation!(int_i64_b4, Integer, I64, B4, 10, 5);\n    test_aggregation!(int_i64_b6, Integer, I64, B6, 10, 5);\n    test_aggregation!(int_i64_bmax, Integer, I64, Bmax, 10, 5);\n\n    test_aggregation!(prime_i64_b0, Prime, I64, B0, 10, 5);\n    test_aggregation!(prime_i64_b2, Prime, I64, B2, 10, 5);\n    test_aggregation!(prime_i64_b4, Prime, I64, B4, 10, 5);\n    test_aggregation!(prime_i64_b6, Prime, I64, B6, 10, 5);\n    test_aggregation!(prime_i64_bmax, Prime, I64, Bmax, 10, 5);\n\n    test_aggregation!(pow_i64_b0, Power2, I64, B0, 10, 5);\n    test_aggregation!(pow_i64_b2, Power2, I64, B2, 10, 5);\n    test_aggregation!(pow_i64_b4, Power2, I64, B4, 10, 5);\n    test_aggregation!(pow_i64_b6, Power2, I64, B6, 10, 5);\n    test_aggregation!(pow_i64_bmax, Power2, I64, Bmax, 10, 5);\n\n    /// Generate tests for masking, aggregation and unmasking of multiple models:\n    /// - generate random weights from a uniform distribution with a seeded PRNG\n    /// - create a model from the weights, mask and aggregate it to the aggregated masked models\n    /// - derive a mask from the mask seed and aggregate it to the aggregated masks\n    /// - unmask the aggregated masked model\n    /// - check that all aggregated unmasked weights are equal to the averaged original weights (up\n    ///   to a tolerance determined by the masking configuration)\n    ///\n    /// The arguments to the macro are:\n    /// - a suffix for the test name\n    /// - the group type of the model (variants of `GroupType`)\n    /// - the data type of the model (either primitives or variants of `DataType`)\n    /// - an absolute bound for the weights (optional, choices: 1, 100, 10_000, 1_000_000)\n    /// - the number of weights per model\n    /// - the number of models\n    macro_rules! test_masking_and_aggregation {\n        ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr, $count:expr $(,)?) => {\n            paste::item! 
{\n                #[test]\n                fn [<test_masking_and_aggregation_ $suffix>]() {\n                    // Step 1: Build the masking config\n                    let config = MaskConfig {\n                        group_type: $group,\n                        data_type: paste::expr! { [<$data:upper>] },\n                        bound_type: match $bound {\n                            1 => B0,\n                            100 => B2,\n                            10_000 => B4,\n                            1_000_000 => B6,\n                            _ => Bmax,\n                        },\n                        model_type: M3,\n                    };\n                    let vect_len = $len as usize;\n                    let model_count = $count as usize;\n\n                    // Step 2: Generate random models\n                    let bound = if $bound == 0 {\n                        paste::expr! { [<$data:lower>]::MAX / (2.1 as [<$data:lower>]) }\n                    } else {\n                        paste::expr! { $bound as [<$data:lower>] }\n                    };\n                    let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array());\n                    let mut models = iter::repeat_with(move || {\n                        Model::from_primitives(\n                            Uniform::new_inclusive(-bound, bound)\n                                .sample_iter(&mut prng)\n                                .take(vect_len)\n                        )\n                        .unwrap()\n                    });\n\n                    // Step 3 (actual test):\n                    // a. average the model weights for later checks\n                    // b. mask the model\n                    // c. derive the mask corresponding to the seed used\n                    // d. aggregate the masked model resp. mask\n                    // e. 
repeat a-d, then unmask the model and check it against the averaged one\n                    let mut averaged_model = Model::from_primitives(\n                        iter::repeat(paste::expr! { 0 as [<$data:lower>] }).take(vect_len)\n                    )\n                    .unwrap();\n                    let mut aggregated_masked_model = Aggregation::new(config.into(), vect_len);\n                    let mut aggregated_mask = Aggregation::new(config.into(), vect_len);\n                    let scalar = Scalar::new(1, model_count);\n                    let scalar_ratio = &scalar.to_ratio();\n                    for _ in 0..model_count {\n                        let model = models.next().unwrap();\n                        averaged_model\n                            .iter_mut()\n                            .zip(model.iter())\n                            .for_each(|(averaged_weight, weight)| {\n                                *averaged_weight += scalar_ratio * weight;\n                            });\n\n                        let (mask_seed, masked_model) =\n                            Masker::new(config.into()).mask(scalar.clone(), &model);\n                        let mask = mask_seed.derive_mask(vect_len, config.into());\n\n                        assert!(\n                            aggregated_masked_model.validate_aggregation(&masked_model).is_ok()\n                        );\n                        aggregated_masked_model.aggregate(masked_model);\n                        assert!(aggregated_mask.validate_aggregation(&mask).is_ok());\n                        aggregated_mask.aggregate(mask);\n                    }\n\n                    let mask = aggregated_mask.into();\n                    assert!(aggregated_masked_model.validate_unmasking(&mask).is_ok());\n                    let unmasked_model = aggregated_masked_model.unmask(mask);\n                    let tolerance = Ratio::from_integer(BigInt::from(model_count))\n                        / 
Ratio::from_integer(config.exp_shift());\n                    assert!(\n                        averaged_model.iter()\n                            .zip(unmasked_model.iter())\n                            .all(|(averaged_weight, unmasked_weight)| {\n                                (averaged_weight - unmasked_weight).abs() <= tolerance\n                            })\n                    );\n                }\n            }\n        };\n        ($suffix:ident, $group:ty, $data:ty, $len:expr, $count:expr $(,)?) => {\n            test_masking_and_aggregation!($suffix, $group, $data, 0, $len, $count);\n        };\n    }\n\n    test_masking_and_aggregation!(int_f32_b0, Integer, f32, 1, 10, 5);\n    test_masking_and_aggregation!(int_f32_b2, Integer, f32, 100, 10, 5);\n    test_masking_and_aggregation!(int_f32_b4, Integer, f32, 10_000, 10, 5);\n    test_masking_and_aggregation!(int_f32_b6, Integer, f32, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(int_f32_bmax, Integer, f32, 10, 5);\n\n    test_masking_and_aggregation!(prime_f32_b0, Prime, f32, 1, 10, 5);\n    test_masking_and_aggregation!(prime_f32_b2, Prime, f32, 100, 10, 5);\n    test_masking_and_aggregation!(prime_f32_b4, Prime, f32, 10_000, 10, 5);\n    test_masking_and_aggregation!(prime_f32_b6, Prime, f32, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(prime_f32_bmax, Prime, f32, 10, 5);\n\n    test_masking_and_aggregation!(pow_f32_b0, Power2, f32, 1, 10, 5);\n    test_masking_and_aggregation!(pow_f32_b2, Power2, f32, 100, 10, 5);\n    test_masking_and_aggregation!(pow_f32_b4, Power2, f32, 10_000, 10, 5);\n    test_masking_and_aggregation!(pow_f32_b6, Power2, f32, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(pow_f32_bmax, Power2, f32, 10, 5);\n\n    test_masking_and_aggregation!(int_f64_b0, Integer, f64, 1, 10, 5);\n    test_masking_and_aggregation!(int_f64_b2, Integer, f64, 100, 10, 5);\n    test_masking_and_aggregation!(int_f64_b4, Integer, f64, 10_000, 10, 5);\n    
test_masking_and_aggregation!(int_f64_b6, Integer, f64, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(int_f64_bmax, Integer, f64, 10, 5);\n\n    test_masking_and_aggregation!(prime_f64_b0, Prime, f64, 1, 10, 5);\n    test_masking_and_aggregation!(prime_f64_b2, Prime, f64, 100, 10, 5);\n    test_masking_and_aggregation!(prime_f64_b4, Prime, f64, 10_000, 10, 5);\n    test_masking_and_aggregation!(prime_f64_b6, Prime, f64, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(prime_f64_bmax, Prime, f64, 10, 5);\n\n    test_masking_and_aggregation!(pow_f64_b0, Power2, f64, 1, 10, 5);\n    test_masking_and_aggregation!(pow_f64_b2, Power2, f64, 100, 10, 5);\n    test_masking_and_aggregation!(pow_f64_b4, Power2, f64, 10_000, 10, 5);\n    test_masking_and_aggregation!(pow_f64_b6, Power2, f64, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(pow_f64_bmax, Power2, f64, 10, 5);\n\n    test_masking_and_aggregation!(int_i32_b0, Integer, i32, 1, 10, 5);\n    test_masking_and_aggregation!(int_i32_b2, Integer, i32, 100, 10, 5);\n    test_masking_and_aggregation!(int_i32_b4, Integer, i32, 10_000, 10, 5);\n    test_masking_and_aggregation!(int_i32_b6, Integer, i32, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(int_i32_bmax, Integer, i32, 10, 5);\n\n    test_masking_and_aggregation!(prime_i32_b0, Prime, i32, 1, 10, 5);\n    test_masking_and_aggregation!(prime_i32_b2, Prime, i32, 100, 10, 5);\n    test_masking_and_aggregation!(prime_i32_b4, Prime, i32, 10_000, 10, 5);\n    test_masking_and_aggregation!(prime_i32_b6, Prime, i32, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(prime_i32_bmax, Prime, i32, 10, 5);\n\n    test_masking_and_aggregation!(pow_i32_b0, Power2, i32, 1, 10, 5);\n    test_masking_and_aggregation!(pow_i32_b2, Power2, i32, 100, 10, 5);\n    test_masking_and_aggregation!(pow_i32_b4, Power2, i32, 10_000, 10, 5);\n    test_masking_and_aggregation!(pow_i32_b6, Power2, i32, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(pow_i32_bmax, 
Power2, i32, 10, 5);\n\n    test_masking_and_aggregation!(int_i64_b0, Integer, i64, 1, 10, 5);\n    test_masking_and_aggregation!(int_i64_b2, Integer, i64, 100, 10, 5);\n    test_masking_and_aggregation!(int_i64_b4, Integer, i64, 10_000, 10, 5);\n    test_masking_and_aggregation!(int_i64_b6, Integer, i64, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(int_i64_bmax, Integer, i64, 10, 5);\n\n    test_masking_and_aggregation!(prime_i64_b0, Prime, i64, 1, 10, 5);\n    test_masking_and_aggregation!(prime_i64_b2, Prime, i64, 100, 10, 5);\n    test_masking_and_aggregation!(prime_i64_b4, Prime, i64, 10_000, 10, 5);\n    test_masking_and_aggregation!(prime_i64_b6, Prime, i64, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(prime_i64_bmax, Prime, i64, 10, 5);\n\n    test_masking_and_aggregation!(pow_i64_b0, Power2, i64, 1, 10, 5);\n    test_masking_and_aggregation!(pow_i64_b2, Power2, i64, 100, 10, 5);\n    test_masking_and_aggregation!(pow_i64_b4, Power2, i64, 10_000, 10, 5);\n    test_masking_and_aggregation!(pow_i64_b6, Power2, i64, 1_000_000, 10, 5);\n    test_masking_and_aggregation!(pow_i64_bmax, Power2, i64, 10, 5);\n\n    /// Generate tests for masking, aggregation and unmasking of multiple models:\n    /// - generate random scalars from a uniform distribution with a seeded PRNG\n    /// - scale a model of unit weights, mask and aggregate it to the aggregated masked models\n    /// - derive a mask from the mask seed and aggregate it to the aggregated masks\n    /// - unmask the aggregated masked model\n    /// - check that all aggregated unmasked weights are equal to the original unit weights (up\n    ///   to a tolerance determined by the masking configuration)\n    ///\n    /// The arguments to the macro are:\n    /// - a suffix for the test name\n    /// - the group type of the model and scalar (variants of `GroupType`)\n    /// - the data type of the model and scalar (either float primitives or float variants of\n    ///   `DataType`)\n    /// - an 
absolute bound for the scalar (optional, choices: 1, 100, 10_000, 1_000_000)\n    /// - the number of weights per model\n    /// - the number of models\n    macro_rules! test_masking_and_aggregation_scalar {\n        ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr, $count:expr $(,)?) => {\n            paste::item! {\n                #[test]\n                fn [<test_masking_and_aggregation_scalar $suffix>]() {\n                    // Step 1: Build the masking config\n                    let config = MaskConfig {\n                        group_type: $group,\n                        data_type: paste::expr! { [<$data:upper>] },\n                        bound_type: match $bound {\n                            1 => B0,\n                            100 => B2,\n                            10_000 => B4,\n                            1_000_000 => B6,\n                            _ => Bmax,\n                        },\n                        model_type: M3,\n                    };\n                    let vect_len = $len as usize;\n                    let model_count = $count as usize;\n\n                    // Step 2: Generate random scalars\n                    // take vectors [1, ..., 1] as models to scale\n                    let bound = if $bound == 0 {\n                        paste::expr! { [<$data:lower>]::MAX / (2 as [<$data:lower>]) }\n                    } else {\n                        paste::expr! 
{ $bound as [<$data:lower>] }\n                    };\n                    let eps = [<$data:lower>]::EPSILON;\n                    let mut prng = ChaCha20Rng::from_seed(MaskSeed::generate().as_array());\n                    let mut scalars = iter::repeat_with(move || {\n                        let random_weight = Uniform::new_inclusive(eps, bound).sample(&mut prng);\n                        Scalar::from_primitive(random_weight).unwrap()\n                    });\n                    let mut models =\n                        iter::repeat(Model::from_primitives(iter::repeat(1).take(vect_len)).unwrap());\n\n                    // Step 3 (actual test):\n                    // a. mask the model\n                    // b. derive the mask corresponding to the seed used\n                    // c. aggregate the masked model resp. mask\n                    // d. repeat a-c, unmask the model and check it against the expected [1, ..., 1]\n                    let mut aggregated_masked_model = Aggregation::new(config.into(), vect_len);\n                    let mut aggregated_mask = Aggregation::new(config.into(), vect_len);\n                    for _ in 0..model_count {\n                        let model = models.next().unwrap();\n                        let scalar = scalars.next().unwrap();\n\n                        let (mask_seed, masked_model) =\n                            Masker::new(config.into()).mask(scalar, &model);\n                        let mask = mask_seed.derive_mask(vect_len, config.into());\n\n                        assert!(\n                            aggregated_masked_model.validate_aggregation(&masked_model).is_ok()\n                        );\n                        aggregated_masked_model.aggregate(masked_model);\n                        assert!(aggregated_mask.validate_aggregation(&mask).is_ok());\n                        aggregated_mask.aggregate(mask);\n                    }\n\n                    let mask = aggregated_mask.into();\n                  
  assert!(aggregated_masked_model.validate_unmasking(&mask).is_ok());\n                    let unmasked_model = aggregated_masked_model.unmask(mask);\n                    let tolerance = Ratio::from_integer(BigInt::from(model_count))\n                        / Ratio::from_integer(config.exp_shift());\n                    let expected_weight = Ratio::from_integer(BigInt::from(1));\n                    assert!(\n                        unmasked_model\n                            .iter()\n                            .all(|unmasked_weight| {\n                                (unmasked_weight - &expected_weight).abs() <= tolerance\n                            })\n                    );\n                }\n            }\n        };\n        ($suffix:ident, $group:ty, $data:ty, $len:expr, $count:expr $(,)?) => {\n            test_masking_and_aggregation_scalar!($suffix, $group, $data, 0, $len, $count);\n        };\n    }\n\n    test_masking_and_aggregation_scalar!(int_f32_b0, Integer, f32, 1, 10, 5);\n    test_masking_and_aggregation_scalar!(int_f32_b2, Integer, f32, 100, 10, 5);\n    test_masking_and_aggregation_scalar!(int_f32_b4, Integer, f32, 10_000, 10, 5);\n    test_masking_and_aggregation_scalar!(int_f32_b6, Integer, f32, 1_000_000, 10, 5);\n    test_masking_and_aggregation_scalar!(int_f32_bmax, Integer, f32, 10, 2);\n\n    test_masking_and_aggregation_scalar!(prime_f32_b0, Prime, f32, 1, 10, 5);\n    test_masking_and_aggregation_scalar!(prime_f32_b2, Prime, f32, 100, 10, 5);\n    test_masking_and_aggregation_scalar!(prime_f32_b4, Prime, f32, 10_000, 10, 5);\n    test_masking_and_aggregation_scalar!(prime_f32_b6, Prime, f32, 1_000_000, 10, 5);\n    test_masking_and_aggregation_scalar!(prime_f32_bmax, Prime, f32, 10, 2);\n\n    test_masking_and_aggregation_scalar!(pow_f32_b0, Power2, f32, 1, 10, 5);\n    test_masking_and_aggregation_scalar!(pow_f32_b2, Power2, f32, 100, 10, 5);\n    test_masking_and_aggregation_scalar!(pow_f32_b4, Power2, f32, 10_000, 10, 5);\n    
test_masking_and_aggregation_scalar!(pow_f32_b6, Power2, f32, 1_000_000, 10, 5);\n    test_masking_and_aggregation_scalar!(pow_f32_bmax, Power2, f32, 10, 2);\n\n    test_masking_and_aggregation_scalar!(int_f64_b0, Integer, f64, 1, 10, 2);\n    test_masking_and_aggregation_scalar!(int_f64_b2, Integer, f64, 100, 10, 2);\n    test_masking_and_aggregation_scalar!(int_f64_b4, Integer, f64, 10_000, 10, 2);\n    test_masking_and_aggregation_scalar!(int_f64_b6, Integer, f64, 1_000_000, 10, 2);\n    test_masking_and_aggregation_scalar!(int_f64_bmax, Integer, f64, 10, 2);\n\n    test_masking_and_aggregation_scalar!(prime_f64_b0, Prime, f64, 1, 10, 2);\n    test_masking_and_aggregation_scalar!(prime_f64_b2, Prime, f64, 100, 10, 2);\n    test_masking_and_aggregation_scalar!(prime_f64_b4, Prime, f64, 10_000, 10, 2);\n    test_masking_and_aggregation_scalar!(prime_f64_b6, Prime, f64, 1_000_000, 10, 2);\n    test_masking_and_aggregation_scalar!(prime_f64_bmax, Prime, f64, 10, 2);\n\n    test_masking_and_aggregation_scalar!(pow_f64_b0, Power2, f64, 1, 10, 2);\n    test_masking_and_aggregation_scalar!(pow_f64_b2, Power2, f64, 100, 10, 2);\n    test_masking_and_aggregation_scalar!(pow_f64_b4, Power2, f64, 10_000, 10, 2);\n    test_masking_and_aggregation_scalar!(pow_f64_b6, Power2, f64, 1_000_000, 10, 2);\n    test_masking_and_aggregation_scalar!(pow_f64_bmax, Power2, f64, 10, 2);\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/mask/mod.rs",
    "content": "//! Masking, aggregation and unmasking of models.\n//!\n//! # Models\n//! A [`Model`] is a collection of weights/parameters which are represented as finite numerical\n//! values (i.e. rational numbers) of arbitrary precision. As such, a model in itself is not bound\n//! to any particular primitive data type, but it can be created from those and converted back into\n//! them.\n//!\n//! Currently, the primitive data types [`f32`], [`f64`], [`i32`] and [`i64`] are supported and\n//! this might be extended in the future.\n//!\n//! ```\n//! # use xaynet_core::mask::{FromPrimitives, IntoPrimitives, Model};\n//! let weights = vec![0_f32; 10];\n//! let model = Model::from_primitives_bounded(weights.into_iter());\n//! assert_eq!(\n//!     model.into_primitives_unchecked().collect::<Vec<f32>>(),\n//!     vec![0_f32; 10],\n//! );\n//! ```\n//!\n//! # Masking configurations\n//! The masking, aggregation and unmasking of models requires certain information about the models\n//! to guarantee that no information is lost during the process, which is configured via the\n//! [`MaskConfig`]. Each masking configuration consists of the group type, data type, bound type and\n//! model type. Usually, a masking configuration is decided on and configured depending on the\n//! specific machine learning use case as part of the setup for the XayNet federated learning\n//! platform.\n//!\n//! Currently, those choices are catalogued for certain fixed variants for each type, but we aim\n//! to generalize this in the future to more flexible masking configurations to allow for a more\n//! fine-grained tradeoff between representability and performance.\n//!\n//! ## Group type\n//! The [`GroupType`] describes the order of the finite group in which the masked model weights are\n//! embedded. The smaller the gap between the maximum possible embedded weights and the group order\n//! is, the less theoretically possible information flow about the masks may be observed. Specific\n//! 
group orders provide potentially higher performance on the other hand, which always makes this\n//! a tradeoff between security and performance. The group type variants are:\n//! - Integer: no gap but potentially slowest performance.\n//! - Prime: usually small gap with higher performance.\n//! - Power2: usually higher gap with potentially highest performance.\n//!\n//! ## Data type\n//! The [`DataType`] describes the original primitive data type of the model weights. This in\n//! combination with the bound type influences the preserved decimal places of the model weights\n//! during the masking, aggregation and unmasking process, which are:\n//! - F32: 10 decimal places for bounded model weights and 45 decimal places for unbounded.\n//! - F64: 20 decimal places for bounded model weights and 324 decimal places for unbounded.\n//! - I32 and I64: 10 decimal places (required for scaled aggregation).\n//!\n//! Currently the primitive data types [`f32`], [`f64`], [`i32`] and [`i64`] are supported via the\n//! data type variants.\n//!\n//! ## Bound type\n//! The [`BoundType`] describes the absolute bounds on all model weights. The smaller the bounds of\n//! the model weights, the fewer bytes are required to represent the masked model weights. These\n//! bounds are enforced on the model weights before masking them to prevent information loss during\n//! the masking, aggregation and unmasking process. The bound type variants are:\n//! - B0: all model weights are absolutely bounded by 1.\n//! - B2: all model weights are absolutely bounded by 100.\n//! - B4: all model weights are absolutely bounded by 10,000.\n//! - B6: all model weights are absolutely bounded by 1,000,000.\n//! - Bmax: all model weights are absolutely bounded by their primitive data type's absolute\n//!   maximum value.\n//!\n//! ## Model type\n//! The [`ModelType`] describes the maximum number of masked models that can be aggregated without\n//! information loss. 
The smaller the number of masked models, the less bytes are required to\n//! represent masked model weights. The model type variants are:\n//! - M3: at most 1,000 masked models may be aggregated.\n//! - M6: at most 1,000,000 masked models may be aggregated.\n//! - M9: at most 1,000,000,000 masked models may be aggregated.\n//! - M12: at most 1,000,000,000,000 masked models may be aggregated.\n//!\n//! # Masking, aggregation and unmasking\n//! Local models should be masked (i.e. encrypted) before they are communicated somewhere else to\n//! protect the possibly sensitive information learned from local data. The masking should allow\n//! for masked models to be aggregated while they are still masked (i.e. homomorphic encryption).\n//! Then the aggregated masked model can safely be unmasked without jeopardizing the secrecy of\n//! personal information if the model is generalized enough.\n//!\n//! ## Masking\n//! A [`Model`] can be masked with a [`Masker`], which requires a [`MaskConfig`]. During the\n//! masking, the model weights are scaled, then embedded as elements of the chosen finite group and\n//! finally masked by randomly generated elements from that very same finite group. The scalar\n//! provides the necessary means to perform different aggregation strategies, for example federated\n//! averaging. The masked model is returned as a [`MaskObject`] and the mask used to mask the model\n//! can be generated via the additionally returned [`MaskSeed`].\n//!\n//! ```\n//! # use xaynet_core::mask::{BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, Model, ModelType, Scalar};\n//! // create local models and a fitting masking configuration\n//! let number_weights = 10;\n//! let scalar = Scalar::new(1, 2_u8);\n//! let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());\n//! let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());\n//! let config = MaskConfig {\n//!     
group_type: GroupType::Prime,\n//!     data_type: DataType::F32,\n//!     bound_type: BoundType::B0,\n//!     model_type: ModelType::M3,\n//! };\n//!\n//! // mask the local models\n//! let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar.clone(), &local_model_1);\n//! let (local_mask_seed_2, masked_local_model_2) = Masker::new(config.into()).mask(scalar, &local_model_2);\n//!\n//! // derive the masks of the local masked models\n//! let local_mask_1 = local_mask_seed_1.derive_mask(number_weights, config.into());\n//! let local_mask_2 = local_mask_seed_2.derive_mask(number_weights, config.into());\n//! ```\n//!\n//! ## Aggregation\n//! Masked models can be aggregated via an [`Aggregation`]. Masks themselves can be aggregated via\n//! an [`Aggregation`] as well. An aggregated masked model can only be unmasked by the aggregation\n//! of masks for each model. Aggregation should always be validated beforehand so that it may be\n//! safely performed wrt the chosen masking configuration without possible loss of information.\n//!\n//! ```\n//! # use xaynet_core::mask::{Aggregation, BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, MaskObject, Model, ModelType, Scalar};\n//! # let number_weights = 10;\n//! # let scalar = Scalar::new(1, 2_u8);\n//! # let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());\n//! # let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());\n//! # let config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3};\n//! # let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar.clone(), &local_model_1);\n//! # let (local_mask_seed_2, masked_local_model_2) = Masker::new(config.into()).mask(scalar, &local_model_2);\n//! # let local_model_mask_1 = local_mask_seed_1.derive_mask(number_weights, config.into());\n//! 
# let local_model_mask_2 = local_mask_seed_2.derive_mask(number_weights, config.into());\n//! // aggregate the local model masks (similarly for local scalar masks)\n//! let mut mask_aggregator = Aggregation::new(config.into(), number_weights);\n//! if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_1) {\n//!     mask_aggregator.aggregate(local_model_mask_1);\n//! };\n//! if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_2) {\n//!     mask_aggregator.aggregate(local_model_mask_2);\n//! };\n//! let global_mask: MaskObject = mask_aggregator.into();\n//!\n//! // aggregate the local masked models\n//! let mut model_aggregator = Aggregation::new(config.into(), number_weights);\n//! if let Ok(_) = model_aggregator.validate_aggregation(&masked_local_model_1) {\n//!     model_aggregator.aggregate(masked_local_model_1);\n//! };\n//! if let Ok(_) = model_aggregator.validate_aggregation(&masked_local_model_2) {\n//!     model_aggregator.aggregate(masked_local_model_2);\n//! };\n//! ```\n//!\n//! ## Unmasking\n//! A masked model can be unmasked by the corresponding mask via an [`Aggregation`]. Unmasking\n//! should always be validated beforehand so that it may be safely performed wrt the chosen mask\n//! configuration without possible loss of information.\n//!\n//! ```no_run\n//! # use xaynet_core::mask::{Aggregation, BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, MaskObject, Model, ModelType, Scalar};\n//! # let number_weights = 10;\n//! # let scalar = Scalar::new(1, 2_u8);\n//! # let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());\n//! # let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());\n//! # let config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3};\n//! 
# let (local_mask_seed_1, masked_local_model_1) = Masker::new(config.into()).mask(scalar.clone(), &local_model_1);\n//! # let (local_mask_seed_2, masked_local_model_2) = Masker::new(config.into()).mask(scalar, &local_model_2);\n//! # let local_model_mask_1 = local_mask_seed_1.derive_mask(number_weights, config.into());\n//! # let local_model_mask_2 = local_mask_seed_2.derive_mask(number_weights, config.into());\n//! # let mut mask_aggregator = Aggregation::new(config.into(), number_weights);\n//! # if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_1) { mask_aggregator.aggregate(local_model_mask_1); };\n//! # if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_2) { mask_aggregator.aggregate(local_model_mask_2); };\n//! # let global_mask: MaskObject = mask_aggregator.into();\n//! # let mut model_aggregator = Aggregation::new(config.into(), number_weights);\n//! # if let Ok(_) = model_aggregator.validate_aggregation(&masked_local_model_1) { model_aggregator.aggregate(masked_local_model_1); };\n//! # if let Ok(_) = model_aggregator.validate_aggregation(&masked_local_model_2) { model_aggregator.aggregate(masked_local_model_2); };\n//! // unmask the aggregated masked model with the aggregated mask\n//! if let Ok(_) = model_aggregator.validate_unmasking(&global_mask) {\n//!     let global_model = model_aggregator.unmask(global_mask);\n//!     assert_eq!(\n//!         global_model,\n//!         Model::from_primitives_bounded(vec![0.5_f32; number_weights].into_iter()),\n//!     );\n//! };\n//! 
```\n\npub(crate) mod config;\npub(crate) mod masking;\npub(crate) mod model;\npub(crate) mod object;\npub(crate) mod scalar;\npub(crate) mod seed;\n\npub use self::{\n    config::{\n        serialization::MaskConfigBuffer,\n        BoundType,\n        DataType,\n        GroupType,\n        InvalidMaskConfigError,\n        MaskConfig,\n        MaskConfigPair,\n        ModelType,\n    },\n    masking::{Aggregation, AggregationError, Masker, UnmaskingError},\n    model::{FromPrimitives, IntoPrimitives, Model, ModelCastError, PrimitiveCastError},\n    object::{\n        serialization::vect::MaskVectBuffer,\n        InvalidMaskObjectError,\n        MaskObject,\n        MaskUnit,\n        MaskVect,\n    },\n    scalar::{FromPrimitive, IntoPrimitive, Scalar, ScalarCastError},\n    seed::{EncryptedMaskSeed, MaskSeed},\n};\n"
  },
  {
    "path": "rust/xaynet-core/src/mask/model.rs",
    "content": "//! Model representation and conversion.\n//!\n//! See the [mask module] documentation since this is a private module anyway.\n//!\n//! [mask module]: crate::mask\n\nuse std::{\n    fmt::Debug,\n    iter::{FromIterator, IntoIterator},\n    slice::{Iter, IterMut},\n};\n\nuse derive_more::{Display, From, Index, IndexMut, Into};\nuse num::{\n    bigint::BigInt,\n    clamp,\n    rational::Ratio,\n    traits::{float::FloatCore, identities::Zero, ToPrimitive},\n};\nuse serde::{Deserialize, Serialize};\nuse thiserror::Error;\n\n#[derive(Debug, Clone, PartialEq, Hash, From, Index, IndexMut, Into, Serialize, Deserialize)]\n/// A numerical representation of a machine learning model.\npub struct Model(Vec<Ratio<BigInt>>);\n\nimpl std::convert::AsRef<Model> for Model {\n    fn as_ref(&self) -> &Model {\n        self\n    }\n}\n\n#[allow(clippy::len_without_is_empty)]\nimpl Model {\n    /// Gets the number of weights/parameters of this model.\n    pub fn len(&self) -> usize {\n        self.0.len()\n    }\n\n    /// Creates an iterator that yields references to the weights/parameters of this model.\n    pub fn iter(&self) -> Iter<Ratio<BigInt>> {\n        self.0.iter()\n    }\n\n    /// Creates an iterator that yields mutable references to the weights/parameters of this model.\n    pub fn iter_mut(&mut self) -> IterMut<Ratio<BigInt>> {\n        self.0.iter_mut()\n    }\n}\n\nimpl FromIterator<Ratio<BigInt>> for Model {\n    fn from_iter<I: IntoIterator<Item = Ratio<BigInt>>>(iter: I) -> Self {\n        let data: Vec<Ratio<BigInt>> = iter.into_iter().collect();\n        Model(data)\n    }\n}\n\nimpl IntoIterator for Model {\n    type Item = Ratio<BigInt>;\n    type IntoIter = std::vec::IntoIter<Self::Item>;\n\n    fn into_iter(self) -> Self::IntoIter {\n        self.0.into_iter()\n    }\n}\n\n#[derive(Debug, Display)]\n/// A primitive data type as a target for model conversion.\npub(crate) enum PrimitiveType {\n    F32,\n    F64,\n    I32,\n    I64,\n}\n\n#[derive(Error, Debug)]\n#[error(\"Could not convert weight {weight} to primitive type {target}\")]\n/// Errors related to model conversion into primitives.\npub struct ModelCastError {\n    weight: Ratio<BigInt>,\n    target: PrimitiveType,\n}\n\n#[derive(Clone, Error, Debug)]\n#[error(\"Could not convert primitive type {0:?} to weight\")]\n/// Errors related to weight conversion from primitives.\npub struct PrimitiveCastError<P: Debug>(pub(crate) P);\n\n/// An interface to convert a collection of numerical values into an iterator of primitive values.\n///\n/// This trait is used to convert a [`Model`], which has its own internal representation of the\n/// weights, into primitive types ([`f32`], [`f64`], [`i32`], [`i64`]). The opposite trait is\n/// [`FromPrimitives`].\npub trait IntoPrimitives<P: 'static>: Sized {\n    /// Consumes `self` and creates an iterator that yields converted primitive values.\n    ///\n    /// # Errors\n    /// Yields an error for each numerical value that can't be converted into a primitive value.\n    fn into_primitives(self) -> Box<dyn Iterator<Item = Result<P, ModelCastError>>>;\n\n    /// Creates an iterator that yields converted primitive values without consuming `self`.\n    ///\n    /// # Errors\n    /// Yields an error for each numerical value that can't be converted into a primitive value.\n    fn to_primitives(&self) -> Box<dyn Iterator<Item = Result<P, ModelCastError>>>;\n\n    /// Consumes `self` and creates an iterator that yields converted `P` values.\n    ///\n    /// # Panics\n    /// Panics if a numerical value can't be converted into a primitive value.\n    fn into_primitives_unchecked(self) -> Box<dyn Iterator<Item = P>> {\n        Box::new(\n            self.into_primitives()\n                .map(|res| res.expect(\"conversion to primitive type failed\")),\n        )\n    }\n}\n\n/// An interface to convert an iterator of primitive values into a collection of numerical values.\n///\n/// This trait is used to convert primitive types ([`f32`], [`f64`], [`i32`], [`i64`]) into a\n/// [`Model`], which has its own internal representation of the weights. The opposite trait is\n/// [`IntoPrimitives`].\npub trait FromPrimitives<P: Debug>: Sized {\n    /// Converts an iterator of primitive values into numerical values.\n    ///\n    /// # Errors\n    /// Yields an error for the first encountered primitive value that can't be converted into a\n    /// numerical value due to not being finite.\n    fn from_primitives<I: Iterator<Item = P>>(iter: I) -> Result<Self, PrimitiveCastError<P>>;\n\n    /// Converts an iterator of primitive values into numerical values.\n    ///\n    /// If a primitive value cannot be directly converted into a numerical value due to not being\n    /// finite, it is clamped instead (for floats, positive/negative infinity map to the max/min of\n    /// the primitive data type and NaN maps to zero).\n    fn from_primitives_bounded<I: Iterator<Item = P>>(iter: I) -> Self;\n}\n\nimpl IntoPrimitives<i32> for Model {\n    fn into_primitives(self) -> Box<dyn Iterator<Item = Result<i32, ModelCastError>>> {\n        Box::new(self.0.into_iter().map(|i| {\n            i.to_integer().to_i32().ok_or(ModelCastError {\n                weight: i,\n                target: PrimitiveType::I32,\n            })\n        }))\n    }\n\n    fn to_primitives(&self) -> Box<dyn Iterator<Item = Result<i32, ModelCastError>>> {\n        let vec = self.0.clone();\n        Box::new(vec.into_iter().map(|i| {\n            i.to_integer().to_i32().ok_or(ModelCastError {\n                weight: i,\n                target: PrimitiveType::I32,\n            })\n        }))\n    }\n}\n\nimpl FromPrimitives<i32> for Model {\n    fn from_primitives<I: Iterator<Item = i32>>(iter: I) -> Result<Self, PrimitiveCastError<i32>> {\n        Ok(iter.map(|p| Ratio::from_integer(BigInt::from(p))).collect())\n    }\n\n    fn from_primitives_bounded<I: Iterator<Item = i32>>(iter: I) -> Self {\n        Self::from_primitives(iter).unwrap()\n    }\n}\n\nimpl IntoPrimitives<i64> for Model {\n    fn into_primitives(self) -> Box<dyn Iterator<Item = Result<i64, ModelCastError>>> {\n        Box::new(self.0.into_iter().map(|i| {\n            i.to_integer().to_i64().ok_or(ModelCastError {\n                weight: i,\n                target: PrimitiveType::I64,\n            })\n        }))\n    }\n\n    fn to_primitives(&self) -> Box<dyn Iterator<Item = Result<i64, ModelCastError>>> {\n        let vec = self.0.clone();\n        Box::new(vec.into_iter().map(|i| {\n            i.to_integer().to_i64().ok_or(ModelCastError {\n                weight: i,\n                target: PrimitiveType::I64,\n            })\n        }))\n    }\n}\n\nimpl FromPrimitives<i64> for Model {\n    fn from_primitives<I: Iterator<Item = i64>>(iter: I) -> Result<Self, PrimitiveCastError<i64>> {\n        Ok(iter.map(|p| Ratio::from_integer(BigInt::from(p))).collect())\n    }\n\n    fn from_primitives_bounded<I: Iterator<Item = i64>>(iter: I) -> Self {\n        Self::from_primitives(iter).unwrap()\n    }\n}\n\nimpl IntoPrimitives<f32> for Model {\n    fn into_primitives(self) -> Box<dyn Iterator<Item = Result<f32, ModelCastError>>> {\n        let iter = self.0.into_iter().map(|r| {\n            ratio_to_float::<f32>(&r).ok_or(ModelCastError {\n                weight: r,\n                target: PrimitiveType::F32,\n            })\n        });\n        Box::new(iter)\n    }\n\n    fn to_primitives(&self) -> Box<dyn Iterator<Item = Result<f32, ModelCastError>>> {\n        let vec = self.0.clone();\n        let iter = vec.into_iter().map(|r| {\n            ratio_to_float::<f32>(&r).ok_or(ModelCastError {\n                weight: r,\n                target: PrimitiveType::F32,\n            })\n        });\n        Box::new(iter)\n    }\n}\n\nimpl FromPrimitives<f32> for Model {\n    fn from_primitives<I: Iterator<Item = f32>>(iter: I) -> Result<Self, PrimitiveCastError<f32>> {\n        iter.map(|f| Ratio::from_float(f).ok_or(PrimitiveCastError(f)))\n            .collect()\n    }\n\n    fn from_primitives_bounded<I: Iterator<Item = f32>>(iter: I) -> Self {\n        iter.map(float_to_ratio_bounded::<f32>).collect()\n    }\n}\n\nimpl IntoPrimitives<f64> for Model {\n    fn into_primitives(self) -> Box<dyn Iterator<Item = Result<f64, ModelCastError>>> {\n        let iter = self.0.into_iter().map(|r| {\n            ratio_to_float::<f64>(&r).ok_or(ModelCastError {\n                weight: r,\n                target: PrimitiveType::F64,\n            })\n        });\n        Box::new(iter)\n    }\n\n    fn to_primitives(&self) -> Box<dyn Iterator<Item = Result<f64, ModelCastError>>> {\n        let vec = self.0.clone();\n        let iter = vec.into_iter().map(|r| {\n            ratio_to_float::<f64>(&r).ok_or(ModelCastError {\n                weight: r,\n                target: PrimitiveType::F64,\n            })\n        });\n        Box::new(iter)\n    }\n}\n\nimpl FromPrimitives<f64> for Model {\n    fn from_primitives<I: Iterator<Item = f64>>(iter: I) -> Result<Self, PrimitiveCastError<f64>> {\n        iter.map(|f| Ratio::from_float(f).ok_or(PrimitiveCastError(f)))\n            .collect()\n    }\n\n    fn from_primitives_bounded<I: Iterator<Item = f64>>(iter: I) -> Self {\n        iter.map(float_to_ratio_bounded::<f64>).collect()\n    }\n}\n\n/// Converts a numerical value into a primitive floating point value.\n///\n/// If the numerator or the denominator is not representable as a primitive value, or their\n/// quotient is not finite, both are repeatedly halved until the conversion succeeds.\n///\n/// # Errors\n/// Fails if the numerical value is not representable in the primitive data type.\npub(crate) fn ratio_to_float<F: FloatCore>(ratio: &Ratio<BigInt>) -> Option<F> {\n    let min_value = Ratio::from_float(F::min_value()).unwrap();\n    let max_value = Ratio::from_float(F::max_value()).unwrap();\n    if ratio < &min_value || ratio > &max_value {\n        return None;\n    }\n\n    let mut numer = ratio.numer().clone();\n    let mut denom = ratio.denom().clone();\n    // safe loop: every iteration that doesn't break halves both terms, so the loop ends once they\n    // fit into the primitive type (at the latest when the denominator has been shifted to zero)\n    loop {\n        if let (Some(n), Some(d)) = (F::from(numer.clone()), F::from(denom.clone())) {\n            if n == F::zero() || d == F::zero() {\n                break Some(F::zero());\n            }\n            let float = n / d;\n            if float.is_finite() {\n                break Some(float);\n            }\n        }\n        // the numerator/denominator is not representable or their quotient is not finite (e.g.\n        // both were converted to infinity): halve both and retry, since retrying without any\n        // progress would never terminate\n        numer >>= 1_usize;\n        denom >>= 1_usize;\n    }\n}\n\n/// Converts the primitive floating point value into a numerical value.\n///\n/// Maps positive/negative infinity to max/min of the primitive data type and NaN to zero.\npub(crate) fn float_to_ratio_bounded<F: FloatCore>(f: F) -> Ratio<BigInt> {\n    if f.is_nan() {\n        Ratio::<BigInt>::zero()\n    } else {\n        let finite_f = clamp(f, F::min_value(), F::max_value());\n        // safe unwrap: clamped weight is guaranteed to be finite\n        Ratio::<BigInt>::from_float(finite_f).unwrap()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use std::iter;\n\n    type R = Ratio<BigInt>;\n\n    #[test]\n    fn test_model_f32() {\n        let expected_primitives = vec![-1_f32, 0_f32, 1_f32];\n        let expected_model = Model::from(vec![\n            R::from_float(-1_f32).unwrap(),\n            R::zero(),\n            R::from_float(1_f32).unwrap(),\n        ]);\n\n        let actual_model = Model::from_primitives(expected_primitives.iter().cloned()).unwrap();\n        assert_eq!(actual_model, expected_model);\n\n        let actual_model = Model::from_primitives_bounded(expected_primitives.iter().cloned());\n        assert_eq!(actual_model, expected_model);\n\n        let actual_primitives: Vec<f32> = expected_model.into_primitives_unchecked().collect();\n        assert_eq!(actual_primitives, expected_primitives);\n    }\n\n    #[test]\n    fn test_model_f64() {\n        let expected_primitives = vec![-1_f64, 0_f64, 1_f64];\n        let expected_model = Model::from(vec![\n            R::from_float(-1_f64).unwrap(),\n            R::zero(),\n            R::from_float(1_f64).unwrap(),\n        ]);\n\n        let actual_model = Model::from_primitives(expected_primitives.iter().cloned()).unwrap();\n        assert_eq!(actual_model, expected_model);\n\n        let actual_model = Model::from_primitives_bounded(expected_primitives.iter().cloned());\n        assert_eq!(actual_model, expected_model);\n\n        let actual_primitives: Vec<f64> = expected_model.into_primitives_unchecked().collect();\n        assert_eq!(actual_primitives, expected_primitives);\n    }\n\n    #[test]\n    fn test_model_f32_from_weird_primitives() {\n        // +infinity\n        assert!(Model::from_primitives(iter::once(f32::INFINITY)).is_err());\n        assert_eq!(\n            Model::from_primitives_bounded(iter::once(f32::INFINITY)),\n            vec![R::from_float(f32::MAX).unwrap()].into()\n        );\n\n        // -infinity\n        assert!(Model::from_primitives(iter::once(f32::NEG_INFINITY)).is_err());\n        assert_eq!(\n            Model::from_primitives_bounded(iter::once(f32::NEG_INFINITY)),\n            vec![R::from_float(f32::MIN).unwrap()].into()\n        );\n\n        // NaN\n        assert!(Model::from_primitives(iter::once(f32::NAN)).is_err());\n        assert_eq!(\n            Model::from_primitives_bounded(iter::once(f32::NAN)),\n            vec![R::zero()].into()\n        );\n    }\n\n    #[test]\n    fn test_model_f64_from_weird_primitives() {\n        // +infinity\n        assert!(Model::from_primitives(iter::once(f64::INFINITY)).is_err());\n        assert_eq!(\n            Model::from_primitives_bounded(iter::once(f64::INFINITY)),\n            vec![R::from_float(f64::MAX).unwrap()].into()\n        );\n\n        // -infinity\n        assert!(Model::from_primitives(iter::once(f64::NEG_INFINITY)).is_err());\n        assert_eq!(\n            Model::from_primitives_bounded(iter::once(f64::NEG_INFINITY)),\n            vec![R::from_float(f64::MIN).unwrap()].into()\n        );\n\n        // NaN\n        assert!(Model::from_primitives(iter::once(f64::NAN)).is_err());\n        assert_eq!(\n            Model::from_primitives_bounded(iter::once(f64::NAN)),\n            vec![R::zero()].into()\n        );\n    }\n\n    #[test]\n    fn test_model_i32() {\n        let expected_primitives = vec![-1_i32, 0_i32, 1_i32];\n        let expected_model = Model::from(vec![\n            R::from_integer(BigInt::from(-1_i32)),\n            R::zero(),\n            R::from_integer(BigInt::from(1_i32)),\n        ]);\n\n        let actual_model = Model::from_primitives(expected_primitives.iter().cloned()).unwrap();\n        assert_eq!(actual_model, expected_model);\n\n        let actual_model = Model::from_primitives_bounded(expected_primitives.iter().cloned());\n        assert_eq!(actual_model, expected_model);\n\n        let actual_primitives: Vec<i32> = expected_model.into_primitives_unchecked().collect();\n        assert_eq!(actual_primitives, expected_primitives);\n    }\n\n    #[test]\n    fn test_model_i64() {\n        let expected_primitives = vec![-1_i64, 0_i64, 1_i64];\n        let expected_model = Model::from(vec![\n            R::from_integer(BigInt::from(-1_i64)),\n            R::zero(),\n            R::from_integer(BigInt::from(1_i64)),\n        ]);\n\n        let actual_model = Model::from_primitives(expected_primitives.iter().cloned()).unwrap();\n        assert_eq!(actual_model, expected_model);\n\n        let actual_model = Model::from_primitives_bounded(expected_primitives.iter().cloned());\n        assert_eq!(actual_model, expected_model);\n\n        let actual_primitives: Vec<i64> = expected_model.into_primitives_unchecked().collect();\n        assert_eq!(actual_primitives, expected_primitives);\n    }\n\n    #[test]\n    #[allow(clippy::float_cmp)]\n    fn test_ratio_to_float() {\n        let ratio = R::from_float(0_f32).unwrap();\n        assert_eq!(ratio_to_float::<f32>(&ratio).unwrap(), 0_f32);\n        let ratio = R::from_float(0_f64).unwrap();\n        assert_eq!(ratio_to_float::<f64>(&ratio).unwrap(), 0_f64);\n\n        let ratio = R::from_float(0.1_f32).unwrap();\n        assert_eq!(ratio_to_float::<f32>(&ratio).unwrap(), 0.1_f32);\n        let ratio = R::from_float(0.1_f64).unwrap();\n        assert_eq!(ratio_to_float::<f64>(&ratio).unwrap(), 0.1_f64);\n\n        let f32_max = R::from_float(f32::max_value()).unwrap();\n        let ratio = &f32_max * BigInt::from(10_usize) / (f32_max * BigInt::from(100_usize));\n        assert_eq!(ratio_to_float::<f32>(&ratio).unwrap(), 0.1_f32);\n\n        let f64_max = R::from_float(f64::max_value()).unwrap();\n        let ratio = &f64_max * BigInt::from(10_usize) / (f64_max * BigInt::from(100_usize));\n        assert_eq!(ratio_to_float::<f64>(&ratio).unwrap(), 0.1_f64);\n    }\n\n    #[test]\n    #[allow(clippy::float_cmp)]\n    fn test_ratio_to_float_large_terms() {\n        // numerator and denominator exceed the primitive range although their ratio does not\n        let denom = BigInt::from(1_u8) << 200_usize;\n        let numer = &denom + BigInt::from(1_u8);\n        assert_eq!(ratio_to_float::<f32>(&R::new(numer, denom)).unwrap(), 1_f32);\n\n        let denom = BigInt::from(1_u8) << 2000_usize;\n        let numer = &denom + BigInt::from(1_u8);\n        assert_eq!(ratio_to_float::<f64>(&R::new(numer, denom)).unwrap(), 1_f64);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/mask/object/mod.rs",
    "content": "//! Masked objects.\n//!\n//! See the [mask module] documentation since this is a private module anyway.\n//!\n//! [mask module]: crate::mask\n\npub mod serialization;\n\nuse std::iter::Iterator;\n\nuse num::bigint::BigUint;\nuse serde::{Deserialize, Serialize};\nuse thiserror::Error;\n\nuse crate::mask::config::{MaskConfig, MaskConfigPair};\n\n#[derive(Error, Debug)]\n#[error(\"the mask object is invalid: data is incompatible with the masking configuration\")]\n/// Errors related to invalid mask objects.\npub struct InvalidMaskObjectError;\n\n#[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)]\n/// A *mask vector* which represents a masked model or its corresponding mask.\npub struct MaskVect {\n    pub data: Vec<BigUint>,\n    pub config: MaskConfig,\n}\n\nimpl MaskVect {\n    /// Creates a new mask vector from the given data and masking configuration.\n    ///\n    /// The data is not checked for conformity with the masking configuration, see\n    /// [`MaskVect::new`].\n    pub fn new_unchecked(config: MaskConfig, data: Vec<BigUint>) -> Self {\n        Self { data, config }\n    }\n\n    /// Creates a new mask vector from the given data and masking configuration.\n    ///\n    /// # Errors\n    /// Fails if the elements of the mask vector don't conform to the given masking configuration.\n    pub fn new(config: MaskConfig, data: Vec<BigUint>) -> Result<Self, InvalidMaskObjectError> {\n        let obj = Self::new_unchecked(config, data);\n        if obj.is_valid() {\n            Ok(obj)\n        } else {\n            Err(InvalidMaskObjectError)\n        }\n    }\n\n    /// Creates a new mask vector of given masking configuration without any elements, but with\n    /// capacity for `size` elements.\n    pub fn empty(config: MaskConfig, size: usize) -> Self {\n        Self {\n            data: Vec::with_capacity(size),\n            config,\n        }\n    }\n\n    /// Checks if the elements of this mask vector conform to the masking configuration.\n    pub fn is_valid(&self) -> bool {\n        let order = self.config.order();\n        self.data.iter().all(|i| i < &order)\n    }\n}\n\n#[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)]\n/// A *mask unit* which represents a masked scalar or its corresponding mask.\npub struct MaskUnit {\n    pub data: BigUint,\n    pub config: MaskConfig,\n}\n\nimpl From<&MaskUnit> for MaskVect {\n    fn from(mask_unit: &MaskUnit) -> Self {\n        Self::new_unchecked(mask_unit.config, vec![mask_unit.data.clone()])\n    }\n}\n\nimpl From<MaskUnit> for MaskVect {\n    fn from(mask_unit: MaskUnit) -> Self {\n        Self::new_unchecked(mask_unit.config, vec![mask_unit.data])\n    }\n}\n\nimpl MaskUnit {\n    /// Creates a new mask unit from the given data and masking configuration.\n    ///\n    /// The data is not checked for conformity with the masking configuration, see\n    /// [`MaskUnit::new`].\n    pub fn new_unchecked(config: MaskConfig, data: BigUint) -> Self {\n        Self { data, config }\n    }\n\n    /// Creates a new mask unit from the given data and masking configuration.\n    ///\n    /// # Errors\n    /// Fails if the mask unit doesn't conform to the given masking configuration.\n    pub fn new(config: MaskConfig, data: BigUint) -> Result<Self, InvalidMaskObjectError> {\n        let obj = Self::new_unchecked(config, data);\n        if obj.is_valid() {\n            Ok(obj)\n        } else {\n            Err(InvalidMaskObjectError)\n        }\n    }\n\n    /// Creates a new mask unit of given masking configuration with default value `1`.\n    pub fn default(config: MaskConfig) -> Self {\n        Self {\n            data: BigUint::from(1_u8),\n            config,\n        }\n    }\n\n    /// Checks if the data value conforms to the masking configuration.\n    pub fn is_valid(&self) -> bool {\n        self.data < self.config.order()\n    }\n}\n\n#[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)]\n/// A mask object consisting of a vector part and unit part.\npub struct MaskObject {\n    pub vect: MaskVect,\n    pub unit: MaskUnit,\n}\n\nimpl MaskObject {\n    /// Creates a new mask object from the given vector and unit.\n    ///\n    /// The parts are not checked for conformity with their masking configurations, see\n    /// [`MaskObject::new`].\n    pub fn new_unchecked(vect: MaskVect, unit: MaskUnit) -> Self {\n        Self { vect, unit }\n    }\n\n    /// Creates a new mask object from the given vector, unit and masking configurations.\n    ///\n    /// # Errors\n    /// Fails if the vector or unit data don't conform to their respective masking configurations.\n    pub fn new(\n        config: MaskConfigPair,\n        data_vect: Vec<BigUint>,\n        data_unit: BigUint,\n    ) -> Result<Self, InvalidMaskObjectError> {\n        let vect = MaskVect::new(config.vect, data_vect)?;\n        let unit = MaskUnit::new(config.unit, data_unit)?;\n        Ok(Self { vect, unit })\n    }\n\n    /// Creates a new mask object of given masking configurations with an empty vector part\n    /// (with capacity for `size` elements) and a unit part of value `1`.\n    pub fn empty(config: MaskConfigPair, size: usize) -> Self {\n        Self {\n            vect: MaskVect::empty(config.vect, size),\n            unit: MaskUnit::default(config.unit),\n        }\n    }\n\n    /// Checks if this mask object conforms to the masking configurations.\n    pub fn is_valid(&self) -> bool {\n        self.vect.is_valid() && self.unit.is_valid()\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/mask/object/serialization/mod.rs",
    "content": "//! Serialization of masked objects.\n//!\n//! See the [mask module] documentation since this is a private module anyways.\n//!\n//! [mask module]: crate::mask\n\npub(crate) mod unit;\npub(crate) mod vect;\n\nuse anyhow::Context;\n\nuse crate::{\n    mask::object::{\n        serialization::{unit::MaskUnitBuffer, vect::MaskVectBuffer},\n        MaskObject,\n        MaskUnit,\n        MaskVect,\n    },\n    message::{\n        traits::{FromBytes, ToBytes},\n        DecodeError,\n    },\n};\n\n// target dependent maximum number of mask object elements\n#[cfg(target_pointer_width = \"16\")]\nconst MAX_NB: u32 = u16::MAX as u32;\n\n/// A buffer for serialized mask objects.\npub struct MaskObjectBuffer<T> {\n    inner: T,\n}\n\nimpl<T: AsRef<[u8]>> MaskObjectBuffer<T> {\n    /// Creates a new buffer from `bytes`.\n    ///\n    /// # Errors\n    /// Fails if the `bytes` don't conform to the required buffer length for mask objects.\n    pub fn new(bytes: T) -> Result<Self, DecodeError> {\n        let buffer = Self { inner: bytes };\n        buffer\n            .check_buffer_length()\n            .context(\"not a valid mask object\")?;\n        Ok(buffer)\n    }\n\n    /// Creates a new buffer from `bytes`.\n    pub fn new_unchecked(bytes: T) -> Self {\n        Self { inner: bytes }\n    }\n\n    /// Checks if this buffer conforms to the required buffer length for mask objects.\n    ///\n    /// # Errors\n    /// Fails if the buffer is too small.\n    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {\n        let inner = self.inner.as_ref();\n        // check length of vect field\n        MaskVectBuffer::new(inner).context(\"invalid vector field\")?;\n        // check length of unit field\n        MaskUnitBuffer::new(&inner[self.unit_offset()..]).context(\"invalid unit field\")?;\n        Ok(())\n    }\n\n    /// Gets the vector part.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn vect(&self) -> &[u8] {\n  
      let len = self.unit_offset();\n        &self.inner.as_ref()[0..len]\n    }\n\n    /// Gets the offset of the unit field.\n    pub fn unit_offset(&self) -> usize {\n        let vect_buf = MaskVectBuffer::new_unchecked(self.inner.as_ref());\n        vect_buf.len()\n    }\n\n    /// Gets the unit part.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn unit(&self) -> &[u8] {\n        let offset = self.unit_offset();\n        &self.inner.as_ref()[offset..]\n    }\n\n    /// Gets the expected number of bytes of this buffer.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn len(&self) -> usize {\n        let unit_offset = self.unit_offset();\n        let unit_buf = MaskUnitBuffer::new_unchecked(&self.inner.as_ref()[unit_offset..]);\n        unit_offset + unit_buf.len()\n    }\n}\n\nimpl<T: AsRef<[u8]> + AsMut<[u8]>> MaskObjectBuffer<T> {\n    /// Gets the vector part.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn vect_mut(&mut self) -> &mut [u8] {\n        self.inner.as_mut()\n    }\n\n    /// Gets the unit part.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn unit_mut(&mut self) -> &mut [u8] {\n        let offset = self.unit_offset();\n        &mut self.inner.as_mut()[offset..]\n    }\n}\n\nimpl ToBytes for MaskObject {\n    fn buffer_length(&self) -> usize {\n        self.vect.buffer_length() + self.unit.buffer_length()\n    }\n\n    fn to_bytes<T: AsMut<[u8]> + AsRef<[u8]>>(&self, buffer: &mut T) {\n        let mut writer = MaskObjectBuffer::new_unchecked(buffer.as_mut());\n        self.vect.to_bytes(&mut writer.vect_mut());\n        self.unit.to_bytes(&mut writer.unit_mut());\n    }\n}\n\nimpl FromBytes for MaskObject {\n    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        let reader = MaskObjectBuffer::new(buffer.as_ref())?;\n        let vect = 
MaskVect::from_byte_slice(&reader.vect()).context(\"invalid vector part\")?;\n        let unit = MaskUnit::from_byte_slice(&reader.unit()).context(\"invalid unit part\")?;\n        Ok(Self { vect, unit })\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        let vect = MaskVect::from_byte_stream(iter).context(\"invalid vector part\")?;\n        let unit = MaskUnit::from_byte_stream(iter).context(\"invalid unit part\")?;\n        Ok(Self { vect, unit })\n    }\n}\n\n#[cfg(test)]\npub(crate) mod tests {\n    use super::*;\n    use crate::mask::{\n        config::{BoundType, DataType, GroupType, MaskConfig, ModelType},\n        object::serialization::{unit::tests::mask_unit, vect::tests::mask_vect},\n        MaskObject,\n    };\n\n    pub fn mask_config() -> (MaskConfig, Vec<u8>) {\n        // config.order() = 20_000_000_000_001 with this config, so the data\n        // should be stored on 6 bytes.\n        let config = MaskConfig {\n            group_type: GroupType::Integer,\n            data_type: DataType::I32,\n            bound_type: BoundType::B0,\n            model_type: ModelType::M3,\n        };\n        let bytes = vec![0x00, 0x02, 0x00, 0x03];\n        (config, bytes)\n    }\n\n    pub fn mask_object() -> (MaskObject, Vec<u8>) {\n        let (mask_vect, mask_vect_bytes) = mask_vect();\n        let (mask_unit, mask_unit_bytes) = mask_unit();\n        let obj = MaskObject::new_unchecked(mask_vect, mask_unit);\n        let bytes = [mask_vect_bytes.as_slice(), mask_unit_bytes.as_slice()].concat();\n\n        (obj, bytes)\n    }\n\n    #[test]\n    fn serialize_mask_object() {\n        let (mask_object, expected) = mask_object();\n        let mut buf = vec![0xff; 42];\n        mask_object.to_bytes(&mut buf);\n        assert_eq!(buf, expected);\n    }\n\n    #[test]\n    fn deserialize_mask_object() {\n        let (expected, bytes) = mask_object();\n        
assert_eq!(MaskObject::from_byte_slice(&&bytes[..]).unwrap(), expected);\n    }\n\n    #[test]\n    fn deserialize_mask_object_from_stream() {\n        let (expected, bytes) = mask_object();\n        assert_eq!(\n            MaskObject::from_byte_stream(&mut bytes.into_iter()).unwrap(),\n            expected\n        );\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/mask/object/serialization/unit.rs",
    "content": "//! Serialization of masked units.\n//!\n//! See the [mask module] documentation since this is a private module anyways.\n//!\n//! [mask module]: crate::mask\n\nuse std::ops::Range;\n\nuse anyhow::{anyhow, Context};\nuse num::bigint::BigUint;\n\nuse crate::{\n    mask::{\n        config::{serialization::MASK_CONFIG_BUFFER_LEN, MaskConfig},\n        object::MaskUnit,\n    },\n    message::{\n        traits::{FromBytes, ToBytes},\n        utils::range,\n        DecodeError,\n    },\n};\n\nconst MASK_CONFIG_FIELD: Range<usize> = range(0, MASK_CONFIG_BUFFER_LEN);\n\n/// A buffer for serialized mask units.\npub struct MaskUnitBuffer<T> {\n    inner: T,\n}\n\nimpl<T: AsRef<[u8]>> MaskUnitBuffer<T> {\n    /// Creates a new buffer from `bytes`.\n    ///\n    /// # Errors\n    /// Fails if the `bytes` don't conform to the required buffer length for mask units.\n    pub fn new(bytes: T) -> Result<Self, DecodeError> {\n        let buffer = Self { inner: bytes };\n        buffer\n            .check_buffer_length()\n            .context(\"not a valid mask unit\")?;\n        Ok(buffer)\n    }\n\n    /// Creates a new buffer from `bytes` without checking its length.\n    pub fn new_unchecked(bytes: T) -> Self {\n        Self { inner: bytes }\n    }\n\n    /// Checks if this buffer conforms to the required buffer length for mask units.\n    ///\n    /// # Errors\n    /// Fails if the buffer is too small or the serialized masking configuration is invalid.\n    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {\n        let len = self.inner.as_ref().len();\n        if len < MASK_CONFIG_FIELD.end {\n            return Err(anyhow!(\n                \"invalid buffer length: {} < {}\",\n                len,\n                MASK_CONFIG_FIELD.end\n            ));\n        }\n\n        let total_expected_length = self.try_len()?;\n        if len < total_expected_length {\n            return Err(anyhow!(\n                \"invalid buffer length: expected {} bytes but buffer has only {} bytes\",\n                total_expected_length,\n                len\n            ));\n        }\n        Ok(())\n    }\n\n    /// Return the expected length of the underlying byte buffer,\n    /// based on the masking config field. This is similar to\n    /// [`len()`] but returns an error instead of panicking if the\n    /// serialized masking configuration is invalid.\n    ///\n    /// [`len()`]: MaskUnitBuffer::len\n    pub fn try_len(&self) -> Result<usize, DecodeError> {\n        let config =\n            MaskConfig::from_byte_slice(&self.config()).context(\"invalid mask unit buffer\")?;\n        let data_length = config.bytes_per_number();\n        Ok(MASK_CONFIG_FIELD.end + data_length)\n    }\n\n    /// Gets the expected number of bytes of this buffer wrt the masking configuration.\n    ///\n    /// # Panics\n    /// Panics if the serialized masking configuration is invalid.\n    pub fn len(&self) -> usize {\n        let config = MaskConfig::from_byte_slice(&self.config()).unwrap();\n        let data_length = config.bytes_per_number();\n        MASK_CONFIG_FIELD.end + data_length\n    }\n\n    /// Gets the serialized masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn config(&self) -> &[u8] {\n        &self.inner.as_ref()[MASK_CONFIG_FIELD]\n    }\n\n    /// Gets the serialized mask unit element.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn data(&self) -> &[u8] {\n        &self.inner.as_ref()[MASK_CONFIG_FIELD.end..self.len()]\n    }\n}\n\nimpl<T: AsRef<[u8]> + AsMut<[u8]>> MaskUnitBuffer<T> {\n    /// Gets the serialized masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn config_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[MASK_CONFIG_FIELD]\n    }\n\n    /// Gets the serialized mask unit element.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn data_mut(&mut self) -> &mut [u8] {\n        let end = self.len();\n        &mut self.inner.as_mut()[MASK_CONFIG_FIELD.end..end]\n    }\n}\n\nimpl ToBytes for MaskUnit {\n    fn buffer_length(&self) -> usize {\n        MASK_CONFIG_FIELD.end + self.config.bytes_per_number()\n    }\n\n    fn to_bytes<T: AsMut<[u8]> + AsRef<[u8]>>(&self, buffer: &mut T) {\n        let mut writer = MaskUnitBuffer::new_unchecked(buffer.as_mut());\n        self.config.to_bytes(&mut writer.config_mut());\n\n        let data = writer.data_mut();\n        // FIXME: this allocates a vec which is sub-optimal. See\n        // https://github.com/rust-num/num-bigint/issues/152\n        let bytes = self.data.to_bytes_le();\n        // This may panic if the data is invalid and is an\n        // integer that is bigger than what is expected by the\n        // configuration.\n        data[..bytes.len()].copy_from_slice(&bytes[..]);\n        // padding\n        for b in data\n            .iter_mut()\n            .take(self.config.bytes_per_number())\n            .skip(bytes.len())\n        {\n            *b = 0;\n        }\n    }\n}\n\nimpl FromBytes for MaskUnit {\n    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        let reader = MaskUnitBuffer::new(buffer.as_ref())?;\n        let config = MaskConfig::from_byte_slice(&reader.config())?;\n        let data = BigUint::from_bytes_le(reader.data());\n\n        Ok(MaskUnit { data, config })\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        let config = MaskConfig::from_byte_stream(iter)?;\n        // NOTE(review): this 4-byte check looks copied from the mask vector\n        // parser, where it guards the 4-byte numbers field. A mask unit has no\n        // such field, the `data_len` check below already detects an exhausted\n        // stream, and this check may reject a complete unit at the end of the\n        // stream when `bytes_per_number()` is smaller than 4.\n        if iter.len() < 4 {\n            return Err(anyhow!(\"byte stream exhausted\"));\n        }\n        let data_len = config.bytes_per_number();\n        if iter.len() < data_len {\n            return Err(anyhow!(\n                \"mask unit is {} bytes long but byte stream only has {} bytes\",\n                data_len,\n                iter.len()\n            ));\n        }\n\n        let mut buf = vec![0; data_len];\n        for (i, b) in iter.take(data_len).enumerate() {\n            buf[i] = b;\n        }\n        let data = BigUint::from_bytes_le(buf.as_slice());\n\n        Ok(MaskUnit { data, config })\n    }\n}\n\n#[cfg(test)]\npub(crate) mod tests {\n    use super::*;\n    use crate::mask::object::serialization::tests::mask_config;\n\n    pub fn mask_unit() -> (MaskUnit, Vec<u8>) {\n        let (config, mut bytes) = mask_config();\n        let data = BigUint::from(1_u8);\n        let mask_unit = MaskUnit::new_unchecked(config, data);\n\n        bytes.extend(vec![\n            // data (6 bytes with this config)\n            0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // 1\n        ]);\n        (mask_unit, bytes)\n    }\n\n    #[test]\n    fn serialize_mask_unit() {\n        let (mask_unit, expected) = mask_unit();\n        let mut buf = vec![0xff; expected.len()];\n        mask_unit.to_bytes(&mut buf);\n        assert_eq!(buf, expected);\n    }\n\n    #[test]\n    fn deserialize_mask_unit() {\n        let (expected, bytes) = mask_unit();\n        assert_eq!(MaskUnit::from_byte_slice(&&bytes[..]).unwrap(), expected);\n    }\n\n    #[test]\n    fn deserialize_mask_unit_from_stream() {\n        let (expected, bytes) = mask_unit();\n        assert_eq!(\n            MaskUnit::from_byte_stream(&mut bytes.into_iter()).unwrap(),\n            expected\n        );\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/mask/object/serialization/vect.rs",
    "content": "//! Serialization of masked vectors.\n//!\n//! See the [mask module] documentation since this is a private module anyways.\n//!\n//! [mask module]: crate::mask\n\nuse std::{convert::TryInto, ops::Range};\n\nuse anyhow::{anyhow, Context};\nuse num::bigint::BigUint;\n\nuse crate::{\n    mask::{\n        config::{serialization::MASK_CONFIG_BUFFER_LEN, MaskConfig},\n        object::MaskVect,\n    },\n    message::{\n        traits::{FromBytes, ToBytes},\n        utils::{range, ChunkableIterator},\n        DecodeError,\n    },\n};\n\nconst MASK_CONFIG_FIELD: Range<usize> = range(0, MASK_CONFIG_BUFFER_LEN);\nconst NUMBERS_FIELD: Range<usize> = range(MASK_CONFIG_FIELD.end, 4);\n\n// target dependent maximum number of mask object elements\n#[cfg(target_pointer_width = \"16\")]\nconst MAX_NB: u32 = u16::MAX as u32;\n\n/// A buffer for serialized mask vectors.\npub struct MaskVectBuffer<T> {\n    inner: T,\n}\n\n#[allow(clippy::len_without_is_empty)]\nimpl<T: AsRef<[u8]>> MaskVectBuffer<T> {\n    /// Creates a new buffer from `bytes`.\n    ///\n    /// # Errors\n    /// Fails if the `bytes` don't conform to the required buffer length for mask vectors.\n    pub fn new(bytes: T) -> Result<Self, DecodeError> {\n        let buffer = Self { inner: bytes };\n        buffer\n            .check_buffer_length()\n            .context(\"not a valid mask vector\")?;\n        Ok(buffer)\n    }\n\n    /// Creates a new buffer from `bytes`.\n    pub fn new_unchecked(bytes: T) -> Self {\n        Self { inner: bytes }\n    }\n\n    /// Checks if this buffer conforms to the required buffer length for mask vectors.\n    ///\n    /// # Errors\n    /// Fails if the buffer is too small.\n    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {\n        let len = self.inner.as_ref().len();\n        if len < NUMBERS_FIELD.end {\n            return Err(anyhow!(\n                \"invalid buffer length: {} < {}\",\n                len,\n                NUMBERS_FIELD.end\n    
        ));\n        }\n\n        let total_expected_length = self.try_len()?;\n        if len < total_expected_length {\n            return Err(anyhow!(\n                \"invalid buffer length: expected {} bytes but buffer has only {} bytes\",\n                total_expected_length,\n                len\n            ));\n        }\n        Ok(())\n    }\n\n    /// Return the expected length of the underlying byte buffer,\n    /// based on the masking config field of numbers field. This is\n    /// similar to [`len()`] but cannot panic.\n    ///\n    /// [`len()`]: MaskVectBuffer::len\n    fn try_len(&self) -> Result<usize, DecodeError> {\n        let config =\n            MaskConfig::from_byte_slice(&self.config()).context(\"invalid mask vector buffer\")?;\n        let bytes_per_number = config.bytes_per_number();\n        let (data_length, overflows) = self.numbers().overflowing_mul(bytes_per_number);\n        if overflows {\n            return Err(anyhow!(\n                \"invalid MaskObject buffer: invalid masking config or numbers field\"\n            ));\n        }\n        Ok(NUMBERS_FIELD.end + data_length)\n    }\n\n    /// Gets the expected number of bytes of this buffer wrt to the masking configuration.\n    ///\n    /// # Panics\n    /// Panics if the serialized masking configuration is invalid.\n    pub fn len(&self) -> usize {\n        let config = MaskConfig::from_byte_slice(&self.config()).unwrap();\n        let bytes_per_number = config.bytes_per_number();\n        let data_length = self.numbers() * bytes_per_number;\n        NUMBERS_FIELD.end + data_length\n    }\n\n    /// Gets the number of serialized mask object elements.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    ///\n    /// Panics if the number can't be represented as usize on targets smaller than 32 bits.\n    pub fn numbers(&self) -> usize {\n        // UNWRAP SAFE: the slice is exactly 4 bytes long\n        let nb = 
u32::from_be_bytes(self.inner.as_ref()[NUMBERS_FIELD].try_into().unwrap());\n\n        // smaller targets than 32 bits are currently not of interest\n        #[cfg(target_pointer_width = \"16\")]\n        if nb > MAX_NB {\n            panic!(\"16 bit targets or smaller are currently not fully supported\")\n        }\n\n        nb as usize\n    }\n\n    /// Gets the serialized masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn config(&self) -> &[u8] {\n        &self.inner.as_ref()[MASK_CONFIG_FIELD]\n    }\n\n    /// Gets the serialized mask vector elements.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn data(&self) -> &[u8] {\n        &self.inner.as_ref()[NUMBERS_FIELD.end..self.len()]\n    }\n}\n\nimpl<T: AsRef<[u8]> + AsMut<[u8]>> MaskVectBuffer<T> {\n    /// Sets the number of serialized mask vector elements.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn set_numbers(&mut self, value: u32) {\n        self.inner.as_mut()[NUMBERS_FIELD].copy_from_slice(&value.to_be_bytes());\n    }\n\n    /// Gets the serialized masking configuration.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn config_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[MASK_CONFIG_FIELD]\n    }\n\n    /// Gets the serialized mask vector elements.\n    ///\n    /// # Panics\n    /// May panic if this buffer is unchecked.\n    pub fn data_mut(&mut self) -> &mut [u8] {\n        let end = self.len();\n        &mut self.inner.as_mut()[NUMBERS_FIELD.end..end]\n    }\n}\n\nimpl ToBytes for MaskVect {\n    fn buffer_length(&self) -> usize {\n        NUMBERS_FIELD.end + self.config.bytes_per_number() * self.data.len()\n    }\n\n    fn to_bytes<T: AsMut<[u8]>>(&self, buffer: &mut T) {\n        let mut writer = MaskVectBuffer::new_unchecked(buffer.as_mut());\n        self.config.to_bytes(&mut writer.config_mut());\n   
     writer.set_numbers(self.data.len() as u32);\n\n        let mut data = writer.data_mut();\n        let bytes_per_number = self.config.bytes_per_number();\n\n        for int in self.data.iter() {\n            // FIXME: this allocates a vec which is sub-optimal. See\n            // https://github.com/rust-num/num-bigint/issues/152\n            let bytes = int.to_bytes_le();\n            // This may panic if the data is invalid and contains\n            // integers that are bigger than what is expected by the\n            // configuration.\n            data[..bytes.len()].copy_from_slice(&bytes[..]);\n            // padding\n            for b in data.iter_mut().take(bytes_per_number).skip(bytes.len()) {\n                *b = 0;\n            }\n            data = &mut data[bytes_per_number..];\n        }\n    }\n}\n\nimpl FromBytes for MaskVect {\n    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        let reader = MaskVectBuffer::new(buffer.as_ref())?;\n\n        let config = MaskConfig::from_byte_slice(&reader.config())?;\n        let mut data = Vec::with_capacity(reader.numbers());\n        let bytes_per_number = config.bytes_per_number();\n        for chunk in reader.data().chunks(bytes_per_number) {\n            data.push(BigUint::from_bytes_le(chunk));\n        }\n\n        Ok(MaskVect { data, config })\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        let config = MaskConfig::from_byte_stream(iter)?;\n        if iter.len() < 4 {\n            return Err(anyhow!(\"byte stream exhausted\"));\n        }\n        let numbers = u32::from_byte_stream(iter)\n            .context(\"failed to parse the number of items in mask vector\")?;\n        let bytes_per_number = config.bytes_per_number();\n\n        let data_len = numbers as usize * bytes_per_number;\n        if iter.len() < data_len {\n            return Err(anyhow!(\n           
     \"mask vector is {} bytes long but byte stream only has {} bytes\",\n                data_len,\n                iter.len()\n            ));\n        }\n\n        let mut data = Vec::with_capacity(numbers as usize);\n        let mut buf = vec![0; bytes_per_number];\n        for chunk in iter.take(data_len).chunks(bytes_per_number).into_iter() {\n            for (i, b) in chunk.enumerate() {\n                buf[i] = b;\n            }\n            data.push(BigUint::from_bytes_le(buf.as_slice()));\n        }\n\n        Ok(MaskVect { data, config })\n    }\n}\n\n#[cfg(test)]\npub(crate) mod tests {\n    use super::*;\n\n    use crate::mask::object::serialization::tests::mask_config;\n\n    pub fn mask_vect() -> (MaskVect, Vec<u8>) {\n        let (config, mut bytes) = mask_config();\n        let data = vec![\n            BigUint::from(1_u8),\n            BigUint::from(2_u8),\n            BigUint::from(3_u8),\n            BigUint::from(4_u8),\n        ];\n        let mask_vect = MaskVect::new_unchecked(config, data);\n\n        bytes.extend(vec![\n            // number of elements\n            0x00, 0x00, 0x00, 0x04, // data (1 weight => 6 bytes with this config)\n            0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // 1\n            0x02, 0x00, 0x00, 0x00, 0x00, 0x00, // 2\n            0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // 3\n            0x04, 0x00, 0x00, 0x00, 0x00, 0x00, // 4\n        ]);\n\n        (mask_vect, bytes)\n    }\n\n    #[test]\n    fn serialize_mask_vect() {\n        let (mask_vect, expected) = mask_vect();\n        let mut buf = vec![0xff; expected.len()];\n        mask_vect.to_bytes(&mut buf);\n        assert_eq!(buf, expected);\n    }\n\n    #[test]\n    fn deserialize_mask_vect() {\n        let (expected, bytes) = mask_vect();\n        assert_eq!(MaskVect::from_byte_slice(&&bytes[..]).unwrap(), expected);\n    }\n\n    #[test]\n    fn deserialize_mask_vect_from_stream() {\n        let (expected, bytes) = mask_vect();\n        assert_eq!(\n           
 MaskVect::from_byte_stream(&mut bytes.into_iter()).unwrap(),\n            expected\n        );\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/mask/scalar.rs",
    "content": "//! Scalar representation and conversion.\n//!\n//! See the [mask module] documentation since this is a private module anyways.\n//!\n//! [mask module]: crate::mask\n\nuse crate::mask::{\n    model::{ratio_to_float, PrimitiveType},\n    PrimitiveCastError,\n};\nuse derive_more::{From, Into};\nuse num::{\n    clamp,\n    rational::Ratio,\n    traits::{float::FloatCore, ToPrimitive},\n    BigInt,\n    BigUint,\n    One,\n    Unsigned,\n    Zero,\n};\nuse serde::{Deserialize, Serialize};\nuse std::{\n    convert::{TryFrom, TryInto},\n    fmt::Debug,\n};\nuse thiserror::Error;\n\n#[derive(Debug, Clone, PartialEq, Hash, From, Into, Serialize, Deserialize)]\n/// A numerical representation of a machine learning scalar.\npub struct Scalar(Ratio<BigUint>);\n\nimpl From<Scalar> for Ratio<BigInt> {\n    fn from(scalar: Scalar) -> Self {\n        let (numer, denom) = scalar.0.into();\n        Ratio::new(numer.into(), denom.into())\n    }\n}\n\nimpl TryFrom<Ratio<BigInt>> for Scalar {\n    type Error = <BigUint as TryFrom<BigInt>>::Error;\n\n    fn try_from(ratio: Ratio<BigInt>) -> Result<Self, Self::Error> {\n        let (numer, denom) = ratio.into();\n        Ok(Self(Ratio::new(numer.try_into()?, denom.try_into()?)))\n    }\n}\n\nimpl Scalar {\n    /// Constructs a new `Scalar` from the given numerator and denominator.\n    pub fn new<U>(numer: U, denom: U) -> Self\n    where\n        U: Unsigned + Into<BigUint>,\n    {\n        Self(Ratio::new(numer.into(), denom.into()))\n    }\n\n    /// Constructs a `Scalar` representing the given integer.\n    pub fn from_integer<U>(u: U) -> Self\n    where\n        U: Unsigned + Into<BigUint>,\n    {\n        Self(Ratio::from_integer(u.into()))\n    }\n\n    /// Constructs a `Scalar` of unit value.\n    pub fn unit() -> Self {\n        Self(Ratio::one())\n    }\n\n    /// Convenience method for conversion to a non-negative ratio of `BigInt`.\n    pub(crate) fn to_ratio(&self) -> Ratio<BigInt> {\n        
self.clone().into()\n    }\n\n    /// Constructs a `Scalar` from a primitive floating point value, clamped where necessary.\n    ///\n    /// Maps positive infinity to max of the primitive data type, negatives and NaN to zero.\n    pub(crate) fn from_float_bounded<F: FloatCore>(f: F) -> Self {\n        if f.is_nan() {\n            Self(Ratio::zero())\n        } else {\n            let finite_f = clamp(f, F::zero(), F::max_value());\n            // safe unwrap: clamped weight is guaranteed to be finite\n            let r = Ratio::from_float(finite_f).unwrap();\n            // safe unwrap: bounded non-negative ratio r\n            r.try_into().unwrap()\n        }\n    }\n}\n\n#[derive(Error, Debug)]\n#[error(\"Could not convert weight {weight} to primitive type {target}\")]\n/// Errors related to scalar conversion into primitives.\npub struct ScalarCastError {\n    weight: Ratio<BigUint>,\n    target: PrimitiveType,\n}\n\n/// An interface for conversion into a primitive value.\n///\n/// This trait is used to convert a [`Scalar`], which has its own internal\n/// representation, into a primitive type ([`f32`], [`f64`], [`i32`], [`i64`]).\n/// The opposite trait is [`FromPrimitive`].\npub trait IntoPrimitive<P>: Sized {\n    /// Consumes into a converted primitive value.\n    ///\n    /// # Errors\n    /// Returns an error if the conversion fails.\n    fn into_primitive(self) -> Result<P, ScalarCastError>;\n\n    /// Converts to a primitive value.\n    ///\n    /// # Errors\n    /// Returns an error if the conversion fails.\n    fn to_primitive(&self) -> Result<P, ScalarCastError>;\n\n    /// Consumes into a converted primitive value.\n    ///\n    /// # Panics\n    /// Panics if the conversion fails.\n    fn into_primitive_unchecked(self) -> P {\n        self.into_primitive()\n            .expect(\"conversion to primitive type failed\")\n    }\n}\n\n/// An interface for conversion from a primitive value.\n///\n/// This trait is used to obtain a [`Scalar`], which has 
its own representation,\n/// from a primitive type ([`f32`], [`f64`], [`i32`], [`i64`]). The opposite\n/// trait is [`IntoPrimitive`].\npub trait FromPrimitive<P: Debug>: Sized {\n    /// Converts from a primitive value.\n    ///\n    /// # Errors\n    /// Returns an error if the conversion fails.\n    fn from_primitive(prim: P) -> Result<Self, PrimitiveCastError<P>>;\n\n    /// Converts from a primitive value.\n    ///\n    /// If a direct conversion cannot be obtained from the primitive value, it is clamped.\n    fn from_primitive_bounded(prim: P) -> Self;\n}\n\nimpl IntoPrimitive<i32> for Scalar {\n    fn into_primitive(self) -> Result<i32, ScalarCastError> {\n        let r = self.0;\n        r.to_integer().to_i32().ok_or(ScalarCastError {\n            weight: r,\n            target: PrimitiveType::I32,\n        })\n    }\n\n    fn to_primitive(&self) -> Result<i32, ScalarCastError> {\n        self.clone().into_primitive()\n    }\n}\n\nimpl FromPrimitive<i32> for Scalar {\n    fn from_primitive(prim: i32) -> Result<Self, PrimitiveCastError<i32>> {\n        let i = BigUint::try_from(prim).map_err(|_| PrimitiveCastError(prim))?;\n        Ok(Self(Ratio::from_integer(i)))\n    }\n\n    fn from_primitive_bounded(prim: i32) -> Self {\n        Self::from_primitive(prim).unwrap_or_else(|_| Self(Ratio::zero()))\n    }\n}\n\nimpl IntoPrimitive<i64> for Scalar {\n    fn into_primitive(self) -> Result<i64, ScalarCastError> {\n        let i = self.0;\n        i.to_integer().to_i64().ok_or(ScalarCastError {\n            weight: i,\n            target: PrimitiveType::I64,\n        })\n    }\n\n    fn to_primitive(&self) -> Result<i64, ScalarCastError> {\n        self.clone().into_primitive()\n    }\n}\n\nimpl FromPrimitive<i64> for Scalar {\n    fn from_primitive(prim: i64) -> Result<Self, PrimitiveCastError<i64>> {\n        let i = BigUint::try_from(prim).map_err(|_| PrimitiveCastError(prim))?;\n        Ok(Self(Ratio::from_integer(i)))\n    }\n\n    fn 
from_primitive_bounded(prim: i64) -> Self {\n        Self::from_primitive(prim).unwrap_or_else(|_| Self(Ratio::zero()))\n    }\n}\n\nimpl IntoPrimitive<f32> for Scalar {\n    fn into_primitive(self) -> Result<f32, ScalarCastError> {\n        let r = self.to_ratio();\n        ratio_to_float(&r).ok_or(ScalarCastError {\n            weight: self.0,\n            target: PrimitiveType::F32,\n        })\n    }\n\n    fn to_primitive(&self) -> Result<f32, ScalarCastError> {\n        self.clone().into_primitive()\n    }\n}\n\nimpl FromPrimitive<f32> for Scalar {\n    fn from_primitive(prim: f32) -> Result<Self, PrimitiveCastError<f32>> {\n        let r = Ratio::from_float(prim).ok_or(PrimitiveCastError(prim))?;\n        r.try_into().map_err(|_| PrimitiveCastError(prim))\n    }\n\n    fn from_primitive_bounded(prim: f32) -> Self {\n        Self::from_float_bounded(prim)\n    }\n}\n\nimpl IntoPrimitive<f64> for Scalar {\n    fn into_primitive(self) -> Result<f64, ScalarCastError> {\n        let r = self.to_ratio();\n        ratio_to_float(&r).ok_or(ScalarCastError {\n            weight: self.0,\n            target: PrimitiveType::F64,\n        })\n    }\n\n    fn to_primitive(&self) -> Result<f64, ScalarCastError> {\n        self.clone().into_primitive()\n    }\n}\n\nimpl FromPrimitive<f64> for Scalar {\n    fn from_primitive(prim: f64) -> Result<Self, PrimitiveCastError<f64>> {\n        let r = Ratio::from_float(prim).ok_or(PrimitiveCastError(prim))?;\n        r.try_into().map_err(|_| PrimitiveCastError(prim))\n    }\n\n    fn from_primitive_bounded(prim: f64) -> Self {\n        Self::from_float_bounded(prim)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_ratio_conversion() {\n        let (numer, denom) = (1_u8, 2_u8);\n        let expected_ratio = Ratio::new(BigInt::from(numer), BigInt::from(denom));\n        let actual_ratio = Scalar::new(numer, denom).into();\n        assert_eq!(expected_ratio, actual_ratio);\n    }\n\n    #[test]\n 
   fn test_ratio_conversion_ok() {\n        let (numer, denom) = (1_u8, 2_u8);\n        let ratio = Ratio::new(BigInt::from(numer), BigInt::from(denom));\n        let sc_res = Scalar::try_from(ratio);\n        assert!(sc_res.is_ok());\n        assert_eq!(sc_res.unwrap(), Scalar::new(numer, denom));\n    }\n\n    #[test]\n    fn test_ratio_conversion_err() {\n        let neg_ratio = Ratio::new(BigInt::from(-1), BigInt::from(2));\n        let sc_res = Scalar::try_from(neg_ratio);\n        assert!(sc_res.is_err());\n    }\n\n    #[test]\n    #[allow(clippy::float_cmp)]\n    fn test_scalar_f32() {\n        let prim_sc_pairs = vec![\n            (0_f32, Scalar::from_integer(0_u8)),\n            (2_f32, Scalar::from_integer(2_u8)),\n            (0.5_f32, Scalar::new(1_u8, 2_u8)),\n        ];\n        for (prim, sc) in prim_sc_pairs {\n            let converted_sc = Scalar::from_primitive(prim);\n            assert!(converted_sc.is_ok());\n            assert_eq!(converted_sc.unwrap(), sc);\n\n            let converted_sc = Scalar::from_primitive_bounded(prim);\n            assert_eq!(converted_sc, sc);\n\n            let converted_prim: f32 = sc.into_primitive_unchecked();\n            assert_eq!(converted_prim, prim);\n        }\n    }\n\n    #[test]\n    fn test_scalar_f32_from_weird_prims() {\n        let prim_pairs = vec![\n            (f32::INFINITY, f32::MAX),\n            (-1_f32, 0_f32),\n            (f32::NAN, 0_f32),\n        ];\n        for (weird, fine) in prim_pairs {\n            let weird_res = Scalar::from_primitive(weird);\n            assert!(weird_res.is_err());\n\n            let bounded = Scalar::from_primitive_bounded(weird);\n            let fine_res = Scalar::try_from(Ratio::from_float(fine).unwrap());\n            assert!(fine_res.is_ok());\n            assert_eq!(bounded, fine_res.unwrap());\n        }\n    }\n\n    #[test]\n    #[allow(clippy::float_cmp)]\n    fn test_scalar_f64() {\n        let prim_sc_pairs = vec![\n            (0_f64, 
Scalar::from_integer(0_u8)),\n            (2_f64, Scalar::from_integer(2_u8)),\n            (0.5_f64, Scalar::new(1_u8, 2_u8)),\n        ];\n        for (prim, sc) in prim_sc_pairs {\n            let converted_sc = Scalar::from_primitive(prim);\n            assert!(converted_sc.is_ok());\n            assert_eq!(converted_sc.unwrap(), sc);\n\n            let converted_sc = Scalar::from_primitive_bounded(prim);\n            assert_eq!(converted_sc, sc);\n\n            let converted_prim: f64 = sc.into_primitive_unchecked();\n            assert_eq!(converted_prim, prim);\n        }\n    }\n\n    #[test]\n    fn test_scalar_f64_from_weird_prims() {\n        let prim_pairs = vec![\n            (f64::INFINITY, f64::MAX),\n            (-1_f64, 0_f64),\n            (f64::NAN, 0_f64),\n        ];\n        for (weird, fine) in prim_pairs {\n            let weird_res = Scalar::from_primitive(weird);\n            assert!(weird_res.is_err());\n\n            let bounded = Scalar::from_primitive_bounded(weird);\n            let fine_res = Scalar::try_from(Ratio::from_float(fine).unwrap());\n            assert!(fine_res.is_ok());\n            assert_eq!(bounded, fine_res.unwrap());\n        }\n    }\n\n    #[test]\n    fn test_scalar_i32() {\n        let prim_sc_pairs = vec![\n            (0_i32, Scalar::from_integer(0_u8)),\n            (2_i32, Scalar::from_integer(2_u8)),\n        ];\n        for (prim, sc) in prim_sc_pairs {\n            let converted_sc = Scalar::from_primitive(prim);\n            assert!(converted_sc.is_ok());\n            assert_eq!(converted_sc.unwrap(), sc);\n\n            let converted_sc = Scalar::from_primitive_bounded(prim);\n            assert_eq!(converted_sc, sc);\n\n            let converted_prim: i32 = sc.into_primitive_unchecked();\n            assert_eq!(converted_prim, prim);\n        }\n    }\n\n    #[test]\n    fn test_scalar_i64() {\n        let prim_sc_pairs = vec![\n            (0_i64, Scalar::from_integer(0_u8)),\n            (2_i64, 
Scalar::from_integer(2_u8)),\n        ];\n        for (prim, sc) in prim_sc_pairs {\n            let converted_sc = Scalar::from_primitive(prim);\n            assert!(converted_sc.is_ok());\n            assert_eq!(converted_sc.unwrap(), sc);\n\n            let converted_sc = Scalar::from_primitive_bounded(prim);\n            assert_eq!(converted_sc, sc);\n\n            let converted_prim: i64 = sc.into_primitive_unchecked();\n            assert_eq!(converted_prim, prim);\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/mask/seed.rs",
    "content": "//! Mask seed and mask generation.\n//!\n//! See the [mask module] documentation since this is a private module anyways.\n//!\n//! [mask module]:  crate::mask\n\nuse std::iter;\n\nuse derive_more::{AsMut, AsRef};\nuse rand::SeedableRng;\nuse rand_chacha::ChaCha20Rng;\nuse serde::{Deserialize, Serialize};\nuse sodiumoxide::crypto::box_;\nuse thiserror::Error;\n\nuse crate::{\n    crypto::{encrypt::SEALBYTES, prng::generate_integer, ByteObject},\n    mask::{\n        object::{MaskObject, MaskUnit, MaskVect},\n        MaskConfigPair,\n    },\n    SumParticipantEphemeralPublicKey,\n    SumParticipantEphemeralSecretKey,\n};\n\n#[derive(AsRef, AsMut, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]\n/// A seed to generate a mask.\n///\n/// When this goes out of scope, its contents will be zeroed out.\npub struct MaskSeed(box_::Seed);\n\nimpl ByteObject for MaskSeed {\n    const LENGTH: usize = box_::SEEDBYTES;\n\n    fn from_slice(bytes: &[u8]) -> Option<Self> {\n        box_::Seed::from_slice(bytes).map(Self)\n    }\n\n    fn zeroed() -> Self {\n        Self(box_::Seed([0_u8; Self::LENGTH]))\n    }\n\n    fn as_slice(&self) -> &[u8] {\n        self.0.as_ref()\n    }\n}\n\nimpl MaskSeed {\n    /// Gets this seed as an array.\n    pub fn as_array(&self) -> [u8; Self::LENGTH] {\n        (self.0).0\n    }\n\n    /// Encrypts this seed with the given public key as an [`EncryptedMaskSeed`].\n    pub fn encrypt(&self, pk: &SumParticipantEphemeralPublicKey) -> EncryptedMaskSeed {\n        // safe unwrap: length of slice is guaranteed by constants\n        EncryptedMaskSeed::from_slice_unchecked(pk.encrypt(self.as_slice()).as_slice())\n    }\n\n    /// Derives a mask of given length from this seed wrt the masking configurations.\n    pub fn derive_mask(&self, len: usize, config: MaskConfigPair) -> MaskObject {\n        let MaskConfigPair {\n            vect: config_n,\n            unit: config_1,\n        } = config;\n        let mut prng = 
ChaCha20Rng::from_seed(self.as_array());\n\n        let rand_int = generate_integer(&mut prng, &config_1.order());\n        let scalar_mask = MaskUnit::new_unchecked(config_1, rand_int);\n\n        let order_n = config_n.order();\n        let rand_ints = iter::repeat_with(|| generate_integer(&mut prng, &order_n))\n            .take(len)\n            .collect();\n        let model_mask = MaskVect::new_unchecked(config_n, rand_ints);\n\n        MaskObject::new_unchecked(model_mask, scalar_mask)\n    }\n}\n\n#[derive(AsRef, AsMut, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]\n/// An encrypted mask seed.\npub struct EncryptedMaskSeed(Vec<u8>);\n\nimpl From<Vec<u8>> for EncryptedMaskSeed {\n    fn from(value: Vec<u8>) -> Self {\n        Self(value)\n    }\n}\n\nimpl ByteObject for EncryptedMaskSeed {\n    const LENGTH: usize = SEALBYTES + MaskSeed::LENGTH;\n\n    fn from_slice(bytes: &[u8]) -> Option<Self> {\n        if bytes.len() == Self::LENGTH {\n            Some(Self(bytes.to_vec()))\n        } else {\n            None\n        }\n    }\n\n    fn zeroed() -> Self {\n        Self(vec![0_u8; Self::LENGTH])\n    }\n\n    fn as_slice(&self) -> &[u8] {\n        self.0.as_slice()\n    }\n}\n\n#[derive(Debug, Error)]\npub enum InvalidMaskSeed {\n    #[error(\"the encrypted mask seed could not be decrypted\")]\n    DecryptionFailed,\n    #[error(\"the mask seed has an invalid length\")]\n    InvalidLength,\n}\n\nimpl EncryptedMaskSeed {\n    /// Decrypts this seed as a [`MaskSeed`].\n    ///\n    /// # Errors\n    /// Fails if the decryption fails.\n    pub fn decrypt(\n        &self,\n        pk: &SumParticipantEphemeralPublicKey,\n        sk: &SumParticipantEphemeralSecretKey,\n    ) -> Result<MaskSeed, InvalidMaskSeed> {\n        MaskSeed::from_slice(\n            sk.decrypt(self.as_slice(), pk)\n                .or(Err(InvalidMaskSeed::DecryptionFailed))?\n                .as_slice(),\n        )\n        .ok_or(InvalidMaskSeed::InvalidLength)\n    
}\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{\n        crypto::encrypt::EncryptKeyPair,\n        mask::config::{BoundType, DataType, GroupType, MaskConfig, ModelType},\n    };\n\n    #[test]\n    fn test_constants() {\n        assert_eq!(MaskSeed::LENGTH, 32);\n        assert_eq!(\n            MaskSeed::zeroed().as_slice(),\n            [0_u8; 32].to_vec().as_slice(),\n        );\n        assert_eq!(EncryptedMaskSeed::LENGTH, 80);\n        assert_eq!(\n            EncryptedMaskSeed::zeroed().as_slice(),\n            [0_u8; 80].to_vec().as_slice(),\n        );\n    }\n\n    #[test]\n    fn test_derive_mask() {\n        let config = MaskConfig {\n            group_type: GroupType::Prime,\n            data_type: DataType::F32,\n            bound_type: BoundType::B0,\n            model_type: ModelType::M3,\n        };\n        let seed = MaskSeed::generate();\n        let mask = seed.derive_mask(10, config.into());\n        assert_eq!(mask.vect.data.len(), 10);\n        assert!(mask\n            .vect\n            .data\n            .iter()\n            .all(|integer| integer < &config.order()));\n    }\n\n    #[test]\n    fn test_encryption() {\n        let seed = MaskSeed::generate();\n        assert_eq!(seed.as_slice().len(), 32);\n        assert_ne!(seed, MaskSeed::zeroed());\n        let EncryptKeyPair { public, secret } = EncryptKeyPair::generate();\n        let encr_seed = seed.encrypt(&public);\n        assert_eq!(encr_seed.as_slice().len(), 80);\n        let decr_seed = encr_seed.decrypt(&public, &secret).unwrap();\n        assert_eq!(seed, decr_seed);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/message/message.rs",
    "content": "//! Message buffers.\n//!\n//! See the [message module] documentation since this is a private module anyways.\n//!\n//! [message module]: crate::message\n\nuse std::convert::{TryFrom, TryInto};\n\nuse anyhow::{anyhow, Context};\nuse serde::{Deserialize, Serialize};\n\nuse crate::{\n    crypto::{ByteObject, PublicEncryptKey, PublicSigningKey, SecretSigningKey, Signature},\n    message::{Chunk, DecodeError, FromBytes, Payload, Sum, Sum2, ToBytes, Update},\n};\n\n/// The minimum number of accepted `sum`/`sum2` messages for the PET protocol to function correctly.\npub const SUM_COUNT_MIN: u64 = 1;\n\n/// The minimum number of accepted `update` messages for the PET protocol to function correctly.\npub const UPDATE_COUNT_MIN: u64 = 3;\n\npub(crate) mod ranges {\n    use std::ops::Range;\n\n    use super::*;\n    use crate::message::utils::range;\n\n    /// Byte range corresponding to the signature in a\n    /// message header\n    pub const SIGNATURE: Range<usize> = range(0, Signature::LENGTH);\n    /// Byte range corresponding to the participant public key in a\n    /// message header\n    pub const PARTICIPANT_PK: Range<usize> = range(SIGNATURE.end, PublicSigningKey::LENGTH);\n    /// Byte range corresponding to the coordinator public key in a\n    /// message header\n    pub const COORDINATOR_PK: Range<usize> = range(PARTICIPANT_PK.end, PublicEncryptKey::LENGTH);\n    /// Byte range corresponding to the length field in a message header\n    pub const LENGTH: Range<usize> = range(COORDINATOR_PK.end, 4);\n    /// Byte index of the tag in a message header\n    pub const TAG: usize = LENGTH.end;\n    /// Byte index of the flags in a message header\n    pub const FLAGS: usize = TAG + 1;\n    /// Byte range reserved for future use\n    pub const RESERVED: Range<usize> = range(FLAGS + 1, 2);\n}\n\n/// Length in bytes of a message header\npub const HEADER_LENGTH: usize = ranges::RESERVED.end;\n\n/// A wrapper around a 
buffer that contains a [`Message`].\n///\n/// It provides getters and setters to access the different fields of\n/// the message safely. A message is made of a header and a payload:\n///\n/// ```no_rust\n///  0                   1                   2                   3\n///  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1\n/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                           signature                           +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                    
                                           |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                         participant_pk                        +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n/// |                                                               |\n/// +                                                               +\n/// |                                
                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                         coordinator_pk                        +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +                                                               +\n/// |                                                               |\n/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n/// |                             length                            |\n/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n/// |      tag      |     flags     |          reserved             |\n/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n/// |                                                               |\n/// +                    payload (variable length)                  +\n/// |                                                               |\n/// ```\n///\n/// - `signature` contains the signature of the entire message\n/// - `participant_pk` contains the public key for verifying the\n///   signature\n/// - `coordinator_pk` is the coordinator public encryption key. It is\n///    embedded in the message for security reasons. See [_Donald\n///    T. Davis, \"Defective Sign & Encrypt in S/MIME, PKCS#7, MOSS,\n///    PEM, PGP, and XML.\", Proc. Usenix Tech. Conf. 
2001 (Boston,\n///    Mass., June 25-30,\n///    2001)_](http://world.std.com/~dtd/sign_encrypt/sign_encrypt7.html)\n/// - `length` is the length in bytes of the _full_ message, _i.e._\n///   including the header. This is a 32-bit field so in theory,\n///   messages can be as big as 2^32 - 1 = 4,294,967,295 bytes.\n/// - `tag` indicates the type of message (sum, update or\n///   sum2); multipart messages are indicated by `flags` instead\n/// - the `flags` field currently supports a single flag, that\n///   indicates whether this is a multipart message\n///\n/// # Examples\n/// ## Reading a sum message\n///\n/// ```rust\n/// use std::convert::TryFrom;\n/// use xaynet_core::message::{Flags, MessageBuffer, Tag};\n///\n/// let mut bytes = vec![0x11; 64]; // message signature\n/// bytes.extend(vec![0x22; 32]); // participant public signing key\n/// bytes.extend(vec![0x33; 32]); // coordinator public encrypt key\n/// bytes.extend(&200_u32.to_be_bytes()); // Length field\n/// bytes.push(0x01); // tag (sum message)\n/// bytes.push(0x00); // flags (not a multipart message)\n/// bytes.extend(vec![0x00, 0x00]); // reserved\n///\n/// // Payload: a sum message contains a signature and an ephemeral public key\n/// bytes.extend(vec![0xaa; 32]); // signature\n/// bytes.extend(vec![0xbb; 32]); // public key\n///\n/// let buffer = MessageBuffer::new(&bytes).unwrap();\n/// assert_eq!(buffer.signature(), vec![0x11; 64].as_slice());\n/// assert_eq!(buffer.participant_pk(), vec![0x22; 32].as_slice());\n/// assert_eq!(buffer.coordinator_pk(), vec![0x33; 32].as_slice());\n/// assert_eq!(Tag::try_from(buffer.tag()).unwrap(), Tag::Sum);\n/// assert_eq!(Flags::try_from(buffer.flags()).unwrap(), Flags::empty());\n/// assert_eq!(\n///     buffer.payload(),\n///     [vec![0xaa; 32], vec![0xbb; 32]].concat().as_slice()\n/// );\n/// ```\n///\n/// ## Writing a sum message\n///\n/// ```rust\n/// use std::convert::TryFrom;\n/// use xaynet_core::message::{Flags, MessageBuffer, Tag};\n///\n/// let mut expected = vec![0x11; 64]; // 
message signature\n/// expected.extend(vec![0x22; 32]); // participant public signing key\n/// expected.extend(vec![0x33; 32]); // coordinator public encrypt key\n/// expected.extend(&200_u32.to_be_bytes()); // length field\n/// expected.push(0x01); // tag (sum message)\n/// expected.push(0x00); // flags (not a multipart message)\n/// expected.extend(vec![0x00, 0x00]); // reserved\n///\n/// // Payload: a sum message contains a signature and an ephemeral public key\n/// expected.extend(vec![0xaa; 32]); // signature\n/// expected.extend(vec![0xbb; 32]); // public key\n///\n/// let mut bytes = vec![0; expected.len()];\n/// let mut buffer = MessageBuffer::new_unchecked(&mut bytes);\n/// buffer\n///     .signature_mut()\n///     .copy_from_slice(vec![0x11; 64].as_slice());\n/// buffer\n///     .participant_pk_mut()\n///     .copy_from_slice(vec![0x22; 32].as_slice());\n/// buffer\n///     .coordinator_pk_mut()\n///     .copy_from_slice(vec![0x33; 32].as_slice());\n/// buffer.set_length(200 as u32);\n/// buffer.set_tag(Tag::Sum.into());\n/// buffer.set_flags(Flags::empty());\n/// buffer\n///     .payload_mut()\n///     .copy_from_slice([vec![0xaa; 32], vec![0xbb; 32]].concat().as_slice());\n/// assert_eq!(expected, bytes);\n/// ```\npub struct MessageBuffer<T> {\n    inner: T,\n}\n\nimpl<T: AsRef<[u8]>> MessageBuffer<T> {\n    pub fn inner(&self) -> &T {\n        &self.inner\n    }\n\n    pub fn as_ref(&self) -> MessageBuffer<&T> {\n        MessageBuffer::new_unchecked(self.inner())\n    }\n    /// Performs bound checks for the various message fields on `bytes` and returns a new\n    /// [`MessageBuffer`].\n    ///\n    /// # Errors\n    /// Fails if the `bytes` are smaller than a minimal-sized message buffer.\n    pub fn new(bytes: T) -> Result<Self, DecodeError> {\n        let buffer = Self { inner: bytes };\n        buffer\n            .check_buffer_length()\n            .context(\"not a valid MessageBuffer\")?;\n        Ok(buffer)\n    }\n\n    /// Returns a 
[`MessageBuffer`] without performing any bound checks.\n    ///\n    /// This means accessing the various fields may panic if the data\n    /// is invalid.\n    pub fn new_unchecked(bytes: T) -> Self {\n        Self { inner: bytes }\n    }\n\n    /// Performs bound checks to ensure the fields can be accessed\n    /// without panicking.\n    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {\n        let len = self.inner.as_ref().len();\n        if len < HEADER_LENGTH {\n            return Err(anyhow!(\n                \"invalid buffer length: {} < {}\",\n                len,\n                HEADER_LENGTH\n            ));\n        }\n        let expected_len = self.length() as usize;\n        let actual_len = self.inner.as_ref().len();\n        if actual_len < expected_len {\n            return Err(anyhow!(\n                \"invalid message length: length field says {}, but buffer is {} bytes long\",\n                expected_len,\n                actual_len\n            ));\n        }\n        Ok(())\n    }\n\n    /// Gets the tag field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn tag(&self) -> u8 {\n        self.inner.as_ref()[ranges::TAG]\n    }\n\n    /// Gets the flags field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn flags(&self) -> Flags {\n        Flags::from_bits_truncate(self.inner.as_ref()[ranges::FLAGS])\n    }\n\n    /// Gets the length field\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn length(&self) -> u32 {\n        // Unwrapping is OK, as the slice is guaranteed to be 4 bytes\n        // long\n        u32::from_be_bytes(self.inner.as_ref()[ranges::LENGTH].try_into().unwrap())\n    }\n}\n\nimpl<'a, T: AsRef<[u8]> + ?Sized> MessageBuffer<&'a T> {\n    /// Gets the message signature field\n    ///\n    /// # 
Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn signature(&self) -> &'a [u8] {\n        &self.inner.as_ref()[ranges::SIGNATURE]\n    }\n\n    /// Gets the participant public key field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn participant_pk(&self) -> &'a [u8] {\n        &self.inner.as_ref()[ranges::PARTICIPANT_PK]\n    }\n\n    /// Gets the coordinator public key field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn coordinator_pk(&self) -> &'a [u8] {\n        &self.inner.as_ref()[ranges::COORDINATOR_PK]\n    }\n\n    /// Gets the rest of the message.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn payload(&self) -> &'a [u8] {\n        &self.inner.as_ref()[HEADER_LENGTH..]\n    }\n\n    /// Parse the signature and public signing key, and check the\n    /// message signature.\n    pub fn check_signature(&self) -> Result<(), DecodeError> {\n        let signature = Signature::from_byte_slice(&self.signature())\n            .context(\"cannot parse the signature field\")?;\n        let participant_pk = PublicSigningKey::from_byte_slice(&self.participant_pk())\n            .context(\"cannot part the public key field\")?;\n\n        if participant_pk.verify_detached(&signature, self.signed_data()) {\n            Ok(())\n        } else {\n            Err(anyhow!(\"invalid message signature\"))\n        }\n    }\n\n    /// Return the portion of the message used to compute the\n    /// signature, ie the entire message except the signature field\n    /// itself.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn signed_data(&self) -> &'a [u8] {\n        let signed_data_range = ranges::SIGNATURE.end..self.length() as usize;\n        
&self.inner.as_ref()[signed_data_range]\n    }\n}\n\nimpl<T: AsMut<[u8]> + AsRef<[u8]>> MessageBuffer<T> {\n    /// Sets the tag field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn set_tag(&mut self, value: u8) {\n        self.inner.as_mut()[ranges::TAG] = value;\n    }\n\n    /// Sets the flags field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn set_flags(&mut self, value: Flags) {\n        self.inner.as_mut()[ranges::FLAGS] = value.bits();\n    }\n\n    /// Sets the length field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn set_length(&mut self, value: u32) {\n        let bytes = value.to_be_bytes();\n        self.inner.as_mut()[ranges::LENGTH].copy_from_slice(&bytes[..]);\n    }\n\n    /// Gets a mutable reference to the message signature field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn signature_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[ranges::SIGNATURE]\n    }\n\n    /// Gets a mutable reference to the participant public key field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn participant_pk_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[ranges::PARTICIPANT_PK]\n    }\n\n    /// Gets a mutable reference to the coordinator public key field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn coordinator_pk_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[ranges::COORDINATOR_PK]\n    }\n\n    /// Gets a mutable reference to the rest of the message.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub 
fn payload_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[HEADER_LENGTH..]\n    }\n\n    /// Gets a mutable reference to the portion of the message used to\n    /// compute the signature, i.e. the entire message except the\n    /// signature field itself.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn signed_data_mut(&mut self) -> &mut [u8] {\n        let signed_data_range = ranges::SIGNATURE.end..self.length() as usize;\n        &mut self.inner.as_mut()[signed_data_range]\n    }\n}\n\nbitflags::bitflags! {\n    /// A bitmask that defines flags for a [`Message`].\n    pub struct Flags: u8 {\n        /// Indicates whether this message is a multipart message\n        const MULTIPART = 1 << 0;\n    }\n}\n\n#[derive(Copy, Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]\n/// A tag that indicates the type of the [`Message`].\npub enum Tag {\n    /// A tag for [`Sum`] messages\n    Sum,\n    /// A tag for [`Update`] messages\n    Update,\n    /// A tag for [`Sum2`] messages\n    Sum2,\n}\n\nimpl TryFrom<u8> for Tag {\n    type Error = DecodeError;\n\n    fn try_from(value: u8) -> Result<Self, Self::Error> {\n        Ok(match value {\n            1 => Tag::Sum,\n            2 => Tag::Update,\n            3 => Tag::Sum2,\n            _ => return Err(anyhow!(\"invalid tag {}\", value)),\n        })\n    }\n}\n\nimpl From<Tag> for u8 {\n    fn from(tag: Tag) -> Self {\n        match tag {\n            Tag::Sum => 1,\n            Tag::Update => 2,\n            Tag::Sum2 => 3,\n        }\n    }\n}\n\n#[derive(Debug, Eq, PartialEq, Clone)]\n/// A message of the PET protocol: the header fields common to all\n/// messages, followed by the message payload.\npub struct Message {\n    /// Message signature. 
This can be `None` if it hasn't been\n    /// computed yet.\n    pub signature: Option<Signature>,\n    /// The participant public key, used to verify the message\n    /// signature.\n    pub participant_pk: PublicSigningKey,\n    /// The coordinator public key\n    pub coordinator_pk: PublicEncryptKey,\n    /// Whether this is a multipart message\n    pub is_multipart: bool,\n    /// The type of message. This information is partially redundant\n    /// with the `payload` field. So when serializing the message,\n    /// this field is ignored if the payload is a [`Payload::Sum`],\n    /// [`Payload::Update`], or [`Payload::Sum2`]. However, it is\n    /// taken as is for [`Payload::Chunk`].\n    pub tag: Tag,\n    /// Message payload\n    pub payload: Payload,\n}\n\nimpl Message {\n    /// Create a new sum message with the given participant and\n    /// coordinator public keys.\n    pub fn new_sum(\n        participant_pk: PublicSigningKey,\n        coordinator_pk: PublicEncryptKey,\n        message: Sum,\n    ) -> Self {\n        Self {\n            signature: None,\n            participant_pk,\n            coordinator_pk,\n            is_multipart: false,\n            tag: Tag::Sum,\n            payload: message.into(),\n        }\n    }\n\n    /// Create a new sum2 message with the given participant and\n    /// coordinator public keys.\n    pub fn new_sum2(\n        participant_pk: PublicSigningKey,\n        coordinator_pk: PublicEncryptKey,\n        message: Sum2,\n    ) -> Self {\n        Self {\n            signature: None,\n            participant_pk,\n            coordinator_pk,\n            is_multipart: false,\n            tag: Tag::Sum2,\n            payload: message.into(),\n        }\n    }\n\n    /// Create a new update message with the given participant and\n    /// coordinator public keys.\n    pub fn new_update(\n        participant_pk: PublicSigningKey,\n        coordinator_pk: PublicEncryptKey,\n        message: Update,\n    ) -> Self {\n        
Self {\n            signature: None,\n            participant_pk,\n            coordinator_pk,\n            is_multipart: false,\n            tag: Tag::Update,\n            payload: message.into(),\n        }\n    }\n\n    /// Create a new multipart message with the given participant and\n    /// coordinator public keys. Since the type of the message cannot be\n    /// derived from a [`Chunk`] payload, `tag` is written as is in the\n    /// message header when serializing.\n    pub fn new_multipart(\n        participant_pk: PublicSigningKey,\n        coordinator_pk: PublicEncryptKey,\n        message: Chunk,\n        tag: Tag,\n    ) -> Self {\n        Self {\n            signature: None,\n            participant_pk,\n            coordinator_pk,\n            is_multipart: true,\n            tag,\n            payload: message.into(),\n        }\n    }\n\n    /// Parse the given message **without** verifying the\n    /// signature. If you need to check the signature, call\n    /// [`MessageBuffer::check_signature`] before parsing the message.\n    pub fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        let reader = MessageBuffer::new(buffer.as_ref())?;\n        let signature =\n            Signature::from_byte_slice(&reader.signature()).context(\"failed to parse signature\")?;\n        let participant_pk = PublicSigningKey::from_byte_slice(&reader.participant_pk())\n            .context(\"failed to parse public key\")?;\n        let coordinator_pk = PublicEncryptKey::from_byte_slice(&reader.coordinator_pk())\n            .context(\"failed to parse public key\")?;\n\n        let tag = reader.tag().try_into()?;\n        let is_multipart = reader.flags().contains(Flags::MULTIPART);\n\n        let payload = if is_multipart {\n            Chunk::from_byte_slice(&reader.payload()).map(Into::into)\n        } else {\n            match tag {\n                Tag::Sum => Sum::from_byte_slice(&reader.payload()).map(Into::into),\n                Tag::Update => Update::from_byte_slice(&reader.payload()).map(Into::into),\n                Tag::Sum2 => 
Sum2::from_byte_slice(&reader.payload()).map(Into::into),\n            }\n        }\n        .context(\"failed to parse message payload\")?;\n\n        Ok(Self {\n            participant_pk,\n            coordinator_pk,\n            signature: Some(signature),\n            payload,\n            is_multipart,\n            tag,\n        })\n    }\n\n    /// Serialize this message. If the `signature` attribute is\n    /// `Some`, the signature will be directly inserted in the message\n    /// header. Otherwise it will be computed.\n    ///\n    /// # Panics\n    ///\n    /// This method panics if the given buffer is too small for the\n    /// message to fit.\n    pub fn to_bytes<T: AsMut<[u8]> + AsRef<[u8]> + ?Sized>(\n        &self,\n        buffer: &mut T,\n        sk: &SecretSigningKey,\n    ) {\n        let mut writer = MessageBuffer::new(buffer.as_mut()).unwrap();\n\n        self.participant_pk\n            .to_bytes(&mut writer.participant_pk_mut());\n        self.coordinator_pk\n            .to_bytes(&mut writer.coordinator_pk_mut());\n        let flags = if self.is_multipart {\n            Flags::MULTIPART\n        } else {\n            Flags::empty()\n        };\n        writer.set_flags(flags);\n        self.payload.to_bytes(&mut writer.payload_mut());\n        // Determine the tag from the payload type if\n        // possible. Otherwise, use the self.tag field.\n        let tag = match self.payload {\n            Payload::Sum(_) => Tag::Sum,\n            Payload::Update(_) => Tag::Update,\n            Payload::Sum2(_) => Tag::Sum2,\n            Payload::Chunk(_) => self.tag,\n        };\n        writer.set_tag(tag.into());\n        writer.set_length(self.buffer_length() as u32);\n        // Insert the signature last, since it covers all the other\n        // fields written above. If the message contains one, use\n        // it. 
Otherwise compute it.\n        let signature = match self.signature {\n            Some(signature) => signature,\n            None => sk.sign_detached(writer.signed_data_mut()),\n        };\n        signature.to_bytes(&mut writer.signature_mut());\n    }\n\n    pub fn buffer_length(&self) -> usize {\n        self.payload.buffer_length() + HEADER_LENGTH\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use std::convert::TryFrom;\n\n    use super::*;\n    use crate::{\n        message::{Message, Tag},\n        testutils::messages as helpers,\n    };\n\n    fn sum_message() -> (Message, Vec<u8>) {\n        helpers::message(helpers::sum::payload)\n    }\n\n    #[test]\n    fn buffer_read() {\n        let bytes = sum_message().1;\n        let buffer = MessageBuffer::new(&bytes).unwrap();\n        assert_eq!(Tag::try_from(buffer.tag()).unwrap(), Tag::Sum);\n        assert_eq!(buffer.signature(), helpers::signature().1.as_slice());\n        assert_eq!(\n            buffer.participant_pk(),\n            helpers::participant_pk().1.as_slice()\n        );\n        assert_eq!(\n            buffer.coordinator_pk(),\n            helpers::coordinator_pk().1.as_slice()\n        );\n        assert_eq!(buffer.length() as usize, bytes.len());\n        assert_eq!(buffer.payload(), helpers::sum::payload().1.as_slice());\n    }\n\n    #[test]\n    fn buffer_write() {\n        let expected = sum_message().1;\n        let mut bytes = vec![0; expected.len()];\n        let mut buffer = MessageBuffer::new_unchecked(&mut bytes);\n\n        buffer\n            .signature_mut()\n            .copy_from_slice(helpers::signature().1.as_slice());\n        buffer\n            .participant_pk_mut()\n            .copy_from_slice(helpers::participant_pk().1.as_slice());\n        buffer\n            .coordinator_pk_mut()\n            .copy_from_slice(helpers::coordinator_pk().1.as_slice());\n        buffer.set_tag(Tag::Sum.into());\n        buffer.set_length(expected.len() as u32);\n        buffer\n          
  .payload_mut()\n            .copy_from_slice(helpers::sum::payload().1.as_slice());\n        assert_eq!(bytes, expected);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/message/mod.rs",
    "content": "//! The messages of the PET protocol.\n//!\n//! # The sum message\n//! The [`Sum`] message is an abstraction for the values which a sum participant communicates to\n//! XayNet during the sum phase of the PET protocol. It contains the following values:\n//! - The sum signature proves the eligibility of the participant for the sum task.\n//! - The ephemeral public key is used by update participants to encrypt mask seeds in the update\n//!   phase for the process of mask aggregation in the sum2 phase.\n//!\n//! # The update message\n//! The [`Update`] message is an abstraction for the values which an update participant communicates\n//! to XayNet during the update phase of the PET protocol. It contains the following values:\n//! - The sum signature proves the ineligibility of the participant for the sum task.\n//! - The update signature proves the eligibility of the participant for the update task.\n//! - The masked model is the encrypted local update to the global model, which is trained on the\n//!   local data of the update participant.\n//! - The local seed dictionary stores the encrypted mask seed, which generates the local mask for\n//!   the local model, which is encrypted by the ephemeral public keys of the sum participants.\n//!\n//! # The sum2 message\n//! The [`Sum2`] message is an abstraction for the values which a sum participant communicates to\n//! XayNet during the sum2 phase of the PET protocol. It contains the following values:\n//! - The sum signature proves the eligibility of the participant for the sum task.\n//! 
- The global mask is used by XayNet to unmask the aggregated global model.\n\n#[allow(clippy::module_inception)]\npub(crate) mod message;\npub(crate) mod payload;\npub(crate) mod traits;\npub(crate) mod utils;\n\npub use self::{\n    message::{\n        Flags,\n        Message,\n        MessageBuffer,\n        Tag,\n        HEADER_LENGTH as MESSAGE_HEADER_LENGTH,\n        SUM_COUNT_MIN,\n        UPDATE_COUNT_MIN,\n    },\n    payload::{\n        chunk::{Chunk, ChunkBuffer},\n        sum::{Sum, SumBuffer},\n        sum2::{Sum2, Sum2Buffer},\n        update::{Update, UpdateBuffer},\n        Payload,\n    },\n    traits::{FromBytes, LengthValueBuffer, ToBytes},\n};\n\n/// An error that signals a failure when trying to decrypt and parse a message.\n///\n/// This is kept generic on purpose to not reveal to the sender what specifically failed during\n/// decryption or parsing.\npub type DecodeError = anyhow::Error;\n"
  },
  {
    "path": "rust/xaynet-core/src/message/payload/chunk.rs",
    "content": "use std::convert::TryInto;\n\nuse anyhow::{anyhow, Context};\n\nuse crate::message::{\n    traits::{FromBytes, ToBytes},\n    DecodeError,\n};\n\npub(crate) mod ranges {\n    use crate::message::utils::range;\n    use std::ops::Range;\n\n    /// Byte range corresponding to the chunk ID in a chunk message\n    pub const ID: Range<usize> = range(0, 2);\n    /// Byte range corresponding to the message ID in a chunk message\n    pub const MESSAGE_ID: Range<usize> = range(ID.end, 2);\n    /// Byte range corresponding to the flags in a chunk message\n    pub const FLAGS: usize = MESSAGE_ID.end;\n    /// Byte range reserved for future use\n    pub const RESERVED: Range<usize> = range(FLAGS + 1, 3);\n}\n\n/// Length in bytes of a chunk message header\nconst HEADER_LENGTH: usize = ranges::RESERVED.end;\n\n/// A message chunk.\n#[derive(Eq, PartialEq, Debug, Clone)]\npub struct Chunk {\n    /// Chunk ID\n    pub id: u16,\n    /// ID of the message this chunk belongs to\n    pub message_id: u16,\n    /// `true` if this is the last chunk of the message, `false` otherwise\n    pub last: bool,\n    /// Data contained in this chunk.\n    pub data: Vec<u8>,\n}\n\nbitflags::bitflags! 
{\n    /// A bitmask that defines flags for a [`Chunk`].\n    pub struct Flags: u8 {\n        /// Indicates whether this message is the last chunk of a\n        /// multipart message\n        const LAST_CHUNK = 1 << 0;\n    }\n}\n\n/// ```no_rust\n///  0                   1                   2                   3\n///  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1\n/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n/// |                id             |           message_id          |\n/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n/// |     flags     |                    reserved                   |\n/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n/// |                                                               |\n/// +                       data (variable length)                  +\n/// |                                                               |\n/// ```\n///\n/// - `id`: ID of the chunk\n/// - `message_id`: ID of the message this chunk belong to\n/// - `flags`: currently the only supported flag indicates whether\n///   this is the last chunk or not\npub struct ChunkBuffer<T> {\n    inner: T,\n}\n\nimpl<T: AsRef<[u8]>> ChunkBuffer<T> {\n    /// Performs bound checks for the various message fields on `bytes` and returns a new\n    /// [`ChunkBuffer`].\n    ///\n    /// # Errors\n    /// Fails if the `bytes` are smaller than a minimal-sized message buffer.\n    pub fn new(bytes: T) -> Result<Self, DecodeError> {\n        let buffer = Self { inner: bytes };\n        buffer\n            .check_buffer_length()\n            .context(\"not a valid ChunkBuffer\")?;\n        Ok(buffer)\n    }\n\n    /// Returns a [`ChunkBuffer`] without performing any bound checks.\n    ///\n    /// This means accessing the various fields may panic if the data\n    /// is invalid.\n    pub fn new_unchecked(bytes: T) -> Self {\n        Self { inner: bytes }\n    }\n\n    /// Performs bound checks 
to ensure the fields can be accessed\n    /// without panicking.\n    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {\n        let len = self.inner.as_ref().len();\n        if len < HEADER_LENGTH {\n            return Err(anyhow!(\n                \"invalid buffer length: {} < {}\",\n                len,\n                HEADER_LENGTH\n            ));\n        }\n        Ok(())\n    }\n\n    /// Gets the flags field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn flags(&self) -> Flags {\n        Flags::from_bits_truncate(self.inner.as_ref()[ranges::FLAGS])\n    }\n\n    /// Gets the chunk ID field\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn id(&self) -> u16 {\n        // Unwrapping is OK, as the slice is guaranteed to be 2 bytes\n        // long, which is exactly what `u16::from_be_bytes` expects\n        u16::from_be_bytes(self.inner.as_ref()[ranges::ID].try_into().unwrap())\n    }\n\n    /// Gets the message ID field\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn message_id(&self) -> u16 {\n        // Unwrapping is OK, as the slice is guaranteed to be 2 bytes\n        // long, which is exactly what `u16::from_be_bytes` expects\n        u16::from_be_bytes(self.inner.as_ref()[ranges::MESSAGE_ID].try_into().unwrap())\n    }\n}\n\nimpl<'a, T: AsRef<[u8]> + ?Sized> ChunkBuffer<&'a T> {\n    /// Gets the rest of the message.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn payload(&self) -> &'a [u8] {\n        &self.inner.as_ref()[HEADER_LENGTH..]\n    }\n}\n\nimpl<T: AsMut<[u8]> + AsRef<[u8]>> ChunkBuffer<T> {\n    /// Sets the flags field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn set_flags(&mut self, value: Flags) {\n        self.inner.as_mut()[ranges::FLAGS] = value.bits();\n    
}\n\n    /// Sets the chunk ID field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn set_id(&mut self, value: u16) {\n        let bytes = value.to_be_bytes();\n        self.inner.as_mut()[ranges::ID].copy_from_slice(&bytes[..]);\n    }\n\n    /// Sets the message ID field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn set_message_id(&mut self, value: u16) {\n        let bytes = value.to_be_bytes();\n        self.inner.as_mut()[ranges::MESSAGE_ID].copy_from_slice(&bytes[..]);\n    }\n\n    /// Gets a mutable reference to the rest of the message.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn payload_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[HEADER_LENGTH..]\n    }\n}\n\nimpl FromBytes for Chunk {\n    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        let reader = ChunkBuffer::new(buffer.as_ref()).context(\"Invalid chunk buffer\")?;\n        Ok(Self {\n            last: reader.flags().contains(Flags::LAST_CHUNK),\n            id: reader.id(),\n            message_id: reader.message_id(),\n            data: reader.payload().to_vec(),\n        })\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        if iter.len() < HEADER_LENGTH {\n            return Err(anyhow!(\"byte stream exhausted\"));\n        }\n        let id = u16::from_byte_stream(iter).context(\"cannot parse id\")?;\n        let message_id = u16::from_byte_stream(iter).context(\"cannot parse message id\")?;\n        let flags = Flags::from_bits_truncate(iter.next().unwrap());\n        let data: Vec<u8> = iter.skip(3).collect();\n        Ok(Self {\n            id,\n            message_id,\n            data,\n            last: 
flags.contains(Flags::LAST_CHUNK),\n        })\n    }\n}\n\nimpl ToBytes for Chunk {\n    fn buffer_length(&self) -> usize {\n        HEADER_LENGTH + self.data.len()\n    }\n\n    fn to_bytes<T: AsMut<[u8]> + AsRef<[u8]>>(&self, buffer: &mut T) {\n        let mut writer = ChunkBuffer::new(buffer.as_mut()).unwrap();\n        let flags = if self.last {\n            Flags::LAST_CHUNK\n        } else {\n            Flags::empty()\n        };\n        writer.set_flags(flags);\n        writer.set_id(self.id);\n        writer.set_message_id(self.message_id);\n        writer.payload_mut()[..self.data.len()].copy_from_slice(self.data.as_slice());\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    fn flags() -> (u8, Flags) {\n        let flags = Flags::LAST_CHUNK;\n        (flags.bits(), flags)\n    }\n\n    fn id() -> (Vec<u8>, u16) {\n        let value = 0xdddd_u16;\n        (value.to_be_bytes().to_vec(), value)\n    }\n\n    fn message_id() -> (Vec<u8>, u16) {\n        let value = 0xeeee_u16;\n        (value.to_be_bytes().to_vec(), value)\n    }\n\n    fn data() -> Vec<u8> {\n        vec![0xff; 10]\n    }\n\n    fn chunk() -> (Vec<u8>, Chunk) {\n        let mut bytes = vec![];\n        bytes.extend(id().0);\n        bytes.extend(message_id().0);\n        bytes.push(flags().0);\n        bytes.extend(vec![0x00, 0x00, 0x00]);\n        bytes.extend(data());\n\n        let message = Chunk {\n            id: id().1,\n            message_id: message_id().1,\n            last: flags().1.contains(Flags::LAST_CHUNK),\n            data: data(),\n        };\n        (bytes, message)\n    }\n\n    #[test]\n    fn buffer_read() {\n        let bytes = chunk().0;\n        let buffer = ChunkBuffer::new(&bytes).unwrap();\n        assert_eq!(buffer.id(), id().1);\n        assert_eq!(buffer.message_id(), message_id().1);\n        assert_eq!(buffer.flags(), flags().1);\n        assert_eq!(buffer.payload(), &data()[..]);\n    }\n\n    #[test]\n    fn stream_parse() {\n        let 
(bytes, expected) = chunk();\n        let actual = Chunk::from_byte_stream(&mut bytes.into_iter()).unwrap();\n        assert_eq!(actual, expected);\n    }\n\n    #[test]\n    fn buffer_write() {\n        let expected = chunk().0;\n        let mut bytes = vec![0; expected.len()];\n        let mut buffer = ChunkBuffer::new_unchecked(&mut bytes);\n\n        buffer.set_id(id().1);\n        buffer.set_message_id(message_id().1);\n        buffer.set_flags(flags().1);\n        buffer.payload_mut().copy_from_slice(data().as_slice());\n        assert_eq!(bytes, expected);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/message/payload/mod.rs",
    "content": "//! Message payloads.\n//!\n//! See the [message module] documentation since this is a private module anyways.\n//!\n//! [message module]: crate::message\n\npub(crate) mod chunk;\npub(crate) mod sum;\npub(crate) mod sum2;\npub(crate) mod update;\n\nuse derive_more::From;\n\nuse crate::message::{\n    payload::{chunk::Chunk, sum::Sum, sum2::Sum2, update::Update},\n    traits::ToBytes,\n};\n\n/// The payload of a [`Message`].\n///\n/// [`Message`]: crate::message::Message\n#[derive(From, Eq, PartialEq, Debug, Clone)]\npub enum Payload {\n    /// The payload of a [`Sum`] message.\n    Sum(Sum),\n    /// The payload of an [`Update`] message.\n    Update(Update),\n    /// The payload of a [`Sum2`] message.\n    Sum2(Sum2),\n    /// The payload of a [`Chunk`] message.\n    Chunk(Chunk),\n}\n\nimpl Payload {\n    pub fn is_sum(&self) -> bool {\n        matches!(self, Self::Sum(_))\n    }\n\n    pub fn is_update(&self) -> bool {\n        matches!(self, Self::Update(_))\n    }\n\n    pub fn is_sum2(&self) -> bool {\n        matches!(self, Self::Sum2(_))\n    }\n\n    pub fn is_chunk(&self) -> bool {\n        matches!(self, Self::Chunk(_))\n    }\n}\n\nimpl ToBytes for Payload {\n    fn buffer_length(&self) -> usize {\n        match self {\n            Payload::Sum(m) => m.buffer_length(),\n            Payload::Sum2(m) => m.buffer_length(),\n            Payload::Update(m) => m.buffer_length(),\n            Payload::Chunk(m) => m.buffer_length(),\n        }\n    }\n\n    fn to_bytes<T: AsMut<[u8]> + AsRef<[u8]>>(&self, buffer: &mut T) {\n        match self {\n            Payload::Sum(m) => m.to_bytes(buffer),\n            Payload::Sum2(m) => m.to_bytes(buffer),\n            Payload::Update(m) => m.to_bytes(buffer),\n            Payload::Chunk(m) => m.to_bytes(buffer),\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/message/payload/sum.rs",
    "content": "//! Sum message payloads.\n//!\n//! See the [message module] documentation since this is a private module anyways.\n//!\n//! [message module]: crate::message\n\nuse std::ops::Range;\n\nuse anyhow::{anyhow, Context};\n\nuse crate::{\n    crypto::ByteObject,\n    message::{\n        traits::{FromBytes, ToBytes},\n        utils::range,\n        DecodeError,\n    },\n    ParticipantTaskSignature,\n    SumParticipantEphemeralPublicKey,\n};\n\nconst SUM_SIGNATURE_RANGE: Range<usize> = range(0, ParticipantTaskSignature::LENGTH);\nconst EPHM_PK_RANGE: Range<usize> = range(\n    SUM_SIGNATURE_RANGE.end,\n    SumParticipantEphemeralPublicKey::LENGTH,\n);\n\n#[derive(Clone, Debug, Eq, PartialEq, Hash)]\n/// A wrapper around a buffer that contains a [`Sum`] message.\n///\n/// It provides getters and setters to access the different fields of the message safely.\n///\n/// # Examples\n/// ## Decoding a sum message\n///\n/// ```rust\n/// # use xaynet_core::message::SumBuffer;\n/// let sum_signature = vec![0x11; 64];\n/// let ephm_pk = vec![0x22; 32];\n/// let bytes = [sum_signature.as_slice(), ephm_pk.as_slice()].concat();\n/// let buffer = SumBuffer::new(&bytes).unwrap();\n/// assert_eq!(buffer.sum_signature(), sum_signature.as_slice());\n/// assert_eq!(buffer.ephm_pk(), ephm_pk.as_slice());\n/// ```\n///\n/// ## Encoding a sum message\n///\n/// ```rust\n/// # use xaynet_core::message::SumBuffer;\n/// let sum_signature = vec![0x11; 64];\n/// let ephm_pk = vec![0x22; 32];\n/// let mut storage = vec![0xff; 96];\n/// let mut buffer = SumBuffer::new_unchecked(&mut storage);\n/// buffer\n///     .sum_signature_mut()\n///     .copy_from_slice(&sum_signature[..]);\n/// buffer.ephm_pk_mut().copy_from_slice(&ephm_pk[..]);\n/// assert_eq!(&storage[..64], sum_signature.as_slice());\n/// assert_eq!(&storage[64..], ephm_pk.as_slice());\n/// ```\npub struct SumBuffer<T> {\n    inner: T,\n}\n\nimpl<T: AsRef<[u8]>> SumBuffer<T> {\n    /// Performs bound checks for the various 
message fields on `bytes` and returns a new\n    /// [`SumBuffer`].\n    ///\n    /// # Errors\n    /// Fails if the `bytes` are smaller than a minimal-sized sum message buffer.\n    pub fn new(bytes: T) -> Result<Self, DecodeError> {\n        let buffer = Self { inner: bytes };\n        buffer\n            .check_buffer_length()\n            .context(\"not a valid SumBuffer\")?;\n        Ok(buffer)\n    }\n\n    /// Returns a [`SumBuffer`] without performing any bound checks.\n    ///\n    /// This means accessing the various fields may panic if the data is invalid.\n    pub fn new_unchecked(bytes: T) -> Self {\n        Self { inner: bytes }\n    }\n\n    /// Performs bound checks to ensure the fields can be accessed without panicking.\n    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {\n        let len = self.inner.as_ref().len();\n        if len < EPHM_PK_RANGE.end {\n            Err(anyhow!(\n                \"invalid buffer length: {} < {}\",\n                len,\n                EPHM_PK_RANGE.end\n            ))\n        } else {\n            Ok(())\n        }\n    }\n}\n\nimpl<T: AsMut<[u8]>> SumBuffer<T> {\n    /// Gets a mutable reference to the sum participant ephemeral public key field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn ephm_pk_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[EPHM_PK_RANGE]\n    }\n\n    /// Gets a mutable reference to the sum signature field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn sum_signature_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[SUM_SIGNATURE_RANGE]\n    }\n}\n\nimpl<'a, T: AsRef<[u8]> + ?Sized> SumBuffer<&'a T> {\n    /// Gets a reference to the sum participant ephemeral public key field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn 
ephm_pk(&self) -> &'a [u8] {\n        &self.inner.as_ref()[EPHM_PK_RANGE]\n    }\n\n    /// Gets a reference to the sum signature field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn sum_signature(&self) -> &'a [u8] {\n        &self.inner.as_ref()[SUM_SIGNATURE_RANGE]\n    }\n}\n\n#[derive(Debug, Eq, PartialEq, Clone)]\n/// A high level representation of a sum message.\n///\n/// These messages are sent by sum participants during the sum phase.\n///\n/// # Examples\n/// ## Decoding a message\n///\n/// ```rust\n/// # use xaynet_core::{crypto::ByteObject, message::{FromBytes, Sum}, ParticipantTaskSignature, SumParticipantEphemeralPublicKey};\n/// let signature = vec![0x11; 64];\n/// let ephm_pk = vec![0x22; 32];\n/// let bytes = [signature.as_slice(), ephm_pk.as_slice()].concat();\n/// let parsed = Sum::from_byte_slice(&bytes).unwrap();\n/// let expected = Sum{\n///     sum_signature: ParticipantTaskSignature::from_slice(&signature[..]).unwrap(),\n///     ephm_pk: SumParticipantEphemeralPublicKey::from_slice(&ephm_pk[..]).unwrap(),\n/// };\n/// assert_eq!(parsed, expected);\n/// ```\n///\n/// ## Encoding a message\n///\n/// ```rust\n/// # use xaynet_core::{crypto::ByteObject, message::{ToBytes, Sum}, ParticipantTaskSignature, SumParticipantEphemeralPublicKey};\n/// let sum_signature = ParticipantTaskSignature::from_slice(vec![0x11; 64].as_slice()).unwrap();\n/// let ephm_pk = SumParticipantEphemeralPublicKey::from_slice(vec![0x22; 32].as_slice()).unwrap();\n/// let msg = Sum {\n///     sum_signature,\n///     ephm_pk,\n/// };\n/// // we need a 96 bytes long buffer to serialize that message\n/// assert_eq!(msg.buffer_length(), 96);\n/// // create a buffer with enough space and encode the message\n/// let mut buf = vec![0xff; 96];\n/// msg.to_bytes(&mut buf);\n///\n/// assert_eq!(buf, [vec![0x11; 64].as_slice(), vec![0x22; 32].as_slice()].concat());\n/// ```\npub struct Sum {\n    /// The 
signature of the round seed and the word \"sum\".\n    ///\n    /// This is used to determine whether a participant is selected for the sum task.\n    pub sum_signature: ParticipantTaskSignature,\n    /// An ephemeral public key generated by a sum participant for the current round.\n    pub ephm_pk: SumParticipantEphemeralPublicKey,\n}\n\nimpl ToBytes for Sum {\n    fn buffer_length(&self) -> usize {\n        EPHM_PK_RANGE.end\n    }\n\n    fn to_bytes<T: AsMut<[u8]> + AsRef<[u8]>>(&self, buffer: &mut T) {\n        let mut writer = SumBuffer::new(buffer.as_mut()).unwrap();\n        self.sum_signature.to_bytes(&mut writer.sum_signature_mut());\n        self.ephm_pk.to_bytes(&mut writer.ephm_pk_mut());\n    }\n}\n\nimpl FromBytes for Sum {\n    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        let reader = SumBuffer::new(buffer.as_ref())?;\n\n        let sum_signature = ParticipantTaskSignature::from_byte_slice(&reader.sum_signature())\n            .context(\"invalid sum signature\")?;\n\n        let ephm_pk = SumParticipantEphemeralPublicKey::from_byte_slice(&reader.ephm_pk())\n            .context(\"invalid ephemeral public key\")?;\n\n        Ok(Self {\n            sum_signature,\n            ephm_pk,\n        })\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        let sum_signature =\n            ParticipantTaskSignature::from_byte_stream(iter).context(\"invalid sum signature\")?;\n        let ephm_pk = SumParticipantEphemeralPublicKey::from_byte_stream(iter)\n            .context(\"invalid ephemeral public key\")?;\n\n        Ok(Self {\n            sum_signature,\n            ephm_pk,\n        })\n    }\n}\n\n#[cfg(test)]\npub(in crate::message) mod tests {\n    use super::*;\n    use crate::testutils::messages::sum as helpers;\n\n    #[test]\n    fn buffer_read() {\n        let bytes = helpers::payload().1;\n        let buffer = 
SumBuffer::new(&bytes).unwrap();\n        assert_eq!(buffer.sum_signature(), &helpers::sum_task_signature().1[..]);\n        assert_eq!(buffer.ephm_pk(), &helpers::ephm_pk().1[..]);\n    }\n\n    #[test]\n    fn buffer_read_invalid() {\n        assert!(SumBuffer::new(&helpers::payload().1[1..]).is_err());\n    }\n\n    #[test]\n    fn buffer_write() {\n        let mut buffer = vec![0xff; EPHM_PK_RANGE.end];\n        let mut writer = SumBuffer::new_unchecked(&mut buffer);\n        writer\n            .sum_signature_mut()\n            .copy_from_slice(helpers::sum_task_signature().1.as_slice());\n        writer\n            .ephm_pk_mut()\n            .copy_from_slice(helpers::ephm_pk().1.as_slice());\n    }\n\n    #[test]\n    fn encode() {\n        let (sum, bytes) = helpers::payload();\n        assert_eq!(sum.buffer_length(), bytes.len());\n\n        let mut buf = vec![0xff; sum.buffer_length()];\n        sum.to_bytes(&mut buf);\n        assert_eq!(buf, bytes);\n    }\n\n    #[test]\n    fn decode() {\n        let (expected, bytes) = helpers::payload();\n        let parsed = Sum::from_byte_slice(&bytes).unwrap();\n        assert_eq!(parsed, expected);\n    }\n\n    #[test]\n    fn stream_parse() {\n        let (expected, bytes) = helpers::payload();\n        let parsed = Sum::from_byte_stream(&mut bytes.into_iter()).unwrap();\n        assert_eq!(parsed, expected);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/message/payload/sum2.rs",
    "content": "//! Sum2 message payloads.\n//!\n//! See the [message module] documentation since this is a private module anyways.\n//!\n//! [message module]: crate::message\n\nuse std::ops::Range;\n\nuse anyhow::{anyhow, Context};\n\nuse crate::{\n    crypto::ByteObject,\n    mask::object::{serialization::MaskObjectBuffer, MaskObject},\n    message::{\n        traits::{FromBytes, ToBytes},\n        utils::range,\n        DecodeError,\n    },\n    ParticipantTaskSignature,\n};\n\nconst SUM_SIGNATURE_RANGE: Range<usize> = range(0, ParticipantTaskSignature::LENGTH);\n\n#[derive(Clone, Debug, Eq, PartialEq, Hash)]\n/// A wrapper around a buffer that contains a [`Sum2`] message.\n///\n/// It provides getters and setters to access the different fields of the message safely.\npub struct Sum2Buffer<T> {\n    inner: T,\n}\n\nimpl<T: AsRef<[u8]>> Sum2Buffer<T> {\n    /// Performs bound checks for the various message fields on `bytes` and returns a new\n    /// [`Sum2Buffer`].\n    ///\n    /// # Errors\n    /// Fails if the `bytes` are smaller than a minimal-sized sum2 message buffer.\n    pub fn new(bytes: T) -> Result<Self, DecodeError> {\n        let buffer = Self { inner: bytes };\n        buffer\n            .check_buffer_length()\n            .context(\"not a valid Sum2Buffer\")?;\n        Ok(buffer)\n    }\n\n    /// Returns a `Sum2Buffer` with the given `bytes` without performing bound checks.\n    ///\n    /// This means that accessing the message fields may panic.\n    pub fn new_unchecked(bytes: T) -> Self {\n        Self { inner: bytes }\n    }\n\n    /// Performs bound checks for the various message fields on this buffer.\n    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {\n        let len = self.inner.as_ref().len();\n        if len < SUM_SIGNATURE_RANGE.end {\n            return Err(anyhow!(\n                \"invalid buffer length: {} < {}\",\n                len,\n                SUM_SIGNATURE_RANGE.end\n            ));\n        }\n\n       
 // check the length of the mask field\n        MaskObjectBuffer::new(&self.inner.as_ref()[self.model_mask_offset()..])\n            .context(\"invalid mask field\")?;\n\n        Ok(())\n    }\n\n    /// Gets the offset of the model mask field.\n    fn model_mask_offset(&self) -> usize {\n        SUM_SIGNATURE_RANGE.end\n    }\n}\n\nimpl<T: AsRef<[u8]> + AsMut<[u8]>> Sum2Buffer<T> {\n    /// Gets a mutable reference to the sum signature field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn sum_signature_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[SUM_SIGNATURE_RANGE]\n    }\n\n    /// Gets a mutable reference to the model mask field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn model_mask_mut(&mut self) -> &mut [u8] {\n        let offset = self.model_mask_offset();\n        &mut self.inner.as_mut()[offset..]\n    }\n}\n\nimpl<'a, T: AsRef<[u8]> + ?Sized> Sum2Buffer<&'a T> {\n    /// Gets a reference to the sum signature field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn sum_signature(&self) -> &'a [u8] {\n        &self.inner.as_ref()[SUM_SIGNATURE_RANGE]\n    }\n\n    /// Gets a reference to the model mask field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn model_mask(&self) -> &'a [u8] {\n        let offset = self.model_mask_offset();\n        &self.inner.as_ref()[offset..]\n    }\n}\n\n#[derive(Eq, PartialEq, Clone, Debug)]\n/// A high level representation of a sum2 message.\n///\n/// These messages are sent by sum participants during the sum2 phase.\npub struct Sum2 {\n    /// The signature of the round seed and the word \"sum\".\n    ///\n    /// This is used to determine whether a participant is selected for the sum task.\n    pub 
sum_signature: ParticipantTaskSignature,\n\n    /// A model mask computed by the participant.\n    pub model_mask: MaskObject,\n}\n\nimpl ToBytes for Sum2 {\n    fn buffer_length(&self) -> usize {\n        SUM_SIGNATURE_RANGE.end + self.model_mask.buffer_length()\n    }\n\n    fn to_bytes<T: AsMut<[u8]> + AsRef<[u8]>>(&self, buffer: &mut T) {\n        let mut writer = Sum2Buffer::new_unchecked(buffer.as_mut());\n        self.sum_signature.to_bytes(&mut writer.sum_signature_mut());\n        self.model_mask.to_bytes(&mut writer.model_mask_mut());\n    }\n}\n\nimpl FromBytes for Sum2 {\n    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        let reader = Sum2Buffer::new(buffer.as_ref())?;\n        Ok(Self {\n            sum_signature: ParticipantTaskSignature::from_byte_slice(&reader.sum_signature())\n                .context(\"invalid sum signature\")?,\n            model_mask: MaskObject::from_byte_slice(&reader.model_mask())\n                .context(\"invalid mask\")?,\n        })\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        Ok(Self {\n            sum_signature: ParticipantTaskSignature::from_byte_stream(iter)\n                .context(\"invalid sum signature\")?,\n            model_mask: MaskObject::from_byte_stream(iter).context(\"invalid mask object\")?,\n        })\n    }\n}\n\n#[cfg(test)]\npub mod tests {\n    use crate::testutils::messages::sum2 as helpers;\n\n    use super::*;\n\n    #[test]\n    fn buffer_read() {\n        let bytes = helpers::payload().1;\n        let buffer = Sum2Buffer::new(&bytes).unwrap();\n        assert_eq!(buffer.sum_signature(), &helpers::sum_task_signature().1[..]);\n\n        let expected_mask = helpers::mask_object().1;\n        let expected_length = expected_mask.len();\n        let actual_mask = &buffer.model_mask()[..expected_length];\n        assert_eq!(actual_mask, expected_mask);\n    
}\n\n    #[test]\n    fn buffer_write() {\n        // length = 64 (signature) + 42 (mask) = 106\n        let mut bytes = vec![0xff; 106];\n        {\n            let mut buffer = Sum2Buffer::new_unchecked(&mut bytes);\n            buffer\n                .sum_signature_mut()\n                .copy_from_slice(&helpers::sum_task_signature().1[..]);\n            let mask = helpers::mask_object().1;\n            buffer.model_mask_mut()[..mask.len()].copy_from_slice(&mask[..]);\n        }\n        assert_eq!(&bytes[..], &helpers::payload().1[..]);\n    }\n\n    #[test]\n    fn encode() {\n        let (sum2, bytes) = helpers::payload();\n        assert_eq!(sum2.buffer_length(), bytes.len());\n\n        let mut buf = vec![0xff; sum2.buffer_length()];\n        sum2.to_bytes(&mut buf);\n        assert_eq!(buf, bytes);\n    }\n\n    #[test]\n    fn decode() {\n        let (sum2, bytes) = helpers::payload();\n        let parsed = Sum2::from_byte_slice(&bytes).unwrap();\n        assert_eq!(parsed, sum2);\n    }\n\n    #[test]\n    fn stream_parse() {\n        let (sum2, bytes) = helpers::payload();\n        let parsed = Sum2::from_byte_stream(&mut bytes.into_iter()).unwrap();\n        assert_eq!(parsed, sum2);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/message/payload/update.rs",
    "content": "//! Update message payloads.\n//!\n//! See the [message module] documentation since this is a private module anyways.\n//!\n//! [message module]: crate::message\n\nuse std::ops::Range;\n\nuse anyhow::{anyhow, Context};\n\nuse crate::{\n    crypto::ByteObject,\n    mask::object::{serialization::MaskObjectBuffer, MaskObject},\n    message::{\n        traits::{FromBytes, LengthValueBuffer, ToBytes},\n        utils::range,\n        DecodeError,\n    },\n    LocalSeedDict,\n    ParticipantTaskSignature,\n};\n\nconst SUM_SIGNATURE_RANGE: Range<usize> = range(0, ParticipantTaskSignature::LENGTH);\nconst UPDATE_SIGNATURE_RANGE: Range<usize> =\n    range(SUM_SIGNATURE_RANGE.end, ParticipantTaskSignature::LENGTH);\n\n#[derive(Clone, Debug)]\n/// A wrapper around a buffer that contains an [`Update`] message.\n///\n/// It provides getters and setters to access the different fields of the message safely.\npub struct UpdateBuffer<T> {\n    inner: T,\n}\n\nimpl<T: AsRef<[u8]>> UpdateBuffer<T> {\n    /// Performs bound checks for the various message fields on `bytes` and returns a new\n    /// [`UpdateBuffer`].\n    ///\n    /// # Errors\n    /// Fails if the `bytes` are smaller than a minimal-sized update message buffer.\n    pub fn new(bytes: T) -> Result<Self, DecodeError> {\n        let buffer = Self { inner: bytes };\n        buffer\n            .check_buffer_length()\n            .context(\"invalid UpdateBuffer\")?;\n        Ok(buffer)\n    }\n\n    /// Returns an [`UpdateBuffer`] without performing any bound checks.\n    ///\n    /// This means accessing the various fields may panic if the data is invalid.\n    pub fn new_unchecked(bytes: T) -> Self {\n        Self { inner: bytes }\n    }\n\n    /// Performs bound checks to ensure the fields can be accessed without panicking.\n    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {\n        let len = self.inner.as_ref().len();\n        // First, check the fixed size portion of the\n        // 
header. UPDATE_SIGNATURE_RANGE is the last field\n        if len < UPDATE_SIGNATURE_RANGE.end {\n            return Err(anyhow!(\n                \"invalid buffer length: {} < {}\",\n                len,\n                UPDATE_SIGNATURE_RANGE.end\n            ));\n        }\n\n        // Check length of the masked object field\n        MaskObjectBuffer::new(&self.inner.as_ref()[self.masked_model_offset()..])\n            .context(\"invalid masked object field\")?;\n\n        // Check the length of the local seed dictionary field\n        let _ = LengthValueBuffer::new(&self.inner.as_ref()[self.local_seed_dict_offset()..])\n            .context(\"invalid local seed dictionary length\")?;\n\n        Ok(())\n    }\n\n    /// Gets the offset of the masked model field.\n    fn masked_model_offset(&self) -> usize {\n        UPDATE_SIGNATURE_RANGE.end\n    }\n\n    /// Gets the offset of the local seed dictionary field.\n    ///\n    /// # Panics\n    /// Computing the offset may panic if the buffer has not been checked before.\n    fn local_seed_dict_offset(&self) -> usize {\n        let masked_model =\n            MaskObjectBuffer::new_unchecked(&self.inner.as_ref()[self.masked_model_offset()..]);\n        self.masked_model_offset() + masked_model.len()\n    }\n}\n\nimpl<'a, T: AsRef<[u8]> + ?Sized> UpdateBuffer<&'a T> {\n    /// Gets the sum signature field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn sum_signature(&self) -> &'a [u8] {\n        &self.inner.as_ref()[SUM_SIGNATURE_RANGE]\n    }\n\n    /// Gets the update signature field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn update_signature(&self) -> &'a [u8] {\n        &self.inner.as_ref()[UPDATE_SIGNATURE_RANGE]\n    }\n\n    /// Gets a slice that starts at the beginning of the masked model field.\n    ///\n    /// # Panics\n    /// Accessing the field may 
panic if the buffer has not been checked before.\n    pub fn masked_model(&self) -> &'a [u8] {\n        let offset = self.masked_model_offset();\n        &self.inner.as_ref()[offset..]\n    }\n\n    /// Gets a slice that starts at the beginning of the local seed dictionary field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn local_seed_dict(&self) -> &'a [u8] {\n        let offset = self.local_seed_dict_offset();\n        &self.inner.as_ref()[offset..]\n    }\n}\n\nimpl<T: AsRef<[u8]> + AsMut<[u8]>> UpdateBuffer<T> {\n    /// Gets a mutable reference to the sum signature field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn sum_signature_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[SUM_SIGNATURE_RANGE]\n    }\n\n    /// Gets a mutable reference to the update signature field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn update_signature_mut(&mut self) -> &mut [u8] {\n        &mut self.inner.as_mut()[UPDATE_SIGNATURE_RANGE]\n    }\n\n    /// Gets a mutable slice that starts at the beginning of the masked model field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn masked_model_mut(&mut self) -> &mut [u8] {\n        let offset = self.masked_model_offset();\n        &mut self.inner.as_mut()[offset..]\n    }\n\n    /// Gets a mutable slice that starts at the beginning of the local seed dictionary field.\n    ///\n    /// # Panics\n    /// Accessing the field may panic if the buffer has not been checked before.\n    pub fn local_seed_dict_mut(&mut self) -> &mut [u8] {\n        let offset = self.local_seed_dict_offset();\n        &mut self.inner.as_mut()[offset..]\n    }\n}\n\n#[derive(Debug, Eq, PartialEq, Clone)]\n/// A high level representation of an update 
message.\n///\n/// These messages are sent by update participants during the update phase.\npub struct Update {\n    /// The signature of the round seed and the word \"sum\".\n    ///\n    /// This is used to determine whether a participant is selected for the sum task.\n    pub sum_signature: ParticipantTaskSignature,\n    /// Signature of the round seed and the word \"update\".\n    ///\n    /// This is used to determine whether a participant is selected for the update task.\n    pub update_signature: ParticipantTaskSignature,\n    /// A model trained by an update participant.\n    ///\n    /// The model is masked with randomness derived from the participant seed.\n    pub masked_model: MaskObject,\n    /// A dictionary that contains the seed used to mask `masked_model`.\n    ///\n    /// The seed is encrypted with the ephemeral public key of each sum participant.\n    pub local_seed_dict: LocalSeedDict,\n}\n\nimpl ToBytes for Update {\n    fn buffer_length(&self) -> usize {\n        UPDATE_SIGNATURE_RANGE.end\n            + self.masked_model.buffer_length()\n            + self.local_seed_dict.buffer_length()\n    }\n\n    fn to_bytes<T: AsMut<[u8]> + AsRef<[u8]>>(&self, buffer: &mut T) {\n        let mut writer = UpdateBuffer::new_unchecked(buffer.as_mut());\n        self.sum_signature.to_bytes(&mut writer.sum_signature_mut());\n        self.update_signature\n            .to_bytes(&mut writer.update_signature_mut());\n        self.masked_model.to_bytes(&mut writer.masked_model_mut());\n        self.local_seed_dict\n            .to_bytes(&mut writer.local_seed_dict_mut());\n    }\n}\n\nimpl FromBytes for Update {\n    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        let reader = UpdateBuffer::new(buffer.as_ref())?;\n        Ok(Self {\n            sum_signature: ParticipantTaskSignature::from_byte_slice(&reader.sum_signature())\n                .context(\"invalid sum signature\")?,\n            update_signature: 
ParticipantTaskSignature::from_byte_slice(&reader.update_signature())\n                .context(\"invalid update signature\")?,\n            masked_model: MaskObject::from_byte_slice(&reader.masked_model())\n                .context(\"invalid masked model\")?,\n            local_seed_dict: LocalSeedDict::from_byte_slice(&reader.local_seed_dict())\n                .context(\"invalid local seed dictionary\")?,\n        })\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        Ok(Self {\n            sum_signature: ParticipantTaskSignature::from_byte_stream(iter)\n                .context(\"invalid sum signature\")?,\n            update_signature: ParticipantTaskSignature::from_byte_stream(iter)\n                .context(\"invalid update signature\")?,\n            masked_model: MaskObject::from_byte_stream(iter).context(\"invalid masked model\")?,\n            local_seed_dict: LocalSeedDict::from_byte_stream(iter)\n                .context(\"invalid local seed dictionary\")?,\n        })\n    }\n}\n\n#[cfg(test)]\npub mod tests {\n    use super::*;\n    use crate::testutils::messages::update as helpers;\n\n    #[test]\n    fn buffer_read() {\n        let bytes = helpers::payload().1;\n        let buffer = UpdateBuffer::new(&bytes).unwrap();\n        assert_eq!(\n            buffer.sum_signature(),\n            helpers::sum_task_signature().1.as_slice()\n        );\n        assert_eq!(\n            buffer.update_signature(),\n            helpers::update_task_signature().1.as_slice()\n        );\n        let expected = helpers::mask_object().1;\n        assert_eq!(&buffer.masked_model()[..expected.len()], &expected[..]);\n        assert_eq!(buffer.local_seed_dict(), &helpers::local_seed_dict().1[..]);\n    }\n\n    #[test]\n    fn decode_invalid_seed_dict() {\n        let mut invalid = helpers::local_seed_dict().1;\n        // This truncates the last entry of the seed 
dictionary\n        invalid[3] = 0xe3;\n        let mut bytes = vec![];\n        bytes.extend(helpers::sum_task_signature().1);\n        bytes.extend(helpers::update_task_signature().1);\n        bytes.extend(helpers::mask_object().1);\n        bytes.extend(invalid);\n\n        let e = Update::from_byte_slice(&bytes).unwrap_err();\n        let cause = e.source().unwrap().to_string();\n        assert_eq!(\n            cause,\n            \"invalid local seed dictionary: trailing bytes\".to_string()\n        );\n    }\n\n    #[test]\n    fn decode() {\n        let (update, bytes) = helpers::payload();\n        let parsed = Update::from_byte_slice(&bytes).unwrap();\n        assert_eq!(parsed, update);\n    }\n\n    #[test]\n    fn stream_parse() {\n        let (update, bytes) = helpers::payload();\n        let parsed = Update::from_byte_stream(&mut bytes.into_iter()).unwrap();\n        assert_eq!(parsed, update);\n    }\n\n    #[test]\n    fn encode() {\n        let (update, bytes) = helpers::payload();\n        assert_eq!(update.buffer_length(), bytes.len());\n        let mut buf = vec![0xff; update.buffer_length()];\n        update.to_bytes(&mut buf);\n        // The order in which the hashmap is serialized is not\n        // guaranteed, but we chose our key/values such that they are\n        // sorted.\n        //\n        // First compute the offset at which the local seed dict value\n        // starts: two signatures (64 bytes each), the masked model (32\n        // bytes), the length field (4 bytes), the masked scalar (10 bytes)\n        let offset = 64 * 2 + 32 + 4 + 10;\n        // Sort the end of the buffer\n        (&mut buf[offset..]).sort_unstable();\n        assert_eq!(buf, bytes);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/message/traits.rs",
    "content": "//! Message traits.\n//!\n//! See the [message module] documentation since this is a private module anyways.\n//!\n//! [message module]: crate::message\n\nuse std::{\n    convert::TryInto,\n    io::{Cursor, Write},\n    iter::{ExactSizeIterator, Iterator},\n    ops::Range,\n};\n\nuse anyhow::{anyhow, Context};\n\nuse crate::{\n    crypto::ByteObject,\n    mask::seed::EncryptedMaskSeed,\n    message::{utils::ChunkableIterator, DecodeError},\n    LocalSeedDict,\n    SumParticipantPublicKey,\n};\n\n/// An interface for serializable message types.\n///\n/// See also [`FromBytes`] for deserialization.\npub trait ToBytes {\n    /// The length of the buffer for encoding the type.\n    fn buffer_length(&self) -> usize;\n\n    /// Serialize the type in the given buffer.\n    ///\n    /// # Panics\n    /// This method may panic if the given buffer is too small. Thus, [`buffer_length()`] must be\n    /// called prior to calling this, and a large enough buffer must be provided.\n    ///\n    /// [`buffer_length()`]: ToBytes::buffer_length\n    fn to_bytes<T: AsMut<[u8]> + AsRef<[u8]>>(&self, buffer: &mut T);\n}\n\n/// An interface for deserializable message types.\n///\n/// See also [`ToBytes`] for serialization.\npub trait FromBytes: Sized {\n    /// Deserialize the type from the given buffer.\n    ///\n    /// # Errors\n    /// May fail if certain parts of the deserialized buffer don't pass message validity checks.\n    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError>;\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError>;\n}\n\nimpl<T> FromBytes for T\nwhere\n    T: ByteObject,\n{\n    fn from_byte_slice<U: AsRef<[u8]>>(buffer: &U) -> Result<Self, DecodeError> {\n        Self::from_slice(buffer.as_ref())\n            .ok_or_else(|| anyhow!(\"failed to deserialize byte object\"))\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + 
ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        let buf: Vec<u8> = iter.take(Self::LENGTH).collect();\n        Self::from_byte_slice(&buf)\n    }\n}\n\nimpl<T> ToBytes for T\nwhere\n    T: ByteObject,\n{\n    fn buffer_length(&self) -> usize {\n        self.as_slice().len()\n    }\n\n    fn to_bytes<U: AsMut<[u8]> + AsRef<[u8]>>(&self, buffer: &mut U) {\n        buffer.as_mut().copy_from_slice(self.as_slice())\n    }\n}\n\n/// A helper for encoding and decoding Length-Value (LV) fields.\n///\n/// Note that the 4 bytes [`length()`] field gives the length of the *total* Length-Value field,\n/// _i.e._ the length of the value, plus the 4 extra bytes of the length field itself.\n///\n/// # Examples\n/// ## Decoding a LV field\n///\n/// ```rust\n/// # use xaynet_core::message::LengthValueBuffer;\n/// let bytes = vec![\n///     0x00, 0x00, 0x00, 0x05, // Length = 5\n///     0xff, // Value = 0xff\n///     0x11, 0x22, // Extra bytes\n/// ];\n/// let buffer = LengthValueBuffer::new(&bytes).unwrap();\n/// assert_eq!(buffer.length(), 5);\n/// assert_eq!(buffer.value_length(), 1);\n/// assert_eq!(buffer.value(), &[0xff][..]);\n/// ```\n///\n/// ## Encoding a LV field\n///\n/// ```rust\n/// # use xaynet_core::message::LengthValueBuffer;\n/// let mut bytes = vec![0xff; 9];\n/// let mut buffer = LengthValueBuffer::new_unchecked(&mut bytes);\n/// // It is important to set the length field before setting the value, otherwise, `value_mut()` will panic.\n/// buffer.set_length(8);\n/// buffer.value_mut().copy_from_slice(&[0, 1, 2, 3][..]);\n/// let expected = vec![\n///     0x00, 0x00, 0x00, 0x08, // Length = 8\n///     0x00, 0x01, 0x02, 0x03, // Value\n///     0xff, // unchanged\n/// ];\n///\n/// assert_eq!(bytes, expected);\n/// ```\n///\n/// [`length()`]: LengthValueBuffer::length\npub struct LengthValueBuffer<T> {\n    inner: T,\n}\n\n/// The size of the length field for encoding a Length-Value item.\nconst LENGTH_FIELD: Range<usize> = 
0..4;\n\nimpl<T: AsRef<[u8]>> LengthValueBuffer<T> {\n    /// Returns a new [`LengthValueBuffer`].\n    ///\n    /// # Errors\n    /// This method performs bound checks and returns an error if the given buffer is not a valid\n    /// Length-Value item.\n    ///\n    /// # Examples\n    ///\n    /// ```rust\n    /// # use xaynet_core::message::LengthValueBuffer;\n    /// // truncated length:\n    /// assert!(LengthValueBuffer::new(&vec![0x00, 0x00, 0x00]).is_err());\n    ///\n    /// // truncated value:\n    /// let bytes = vec![\n    ///     0x00, 0x00, 0x00, 0x08, // length: 8\n    ///     0x11, 0x22, 0x33, // value\n    /// ];\n    /// assert!(LengthValueBuffer::new(&bytes).is_err());\n    ///\n    /// // valid Length-Value item\n    /// let bytes = vec![\n    ///     0x00, 0x00, 0x00, 0x08, // length: 8\n    ///     0x11, 0x22, 0x33, 0x44, // value\n    ///     0xaa, 0xbb, // extra bytes are ignored\n    /// ];\n    /// let buf = LengthValueBuffer::new(&bytes).unwrap();\n    /// assert_eq!(buf.length(), 8);\n    /// assert_eq!(buf.value(), &[0x11, 0x22, 0x33, 0x44][..]);\n    /// ```\n    pub fn new(bytes: T) -> Result<Self, DecodeError> {\n        let buffer = Self { inner: bytes };\n        buffer\n            .check_buffer_length()\n            .context(\"not a valid LengthValueBuffer\")?;\n        Ok(buffer)\n    }\n\n    /// Create a new [`LengthValueBuffer`] without any bound checks.\n    pub fn new_unchecked(bytes: T) -> Self {\n        Self { inner: bytes }\n    }\n\n    /// Check that the buffer is a valid Length-Value item.\n    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {\n        let len = self.inner.as_ref().len();\n        if len < LENGTH_FIELD.end {\n            return Err(anyhow!(\n                \"invalid buffer length: {} < {}\",\n                len,\n                LENGTH_FIELD.end\n            ));\n        }\n\n        if (self.length() as usize) < LENGTH_FIELD.end {\n            return Err(anyhow!(\n                
\"invalid length value: {} (should be >= {})\",\n                len,\n                LENGTH_FIELD.end\n            ));\n        }\n\n        if len < self.length() as usize {\n            return Err(anyhow!(\n                \"invalid buffer length: {} < {}\",\n                len,\n                self.length(),\n            ));\n        }\n        Ok(())\n    }\n\n    /// Returns the length field. Note that the value of the length\n    /// field includes the length of the field itself (4 bytes).\n    ///\n    /// # Panics\n    /// This method may panic if buffer is not a valid Length-Value item.\n    pub fn length(&self) -> u32 {\n        // unwrap safe: the slice is exactly 4 bytes long\n        u32::from_be_bytes(self.inner.as_ref()[LENGTH_FIELD].try_into().unwrap())\n    }\n\n    /// Returns the length of the value.\n    pub fn value_length(&self) -> usize {\n        self.length() as usize - LENGTH_FIELD.end\n    }\n\n    /// Returns the range corresponding to the value.\n    fn value_range(&self) -> Range<usize> {\n        let offset = LENGTH_FIELD.end;\n        let value_length = self.value_length();\n        offset..offset + value_length\n    }\n}\n\nimpl<T: AsMut<[u8]>> LengthValueBuffer<T> {\n    /// Sets the length field to the given value.\n    ///\n    /// # Panics\n    /// This method may panic if buffer is not a valid Length-Value item.\n    pub fn set_length(&mut self, value: u32) {\n        self.inner.as_mut()[LENGTH_FIELD].copy_from_slice(&value.to_be_bytes());\n    }\n}\n\nimpl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> LengthValueBuffer<&'a mut T> {\n    /// Gets a mutable reference to the value field.\n    ///\n    /// # Panics\n    /// This method may panic if buffer is not a valid Length-Value item.\n    pub fn value_mut(&mut self) -> &mut [u8] {\n        let range = self.value_range();\n        &mut self.inner.as_mut()[range]\n    }\n\n    /// Gets a mutable reference to the underlying buffer.\n    ///\n    /// # Panics\n    /// This 
method may panic if buffer is not a valid Length-Value item.\n    pub fn bytes_mut(&mut self) -> &mut [u8] {\n        self.inner.as_mut()\n    }\n}\n\nimpl<'a, T: AsRef<[u8]> + ?Sized> LengthValueBuffer<&'a T> {\n    /// Gets a reference to the value field.\n    ///\n    /// # Panics\n    /// This method may panic if buffer is not a valid Length-Value item.\n    pub fn value(&self) -> &'a [u8] {\n        &self.inner.as_ref()[self.value_range()]\n    }\n\n    /// Gets a reference to the underlying buffer.\n    ///\n    /// # Panics\n    /// This method may panic if buffer is not a valid Length-Value item.\n    pub fn bytes(self) -> &'a [u8] {\n        let range = self.value_range();\n        &self.inner.as_ref()[..range.end]\n    }\n}\n\nconst ENTRY_LENGTH: usize = SumParticipantPublicKey::LENGTH + EncryptedMaskSeed::LENGTH;\n\nimpl ToBytes for LocalSeedDict {\n    fn buffer_length(&self) -> usize {\n        LENGTH_FIELD.end + self.len() * ENTRY_LENGTH\n    }\n\n    fn to_bytes<T: AsMut<[u8]> + AsRef<[u8]>>(&self, buffer: &mut T) {\n        let mut writer = Cursor::new(buffer.as_mut());\n        let length = self.buffer_length() as u32;\n        let _ = writer.write(&length.to_be_bytes()).unwrap();\n        for (key, value) in self {\n            let _ = writer.write(key.as_slice()).unwrap();\n            let _ = writer.write(value.as_ref()).unwrap();\n        }\n    }\n}\n\nimpl FromBytes for LocalSeedDict {\n    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        let reader = LengthValueBuffer::new(buffer.as_ref())?;\n        let mut dict = LocalSeedDict::new();\n\n        let key_length = SumParticipantPublicKey::LENGTH;\n        let mut entries = reader.value().chunks_exact(ENTRY_LENGTH);\n        for chunk in &mut entries {\n            // safe unwraps: lengths of slices are guaranteed\n            // by constants.\n            let key = SumParticipantPublicKey::from_slice(&chunk[..key_length]).unwrap();\n            let value 
= EncryptedMaskSeed::from_slice(&chunk[key_length..]).unwrap();\n            if dict.insert(key, value).is_some() {\n                return Err(anyhow!(\"invalid local seed dictionary: duplicated key\"));\n            }\n        }\n        if !entries.remainder().is_empty() {\n            return Err(anyhow!(\"invalid local seed dictionary: trailing bytes\"));\n        }\n        Ok(dict)\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        let len = u32::from_byte_stream(iter).context(\"cannot parse length field\")? as usize;\n        if len < 4 {\n            return Err(anyhow!(\"invalid length field\"));\n        }\n        if iter.len() < len - 4 {\n            return Err(anyhow!(\n                \"expected {} bytes, but only {} left\",\n                len - 4,\n                iter.len()\n            ));\n        }\n\n        let mut dict = LocalSeedDict::new();\n        let entries = iter.take(len - 4).chunks(ENTRY_LENGTH);\n        for mut chunk in entries.into_iter() {\n            let key = SumParticipantPublicKey::from_byte_stream(&mut chunk)\n                .context(\"invalid entry: cannot parse public key\")?;\n            let value = EncryptedMaskSeed::from_byte_stream(&mut chunk)\n                .context(\"invalid entry: cannot parse encrypted mask seed\")?;\n            // This should really not happen, but it's worth checking\n            // because our chunkable iterator panics if the chunks are\n            // not fully consumed.\n            if chunk.len() > 0 {\n                return Err(anyhow!(\n                    \"unknown error while parsing seed dict entry: entry buffer not fully consumed\"\n                ));\n            }\n            if dict.insert(key, value).is_some() {\n                return Err(anyhow!(\"duplicated key\"));\n            }\n        }\n        Ok(dict)\n    }\n}\n\nimpl FromBytes for u16 {\n    fn from_byte_slice<T: 
AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        Ok(u16::from_be_bytes(\n            buffer\n                .as_ref()\n                .try_into()\n                .context(\"failed to parse u16: invalid length\")?,\n        ))\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        fn err() -> DecodeError {\n            anyhow!(\"cannot read u16: byte stream exhausted\")\n        }\n        let b1 = (iter.next().ok_or_else(err)? as u16) << 8;\n        let b2 = iter.next().ok_or_else(err)? as u16;\n        Ok(b1 | b2)\n    }\n}\n\nimpl FromBytes for u32 {\n    fn from_byte_slice<T: AsRef<[u8]>>(buffer: &T) -> Result<Self, DecodeError> {\n        Ok(u32::from_be_bytes(\n            buffer\n                .as_ref()\n                .try_into()\n                .context(\"failed to parse u32: invalid length\")?,\n        ))\n    }\n\n    fn from_byte_stream<I: Iterator<Item = u8> + ExactSizeIterator>(\n        iter: &mut I,\n    ) -> Result<Self, DecodeError> {\n        fn err() -> DecodeError {\n            anyhow!(\"cannot read u32: byte stream exhausted\")\n        }\n        let b1 = (iter.next().ok_or_else(err)? as u32) << 24;\n        let b2 = (iter.next().ok_or_else(err)? as u32) << 16;\n        let b3 = (iter.next().ok_or_else(err)? as u32) << 8;\n        let b4 = iter.next().ok_or_else(err)? 
as u32;\n        Ok(b1 | b2 | b3 | b4)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn decode_length_value_buffer() {\n        let bytes = vec![\n            0x00, 0x00, 0x00, 0x05, // Length = 5 (value length = 1)\n            0xff, // Value = 0xff\n            0x11, 0x22, // Extra bytes\n        ];\n        let buffer = LengthValueBuffer::new(&bytes).unwrap();\n        assert_eq!(buffer.length(), 5);\n        assert_eq!(buffer.value_length(), 1);\n        assert_eq!(buffer.value(), &[0xff][..]);\n    }\n\n    #[test]\n    fn decode_empty_value() {\n        let bytes = vec![0x00, 0x00, 0x00, 0x04];\n        let buffer = LengthValueBuffer::new(&bytes).unwrap();\n        assert_eq!(buffer.length(), 4);\n        assert_eq!(buffer.value_length(), 0);\n    }\n\n    #[test]\n    fn decode_length_value_buffer_buffer_exhausted() {\n        let bytes = vec![\n            0x00, 0x00, 0x00, 0x08, // Length = 8 (value length = 4)\n            0x11, 0x22, // Only 2 of the 4 value bytes\n        ];\n        assert!(LengthValueBuffer::new(bytes).is_err());\n    }\n\n    #[test]\n    fn decode_length_value_buffer_invalid_length() {\n        // Missing bytes\n        let bytes = vec![0x00, 0x00, 0x00];\n        assert!(LengthValueBuffer::new(bytes).is_err());\n        // Length field invalid (smaller than 4)\n        let bytes = vec![0x00, 0x00, 0x00, 0x03];\n        assert!(LengthValueBuffer::new(bytes).is_err());\n    }\n\n    #[test]\n    fn encode_length_value_buffer() {\n        let mut bytes = vec![0xff; 7];\n        let mut buffer = LengthValueBuffer::new_unchecked(&mut bytes);\n        buffer.set_length(6);\n        buffer.value_mut().copy_from_slice(&[0x11, 0x22][..]);\n        let expected = vec![\n            0x00, 0x00, 0x00, 0x06, // Length = 6 (value length = 2)\n            0x11, 0x22, // Value\n            0xff, // unchanged\n        ];\n\n        assert_eq!(bytes, expected);\n    }\n\n    #[test]\n    fn encode_length_value_buffer_emty() {\n        let mut bytes = vec![0xff; 5];\n        let mut buffer = 
LengthValueBuffer::new_unchecked(&mut bytes);\n        buffer.set_length(4);\n        buffer.value_mut().copy_from_slice(&[][..]);\n        let expected = vec![\n            0x00, 0x00, 0x00, 0x04, // Length = 4 (value length = 0)\n            0xff, // unchanged\n        ];\n\n        assert_eq!(bytes, expected);\n    }\n\n    #[test]\n    fn parse_u16() {\n        let buf = vec![0x12, 0x34];\n        assert_eq!(u16::from_byte_slice(&buf.as_slice()).unwrap(), 0x1234);\n        assert_eq!(u16::from_byte_stream(&mut buf.into_iter()).unwrap(), 0x1234);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/message/utils/chunkable_iterator.rs",
    "content": "//! This module provides an extension to the [`Iterator`] trait that allows iterating by chunks. One\n//! important property of our chunks, is that they implement [`ExactSizeIterator`], which is\n//! required by the [`FromBytes`] trait.\n//!\n//! [`Iterator`]: std::iter::Iterator\n//! [`ExactSizeIterator`]: std::iter::ExactSizeIterator\n//! [`FromBytes`]: crate::message::FromBytes\n\nuse std::{\n    cell::RefCell,\n    cmp,\n    fmt,\n    iter::{ExactSizeIterator, Iterator},\n    ops::Range,\n};\n\npub trait ChunkableIterator: Iterator + Sized {\n    /// Return an _iterable_ that can chunk the iterator.\n    ///\n    /// Yield subiterators (chunks) that each yield a fixed number of\n    /// elements, determined by `size`. The last chunk will be shorter\n    /// if there aren't enough elements.\n    ///\n    /// Note that the chunks *must* be fully consumed in the order\n    /// they are yielded. Otherwise, they will panic.\n    ///\n    /// # Examples\n    ///\n    /// ```compile_fail\n    /// # // private items can't be tested with doc tests\n    /// let chunks = vec![0, 1, 2, 3, 4].into_iter().chunks(2);\n    /// let mut chunks_iter = chunks.into_iter();\n    ///\n    /// let mut chunk_1 = chunks_iter.next().unwrap();\n    /// assert_eq!(chunk_1.next().unwrap(), 0);\n    /// assert_eq!(chunk_1.next().unwrap(), 1);\n    /// assert!(chunk_1.next().is_none());\n    ///\n    /// let mut chunk_2 = chunks_iter.next().unwrap();\n    /// assert_eq!(chunk_2.next().unwrap(), 2);\n    /// assert_eq!(chunk_2.next().unwrap(), 3);\n    /// assert!(chunk_2.next().is_none());\n    ///\n    /// let mut chunk_3 = chunks_iter.next().unwrap();\n    /// assert_eq!(chunk_3.next().unwrap(), 4);\n    /// assert!(chunk_3.next().is_none());\n    ///\n    /// assert!(chunks_iter.next().is_none());\n    /// ```\n    ///\n    /// Attempting to consume chunks out of order fails:\n    ///\n    /// ```compile_fail\n    /// # // private items can't be tested with doc tests\n    
/// let chunks = vec![0, 1, 2, 3, 4].into_iter().chunks(2);\n    /// let mut chunks_iter = chunks.into_iter();\n    ///\n    /// let mut chunk_1 = chunks_iter.next().unwrap();\n    /// let mut chunk_2 = chunks_iter.next().unwrap();\n    ///\n    /// chunk_2.next(); // panics because chunk_1 was not consumed\n    /// ```\n    ///\n    /// Similarly, not _fully_ consuming the chunks fails:\n    ///\n    /// ```compile_fail\n    /// # // private items can't be tested with doc tests\n    /// let chunks = vec![0, 1, 2, 3, 4].into_iter().chunks(2);\n    /// let mut chunks_iter = chunks.into_iter();\n    ///\n    /// let mut chunk_1 = chunks_iter.next().unwrap();\n    /// let _ = chunk_1.next().unwrap();\n    /// let mut chunk_2 = chunks_iter.next().unwrap();\n    ///\n    /// chunk_2.next(); // panics because chunk_1 was not fully consumed\n    /// ```\n    ///\n    /// # Panics\n    ///\n    /// Panics if size is 0.\n    fn chunks(self, size: usize) -> IntoChunks<Self>;\n}\n\nimpl<I> ChunkableIterator for I\nwhere\n    I: Iterator,\n{\n    fn chunks(self, size: usize) -> IntoChunks<Self> {\n        IntoChunks::new(self, size)\n    }\n}\n\nstruct Inner<I>\nwhere\n    I: Iterator,\n{\n    /// The iterator we're chunking\n    iter: I,\n    /// Size of each chunk. Note that the last chunk may be smaller\n    chunk_size: usize,\n    /// Number of chunks that have been yielded\n    nb_chunks: usize,\n    /// Next item from `iter`. 
By buffering it, we can know when `iter`\n    /// is exhausted.\n    next: Option<(usize, I::Item)>,\n}\n\nimpl<I> fmt::Debug for Inner<I>\nwhere\n    I: Iterator + fmt::Debug,\n    I::Item: fmt::Debug,\n{\n    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n        f.debug_struct(\"Inner\")\n            .field(\"iter\", &self.iter)\n            .field(\"chunk_size\", &self.chunk_size)\n            .field(\"nb_chunks\", &self.nb_chunks)\n            .field(\"next\", &self.next)\n            .finish()\n    }\n}\n\nimpl<I> Inner<I>\nwhere\n    I: ExactSizeIterator,\n{\n    /// Number of items left in `self.iter`\n    fn remaining(&self) -> usize {\n        self.next.as_ref().map(|_| 1).unwrap_or(0) + self.iter.len()\n    }\n}\n\nimpl<I> Inner<I>\nwhere\n    I: Iterator,\n{\n    /// Return a new `Inner` with the given iterator and chunk size\n    fn new(mut iter: I, chunk_size: usize) -> Self {\n        if chunk_size == 0 {\n            panic!(\"invalid chunk size (must be > 0)\")\n        }\n        let next = iter.next().map(|elt| (0, elt));\n        Self {\n            iter,\n            chunk_size,\n            nb_chunks: 0,\n            next,\n        }\n    }\n\n    /// Get the `index`-th item from the underlying iterator. 
See\n    /// [`IntoChunks::get`].\n    fn get(&mut self, index: usize) -> Option<I::Item> {\n        self.next.as_ref()?;\n\n        let current_index = self.next.as_ref().unwrap().0;\n        if index < current_index {\n            return None;\n        }\n\n        if index == current_index {\n            let res = Some(self.next.take().unwrap().1);\n            // Buffer the next element\n            self.next = self.iter.next().map(|elt| (index + 1, elt));\n            res\n        } else {\n            panic!(\"previous chunks must be consumed\");\n        }\n    }\n}\n\n/// A type whose shared references can be turned into an\n/// `Iterator<Item=Chunk<I>>`.\npub struct IntoChunks<I>\nwhere\n    I: Iterator,\n{\n    /// The chunking state, wrapped in a `RefCell` so that chunks holding a\n    /// shared reference to this `IntoChunks` can still advance it.\n    inner: RefCell<Inner<I>>,\n}\n\nimpl<I> fmt::Debug for IntoChunks<I>\nwhere\n    I: Iterator + fmt::Debug,\n    I::Item: fmt::Debug,\n{\n    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n        f.debug_struct(\"IntoChunks\")\n            .field(\"inner\", &self.inner)\n            .finish()\n    }\n}\n\nimpl<I> IntoChunks<I>\nwhere\n    I: Iterator,\n{\n    /// Return a new `IntoChunks<I>` that splits `iter` into chunks of\n    /// `chunk_size` items.\n    ///\n    /// # Panics\n    ///\n    /// Panics if `chunk_size` is 0.\n    pub fn new(iter: I, chunk_size: usize) -> Self {\n        Self {\n            inner: RefCell::new(Inner::new(iter, chunk_size)),\n        }\n    }\n\n    /// Get the range of item indices covered by the next chunk\n    fn next_chunk_range(&self) -> Range<usize> {\n        let start = self.inner.borrow().nb_chunks * self.inner.borrow().chunk_size;\n        let end = start + self.inner.borrow().chunk_size;\n        start..end\n    }\n\n    /// Return `true` if the iterator we're chunking is exhausted\n    fn exhausted(&self) -> bool {\n        self.inner.borrow().next.is_none()\n    }\n\n    /// Get the `index`-th item from the underlying iterator. If the\n    /// iterator already advanced beyond `index`, `None` is\n    /// returned. If the requested `index` hasn't been reached yet,\n    /// this method panics. 
This is to enforce the invariant that all\n    /// chunks must be consumed in order.\n    ///\n    /// # Examples\n    ///\n    /// ```compile_fail\n    /// # // private items can't be tested with doc tests\n    /// let iter = vec![0, 1, 2, 3, 4, 5].into_iter();\n    /// let chunk_size = 2;\n    /// let chunks = IntoChunks::new(iter, chunk_size);\n    /// assert_eq!(chunks.get(0), Some(0));\n    /// assert_eq!(chunks.get(1), Some(1));\n    /// // calling `get` for an index that has already been consumed\n    /// assert_eq!(chunks.get(1), None);\n    /// // this panics, because the expected index is `2`\n    /// chunks.get(3);\n    /// ```\n    pub fn get(&self, index: usize) -> Option<I::Item> {\n        self.inner.borrow_mut().get(index)\n    }\n}\n\nimpl<I> IntoChunks<I>\nwhere\n    I: ExactSizeIterator,\n{\n    /// Number of items left in the iterator we're chunking\n    fn remaining(&self) -> usize {\n        self.inner.borrow().remaining()\n    }\n}\n\nimpl<'a, I> IntoIterator for &'a IntoChunks<I>\nwhere\n    I: Iterator,\n{\n    type Item = Chunk<'a, I>;\n    type IntoIter = Chunks<'a, I>;\n\n    fn into_iter(self) -> Self::IntoIter {\n        Chunks { parent: self }\n    }\n}\n\n/// An iterator that yields chunks\npub struct Chunks<'a, I>\nwhere\n    I: Iterator,\n{\n    parent: &'a IntoChunks<I>,\n}\n\nimpl<'a, I> fmt::Debug for Chunks<'a, I>\nwhere\n    I: Iterator + fmt::Debug,\n    I::Item: fmt::Debug,\n{\n    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n        f.debug_struct(\"Chunks\")\n            .field(\"parent\", &self.parent)\n            .finish()\n    }\n}\n\nimpl<'a, I> Iterator for Chunks<'a, I>\nwhere\n    I: Iterator,\n{\n    type Item = Chunk<'a, I>;\n\n    fn next(&mut self) -> Option<Chunk<'a, I>> {\n        if self.parent.exhausted() {\n            return None;\n        }\n\n        let chunk = Chunk {\n            range: self.parent.next_chunk_range(),\n            chunks: self.parent,\n        };\n        
self.parent.inner.borrow_mut().nb_chunks += 1;\n        Some(chunk)\n    }\n}\n\n/// A chunk\npub struct Chunk<'a, I>\nwhere\n    I: Iterator,\n{\n    range: Range<usize>,\n    chunks: &'a IntoChunks<I>,\n}\n\nimpl<'a, I> fmt::Debug for Chunk<'a, I>\nwhere\n    I: Iterator + fmt::Debug,\n    I::Item: fmt::Debug,\n{\n    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n        f.debug_struct(\"Chunk\")\n            .field(\"range\", &self.range)\n            .field(\"chunks\", &self.chunks)\n            .finish()\n    }\n}\n\nimpl<'a, I> Iterator for Chunk<'a, I>\nwhere\n    I: Iterator,\n{\n    type Item = I::Item;\n\n    fn next(&mut self) -> Option<Self::Item> {\n        if self.range.start >= self.range.end {\n            return None;\n        }\n        match self.chunks.get(self.range.start) {\n            Some(elt) => {\n                self.range.start += 1;\n                Some(elt)\n            }\n            None => {\n                self.range.start = self.range.end;\n                None\n            }\n        }\n    }\n}\n\nimpl<'a, I> ExactSizeIterator for Chunk<'a, I>\nwhere\n    I: Iterator + ExactSizeIterator,\n{\n    fn len(&self) -> usize {\n        cmp::min(self.chunks.remaining(), self.range.end - self.range.start)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn full_chunks_1() {\n        let iter = vec![0, 1, 2].into_iter();\n        let chunks = IntoChunks::new(iter, 1);\n        let mut chunks_iter = chunks.into_iter();\n        let mut c = chunks_iter.next().unwrap();\n        assert_eq!(c.len(), 1);\n        assert_eq!(c.next().unwrap(), 0);\n        assert_eq!(c.len(), 0);\n        assert!(c.next().is_none());\n\n        let mut c = chunks_iter.next().unwrap();\n        assert_eq!(c.len(), 1);\n        assert_eq!(c.next().unwrap(), 1);\n        assert_eq!(c.len(), 0);\n        assert!(c.next().is_none());\n\n        let mut c = chunks_iter.next().unwrap();\n        assert_eq!(c.len(), 1);\n   
     assert_eq!(c.next().unwrap(), 2);\n        assert_eq!(c.len(), 0);\n        assert!(c.next().is_none());\n\n        assert!(chunks_iter.next().is_none());\n    }\n\n    #[test]\n    fn full_chunks_2() {\n        let iter = vec![0, 1, 2, 3, 4, 5].into_iter();\n        let chunks = IntoChunks::new(iter, 2);\n        let mut chunks_iter = chunks.into_iter();\n\n        let mut c = chunks_iter.next().unwrap();\n        assert_eq!(c.len(), 2);\n        assert_eq!(c.next().unwrap(), 0);\n        assert_eq!(c.len(), 1);\n        assert_eq!(c.next().unwrap(), 1);\n        assert_eq!(c.len(), 0);\n        assert!(c.next().is_none());\n\n        let mut c = chunks_iter.next().unwrap();\n        assert_eq!(c.len(), 2);\n        assert_eq!(c.next().unwrap(), 2);\n        assert_eq!(c.len(), 1);\n        assert_eq!(c.next().unwrap(), 3);\n        assert_eq!(c.len(), 0);\n        assert!(c.next().is_none());\n\n        let mut c = chunks_iter.next().unwrap();\n        assert_eq!(c.len(), 2);\n        assert_eq!(c.next().unwrap(), 4);\n        assert_eq!(c.len(), 1);\n        assert_eq!(c.next().unwrap(), 5);\n        assert_eq!(c.len(), 0);\n        assert!(c.next().is_none());\n\n        assert!(chunks_iter.next().is_none());\n    }\n\n    #[test]\n    fn partial_chunk() {\n        let iter = vec![0, 1, 2].into_iter();\n        let chunks = IntoChunks::new(iter, 2);\n        let mut chunks_iter = chunks.into_iter();\n\n        let mut c = chunks_iter.next().unwrap();\n        assert_eq!(c.len(), 2);\n        assert_eq!(c.next().unwrap(), 0);\n        assert_eq!(c.len(), 1);\n        assert_eq!(c.next().unwrap(), 1);\n        assert_eq!(c.len(), 0);\n        assert!(c.next().is_none());\n\n        let mut c = chunks_iter.next().unwrap();\n        assert_eq!(c.len(), 1);\n        assert_eq!(c.next().unwrap(), 2);\n        assert_eq!(c.len(), 0);\n        assert!(c.next().is_none());\n    }\n\n    #[test]\n    #[should_panic(expected = \"previous chunks must be consumed\")]\n 
   fn chunks_consumed_out_of_order() {\n        let iter = vec![0, 1, 2, 3, 4, 5].into_iter();\n        let chunks = IntoChunks::new(iter, 2);\n        let mut chunks_iter = chunks.into_iter();\n\n        let mut c1 = chunks_iter.next().unwrap();\n        assert_eq!(c1.next().unwrap(), 0);\n        assert_eq!(c1.next().unwrap(), 1);\n        assert!(c1.next().is_none());\n\n        let _c2 = chunks_iter.next().unwrap();\n        let mut c3 = chunks_iter.next().unwrap();\n\n        assert_eq!(c3.next().unwrap(), 4);\n    }\n\n    // This test case illustrates a weird behavior of our iterator:\n    // everything being lazy, we can create chunks that start *beyond*\n    // what our main iterator can provide in theory. Consuming such a\n    // chunk before the previous chunks panics, as with any out-of-order\n    // consumption\n    #[test]\n    #[should_panic(expected = \"previous chunks must be consumed\")]\n    fn weird() {\n        let iter = vec![0, 1, 2].into_iter();\n        let chunks = IntoChunks::new(iter, 1);\n        let mut chunks_iter = chunks.into_iter();\n\n        let mut c1 = chunks_iter.next().unwrap();\n        let mut c2 = chunks_iter.next().unwrap();\n        let mut c3 = chunks_iter.next().unwrap();\n        // This chunk starts at index 3, which the underlying iterator doesn't have\n        let mut c4 = chunks_iter.next().unwrap();\n        assert!(c4.next().is_none());\n        assert!(c1.next().is_none());\n        assert!(c2.next().is_none());\n        assert!(c3.next().is_none());\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/message/utils/mod.rs",
    "content": "//! Message utilities.\n//!\n//! See the [message module] documentation, since this is a private module anyway.\n//!\n//! [message module]: crate::message\n\nmod chunkable_iterator;\npub use chunkable_iterator::{Chunk, ChunkableIterator, Chunks, IntoChunks};\n\nuse std::ops::Range;\n\n/// Creates the half-open range `start..(start + length)`, i.e. `length`\n/// indices beginning at `start`.\npub(crate) const fn range(start: usize, length: usize) -> Range<usize> {\n    start..(start + length)\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/testutils/messages.rs",
    "content": "//! This module provides helpers for generating messages or messages\n//! parts such as signatures, cryptographic keys, or mask objects.\n\nuse std::convert::TryFrom;\n\nuse num::BigUint;\n\nuse crate::{\n    crypto::{ByteObject, PublicEncryptKey, PublicSigningKey, Signature},\n    mask::EncryptedMaskSeed,\n    message::{Message, Payload, Sum, Sum2, Tag, Update},\n    LocalSeedDict,\n};\n\n// A message adds 136 bytes of overhead:\n//\n// - a signature (64 bytes)\n// - the participant pk (32 bytes)\n// - the coordinator pk (32 bytes)\n// - a length field (4 bytes)\n// - a tag (1 byte)\n// - flags (1 byte)\n// - a reserved field (2 bytes)\npub const HEADER_LENGTH: usize = 136;\n\npub fn signature() -> (Signature, Vec<u8>) {\n    let bytes = vec![0x1a; 64];\n    let signature = Signature::from_slice(bytes.as_slice()).unwrap();\n    (signature, bytes)\n}\n\npub fn participant_pk() -> (PublicSigningKey, Vec<u8>) {\n    let bytes = vec![0xbb; 32];\n    let pk = PublicSigningKey::from_slice(&bytes).unwrap();\n    (pk, bytes)\n}\n\npub fn coordinator_pk() -> (PublicEncryptKey, Vec<u8>) {\n    let bytes = vec![0xcc; 32];\n    let pk = PublicEncryptKey::from_slice(&bytes).unwrap();\n    (pk, bytes)\n}\n\npub fn message<F, P>(f: F) -> (Message, Vec<u8>)\nwhere\n    F: Fn() -> (P, Vec<u8>),\n    P: Into<Payload>,\n{\n    let (payload, payload_bytes) = f();\n    let payload: Payload = payload.into();\n    let tag = match payload {\n        Payload::Sum(_) => Tag::Sum,\n        Payload::Update(_) => Tag::Update,\n        Payload::Sum2(_) => Tag::Sum2,\n        _ => panic!(\"chunks not supported\"),\n    };\n    let message = Message {\n        signature: Some(signature().0),\n        participant_pk: participant_pk().0,\n        coordinator_pk: coordinator_pk().0,\n        payload,\n        is_multipart: false,\n        tag,\n    };\n\n    let mut buf = signature().1;\n    buf.extend(participant_pk().1);\n    buf.extend(coordinator_pk().1);\n    let length = 
payload_bytes.len() + HEADER_LENGTH;\n    buf.extend(&(length as u32).to_be_bytes());\n    buf.push(tag.into());\n    // flags (1 byte) and reserved field (2 bytes)\n    buf.extend(vec![0, 0, 0]);\n    buf.extend(payload_bytes);\n\n    (message, buf)\n}\n\npub mod sum {\n    //! This module provides helpers for generating sum payloads\n\n    use super::*;\n\n    /// Return a fake sum task signature and its serialized version\n    pub fn sum_task_signature() -> (Signature, Vec<u8>) {\n        let bytes = vec![0x11; 64];\n        let signature = Signature::from_slice(&bytes[..]).unwrap();\n        (signature, bytes)\n    }\n\n    /// Return a fake ephemeral public key and its serialized version\n    pub fn ephm_pk() -> (PublicEncryptKey, Vec<u8>) {\n        let bytes = vec![0x22; 32];\n        let pk = PublicEncryptKey::from_slice(&bytes[..]).unwrap();\n        (pk, bytes)\n    }\n\n    /// Return a sum payload with its serialized version\n    pub fn payload() -> (Sum, Vec<u8>) {\n        let mut bytes = sum_task_signature().1;\n        bytes.extend(ephm_pk().1);\n\n        let sum = Sum {\n            sum_signature: sum_task_signature().0,\n            ephm_pk: ephm_pk().0,\n        };\n        (sum, bytes)\n    }\n}\n\npub mod update {\n    //! 
This module provides helpers for generating update payloads\n    pub use mask::{mask_object, mask_unit, mask_vect};\n    pub use sum::sum_task_signature;\n\n    use super::*;\n\n    /// Return a fake update task signature and its serialized version\n    pub fn update_task_signature() -> (Signature, Vec<u8>) {\n        let bytes = vec![0x14; 64];\n        let signature = Signature::from_slice(&bytes[..]).unwrap();\n        (signature, bytes)\n    }\n\n    /// Return a local seed dictionary with two entries with its\n    /// expected serialized version\n    pub fn local_seed_dict() -> (LocalSeedDict, Vec<u8>) {\n        let mut local_seed_dict = LocalSeedDict::new();\n        let mut bytes = vec![];\n\n        // Length (32+80) * 2 + 4 = 228\n        bytes.extend(vec![0x00, 0x00, 0x00, 0xe4]);\n\n        bytes.extend(vec![0x55; PublicSigningKey::LENGTH]);\n        bytes.extend(vec![0x66; EncryptedMaskSeed::LENGTH]);\n        local_seed_dict.insert(\n            PublicSigningKey::from_slice(vec![0x55; 32].as_slice()).unwrap(),\n            EncryptedMaskSeed::try_from(vec![0x66; EncryptedMaskSeed::LENGTH]).unwrap(),\n        );\n\n        // Second entry\n        bytes.extend(vec![0x77; PublicSigningKey::LENGTH]);\n        bytes.extend(vec![0x88; EncryptedMaskSeed::LENGTH]);\n        local_seed_dict.insert(\n            PublicSigningKey::from_slice(vec![0x77; 32].as_slice()).unwrap(),\n            EncryptedMaskSeed::try_from(vec![0x88; EncryptedMaskSeed::LENGTH]).unwrap(),\n        );\n\n        (local_seed_dict, bytes)\n    }\n\n    /// Return an update payload with its serialized version\n    pub fn payload() -> (Update, Vec<u8>) {\n        let mut bytes = sum_task_signature().1;\n        bytes.extend(update_task_signature().1);\n        bytes.extend(mask_object().1);\n        bytes.extend(local_seed_dict().1);\n\n        let update = Update {\n            sum_signature: sum_task_signature().0,\n            update_signature: update_task_signature().0,\n            
masked_model: mask_object().0,\n            local_seed_dict: local_seed_dict().0,\n        };\n        (update, bytes)\n    }\n}\n\npub mod sum2 {\n    //! This module provides helpers for generating sum2 payloads\n    pub use mask::{mask_object, mask_unit, mask_vect};\n    pub use sum::sum_task_signature;\n\n    use super::*;\n\n    /// Return a sum2 payload and its serialized version\n    pub fn payload() -> (Sum2, Vec<u8>) {\n        let (sum_signature, sum_signature_bytes) = sum_task_signature();\n        let (model_mask, model_mask_bytes) = mask_object();\n        let bytes = [sum_signature_bytes.as_slice(), model_mask_bytes.as_slice()].concat();\n\n        let sum2 = Sum2 {\n            sum_signature,\n            model_mask,\n        };\n        (sum2, bytes)\n    }\n}\n\npub mod mask {\n    //! This module provides helpers for generating mask objects\n    use crate::mask::{\n        BoundType,\n        DataType,\n        GroupType,\n        MaskConfig,\n        MaskObject,\n        MaskUnit,\n        MaskVect,\n        ModelType,\n    };\n\n    use super::*;\n\n    /// Return a mask config and its serialized version\n    pub fn mask_config() -> (MaskConfig, Vec<u8>) {\n        // config.order() = 20_000_000_000_001 with this config, so the data\n        // should be stored in 6 bytes.\n        let config = MaskConfig {\n            group_type: GroupType::Integer,\n            data_type: DataType::I32,\n            bound_type: BoundType::B0,\n            model_type: ModelType::M3,\n        };\n        let bytes = vec![0x00, 0x02, 0x00, 0x03];\n        (config, bytes)\n    }\n\n    /// Return a masked vector and its serialized version\n    pub fn mask_vect() -> (MaskVect, Vec<u8>) {\n        let (config, mut bytes) = mask_config();\n        let data = vec![\n            BigUint::from(1_u8),\n            BigUint::from(2_u8),\n            BigUint::from(3_u8),\n            BigUint::from(4_u8),\n        ];\n        let mask_vect = MaskVect::new(config, 
data).unwrap();\n\n        bytes.extend(vec![\n            // number of elements\n            0x00, 0x00, 0x00, 0x04,\n            // data (1 weight => 6 bytes with this config)\n            0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // 1\n            0x02, 0x00, 0x00, 0x00, 0x00, 0x00, // 2\n            0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // 3\n            0x04, 0x00, 0x00, 0x00, 0x00, 0x00, // 4\n        ]);\n\n        (mask_vect, bytes)\n    }\n\n    /// Return a masked scalar and its serialized version\n    pub fn mask_unit() -> (MaskUnit, Vec<u8>) {\n        let (config, mut bytes) = mask_config();\n        let data = BigUint::from(1_u8);\n        let mask_unit = MaskUnit::new(config, data).unwrap();\n\n        bytes.extend(vec![\n            0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // data: 1\n        ]);\n        (mask_unit, bytes)\n    }\n\n    /// Return a mask object, containing a masked vector and a masked\n    /// scalar, and its serialized version\n    pub fn mask_object() -> (MaskObject, Vec<u8>) {\n        let (mask_vect, mask_vect_bytes) = mask_vect();\n        let (mask_unit, mask_unit_bytes) = mask_unit();\n        let obj = MaskObject::new_unchecked(mask_vect, mask_unit);\n        let bytes = [mask_vect_bytes.as_slice(), mask_unit_bytes.as_slice()].concat();\n\n        (obj, bytes)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    // This test exists so that if something changes, we catch it\n    // and can update the helpers accordingly\n    #[test]\n    fn check_object_lengths() {\n        assert_eq!(Signature::LENGTH, 64);\n        assert_eq!(PublicEncryptKey::LENGTH, 32);\n        assert_eq!(PublicSigningKey::LENGTH, 32);\n        assert_eq!(EncryptedMaskSeed::LENGTH, 80);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-core/src/testutils/mod.rs",
    "content": "pub mod messages;\npub mod multipart;\n"
  },
  {
    "path": "rust/xaynet-core/src/testutils/multipart.rs",
    "content": "use num::BigUint;\n\nuse crate::{\n    crypto::{ByteObject, PublicSigningKey, Signature},\n    mask::{\n        BoundType,\n        DataType,\n        EncryptedMaskSeed,\n        GroupType,\n        MaskConfig,\n        MaskObject,\n        MaskUnit,\n        MaskVect,\n        ModelType,\n    },\n    message::{Message, ToBytes, Update},\n    testutils::messages,\n    LocalSeedDict,\n};\n\n/// Return a seed dict that has the given length `len` once\n/// serialized. `len - 4` must be multiple of 112.\npub fn local_seed_dict(len: usize) -> LocalSeedDict {\n    // a public key is 32 bytes and an encrypted mask seed 80.\n    let entry_len = 32 + 80;\n    if ((len - 4) % entry_len) != 0 {\n        panic!(\"invalid length for seed dict\");\n    }\n\n    let nb_entries = (len - 4) / entry_len;\n    let mut dict = LocalSeedDict::new();\n    for i in 0..nb_entries {\n        let bytes = (i as u64).to_be_bytes();\n        let pk_bytes = bytes.iter().cycle().take(32).copied().collect::<Vec<_>>();\n        let seed_bytes = bytes.iter().cycle().take(80).copied().collect::<Vec<_>>();\n        let pk = PublicSigningKey::from_slice(pk_bytes.as_slice()).unwrap();\n        let mask_seed = EncryptedMaskSeed::from_slice(seed_bytes.as_slice()).unwrap();\n        dict.insert(pk, mask_seed);\n    }\n\n    // Check that our calculations are correct\n    assert_eq!(dict.buffer_length(), len);\n    dict\n}\n\npub fn mask_object(len: usize) -> MaskObject {\n    // The model contains 2 sub mask objects:\n    //    - the masked model, which has:\n    //         - 4 bytes for the config\n    //         - 4 bytes for the number of weights\n    //         - 6 bytes (with our config) for each weight\n    //    - the masked scalar:\n    //         - 4 bytes for the config\n    //         - 6 bytes (with our config) for the scalar\n    //\n    // The only parameter we control to make the length vary is\n    // the number of weights. 
The length is then:\n    //\n    // len = (4 + 4 + n_weights * 6) + (4 + 6) = 18 + 6 * n_weights\n    //\n    // So we must have: (len - 18) % 6 = 0\n    if (len - 18) % 6 != 0 {\n        panic!(\"invalid masked model length\")\n    }\n    let n_weights = (len - 18) / 6;\n    // Let's not be too crazy, it makes no sense to test with too\n    // many weights\n    assert!(n_weights < u32::MAX as usize);\n\n    let mut weights = vec![];\n    for i in 0..n_weights {\n        weights.push(BigUint::from(i));\n    }\n\n    let masked_model = MaskVect::new(mask_config(), weights).unwrap();\n    let masked_scalar = MaskUnit::new(mask_config(), BigUint::from(0_u32)).unwrap();\n    let obj = MaskObject::new_unchecked(masked_model, masked_scalar);\n\n    // Check that our calculations are correct\n    assert_eq!(obj.buffer_length(), len);\n    obj\n}\n\npub fn mask_config() -> MaskConfig {\n    // config.order() = 20_000_000_000_001 with this config, so the data\n    // should be stored in 6 bytes.\n    MaskConfig {\n        group_type: GroupType::Integer,\n        data_type: DataType::I32,\n        bound_type: BoundType::B0,\n        model_type: ModelType::M3,\n    }\n}\n\npub fn task_signatures() -> (Signature, Signature) {\n    (\n        messages::sum::sum_task_signature().0,\n        messages::update::update_task_signature().0,\n    )\n}\n\n/// Create an update payload with a seed dictionary of length\n/// `dict_len` and a mask object of length `mask_obj_len`. For a payload\n/// of size `S`, the following must hold true:\n///\n/// ```no_rust\n/// (mask_obj_len - 18) % 6 = 0\n/// (dict_len - 4) % 112 = 0\n/// S = dict_len + mask_obj_len + 64*2\n/// ```\npub fn update(dict_len: usize, mask_obj_len: usize) -> Update {\n    // An update payload is made of:\n    // - 2 signatures of 64 bytes each\n    // - a mask object of variable length\n    // - a seed dictionary of variable length\n    //\n    // The `Message` overhead is 136 bytes (see\n    // `crate::testutils::messages::HEADER_LENGTH`). 
So a message with\n    // `dict_len` = 116 (one seed dict entry) and `mask_obj_len` = 120\n    // (17 weights) will be:\n    //\n    //    116 + 120 + 64*2 + 136 = 500 bytes\n    let (sum_signature, update_signature) = task_signatures();\n\n    let payload = Update {\n        sum_signature,\n        update_signature,\n        masked_model: mask_object(mask_obj_len),\n        local_seed_dict: local_seed_dict(dict_len),\n    };\n\n    assert_eq!(payload.buffer_length(), mask_obj_len + dict_len + 64 * 2);\n    payload\n}\n\n/// Create an update message with a seed dictionary of length\n/// `dict_len` and a mask object of length `mask_obj_len`. For a message\n/// of size `S`, the following must hold true:\n///\n/// ```no_rust\n/// (mask_obj_len - 18) % 6 = 0\n/// (dict_len - 4) % 112 = 0\n/// S = dict_len + mask_obj_len + 64*2 + 136\n/// ```\npub fn message(dict_len: usize, mask_obj_len: usize) -> Message {\n    let (message, _) = messages::message(|| {\n        let payload = update(dict_len, mask_obj_len);\n        let dummy_buf = vec![];\n        (payload, dummy_buf)\n    });\n    message\n}\n"
  },
  {
    "path": "rust/xaynet-mobile/.cargo/config.toml",
    "content": "# These settings reduce the size of the libraries a lot!\n# See: https://github.com/johnthagen/min-sized-rust\n[profile.release]\nlto = true\ncodegen-units = 1\nopt-level = 'z'\n"
  },
  {
    "path": "rust/xaynet-mobile/.gitignore",
    "content": "ffi_test.o.dSYM\nffi_test.o\ntest_participant_save_and_restore.txt\n"
  },
  {
    "path": "rust/xaynet-mobile/Cargo.toml",
    "content": "[package]\nname = \"xaynet-mobile\"\nversion = \"0.1.0\"\nauthors = [\"Xayn Engineering <engineering@xaynet.dev>\"]\nedition = \"2018\"\ndescription = \"The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant.\"\nreadme = \"README.md\"\nhomepage = \"https://xaynet.dev/\"\nrepository = \"https://github.com/xaynetwork/xaynet/\"\nlicense-file = \"../../LICENSE\"\nkeywords = [\"federated-learning\", \"fl\", \"ai\", \"machine-learning\"]\ncategories = [\"science\", \"cryptography\"]\n\n[dependencies]\nasync-trait = \"0.1.57\"\nbincode = \"1.3.3\"\nffi-support = \"0.4.4\"\nfutures = \"0.3.24\"\nreqwest = { version = \"0.11.10\", default-features = false, features = [\"rustls-tls\"]}\nsodiumoxide = \"0.2.7\"\nthiserror = \"1.0.32\"\ntracing = \"0.1.36\"\ntokio = { version = \"1.20.1\", default-features = false, features = [\"rt\"] }\nxaynet-core = { path = \"../xaynet-core\", version = \"0.2.0\" }\nxaynet-sdk = { path = \"../xaynet-sdk\", default-features = false, version = \"0.1.0\", features = [\"reqwest-client\"]}\nzeroize = \"1.5.7\"\n\n[build-dependencies]\ncbindgen = \"=0.17.0\"\n\n[lib]\nname = \"xaynet_mobile\"\ncrate-type = [\"staticlib\", \"cdylib\", \"rlib\"]\n\n[features]\ndefault = []\n"
  },
  {
    "path": "rust/xaynet-mobile/README.md",
    "content": "# Xaynet FFI\n\n## Generate C-Header File\n\nTo generate the C header file (`xaynet_ffi.h`), run `cargo build`.\n\n\n## Run tests\n\n### macOS\n\n```\ncc -o tests/ffi_test.o -Wl,-dead_strip -I. tests/ffi_test.c ../target/debug/libxaynet_mobile.a -framework Security -framework Foundation\n./tests/ffi_test.o\n```\n\n### Linux\n\n```\ngcc \\\n    tests/ffi_test.c \\\n    -Wall \\\n    -I. \\\n    -pthread -Wl,--no-as-needed -lm -ldl \\\n    ../target/debug/libxaynet_mobile.a \\\n    -o tests/ffi_test.o\n./tests/ffi_test.o\n```\n\nTo check for memory leaks, you can use Valgrind:\n\n```\nvalgrind --tool=memcheck  --leak-check=full --show-leak-kinds=all -s ./tests/ffi_test.o\n```\n"
  },
  {
    "path": "rust/xaynet-mobile/build.rs",
    "content": "use std::{\n    env,\n    fs::read_dir,\n    path::{Path, PathBuf},\n};\n\nuse cbindgen::{generate_with_config, Config};\n\n// cargo doesn't check directories recursively so we have to do it by hand, also emitting a\n// rerun-if line cancels the default rerun for changes in the crate directory\nfn cargo_rerun_if_changed(entry: impl AsRef<Path>) {\n    let entry = entry.as_ref();\n    if entry.is_dir() {\n        for entry in read_dir(entry).expect(\"Failed to read dir.\") {\n            cargo_rerun_if_changed(entry.expect(\"Failed to read entry.\").path());\n        }\n    } else {\n        println!(\"cargo:rerun-if-changed={}\", entry.display());\n    }\n}\n\nfn main() {\n    let crate_dir = PathBuf::from(\n        env::var(\"CARGO_MANIFEST_DIR\").expect(\"Failed to read CARGO_MANIFEST_DIR env.\"),\n    );\n    let bind_config = crate_dir.join(\"cbindgen.toml\");\n    let bind_file = crate_dir.join(\"xaynet_ffi.h\");\n\n    cargo_rerun_if_changed(crate_dir.join(\"src\"));\n    cargo_rerun_if_changed(crate_dir.join(\"Cargo.toml\"));\n    cargo_rerun_if_changed(bind_config.as_path());\n\n    let config = Config::from_file(bind_config).expect(\"Failed to read config.\");\n    generate_with_config(crate_dir, config)\n        .expect(\"Failed to generate bindings.\")\n        .write_to_file(bind_file);\n}\n"
  },
  {
    "path": "rust/xaynet-mobile/cbindgen.toml",
    "content": "language = \"C\"\nautogen_warning = \"/* Warning, this file is autogenerated by cbindgen. Don't modify this manually. */\"\ninclude_version = true\n\n[export]\nexclude = [\"_xaynet_ffi_settings_destroy\", \"_xaynet_ffi_participant_destroy\", \"_xaynet_ffi_local_model_config_destroy\"]\n\n[parse]\nparse_deps = true\ninclude = [\"ffi-support\"]\n\n[enum]\nrename_variants = \"ScreamingSnakeCase\"\nprefix_with_name = true\n"
  },
  {
    "path": "rust/xaynet-mobile/src/ffi/config.rs",
    "content": "use crate::ffi::{ERR_NULLPTR, OK};\nuse std::os::raw::c_int;\nuse xaynet_core::mask::DataType;\n\nmod pv {\n    use super::LocalModelConfig;\n    ffi_support::define_box_destructor!(LocalModelConfig, _xaynet_ffi_local_model_config_destroy);\n}\n\n/// Destroy the model configuration created by [`xaynet_ffi_participant_local_model_config()`].\n///\n/// # Return value\n///\n/// - [`OK`] on success\n/// - [`ERR_NULLPTR`] if `local_model_config` is NULL\n///\n/// # Safety\n///\n/// 1. When calling this method, you have to ensure that *either* the pointer is NULL\n///    *or* all of the following is true:\n///    - The pointer must be properly [aligned].\n///    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///      documentation.\n/// 2. After destroying the `LocalModelConfig`, the pointer becomes invalid and must not be\n///    used.\n/// 3. This function should only be called on a pointer that has been created by\n///    [`xaynet_ffi_participant_local_model_config()`].\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n/// [`xaynet_ffi_participant_local_model_config()`]: crate::ffi::xaynet_ffi_participant_local_model_config\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_local_model_config_destroy(\n    local_model_config: *mut LocalModelConfig,\n) -> c_int {\n    if local_model_config.is_null() {\n        return ERR_NULLPTR;\n    }\n    pv::_xaynet_ffi_local_model_config_destroy(local_model_config);\n    OK\n}\n\n#[repr(C)]\n/// The model configuration of the model that is expected in [`xaynet_ffi_participant_set_model()`].\n///\n/// [`xaynet_ffi_participant_set_model()`]: crate::ffi::xaynet_ffi_participant_set_model\npub struct LocalModelConfig {\n    /// The expected data type of the model.\n    pub data_type: ModelDataType,\n    /// the expected length of the model.\n    pub len: u64,\n}\n\nimpl 
From<xaynet_sdk::LocalModelConfig> for LocalModelConfig {\n    fn from(lmc: xaynet_sdk::LocalModelConfig) -> Self {\n        LocalModelConfig {\n            data_type: lmc.data_type.into(),\n            len: lmc.len as u64,\n        }\n    }\n}\n\n#[repr(u8)]\n/// The original primitive data type of the numerical values to be masked.\npub enum ModelDataType {\n    /// Numbers of type f32.\n    F32 = 0,\n    /// Numbers of type f64.\n    F64 = 1,\n    /// Numbers of type i32.\n    I32 = 2,\n    /// Numbers of type i64.\n    I64 = 3,\n}\n\nimpl From<DataType> for ModelDataType {\n    fn from(dt: DataType) -> Self {\n        match dt {\n            DataType::F32 => ModelDataType::F32,\n            DataType::F64 => ModelDataType::F64,\n            DataType::I32 => ModelDataType::I32,\n            DataType::I64 => ModelDataType::I64,\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-mobile/src/ffi/mod.rs",
    "content": "#![allow(unused_unsafe)]\n\nmod participant;\npub use participant::*;\n\nmod settings;\npub use settings::*;\n\nmod config;\npub use config::*;\n\npub use ffi_support::{ByteBuffer, FfiStr};\nuse std::os::raw::c_int;\n\n/// Destroy the given `ByteBuffer` and free its memory. This function must only be\n/// called on `ByteBuffer`s that have been created on the Rust side of the FFI. If you\n/// have created a `ByteBuffer` on the other side of the FFI, do not use this function,\n/// use `free()` instead.\n///\n/// # Return value\n///\n/// - [`OK`] on success\n/// - [`ERR_NULLPTR`] if `buf` is NULL\n///\n/// # Safety\n///\n/// 1. When calling this method, you have to ensure that *either* the pointer is NULL\n/// *or* all of the following is true:\n///  - The pointer must be properly [aligned].\n///  - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///    documentation.\n/// 2. After destroying the `ByteBuffer` the pointer becomes invalid and must not be\n///    used.\n/// 3. Calling this function on a `ByteBuffer` that has not been created on the Rust\n///    side of the FFI is UB.\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_byte_buffer_destroy(\n    // Note that we use a *const instead of a *mut here. The reason is\n    // that the functions that create byte buffers return *const\n    // pointers. Taking a *mut here would trigger a\n    // -Wdiscarded-qualifiers warning from C. Forcing users to use\n    // *const pointers brings some safety, and casting back to *mut\n    // here is no big deal since the pointer becomes invalid afterward\n    // anyway.\n    buf: *const ByteBuffer,\n) -> c_int {\n    if buf.is_null() {\n        return ERR_NULLPTR;\n    }\n    Box::from_raw(buf as *mut ByteBuffer).destroy();\n    OK\n}\n\n/// Initialize the crypto library. 
This method must be called before instantiating a\n/// participant with [`xaynet_ffi_participant_new()`] or before generating new keys with\n/// [`xaynet_ffi_generate_key_pair()`].\n///\n/// # Return value\n///\n/// - [`OK`] if the initialization succeeded\n/// - [`ERR_CRYPTO_INIT`] if the initialization failed\n///\n/// # Safety\n///\n/// This function is safe to call\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_crypto_init() -> c_int {\n    if sodiumoxide::init().is_err() {\n        ERR_CRYPTO_INIT\n    } else {\n        OK\n    }\n}\n\n/// Return value upon success\npub const OK: c_int = 0;\n/// NULL pointer argument\npub const ERR_NULLPTR: c_int = 1;\n/// Invalid coordinator URL\npub const ERR_INVALID_URL: c_int = 2;\n/// Invalid settings: coordinator URL is not set\npub const ERR_SETTINGS_URL: c_int = 3;\n/// Invalid settings: signing keys are not set\npub const ERR_SETTINGS_KEYS: c_int = 4;\n/// Invalid settings: scalar is out of bounds\npub const ERR_SETTINGS_SCALAR: c_int = 5;\n/// Failed to set the local model: invalid model\npub const ERR_SETMODEL_MODEL: c_int = 6;\n/// Failed to set the local model: invalid data type\npub const ERR_SETMODEL_DATATYPE: c_int = 7;\n/// Failed to initialize the crypto library\npub const ERR_CRYPTO_INIT: c_int = 8;\n/// Invalid secret signing key\npub const ERR_CRYPTO_SECRET_KEY: c_int = 9;\n/// Invalid public signing key\npub const ERR_CRYPTO_PUBLIC_KEY: c_int = 10;\n/// No global model is currently available\npub const GLOBALMODEL_NONE: c_int = 11;\n/// Failed to get the global model: communication with the coordinator failed\npub const ERR_GLOBALMODEL_IO: c_int = 12;\n/// Failed to get the global model: invalid data type\npub const ERR_GLOBALMODEL_DATATYPE: c_int = 13;\n/// Failed to get the global model: invalid buffer length\npub const ERR_GLOBALMODEL_LEN: c_int = 14;\n/// Failed to get the global model: invalid model\npub const ERR_GLOBALMODEL_CONVERT: c_int = 15;\n"
  },
  {
    "path": "rust/xaynet-mobile/src/ffi/participant.rs",
    "content": "use std::{\n    convert::TryFrom,\n    os::raw::{c_int, c_uchar, c_uint, c_void},\n    ptr,\n    slice,\n};\n\nuse ffi_support::{ByteBuffer, FfiStr};\nuse xaynet_core::mask::{DataType, FromPrimitives, IntoPrimitives, Model};\n\nuse super::{\n    LocalModelConfig,\n    ERR_GLOBALMODEL_CONVERT,\n    ERR_GLOBALMODEL_DATATYPE,\n    ERR_GLOBALMODEL_IO,\n    ERR_GLOBALMODEL_LEN,\n    ERR_NULLPTR,\n    ERR_SETMODEL_DATATYPE,\n    ERR_SETMODEL_MODEL,\n    GLOBALMODEL_NONE,\n    OK,\n};\nuse crate::{into_primitives, Participant, Settings, Task};\n\nmod pv {\n    use super::Participant;\n    ffi_support::define_box_destructor!(Participant, _xaynet_ffi_participant_destroy);\n}\n\n/// Destroy the participant created by [`xaynet_ffi_participant_new()`] or\n/// [`xaynet_ffi_participant_restore()`].\n///\n/// # Return value\n///\n/// - [`OK`] on success\n/// - [`ERR_NULLPTR`] if `participant` is NULL\n///\n/// # Safety\n///\n/// 1. When calling this method, you have to ensure that *either* the pointer is NULL\n///    *or* all of the following is true:\n///    - The pointer must be properly [aligned].\n///    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///      documentation.\n/// 2. After destroying the `Participant`, the pointer becomes invalid and must not be\n///    used.\n/// 3. 
This function should only be called on a pointer that has been created by\n///    [`xaynet_ffi_participant_new()`] or [`xaynet_ffi_participant_restore()`].\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_participant_destroy(participant: *mut Participant) -> c_int {\n    if participant.is_null() {\n        return ERR_NULLPTR;\n    }\n    pv::_xaynet_ffi_participant_destroy(participant);\n    OK\n}\n\n/// The participant is not taking part in the sum or update task\npub const PARTICIPANT_TASK_NONE: c_int = 1;\n/// The participant is taking part in the sum task\npub const PARTICIPANT_TASK_SUM: c_int = 1 << 1;\n/// The participant is taking part in the update task\npub const PARTICIPANT_TASK_UPDATE: c_int = 1 << 2;\n/// The participant is expected to set the model it trained\npub const PARTICIPANT_SHOULD_SET_MODEL: c_int = 1 << 3;\n/// The participant's internal state machine made progress during the last tick\npub const PARTICIPANT_MADE_PROGRESS: c_int = 1 << 4;\n/// A new global model is available\npub const PARTICIPANT_NEW_GLOBALMODEL: c_int = 1 << 5;\n\n/// Instantiate a new participant with the given settings. 
The participant must be\n/// destroyed with [`xaynet_ffi_participant_destroy`].\n///\n/// # Return value\n///\n/// - a NULL pointer if `settings` is NULL or if the participant creation failed\n/// - a valid pointer to a [`Participant`] otherwise\n///\n/// # Safety\n///\n/// When calling this method, you have to ensure that *either* the pointer is NULL *or*\n/// all of the following is true:\n///\n/// - The pointer must be properly [aligned].\n/// - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///   documentation.\n///\n/// After the participant has been destroyed with [`xaynet_ffi_participant_destroy`], the\n/// pointer becomes invalid and must not be used.\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_participant_new(settings: *const Settings) -> *mut Participant {\n    let settings = match unsafe { settings.as_ref() } {\n        Some(settings) => settings.clone(),\n        None => return std::ptr::null_mut(),\n    };\n\n    match Participant::new(settings) {\n        Ok(participant) => Box::into_raw(Box::new(participant)),\n        Err(_) => std::ptr::null_mut(),\n    }\n}\n\n/// Drive the participant internal state machine. Every tick, the state machine\n/// attempts to perform a small work unit.\n///\n/// # Return value\n///\n/// - [`ERR_NULLPTR`] if `participant` is NULL\n/// - a bitflag otherwise, with the following flags:\n///   - [`PARTICIPANT_MADE_PROGRESS`]: if set, this flag indicates that the participant\n///     internal state machine was able to make some progress, and that the participant\n///     state changed. This information can be used as an indication for saving the\n///     participant state for instance. If the flag is not set, the state machine was\n///     not able to make progress. 
There are many potential causes for this, including:\n///       - the participant is not taking part in the current training round and is just\n///         waiting for a new one to start\n///       - the Xaynet coordinator is not reachable or has not published some\n///         information the participant is waiting for\n///       - the state machine is waiting for the model to be set (see\n///         [`xaynet_ffi_participant_set_model()`])\n///   - [`PARTICIPANT_TASK_NONE`], [`PARTICIPANT_TASK_SUM`] and\n///     [`PARTICIPANT_TASK_UPDATE`]: these flags are mutually exclusive, and indicate\n///     which task the participant has been selected for, for the current round. If\n///     [`PARTICIPANT_TASK_NONE`] is set, then the participant will just wait for a new\n///     round to start. If [`PARTICIPANT_TASK_UPDATE`] is set, then the participant has\n///     been selected to update the global model, and should prepare to provide a new\n///     model once the [`PARTICIPANT_SHOULD_SET_MODEL`] flag is set.\n///   - [`PARTICIPANT_SHOULD_SET_MODEL`]: if set, then the participant should set its\n///     model, by calling [`xaynet_ffi_participant_set_model()`]\n///   - [`PARTICIPANT_NEW_GLOBALMODEL`]: if set, the participant can fetch the new global\n///     model, by calling [`xaynet_ffi_participant_global_model()`]\n///\n/// # Safety\n///\n/// When calling this method, you have to ensure that *either* the pointer is NULL *or*\n/// all of the following is true:\n///\n/// - The pointer must be properly [aligned].\n/// - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///   documentation.\n///\n/// After the participant has been destroyed with [`xaynet_ffi_participant_destroy`], the\n/// pointer becomes invalid and must not be used.\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_participant_tick(participant: *mut 
Participant) -> c_int {\n    let participant = match unsafe { participant.as_mut() } {\n        Some(participant) => participant,\n        None => return ERR_NULLPTR,\n    };\n\n    participant.tick();\n\n    let mut flags: c_int = 0;\n    match participant.task() {\n        Task::None => flags |= PARTICIPANT_TASK_NONE,\n        Task::Sum => flags |= PARTICIPANT_TASK_SUM,\n        Task::Update => flags |= PARTICIPANT_TASK_UPDATE,\n    };\n    if participant.should_set_model() {\n        flags |= PARTICIPANT_SHOULD_SET_MODEL;\n    }\n    if participant.made_progress() {\n        flags |= PARTICIPANT_MADE_PROGRESS;\n    }\n    if participant.new_global_model() {\n        flags |= PARTICIPANT_NEW_GLOBALMODEL;\n    }\n    flags\n}\n\n/// Serialize the participant state and return a buffer that contains the serialized\n/// participant.\n///\n/// # Safety\n///\n/// 1. When calling this method, you have to ensure that *either* the pointer is NULL\n///    *or* all of the following is true:\n///    - The pointer must be properly [aligned].\n///    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///      documentation.\n/// 2. The `ByteBuffer` created by this function must be destroyed with\n///    [`xaynet_ffi_byte_buffer_destroy()`](crate::ffi::xaynet_ffi_byte_buffer_destroy).\n///    Attempting to free the memory from the other side of the FFI is UB.\n/// 3. This function destroys the participant. Therefore, **the pointer becomes invalid\n///    and must not be used anymore**. 
Instead, a new participant should be created,\n///    either with [`xaynet_ffi_participant_new()`] or\n///    [`xaynet_ffi_participant_restore()`]\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n///\n/// # Example\n///\n/// To save the participant into a file:\n///\n/// ```c\n///  const ByteBuffer *save_buf = xaynet_ffi_participant_save(participant);\n///  assert(save_buf);\n///\n///  char *path = \"./participant.bin\";\n///  FILE *f = fopen(path, \"w\");\n///  fwrite(save_buf->data, 1, save_buf->len, f);\n///  fclose(f);\n/// ```\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_participant_save(\n    participant: *mut Participant,\n) -> *const ByteBuffer {\n    let participant: Participant = match unsafe { participant.as_mut() } {\n        Some(ptr) => unsafe { *Box::from_raw(ptr) },\n        None => return std::ptr::null(),\n    };\n\n    Box::into_raw(Box::new(ByteBuffer::from_vec(participant.save())))\n}\n\n/// Restore the participant from a buffer that contained its serialized state.\n///\n/// # Return value\n///\n/// - a NULL pointer on failure\n/// - a pointer to the restored participant on success\n///\n/// # Safety\n///\n/// When calling this method, you have to ensure that *either* the pointers are NULL\n/// *or* all of the following is true:\n/// - The pointers must be properly [aligned].\n/// - They must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///   documentation.\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n///\n/// # Example\n///\n/// To restore a participant from a file:\n///\n/// ```c\n/// f = fopen(\"./participant.bin\", \"r\");\n/// fseek(f, 0L, SEEK_END);\n/// int fsize = ftell(f);\n/// fseek(f, 0L, SEEK_SET);\n/// ByteBuffer buf = {\n///     .len = fsize,\n///     .data = (uint8_t *)malloc(fsize),\n/// };\n/// 
int n_read = fread(buf.data, 1, fsize, f);\n/// assert(n_read == fsize);\n/// fclose(f);\n/// Participant *restored =\n///     xaynet_ffi_participant_restore(\"http://localhost:8081\", &buf);\n/// free(buf.data);\n/// ```\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_participant_restore(\n    url: FfiStr,\n    buffer: *const ByteBuffer,\n) -> *mut Participant {\n    let url = match url.as_opt_str() {\n        Some(url) => url,\n        None => return ptr::null_mut(),\n    };\n\n    let buffer: &ByteBuffer = match unsafe { buffer.as_ref() } {\n        Some(ptr) => ptr,\n        None => return ptr::null_mut(),\n    };\n\n    if let Ok(participant) = Participant::restore(buffer.as_slice(), url) {\n        Box::into_raw(Box::new(participant))\n    } else {\n        ptr::null_mut()\n    }\n}\n\n/// Set the participant's model. Usually this should be called when the value returned\n/// by [`xaynet_ffi_participant_tick()`] contains the [`PARTICIPANT_SHOULD_SET_MODEL`]\n/// flag, but it can be called anytime. The model just won't be sent to the coordinator\n/// until it's time.\n///\n/// - `buffer` should be a pointer to a buffer that contains the model\n/// - `data_type` specifies the type of the model weights (see [`DataType`]). The C header\n///   file generated by this crate provides an enum corresponding to the parameters: `DataType`.\n/// - `len` is the number of weights the model has\n///\n/// # Return value\n///\n/// - [`OK`] if the model is set successfully\n/// - [`ERR_NULLPTR`] if `participant` is NULL\n/// - [`ERR_SETMODEL_DATATYPE`] if the datatype is invalid\n/// - [`ERR_SETMODEL_MODEL`] if the model is invalid\n///\n/// # Safety\n///\n/// 1. When calling this method, you have to ensure that *either* the pointer is NULL\n///    *or* all of the following is true:\n///    - The pointer must be properly [aligned].\n///    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///      documentation.\n/// 2. 
If `len` or `data_type` do not match the model in `buffer`, this method will\n///    result in a buffer over-read.\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_participant_set_model(\n    participant: *mut Participant,\n    buffer: *const c_void,\n    data_type: c_uchar,\n    len: c_uint,\n) -> c_int {\n    let participant = match unsafe { participant.as_mut() } {\n        Some(participant) => participant,\n        None => return ERR_NULLPTR,\n    };\n\n    if buffer.is_null() {\n        return ERR_NULLPTR;\n    }\n\n    let data_type = match DataType::try_from(data_type) {\n        Ok(data_type) => data_type,\n        Err(_) => return ERR_SETMODEL_DATATYPE,\n    };\n\n    let len = len as usize;\n    let model = match data_type {\n        DataType::F32 => {\n            let buffer = unsafe { slice::from_raw_parts(buffer as *const f32, len) };\n            // we map the error so that we get an uniform error type\n            Model::from_primitives(buffer.iter().copied()).map_err(|_| ())\n        }\n        DataType::F64 => {\n            let buffer = unsafe { slice::from_raw_parts(buffer as *const f64, len) };\n            Model::from_primitives(buffer.iter().copied()).map_err(|_| ())\n        }\n        DataType::I32 => {\n            let buffer = unsafe { slice::from_raw_parts(buffer as *const i32, len) };\n            Model::from_primitives(buffer.iter().copied()).map_err(|_| ())\n        }\n        DataType::I64 => {\n            let buffer = unsafe { slice::from_raw_parts(buffer as *const i64, len) };\n            Model::from_primitives(buffer.iter().copied()).map_err(|_| ())\n        }\n    };\n\n    if let Ok(m) = model {\n        participant.set_model(m);\n        OK\n    } else {\n        ERR_SETMODEL_MODEL\n    }\n}\n\n/// Return the latest global model from the coordinator.\n///\n/// - `buffer` is the 
array into which the global model is copied.\n/// - `data_type` specifies the type of the model weights (see [`DataType`]). The C header\n///   file generated by this crate provides an enum corresponding to the parameters: `DataType`.\n/// - `len` is the number of weights the model has\n///\n/// # Return Value\n///\n/// - [`OK`] if the global model was copied into `buffer` successfully\n/// - [`ERR_NULLPTR`] if `participant` or the `buffer` is NULL\n/// - [`GLOBALMODEL_NONE`] if no model exists\n/// - [`ERR_GLOBALMODEL_IO`] if the communication with the coordinator failed\n/// - [`ERR_GLOBALMODEL_DATATYPE`] if the datatype is invalid\n/// - [`ERR_GLOBALMODEL_LEN`] if `len` does not match the length of the global model\n/// - [`ERR_GLOBALMODEL_CONVERT`] if the conversion of the model failed\n///\n/// # Note\n///\n///   It is **not** guaranteed that the model configuration returned by\n///   [`xaynet_ffi_participant_local_model_config`] corresponds to the configuration of\n///   the global model. This means that the global model can have a different length / data type\n///   than it is defined in model configuration. That both model configurations are the same is\n///   only guaranteed if the model config **never** changes on the coordinator side.\n///\n/// # Safety\n///\n/// 1. When calling this method, you have to ensure that *either* the pointer is NULL\n///    *or* all of the following is true:\n///    - The pointer must be properly [aligned].\n///    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///      documentation.\n/// 2. 
If `buffer` is too small to hold `len` values of type `data_type`, this method will\n///    result in a buffer overflow (out-of-bounds write).\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_participant_global_model(\n    participant: *mut Participant,\n    buffer: *mut c_void,\n    data_type: c_uchar,\n    len: c_uint,\n) -> c_int {\n    let participant = match unsafe { participant.as_mut() } {\n        Some(participant) => participant,\n        None => return ERR_NULLPTR,\n    };\n\n    if buffer.is_null() {\n        return ERR_NULLPTR;\n    }\n\n    let global_model = match participant.global_model() {\n        Ok(Some(model)) => model,\n        Ok(None) => return GLOBALMODEL_NONE,\n        Err(_) => return ERR_GLOBALMODEL_IO,\n    };\n\n    let data_type = match DataType::try_from(data_type) {\n        Ok(data_type) => data_type,\n        Err(_) => return ERR_GLOBALMODEL_DATATYPE,\n    };\n\n    let len = len as usize;\n    if len != global_model.len() {\n        return ERR_GLOBALMODEL_LEN;\n    }\n\n    match data_type {\n        DataType::F32 => into_primitives!(global_model, buffer, f32, len),\n        DataType::F64 => into_primitives!(global_model, buffer, f64, len),\n        DataType::I32 => into_primitives!(global_model, buffer, i32, len),\n        DataType::I64 => into_primitives!(global_model, buffer, i64, len),\n    }\n}\n\n#[macro_export]\nmacro_rules! 
into_primitives {\n    ($global_model:expr, $buffer:expr, $data_type:ty, $len:expr) => {{\n        if let Ok(global_model) = $global_model\n            .into_primitives()\n            .collect::<Result<Vec<$data_type>, _>>()\n        {\n            let buffer = unsafe { slice::from_raw_parts_mut($buffer as *mut $data_type, $len) };\n            buffer.copy_from_slice(global_model.as_slice());\n            OK\n        } else {\n            ERR_GLOBALMODEL_CONVERT\n        }\n    }};\n}\n\n/// Return the local model configuration of the model that is expected in the\n/// [`xaynet_ffi_participant_set_model()`] function.\n///\n/// # Safety\n///\n/// 1. When calling this method, you have to ensure that *either* the pointer is NULL\n///    *or* all of the following is true:\n///    - The pointer must be properly [aligned].\n///    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///      documentation.\n///\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_participant_local_model_config(\n    participant: *const Participant,\n) -> *mut LocalModelConfig {\n    let participant = match unsafe { participant.as_ref() } {\n        Some(ptr) => ptr,\n        None => return std::ptr::null_mut(),\n    };\n\n    Box::into_raw(Box::new(participant.local_model_config().into()))\n}\n"
  },
  {
    "path": "rust/xaynet-mobile/src/ffi/settings.rs",
    "content": "use std::os::raw::{c_double, c_int};\n\nuse ffi_support::{ByteBuffer, FfiStr};\nuse xaynet_core::crypto::{ByteObject, PublicSigningKey, SecretSigningKey, SigningKeyPair};\nuse zeroize::Zeroize;\n\nuse super::{\n    ERR_CRYPTO_PUBLIC_KEY,\n    ERR_CRYPTO_SECRET_KEY,\n    ERR_INVALID_URL,\n    ERR_NULLPTR,\n    ERR_SETTINGS_KEYS,\n    ERR_SETTINGS_SCALAR,\n    ERR_SETTINGS_URL,\n    OK,\n};\nuse crate::{Settings, SettingsError};\n\nmod pv {\n    use super::Settings;\n    ffi_support::define_box_destructor!(Settings, _xaynet_ffi_settings_destroy);\n}\n\n/// Destroy the settings created by [`xaynet_ffi_settings_new()`].\n///\n/// # Return value\n///\n/// - [`OK`] on success\n/// - [`ERR_NULLPTR`] if `settings` is NULL\n///\n/// # Safety\n///\n/// 1. When calling this method, you have to ensure that *either* the pointer is NULL\n///    *or* all of the following is true:\n///    - The pointer must be properly [aligned].\n///    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///      documentation.\n/// 2. After destroying the `Settings`, the pointer becomes invalid and must not be\n///    used.\n/// 3. This function should only be called on a pointer that has been created by\n///    [`xaynet_ffi_settings_new`].\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_settings_destroy(settings: *mut Settings) -> c_int {\n    if settings.is_null() {\n        return ERR_NULLPTR;\n    }\n    pv::_xaynet_ffi_settings_destroy(settings);\n    OK\n}\n\n/// Create new [`Settings`] and return a pointer to it.\n///\n/// # Safety\n///\n/// The `Settings` created by this function must be destroyed with\n/// [`xaynet_ffi_settings_destroy()`]. 
Attempting to free the memory from the other side\n/// of the FFI is UB.\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_settings_new() -> *mut Settings {\n    Box::into_raw(Box::new(Settings::new()))\n}\n\n/// Set scalar setting.\n///\n/// # Return value\n///\n/// - [`OK`] if successful\n/// - [`ERR_NULLPTR`] if `settings` is `NULL`\n///\n/// # Safety\n///\n/// When calling this method, you have to ensure that *either* the pointer is NULL *or*\n/// all of the following is true:\n/// - The pointer must be properly [aligned].\n/// - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///   documentation.\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_settings_set_scalar(\n    settings: *mut Settings,\n    scalar: c_double,\n) -> c_int {\n    match unsafe { settings.as_mut() } {\n        Some(settings) => {\n            settings.set_scalar(scalar);\n            OK\n        }\n        None => ERR_NULLPTR,\n    }\n}\n\n/// Set coordinator URL.\n///\n/// # Return value\n///\n/// - [`OK`] if successful\n/// - [`ERR_INVALID_URL`] if `url` is not a valid string\n/// - [`ERR_NULLPTR`] if `settings` is `NULL`\n///\n/// # Safety\n///\n/// When calling this method, you have to ensure that *either* the pointers are NULL\n/// *or* all of the following is true:\n/// - The pointers must be properly [aligned].\n/// - They must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///   documentation.\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_settings_set_url(\n    settings: *mut Settings,\n    url: FfiStr,\n) -> c_int {\n    let url = match url.as_opt_str() {\n        Some(url) => url,\n        None => return ERR_INVALID_URL,\n  
  };\n    match unsafe { settings.as_mut() } {\n        Some(settings) => {\n            settings.set_url(url.to_string());\n            OK\n        }\n        None => ERR_NULLPTR,\n    }\n}\n\n// TODO: add a way to save the key pair\n/// A signing key pair\npub struct KeyPair {\n    public: ByteBuffer,\n    secret: ByteBuffer,\n}\n\n// TODO: document that crypto must be initialized.\n/// Generate a new signing key pair that can be used in the [`Settings`]. **Before\n/// calling this function you must initialize the crypto library with\n/// [`xaynet_ffi_crypto_init()`]**.\n///\n/// The returned value contains a pointer to the secret key. For security reasons, you\n/// must make sure that this buffer's lifetime is as short as possible, and call\n/// [`xaynet_ffi_forget_key_pair`] to destroy it.\n///\n/// [`xaynet_ffi_crypto_init()`]: crate::ffi::xaynet_ffi_crypto_init\n///\n/// # Safety\n///\n/// This function is safe to call\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_generate_key_pair() -> *const KeyPair {\n    let SigningKeyPair { public, secret } = SigningKeyPair::generate();\n    let public_vec = public.as_slice().to_vec();\n    let secret_vec = secret.as_slice().to_vec();\n    let keys = KeyPair {\n        public: ByteBuffer::from_vec(public_vec),\n        // under the hood, ByteBuffer takes ownership of the memory\n        // without copying/leaking anything. 
There's no need to zero\n        // out anything yet\n        secret: ByteBuffer::from_vec(secret_vec),\n    };\n    Box::into_raw(Box::new(keys))\n}\n\n/// De-allocate the buffers that contain the signing keys, and zero out the content of\n/// the buffer that contains the secret key.\n///\n/// # Return value\n///\n/// - [`ERR_NULLPTR`] if `key_pair` is `NULL`\n/// - [`OK`] otherwise\n///\n/// # Safety\n///\n/// When calling this method, you have to ensure that *either* the pointer is NULL *or*\n/// all of the following is true:\n/// - The pointer must be properly [aligned].\n/// - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///   documentation.\n/// - It must have been returned by `xaynet_ffi_generate_key_pair()` and must not have\n///   been passed to this function already: the key pair is freed here and must not be\n///   used afterwards.\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_forget_key_pair(key_pair: *const KeyPair) -> c_int {\n    if key_pair.is_null() {\n        return ERR_NULLPTR;\n    }\n    let key_pair = unsafe { Box::from_raw(key_pair as *mut KeyPair) };\n    // IMPORTANT: we need to free the ByteBuffer memory, since it does\n    // not implement drop. 
We also take care of zero-ing the memory\n    // for the secret key.\n    key_pair.secret.destroy_into_vec().zeroize();\n    key_pair.public.destroy_into_vec();\n    OK\n}\n\n/// Set participant signing keys.\n///\n/// # Return value\n///\n/// - [`OK`] if successful\n/// - [`ERR_NULLPTR`] if `settings` or `key_pair` is `NULL`\n/// - [`ERR_CRYPTO_PUBLIC_KEY`] if the given `key_pair` contains an invalid public key\n/// - [`ERR_CRYPTO_SECRET_KEY`] if the given `key_pair` contains an invalid secret key\n///\n/// # Safety\n///\n/// When calling this method, you have to ensure that *either* the pointers are NULL\n/// *or* all of the following is true:\n/// - The pointers must be properly [aligned].\n/// - They must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///   documentation.\n///\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_settings_set_keys(\n    settings: *mut Settings,\n    key_pair: *const KeyPair,\n) -> c_int {\n    let key_pair = match unsafe { key_pair.as_ref() } {\n        Some(key_pair) => key_pair,\n        None => return ERR_NULLPTR,\n    };\n\n    let secret_slice = key_pair.secret.as_slice();\n    if secret_slice.len() != SecretSigningKey::LENGTH {\n        return ERR_CRYPTO_SECRET_KEY;\n    }\n    let secret = SecretSigningKey::from_slice_unchecked(secret_slice);\n\n    let public_slice = key_pair.public.as_slice();\n    if public_slice.len() != PublicSigningKey::LENGTH {\n        return ERR_CRYPTO_PUBLIC_KEY;\n    }\n    let public = PublicSigningKey::from_slice_unchecked(public_slice);\n\n    match unsafe { settings.as_mut() } {\n        Some(settings) => {\n            settings.set_keys(SigningKeyPair { public, secret });\n            OK\n        }\n        None => ERR_NULLPTR,\n    }\n}\n\n/// Check whether the given settings are valid and can be used to instantiate a\n/// 
participant (see [`xaynet_ffi_participant_new()`]).\n///\n/// # Return value\n///\n/// - [`OK`] on success\n/// - [`ERR_SETTINGS_URL`] if the URL has not been set\n/// - [`ERR_SETTINGS_KEYS`] if the signing keys have not been set\n/// - [`ERR_SETTINGS_SCALAR`] if the scalar is out of bounds\n///\n/// # Safety\n///\n/// When calling this method, you have to ensure that *either* the pointer is NULL *or*\n/// all of the following is true:\n///\n/// - The pointer must be properly [aligned].\n/// - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n///   documentation.\n///\n/// [`xaynet_ffi_participant_new()`]: crate::ffi::xaynet_ffi_participant_new\n/// [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n/// [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n#[no_mangle]\npub unsafe extern \"C\" fn xaynet_ffi_check_settings(settings: *const Settings) -> c_int {\n    match unsafe { settings.as_ref() } {\n        Some(settings) => match settings.check() {\n            Ok(()) => OK,\n            Err(SettingsError::MissingUrl) => ERR_SETTINGS_URL,\n            Err(SettingsError::MissingKeys) => ERR_SETTINGS_KEYS,\n            Err(SettingsError::OutOfScalarRange(_)) => ERR_SETTINGS_SCALAR,\n        },\n        None => ERR_NULLPTR,\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-mobile/src/lib.rs",
    "content": "#![cfg_attr(\n    doc,\n    forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links)\n)]\n#![doc(\n    html_logo_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png\",\n    html_favicon_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png\",\n    issue_tracker_base_url = \"https://github.com/xaynetwork/xaynet/issues\"\n)]\n//! This crate provides a mobile-friendly implementation of a Xaynet Federated Learning\n//! participant, along with FFI C bindings for building applications in languages that\n//! can use C bindings.\n//!\n//! The [`Participant`] provided by this crate is mobile-friendly because the caller has\n//! a lot of control over how to drive the participant execution. You can regularly pause\n//! the execution of the participant, save it, and later restore it and continue the\n//! execution. When running on a device that is low on battery or does not have access\n//! to Wi-Fi for instance, it can be useful to be able to pause the participant.\n//!\n//! This control comes at a complexity cost though. Usually, a participant is split into\n//! two tasks:\n//! - a task that executes a state machine that implements the PET protocol and emits\n//!   notifications.\n//! - a task that reacts to these events, for instance by downloading the latest global\n//!   model at the end of a round, or by training a new model when the participant has\n//!   been selected for the update task.\n//!\n//! The task that executes the PET protocol usually runs in the background and we have\n//! little control over it. This is a problem in mobile environments:\n//! - first, the app may be killed at any moment and we'd lose the participant state.\n//! - second, we don't really want a background task to potentially perform CPU-heavy or\n//!   network-heavy operations without having a say, since it may drain the battery or\n//!   consume too much data.\n//!\n//! 
To solve this problem, the [`Participant`] provided in this crate embeds the PET\n//! state machine, and it is the caller's responsibility to drive its execution (see\n//! [`Participant::tick()`]).\n#[macro_use]\nextern crate ffi_support;\n#[macro_use]\nextern crate async_trait;\n\n#[macro_use]\nextern crate tracing;\n\nmod participant;\nmod settings;\npub use self::{\n    participant::{Event, Events, InitError, Notifier, Participant, Task},\n    settings::{Settings, SettingsError},\n};\npub mod ffi;\n\nmod reqwest_client;\npub(crate) use reqwest_client::new_client;\npub use reqwest_client::ClientError;\n"
  },
  {
    "path": "rust/xaynet-mobile/src/participant.rs",
    "content": "//! Participant implementation\nuse std::{convert::TryInto, sync::Arc};\n\nuse futures::future::FutureExt;\nuse thiserror::Error;\nuse tokio::{\n    runtime::Runtime,\n    sync::{mpsc, Mutex},\n};\nuse xaynet_core::mask::Model;\nuse xaynet_sdk::{\n    client::Client,\n    LocalModelConfig,\n    ModelStore,\n    Notify,\n    SerializableState,\n    StateMachine,\n    TransitionOutcome,\n    XaynetClient,\n};\n\nuse crate::{\n    new_client,\n    settings::{Settings, SettingsError},\n    ClientError,\n};\n\n/// Event emitted by the participant internal state machine as it advances through the\n/// PET protocol\npub enum Event {\n    /// Event emitted when the participant is selected for the update task\n    Update,\n    /// Event emitted when the participant is selected for the sum task\n    Sum,\n    /// Event emitted when the participant is done with its task\n    Idle,\n    /// Event emitted when a new round starts\n    NewRound,\n    /// Event emitted when the participant should load its model. This only happens if\n    /// the participant has been selected for the update task\n    LoadModel,\n}\n\n/// Event sender that is passed to the participant internal state machine for emitting\n/// notification\npub struct Notifier(mpsc::Sender<Event>);\nimpl Notifier {\n    fn notify(&mut self, event: Event) {\n        if let Err(e) = self.0.try_send(event) {\n            warn!(\"failed to notify participant: {}\", e);\n        }\n    }\n}\n\n/// A receiver for events emitted by the participant internal state machine\npub struct Events(mpsc::Receiver<Event>);\n\nimpl Events {\n    /// Create a new event sender and receiver.\n    fn new() -> (Self, Notifier) {\n        let (tx, rx) = mpsc::channel(10);\n        (Self(rx), Notifier(tx))\n    }\n\n    /// Pop the next event. 
If no event has been received, return `None`.\n    ///\n    /// # Panics\n    ///\n    /// Panics if the channel is closed, i.e. if the `Notifier` was dropped.\n    fn next(&mut self) -> Option<Event> {\n        // Note: `try_recv` (tokio 0.2.x) or `recv().now_or_never()` (tokio 1.x)\n        // has an implementation bug where previously sent messages may not be\n        // available immediately.\n        // Related issue: https://github.com/tokio-rs/tokio/issues/3350\n        // However, that should not be an issue for us.\n        let next = self.0.recv().now_or_never()?;\n        if next.is_none() {\n            // if next is `None`, the channel is closed.\n            // This can happen if:\n            //  1. the state machine crashed. In that case it's OK to crash.\n            //  2. `next` was called after the state machine was\n            //     dropped, which is an error. So crashing is OK as\n            //     well.\n            panic!(\"notifier dropped\")\n        }\n        next\n    }\n}\n\nimpl Notify for Notifier {\n    fn new_round(&mut self) {\n        self.notify(Event::NewRound)\n    }\n    fn sum(&mut self) {\n        self.notify(Event::Sum)\n    }\n    fn update(&mut self) {\n        self.notify(Event::Update)\n    }\n    fn load_model(&mut self) {\n        self.notify(Event::LoadModel)\n    }\n    fn idle(&mut self) {\n        self.notify(Event::Idle)\n    }\n}\n\n/// A store shared by the participant and its internal state machine. When the\n/// state machine emits a [`Event::LoadModel`] event, the participant is expected to\n/// load its model into the store. 
See [`Participant::set_model()`].\n#[derive(Clone)]\nstruct Store(Arc<Mutex<Option<Model>>>);\n\nimpl Store {\n    /// Create a new model store.\n    fn new() -> Self {\n        Self(Arc::new(Mutex::new(None)))\n    }\n}\n\n#[async_trait]\nimpl ModelStore for Store {\n    type Model = Model;\n    type Error = std::convert::Infallible;\n\n    async fn load_model(&mut self) -> Result<Option<Self::Model>, Self::Error> {\n        Ok(self.0.lock().await.take())\n    }\n}\n\n/// Represent the participant current task\n#[derive(Clone, Debug, Copy)]\npub enum Task {\n    /// The participant is taking part in the sum task\n    Sum,\n    /// The participant is taking part in the update task\n    Update,\n    /// The participant is not taking part in any task\n    None,\n}\n\n/// A participant. It embeds an internal state machine that executes the PET\n/// protocol. However, it is the caller's responsibility to drive this state machine by\n/// calling [`Participant::tick()`], and to take action when the participant state\n/// changes.\npub struct Participant {\n    /// Internal state machine\n    state_machine: Option<StateMachine>,\n    /// Receiver for the events emitted by the state machine\n    events: Events,\n    /// Model store where the participant should load its model, when\n    /// `self.should_set_model` is `true`.\n    store: Store,\n    /// Async runtime to execute the state machine\n    runtime: Runtime,\n    /// Xaynet client\n    client: Client<reqwest::Client>,\n    /// Whether the participant state changed after the last call to\n    /// [`Participant::tick()`]\n    made_progress: bool,\n    /// Whether the participant should load its model into the store.\n    should_set_model: bool,\n    /// Whether a new global model is available.\n    new_global_model: bool,\n    /// The participant current task\n    task: Task,\n}\n\n/// Error that can occur when instantiating a new [`Participant`], either with\n/// [`Participant::new()`] or 
[`Participant::restore()`].\n#[derive(Error, Debug)]\npub enum InitError {\n    #[error(\"failed to deserialize the participant state {:?}\", _0)]\n    Deserialization(#[from] Box<bincode::ErrorKind>),\n    #[error(\"failed to initialize the participant runtime {:?}\", _0)]\n    Runtime(std::io::Error),\n    #[error(\"failed to initialize HTTP client {:?}\", _0)]\n    Client(#[from] ClientError),\n    #[error(\"invalid participant settings {:?}\", _0)]\n    InvalidSettings(#[from] SettingsError),\n}\n\n/// Error returned by `Participant::global_model()` when the global model could not be\n/// fetched from the coordinator.\n#[derive(Error, Debug)]\n#[error(\"failed to fetch global model: {}\", self.0)]\npub struct GetGlobalModelError(xaynet_sdk::client::ClientError);\n\nimpl Participant {\n    /// Create a new participant with the given settings.\n    pub fn new(settings: Settings) -> Result<Self, InitError> {\n        let (url, pet_settings) = settings.try_into()?;\n        let client = new_client(url.as_str(), None, None)?;\n        let (events, notifier) = Events::new();\n        let store = Store::new();\n        let state_machine =\n            StateMachine::new(pet_settings, client.clone(), store.clone(), notifier);\n        Self::init(state_machine, client, events, store)\n    }\n\n    /// Restore a participant from its serialized state. 
The coordinator client that\n    /// the participant uses internally is not part of the participant state, so the\n    /// `url` is used to instantiate a new one.\n    pub fn restore(state: &[u8], url: &str) -> Result<Self, InitError> {\n        let state: SerializableState = bincode::deserialize(state)?;\n        let (events, notifier) = Events::new();\n        let store = Store::new();\n        let client = new_client(url, None, None)?;\n        let state_machine = StateMachine::restore(state, client.clone(), store.clone(), notifier);\n        Self::init(state_machine, client, events, store)\n    }\n\n    fn init(\n        state_machine: StateMachine,\n        client: Client<reqwest::Client>,\n        events: Events,\n        store: Store,\n    ) -> Result<Self, InitError> {\n        let mut participant = Self {\n            runtime: Self::runtime()?,\n            state_machine: Some(state_machine),\n            events,\n            store,\n            client,\n            task: Task::None,\n            made_progress: true,\n            should_set_model: false,\n            new_global_model: false,\n        };\n        participant.process_events();\n        Ok(participant)\n    }\n\n    fn runtime() -> Result<Runtime, InitError> {\n        tokio::runtime::Builder::new_current_thread()\n            .enable_all()\n            .build()\n            .map_err(InitError::Runtime)\n    }\n\n    /// Serialize the participant state and return the corresponding buffer.\n    pub fn save(self) -> Vec<u8> {\n        // UNWRAP_SAFE: the state machine is always set.\n        let state_machine = self.state_machine.unwrap().save();\n        bincode::serialize(&state_machine).unwrap()\n    }\n\n    /// Drive the participant internal state machine.\n    ///\n    /// After calling this method, the caller should check whether the participant state\n    /// changed, by calling [`Participant::made_progress()`].  
If the state changed, the\n    /// caller should perform the following checks and react appropriately:\n    ///\n    /// - whether the participant is taking part in any task by calling\n    ///   [`Participant::task()`]\n    /// - whether the participant should load its model into the store by calling\n    ///   [`Participant::should_set_model()`]\n    /// - whether a new global model is available by calling\n    ///   [`Participant::new_global_model()`]\n    pub fn tick(&mut self) {\n        // UNWRAP_SAFE: the state machine is always set.\n        let state_machine = self.state_machine.take().unwrap();\n        let outcome = self\n            .runtime\n            .block_on(async { state_machine.transition().await });\n        match outcome {\n            TransitionOutcome::Pending(new_state_machine) => {\n                self.made_progress = false;\n                self.state_machine = Some(new_state_machine);\n            }\n            TransitionOutcome::Complete(new_state_machine) => {\n                self.made_progress = true;\n                self.state_machine = Some(new_state_machine)\n            }\n        };\n        self.process_events();\n    }\n\n    /// Drain all the events currently queued by the internal state machine and update\n    /// the participant's current task and flags accordingly.\n    fn process_events(&mut self) {\n        loop {\n            match self.events.next() {\n                Some(Event::Idle) => {\n                    self.task = Task::None;\n                }\n                Some(Event::Update) => {\n                    self.task = Task::Update;\n                }\n                Some(Event::Sum) => {\n                    self.task = Task::Sum;\n                }\n                Some(Event::NewRound) => {\n                    self.should_set_model = false;\n                    self.new_global_model = true;\n                }\n                Some(Event::LoadModel) => {\n                    
self.should_set_model = true;\n                }\n                None => break,\n            }\n        }\n    }\n\n    /// Check whether the participant internal state machine made progress while\n    /// executing the PET protocol. 
If so, the participant state likely changed.\n    pub fn made_progress(&self) -> bool {\n        self.made_progress\n    }\n\n    /// Check whether the participant internal state machine is waiting for the\n    /// participant to load its model into the store. If this method returns `true`, the\n    /// caller should make sure to call [`Participant::set_model()`] at some point.\n    pub fn should_set_model(&self) -> bool {\n        self.should_set_model\n    }\n\n    /// Check whether a new global model is available. If this method returns `true`, the\n    /// caller can call [`Participant::global_model()`] to fetch the new global model.\n    pub fn new_global_model(&self) -> bool {\n        self.new_global_model\n    }\n\n    /// Return the participant current task\n    pub fn task(&self) -> Task {\n        self.task\n    }\n\n    /// Load the given model into the store, so that the participant internal state\n    /// machine can process it.\n    pub fn set_model(&mut self, model: Model) {\n        let Self {\n            ref mut runtime,\n            ref store,\n            ..\n        } = self;\n\n        runtime.block_on(async {\n            let mut stored_model = store.0.lock().await;\n            *stored_model = Some(model)\n        });\n        self.should_set_model = false;\n    }\n\n    /// Retrieve the current global model, if available.\n    pub fn global_model(&mut self) -> Result<Option<Model>, GetGlobalModelError> {\n        let Self {\n            ref mut runtime,\n            ref mut client,\n            ..\n        } = self;\n\n        let global_model =\n            runtime.block_on(async { client.get_model().await.map_err(GetGlobalModelError) });\n        if global_model.is_ok() {\n            self.new_global_model = false;\n        }\n        global_model\n    }\n\n    /// Return the local model configuration of the model that is expected in the\n    /// [`Participant::set_model`] method.\n    pub fn local_model_config(&self) -> LocalModelConfig 
{\n        // UNWRAP_SAFE: the state machine is always set.\n        let state_machine = self.state_machine.as_ref().unwrap();\n        state_machine.local_model_config()\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-mobile/src/reqwest_client.rs",
    "content": "use std::{fs::File, io::Read};\n\nuse thiserror::Error;\n\nuse xaynet_sdk::client::Client;\n\n/// Error returned upon failing to instantiate a new [`xaynet_sdk::client::Client`]\n#[derive(Debug, Error)]\npub enum ClientError {\n    #[error(\"invalid URL: {0}\")]\n    InvalidUrl(String),\n    #[error(\"failed to read trust anchor {0}: {1}\")]\n    TrustAnchor(String, String),\n    #[error(\"failed to read client certificate {0}: {1}\")]\n    ClientCert(String, String),\n    #[error(\"{0}\")]\n    Other(String),\n}\n\nimpl ClientError {\n    fn trust_anchor<E: std::error::Error>(path: String, e: E) -> Self {\n        Self::TrustAnchor(path, format!(\"{}\", e))\n    }\n\n    fn client_cert<E: std::error::Error>(path: String, e: E) -> Self {\n        Self::ClientCert(path, format!(\"{}\", e))\n    }\n\n    fn other<E: std::error::Error>(e: E) -> Self {\n        Self::Other(format!(\"{}\", e))\n    }\n}\n\n/// Build a new [`xaynet_sdk::client::Client`]\n///\n/// # Args\n///\n/// - `address`: URL of the Xaynet coordinator to connect to\n/// - `trust_anchor_path`: path to the root certificate for TLS server authentication. 
The\n///   certificate must be PEM encoded.\npub fn new_client(\n    address: &str,\n    trust_anchor_path: Option<String>,\n    client_cert_path: Option<String>,\n) -> Result<Client<reqwest::Client>, ClientError> {\n    let builder = reqwest::ClientBuilder::new();\n\n    let builder = if let Some(path) = trust_anchor_path {\n        let mut buf = Vec::new();\n        File::open(&path)\n            .map_err(|e| ClientError::trust_anchor(path.clone(), e))?\n            .read_to_end(&mut buf)\n            .map_err(|e| ClientError::trust_anchor(path.clone(), e))?;\n        let root_cert =\n            reqwest::Certificate::from_pem(&buf).map_err(|e| ClientError::trust_anchor(path, e))?;\n        builder.use_rustls_tls().add_root_certificate(root_cert)\n    } else {\n        builder\n    };\n\n    let builder = if let Some(path) = client_cert_path {\n        let mut buf = Vec::new();\n        File::open(&path)\n            .map_err(|e| ClientError::client_cert(path.clone(), e))?\n            .read_to_end(&mut buf)\n            .map_err(|e| ClientError::client_cert(path.clone(), e))?;\n        let identity =\n            reqwest::Identity::from_pem(&buf).map_err(|e| ClientError::client_cert(path, e))?;\n        builder.use_rustls_tls().identity(identity)\n    } else {\n        builder\n    };\n\n    let reqwest_client = builder.build().map_err(ClientError::other)?;\n\n    let xaynet_client = Client::new(reqwest_client, address)\n        .map_err(|_| ClientError::InvalidUrl(address.to_string()))?;\n    Ok(xaynet_client)\n}\n"
  },
  {
    "path": "rust/xaynet-mobile/src/settings.rs",
    "content": "//! This module provides utilities to configure a [`Participant`].\n//!\n//! [`Participant`]: crate::Participant\n\nuse std::convert::TryInto;\nuse thiserror::Error;\nuse xaynet_core::{\n    crypto::SigningKeyPair,\n    mask::{FromPrimitive, PrimitiveCastError, Scalar},\n};\nuse xaynet_sdk::settings::{MaxMessageSize, PetSettings};\n\n/// A participant settings\n#[derive(Clone, Debug)]\npub struct Settings {\n    /// The Xaynet coordinator URL.\n    url: Option<String>,\n    /// The participant signing keys.\n    keys: Option<SigningKeyPair>,\n    /// The scalar used for masking.\n    scalar: Result<Scalar, PrimitiveCastError<f64>>,\n    /// The maximum possible size of a message.\n    max_message_size: MaxMessageSize,\n}\n\nimpl Default for Settings {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl Settings {\n    /// Create new empty settings.\n    pub fn new() -> Self {\n        Self {\n            url: None,\n            keys: None,\n            scalar: Ok(Scalar::unit()),\n            max_message_size: MaxMessageSize::default(),\n        }\n    }\n\n    /// Set the participant signing keys\n    pub fn set_keys(&mut self, keys: SigningKeyPair) {\n        self.keys = Some(keys);\n    }\n\n    /// Set the scalar to use for masking\n    pub fn set_scalar(&mut self, scalar: f64) {\n        self.scalar = Scalar::from_primitive(scalar)\n    }\n\n    /// Set the Xaynet coordinator address\n    pub fn set_url(&mut self, url: String) {\n        self.url = Some(url);\n    }\n\n    /// Sets the maximum possible size of a message.\n    pub fn set_max_message_size(&mut self, size: MaxMessageSize) {\n        self.max_message_size = size;\n    }\n\n    /// Check whether the settings are complete and valid\n    pub fn check(&self) -> Result<(), SettingsError> {\n        if self.url.is_none() {\n            Err(SettingsError::MissingUrl)\n        } else if self.keys.is_none() {\n            Err(SettingsError::MissingKeys)\n        } else if let 
Err(e) = &self.scalar {\n            Err(e.clone().into())\n        } else {\n            Ok(())\n        }\n    }\n}\n\n/// Error returned when the settings are invalid\n#[derive(Debug, Error)]\npub enum SettingsError {\n    #[error(\"the Xaynet coordinator URL must be specified\")]\n    MissingUrl,\n    #[error(\"the participant signing key pair must be specified\")]\n    MissingKeys,\n    #[error(\"float not within range of scalar: {0}\")]\n    OutOfScalarRange(#[from] PrimitiveCastError<f64>),\n}\n\nimpl TryInto<(String, PetSettings)> for Settings {\n    type Error = SettingsError;\n\n    fn try_into(self) -> Result<(String, PetSettings), Self::Error> {\n        let Settings {\n            keys,\n            url,\n            scalar,\n            max_message_size,\n        } = self;\n\n        let url = url.ok_or(SettingsError::MissingUrl)?;\n        let keys = keys.ok_or(SettingsError::MissingKeys)?;\n        let scalar = scalar.map_err(SettingsError::OutOfScalarRange)?;\n\n        let pet_settings = PetSettings {\n            keys,\n            scalar,\n            max_message_size,\n        };\n\n        Ok((url, pet_settings))\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-mobile/tests/ffi_test.c",
    "content": "#include <assert.h>\n#include <stdint.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <string.h>\n\n#include \"minunit.h\"\n#include \"xaynet_ffi.h\"\n\nstatic char *test_settings_new() {\n  Settings *settings = xaynet_ffi_settings_new();\n  xaynet_ffi_settings_destroy(settings);\n  return 0;\n}\n\nstatic char *test_settings_set_keys() {\n  mu_assert(\"failed to init crypto\", xaynet_ffi_crypto_init() == OK);\n  Settings *settings = xaynet_ffi_settings_new();\n  const KeyPair *keys = xaynet_ffi_generate_key_pair();\n  int err = xaynet_ffi_settings_set_keys(settings, keys);\n  mu_assert(\"failed to set keys\", !err);\n  xaynet_ffi_forget_key_pair(keys);\n\n  xaynet_ffi_settings_destroy(settings);\n  return 0;\n}\n\nstatic char *test_settings_set_url() {\n  Settings *settings = xaynet_ffi_settings_new();\n\n  int err = xaynet_ffi_settings_set_url(settings, NULL);\n  mu_assert(\"settings invalid URL should fail\", err == ERR_INVALID_URL);\n\n  char *url = \"http://localhost:1234\";\n  err = xaynet_ffi_settings_set_url(settings, url);\n  mu_assert(\"failed to set url\", !err);\n\n  char *url2 = strdup(url);\n  err = xaynet_ffi_settings_set_url(settings, url2);\n  mu_assert(\"failed to set url from allocated string\", !err);\n\n  // cleanup\n  free(url2);\n  xaynet_ffi_settings_destroy(settings);\n\n  return 0;\n}\n\nvoid with_keys(Settings *settings) {\n  const KeyPair *keys = xaynet_ffi_generate_key_pair();\n  int err = xaynet_ffi_settings_set_keys(settings, keys);\n  assert(!err);\n  xaynet_ffi_forget_key_pair(keys);\n}\n\nvoid with_url(Settings *settings) {\n  int err = xaynet_ffi_settings_set_url(settings, \"http://localhost:1234\");\n  assert(!err);\n}\n\nstatic char *test_settings() {\n  Settings *settings = xaynet_ffi_settings_new();\n  with_keys(settings);\n  int err = xaynet_ffi_check_settings(settings);\n  mu_assert(\"expected missing url error\", err == ERR_SETTINGS_URL);\n  xaynet_ffi_settings_destroy(settings);\n\n  settings = 
xaynet_ffi_settings_new();\n  with_url(settings);\n  err = xaynet_ffi_check_settings(settings);\n  mu_assert(\"expected missing keys error\", err == ERR_SETTINGS_KEYS);\n  xaynet_ffi_settings_destroy(settings);\n\n  return 0;\n}\n\nstatic char *test_global_model() {\n  Settings *settings = xaynet_ffi_settings_new();\n  with_keys(settings);\n  with_url(settings);\n  xaynet_ffi_settings_set_url(settings, \"http://localhost:8081\");\n\n  Participant *participant = xaynet_ffi_participant_new(settings);\n  mu_assert(\"failed to create participant\", participant != NULL);\n  LocalModelConfig *local_model_config = xaynet_ffi_participant_local_model_config(participant);\n  float* buffer = (float *)malloc(sizeof(float) * local_model_config->len);\n\n  int err = xaynet_ffi_participant_global_model(NULL, buffer, local_model_config->data_type, local_model_config->len);\n  mu_assert(\"expected participant is null error\", err == ERR_NULLPTR);\n\n  err = xaynet_ffi_participant_global_model(participant, NULL, local_model_config->data_type, local_model_config->len);\n  mu_assert(\"expected buffer is null error\", err == ERR_NULLPTR);\n\n  err = xaynet_ffi_participant_global_model(participant, buffer, local_model_config->data_type, local_model_config->len);\n  mu_assert(\"expected io error (cannot connect to coordinator)\", err == ERR_GLOBALMODEL_IO);\n\n  free(buffer);\n  xaynet_ffi_local_model_config_destroy(local_model_config);\n  xaynet_ffi_participant_destroy(participant);\n  xaynet_ffi_settings_destroy(settings);\n\n  return 0;\n}\n\nstatic char *test_participant_save_and_restore() {\n  Settings *settings = xaynet_ffi_settings_new();\n  with_keys(settings);\n  with_url(settings);\n\n  Participant *participant = xaynet_ffi_participant_new(settings);\n  mu_assert(\"failed to create participant\", participant != NULL);\n  xaynet_ffi_settings_destroy(settings);\n\n  // save the participant\n  const ByteBuffer *save_buf = xaynet_ffi_participant_save(participant);\n  
mu_assert(\"failed to save participant\", save_buf != NULL);\n\n  // write the serialized participant to a file\n  char *path = \"./test_participant_save_and_restore.txt\";\n  FILE *f = fopen(path, \"w\");\n  fwrite(save_buf->data, 1, save_buf->len, f);\n  fclose(f);\n  int err = xaynet_ffi_byte_buffer_destroy(save_buf);\n  assert(!err);\n\n  // read the serialized participant from the file\n  f = fopen(path, \"r\");\n  fseek(f, 0L, SEEK_END);\n  int fsize = ftell(f);\n  fseek(f, 0L, SEEK_SET);\n  ByteBuffer restore_buf = {\n      .len = fsize,\n      .data = (uint8_t *)malloc(fsize),\n  };\n  int n_read = fread(restore_buf.data, 1, fsize, f);\n  mu_assert(\"failed to read serialized participant\", n_read == fsize);\n  fclose(f);\n\n  // restore the participant\n  Participant *restored =\n      xaynet_ffi_participant_restore(\"http://localhost:8081\", &restore_buf);\n  mu_assert(\"failed to restore participant\", restored != NULL);\n\n  // free memory\n  free(restore_buf.data);\n  xaynet_ffi_participant_destroy(restored);\n\n  return 0;\n}\n\nstatic char *test_participant_tick() {\n  Settings *settings = xaynet_ffi_settings_new();\n  with_keys(settings);\n  with_url(settings);\n\n  Participant *participant = xaynet_ffi_participant_new(settings);\n  mu_assert(\"failed to create participant\", participant != NULL);\n\n  int status = xaynet_ffi_participant_tick(participant);\n  mu_assert(\"missing no task flag\", (status & PARTICIPANT_TASK_NONE));\n  mu_assert(\"unexpected sum task flag\", !(status & PARTICIPANT_TASK_SUM));\n  mu_assert(\"unexpected update task flag\", !(status & PARTICIPANT_TASK_UPDATE));\n  mu_assert(\"unexpected set model flag\",\n            !(status & PARTICIPANT_SHOULD_SET_MODEL));\n  mu_assert(\"unexpected made progress flag\",\n            !(status & PARTICIPANT_MADE_PROGRESS));\n  // free memory\n  xaynet_ffi_settings_destroy(settings);\n  xaynet_ffi_participant_destroy(participant);\n\n  return 0;\n}\n\nstatic char *all_tests() {\n  
mu_run_test(test_settings_new);\n  mu_run_test(test_settings_set_keys);\n  mu_run_test(test_settings_set_url);\n  mu_run_test(test_settings);\n  mu_run_test(test_global_model);\n  mu_run_test(test_participant_save_and_restore);\n  mu_run_test(test_participant_tick);\n  return 0;\n}\n\nint tests_run = 0;\n\nint main(int argc, char **argv) {\n  assert(xaynet_ffi_crypto_init() == OK);\n\n  char *result = all_tests();\n  if (result != 0) {\n    fprintf(stderr, RED \"ERROR: %s\\n\" RESET, result);\n  } else {\n    printf(GREEN \"ALL TESTS PASSED\\n\" RESET);\n  }\n  printf(\"Tests run: %d\\n\", tests_run);\n\n  return result != 0;\n}\n"
  },
  {
    "path": "rust/xaynet-mobile/tests/minunit.h",
    "content": "#define RESET   \"\\033[0m\"\n#define BLACK   \"\\033[30m\"      /* Black */\n#define RED     \"\\033[31m\"      /* Red */\n#define GREEN   \"\\033[32m\"      /* Green */\n#define mu_assert(message, test) \\\n    do                           \\\n    {                            \\\n        if (!(test))             \\\n            return message;      \\\n    } while (0)\n#define mu_run_test(test)       \\\n    do                          \\\n    {                           \\\n        char *message = test(); \\\n        tests_run++;            \\\n        if (message)            \\\n            return message;     \\\n    } while (0)\nextern int tests_run;\n"
  },
  {
    "path": "rust/xaynet-mobile/xaynet_ffi.h",
    "content": "/* Generated with cbindgen:0.17.0 */\n\n/* Warning, this file is autogenerated by cbindgen. Don't modify this manually. */\n\n#include <stdarg.h>\n#include <stdbool.h>\n#include <stdint.h>\n#include <stdlib.h>\n\n/**\n * Return value upon success\n */\n#define OK 0\n\n/**\n * NULL pointer argument\n */\n#define ERR_NULLPTR 1\n\n/**\n * Invalid coordinator URL\n */\n#define ERR_INVALID_URL 2\n\n/**\n * Invalid settings: coordinator URL is not set\n */\n#define ERR_SETTINGS_URL 3\n\n/**\n * Invalid settings: signing keys are not set\n */\n#define ERR_SETTINGS_KEYS 4\n\n/**\n * Invalid settings: scalar is out of bounds\n */\n#define ERR_SETTINGS_SCALAR 5\n\n/**\n * Failed to set the local model: invalid model\n */\n#define ERR_SETMODEL_MODEL 6\n\n/**\n * Failed to set the local model: invalid data type\n */\n#define ERR_SETMODEL_DATATYPE 7\n\n/**\n * Failed to initialize the crypto library\n */\n#define ERR_CRYPTO_INIT 8\n\n/**\n * Invalid secret signing key\n */\n#define ERR_CRYPTO_SECRET_KEY 9\n\n/**\n * Invalid public signing key\n */\n#define ERR_CRYPTO_PUBLIC_KEY 10\n\n/**\n * No global model is currently available\n */\n#define GLOBALMODEL_NONE 11\n\n/**\n * Failed to get the global model: communication with the coordinator failed\n */\n#define ERR_GLOBALMODEL_IO 12\n\n/**\n * Failed to get the global model: invalid data type\n */\n#define ERR_GLOBALMODEL_DATATYPE 13\n\n/**\n * Failed to get the global model: invalid buffer length\n */\n#define ERR_GLOBALMODEL_LEN 14\n\n/**\n * Failed to get the global model: invalid model\n */\n#define ERR_GLOBALMODEL_CONVERT 15\n\n/**\n * The participant is not taking part in the sum or update task\n */\n#define PARTICIPANT_TASK_NONE 1\n\n/**\n * The participant is taking part in the sum task\n */\n#define PARTICIPANT_TASK_SUM (1 << 1)\n\n/**\n * The participant is taking part in the update task\n */\n#define PARTICIPANT_TASK_UPDATE (1 << 2)\n\n/**\n * The participant is expected to set the model it 
trained\n */\n#define PARTICIPANT_SHOULD_SET_MODEL (1 << 3)\n\n/**\n * The participant's internal state machine made some progress\n */\n#define PARTICIPANT_MADE_PROGRESS (1 << 4)\n\n/**\n * A new global model is available\n */\n#define PARTICIPANT_NEW_GLOBALMODEL (1 << 5)\n\n/**\n * The original primitive data type of the numerical values to be masked.\n */\nenum ModelDataType {\n  /**\n   * Numbers of type f32.\n   */\n  MODEL_DATA_TYPE_F32 = 0,\n  /**\n   * Numbers of type f64.\n   */\n  MODEL_DATA_TYPE_F64 = 1,\n  /**\n   * Numbers of type i32.\n   */\n  MODEL_DATA_TYPE_I32 = 2,\n  /**\n   * Numbers of type i64.\n   */\n  MODEL_DATA_TYPE_I64 = 3,\n};\ntypedef uint8_t ModelDataType;\n\n/**\n * A signing key pair\n */\ntypedef struct KeyPair KeyPair;\n\n/**\n * A participant. It embeds an internal state machine that executes the PET\n * protocol. However, it is the caller's responsibility to drive this state machine by\n * calling [`Participant::tick()`], and to take action when the participant state\n * changes.\n */\ntypedef struct Participant Participant;\n\n/**\n * The participant settings\n */\ntypedef struct Settings Settings;\n\n/**\n * ByteBuffer is a struct that represents an array of bytes to be sent over the FFI boundaries.\n * There are several cases when you might want to use this, but the primary one for us\n * is for returning protobuf-encoded data to Swift and Java. The type is currently rather\n * limited (implementing almost no functionality), however in the future it may be\n * more expanded.\n *\n * ## Caveats\n *\n * Note that the order of the fields is `len` (an i64) then `data` (a `*mut u8`), getting\n * this wrong on the other side of the FFI will cause memory corruption and crashes.\n * `i64` is used for the length instead of `u64` and `usize` because JNA has interop\n * issues with both these types.\n *\n * ### `Drop` is not implemented\n *\n * ByteBuffer does not implement Drop. This is intentional. 
Memory passed into it will\n * be leaked if it is not explicitly destroyed by calling [`ByteBuffer::destroy`], or\n * [`ByteBuffer::destroy_into_vec`]. This is for two reasons:\n *\n * 1. In the future, we may allow it to be used for data that is not managed by\n *    the Rust allocator\\*, and `ByteBuffer` assuming it's okay to automatically\n *    deallocate this data with the Rust allocator.\n *\n * 2. Automatically running destructors in unsafe code is a\n *    [frequent footgun](https://without.boats/blog/two-memory-bugs-from-ringbahn/)\n *    (among many similar issues across many crates).\n *\n * Note that calling `destroy` manually is often not needed, as usually you should\n * be passing these to the function defined by [`define_bytebuffer_destructor!`] from\n * the other side of the FFI.\n *\n * Because this type is essentially *only* useful in unsafe or FFI code (and because\n * the most common usage pattern does not require manually managing the memory), it\n * does not implement `Drop`.\n *\n * \\* Note: in the case of multiple Rust shared libraries loaded at the same time,\n * there may be multiple instances of \"the Rust allocator\" (one per shared library),\n * in which case we're referring to whichever instance is active for the code using\n * the `ByteBuffer`. Note that this doesn't occur on all platforms or build\n * configurations, but treating allocators in different shared libraries as fully\n * independent is always safe.\n *\n * ## Layout/fields\n *\n * This struct's field are not `pub` (mostly so that we can soundly implement `Send`, but also so\n * that we can verify rust users are constructing them appropriately), the fields, their types, and\n * their order are *very much* a part of the public API of this type. 
Consumers on the other side\n * of the FFI will need to know its layout.\n *\n * If this were a C struct, it would look like\n *\n * ```c,no_run\n * struct ByteBuffer {\n *     // Note: This should never be negative, but values above\n *     // INT64_MAX / i64::MAX are not allowed.\n *     int64_t len;\n *     // Note: nullable!\n *     uint8_t *data;\n * };\n * ```\n *\n * In rust, there are two fields, in this order: `len: i64`, and `data: *mut u8`.\n *\n * For clarity, the fact that the data pointer is nullable means that `Option<ByteBuffer>` is not\n * the same size as ByteBuffer, and additionally is not FFI-safe (the latter point is not\n * currently guaranteed anyway as of the time of writing this comment).\n *\n * ### Description of fields\n *\n * `data` is a pointer to an array of `len` bytes. Note that data can be a null pointer and therefore\n * should be checked.\n *\n * The bytes array is allocated on the heap and must be freed on it as well. Critically, if there\n * are multiple rust shared libraries using being used in the same application, it *must be freed\n * on the same heap that allocated it*, or you will corrupt both heaps.\n *\n * Typically, this object is managed on the other side of the FFI (on the \"FFI consumer\"), which\n * means you must expose a function to release the resources of `data` which can be done easily\n * using the [`define_bytebuffer_destructor!`] macro provided by this crate.\n */\ntypedef struct ByteBuffer {\n  int64_t len;\n  uint8_t *data;\n} ByteBuffer;\n\n/**\n * `FfiStr<'a>` is a safe (`#[repr(transparent)]`) wrapper around a\n * nul-terminated `*const c_char` (e.g. a C string). Conceptually, it is\n * similar to [`std::ffi::CStr`], except that it may be used in the signatures\n * of extern \"C\" functions.\n *\n * Functions accepting strings should use this instead of accepting a C string\n * directly. 
This allows us to write those functions using safe code without\n * allowing safe Rust to cause memory unsafety.\n *\n * A single function for constructing these from Rust ([`FfiStr::from_raw`])\n * has been provided. Most of the time, this should not be necessary, and users\n * should accept `FfiStr` in the parameter list directly.\n *\n * ## Caveats\n *\n * An effort has been made to make this struct hard to misuse, however it is\n * still possible, if the `'static` lifetime is manually specified in the\n * struct. E.g.\n *\n * ```rust,no_run\n * # use ffi_support::FfiStr;\n * // NEVER DO THIS\n * #[no_mangle]\n * extern \"C\" fn never_do_this(s: FfiStr<'static>) {\n *     // save `s` somewhere, and access it after this\n *     // function returns.\n * }\n * ```\n *\n * Instead, one of the following patterns should be used:\n *\n * ```\n * # use ffi_support::FfiStr;\n * #[no_mangle]\n * extern \"C\" fn valid_use_1(s: FfiStr<'_>) {\n *     // Use of `s` after this function returns is impossible\n * }\n * // Alternative:\n * #[no_mangle]\n * extern \"C\" fn valid_use_2(s: FfiStr) {\n *     // Use of `s` after this function returns is impossible\n * }\n * ```\n */\ntypedef const char *FfiStr;\n\n/**\n * The model configuration of the model that is expected in [`xaynet_ffi_participant_set_model()`].\n *\n * [`xaynet_ffi_participant_set_model()`]: crate::ffi::xaynet_ffi_participant_set_model\n */\ntypedef struct LocalModelConfig {\n  /**\n   * The expected data type of the model.\n   */\n  ModelDataType data_type;\n  /**\n   * the expected length of the model.\n   */\n  uint64_t len;\n} LocalModelConfig;\n\n/**\n * Destroy the given `ByteBuffer` and free its memory. This function must only be\n * called on `ByteBuffer`s that have been created on the Rust side of the FFI. 
If you\n * have created a `ByteBuffer` on the other side of the FFI, do not use this function,\n * use `free()` instead.\n *\n * # Return value\n *\n * - [`OK`] on success\n * - [`ERR_NULLPTR`] if `buf` is NULL\n *\n * # Safety\n *\n * 1. When calling this method, you have to ensure that *either* the pointer is NULL\n * *or* all of the following is true:\n *  - The pointer must be properly [aligned].\n *  - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *    documentation.\n * 2. After destroying the `ByteBuffer` the pointer becomes invalid and must not be\n *    used.\n * 3. Calling this function on a `ByteBuffer` that has not been created on the Rust\n *    side of the FFI is UB.\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nint xaynet_ffi_byte_buffer_destroy(const struct ByteBuffer *buf);\n\n/**\n * Initialize the crypto library. This method must be called before instantiating a\n * participant with [`xaynet_ffi_participant_new()`] or before generating new keys with\n * [`xaynet_ffi_generate_key_pair()`].\n *\n * # Return value\n *\n * - [`OK`] if the initialization succeeded\n * - -[`ERR_CRYPTO_INIT`] if the initialization failed\n *\n * # Safety\n *\n * This function is safe to call\n */\nint xaynet_ffi_crypto_init(void);\n\n/**\n * Destroy the participant created by [`xaynet_ffi_participant_new()`] or\n * [`xaynet_ffi_participant_restore()`].\n *\n * # Return value\n *\n * - [`OK`] on success\n * - [`ERR_NULLPTR`] if `participant` is NULL\n *\n * # Safety\n *\n * 1. When calling this method, you have to ensure that *either* the pointer is NULL\n *    *or* all of the following is true:\n *    - The pointer must be properly [aligned].\n *    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *      documentation.\n * 2. 
After destroying the `Participant`, the pointer becomes invalid and must not be\n *    used.\n * 3. This function should only be called on a pointer that has been created by\n *    [`xaynet_ffi_participant_new()`] or [`xaynet_ffi_participant_restore()`]\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nint xaynet_ffi_participant_destroy(struct Participant *participant);\n\n/**\n * Instantiate a new participant with the given settings. The participant must be\n * destroyed with [`xaynet_ffi_participant_destroy`].\n *\n * # Return value\n *\n * - a NULL pointer if `settings` is NULL or if the participant creation failed\n * - a valid pointer to a [`Participant`] otherwise\n *\n * # Safety\n *\n * When calling this method, you have to ensure that *either* the pointer is NULL *or*\n * all of the following is true:\n *\n * - The pointer must be properly [aligned].\n * - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *   documentation.\n *\n * After destroying the participant with [`xaynet_ffi_participant_destroy`] becomes\n * invalid and must not be used.\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nstruct Participant *xaynet_ffi_participant_new(const struct Settings *settings);\n\n/**\n * Drive the participant internal state machine. Every tick, the state machine\n * attempts to perform a small work unit.\n *\n * # Return value\n *\n * - [`ERR_NULLPTR`] is `participant` is NULL\n * - a bitflag otherwise, with the following flags:\n *   - [`PARTICIPANT_MADE_PROGRESS`]: if set, this flag indicates that the participant\n *     internal state machine was able to make some progress, and that the participant\n *     state changed. This information can be used as an indication for saving the\n *     participant state for instance. 
If the flag is not set, the state machine was\n *     not able to make progress. There are many potential causes for this, including:\n *       - the participant is not taking part to the current training round and is just\n *         waiting for a new one to start\n *       - the Xaynet coordinator is not reachable or has not published some\n *         information the participant is waiting for\n *       - the state machine is waiting for the model to be set (see\n *         [`xaynet_ffi_participant_set_model()`])\n *   - [`PARTICIPANT_TASK_NONE`], [`PARTICIPANT_TASK_SUM`] and\n *     [`PARTICIPANT_TASK_UPDATE`]: these flags are mutually exclusive, and indicate\n *     which task the participant has been selected for, for the current round. If\n *     [`PARTICIPANT_TASK_NONE`] is set, then the participant will just wait for a new\n *     round to start. If [`PARTICIPANT_TASK_UPDATE`] is set, then the participant has\n *     been selected to update the global model, and should prepare to provide a new\n *     model once the [`PARTICIPANT_SHOULD_SET_MODEL`] flag is set.\n *   - [`PARTICIPANT_SHOULD_SET_MODEL`]: if set, then the participant should set its\n *     model, by calling [`xaynet_ffi_participant_set_model()`]\n *   - [`PARTICIPANT_NEW_GLOBALMODEL`]: if set, the participant can fetch the new global\n *     model, by calling [`xaynet_ffi_participant_global_model()`]\n *\n * # Safety\n *\n * When calling this method, you have to ensure that *either* the pointer is NULL *or*\n * all of the following is true:\n *\n * - The pointer must be properly [aligned].\n * - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *   documentation.\n *\n * After destroying the participant with [`xaynet_ffi_participant_destroy`] becomes\n * invalid and must not be used.\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nint 
xaynet_ffi_participant_tick(struct Participant *participant);\n\n/**\n * Serialize the participant state and return a buffer that contains the serialized\n * participant.\n *\n * # Safety\n *\n * 1. When calling this method, you have to ensure that *either* the pointer is NULL\n *    *or* all of the following is true:\n *    - The pointer must be properly [aligned].\n *    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *      documentation.\n * 2. The `ByteBuffer` created by this function must be destroyed with\n *    [`xaynet_ffi_byte_buffer_destroy`]. Attempting to free the memory from the other\n *    side of the FFI is UB.\n * 3. This function destroys the participant. Therefore, **the pointer becomes invalid\n *    and must not be used anymore**. Instead, a new participant should be created,\n *    either with [`xaynet_ffi_participant_new()`] or\n *    [`xaynet_ffi_participant_restore()`]\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n *\n * # Example\n *\n * To save the participant into a file:\n *\n * ```c\n *  const ByteBuffer *save_buf = xaynet_ffi_participant_save(participant);\n *  assert(save_buf);\n *\n *  char *path = \"./participant.bin\";\n *  FILE *f = fopen(path, \"w\");\n *  fwrite(save_buf->data, 1, save_buf->len, f);\n *  fclose(f);\n *  xaynet_ffi_byte_buffer_destroy(save_buf);\n * ```\n */\nconst struct ByteBuffer *xaynet_ffi_participant_save(struct Participant *participant);\n\n/**\n * Restore the participant from a buffer that contains its serialized state.\n *\n * # Return value\n *\n * - a NULL pointer on failure\n * - a pointer to the restored participant on success\n *\n * # Safety\n *\n * When calling this method, you have to ensure that *either* the pointers are NULL\n * *or* all of the following is true:\n * - The pointers must be properly [aligned].\n * - They must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *   
documentation.\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n *\n * # Example\n *\n * To restore a participant from a file:\n *\n * ```c\n * f = fopen(\"./participant.bin\", \"r\");\n * fseek(f, 0L, SEEK_END);\n * int fsize = ftell(f);\n * fseek(f, 0L, SEEK_SET);\n * ByteBuffer buf = {\n *     .len = fsize,\n *     .data = (uint8_t *)malloc(fsize),\n * };\n * int n_read = fread(buf.data, 1, fsize, f);\n * assert(n_read == fsize);\n * fclose(f);\n * Participant *restored =\n *     xaynet_ffi_participant_restore(\"http://localhost:8081\", &buf);\n * free(buf.data);\n * ```\n */\nstruct Participant *xaynet_ffi_participant_restore(FfiStr url, const struct ByteBuffer *buffer);\n\n/**\n * Set the participant's model. Usually this should be called when the value returned\n * by [`xaynet_ffi_participant_tick()`] contains the [`PARTICIPANT_SHOULD_SET_MODEL`]\n * flag, but it can be called anytime. The model just won't be sent to the coordinator\n * until it's time.\n *\n * - `buffer` should be a pointer to a buffer that contains the model\n * - `data_type` specifies the type of the model weights (see [`DataType`]). The C header\n *   file generated by this crate provides an enum corresponding to the parameters: `DataType`.\n * - `len` is the number of weights the model has\n *\n * # Return value\n *\n * - [`OK`] if the model is set successfully\n * - [`ERR_NULLPTR`] if `participant` is NULL\n * - [`ERR_SETMODEL_DATATYPE`] if the datatype is invalid\n * - [`ERR_SETMODEL_MODEL`] if the model is invalid\n *\n * # Safety\n *\n * 1. When calling this method, you have to ensure that *either* the pointer is NULL\n *    *or* all of the following is true:\n *    - The pointer must be properly [aligned].\n *    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *      documentation.\n * 2. 
If `len` or `data_type` do not match the model in `buffer`, this method will\n *    result in a buffer over-read.\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nint xaynet_ffi_participant_set_model(struct Participant *participant,\n                                     const void *buffer,\n                                     unsigned char data_type,\n                                     unsigned int len);\n\n/**\n * Return the latest global model from the coordinator.\n *\n * - `buffer` is the array in which the global model should be copied.\n * - `data_type` specifies the type of the model weights (see [`DataType`]). The C header\n *   file generated by this crate provides an enum corresponding to the parameters: `DataType`.\n * - `len` is the number of weights the model has\n *\n * # Return Value\n *\n * - [`OK`] if the model is set successfully\n * - [`ERR_NULLPTR`] if `participant` or the `buffer` is NULL\n * - [`GLOBALMODEL_NONE`] if no model exists\n * - [`ERR_GLOBALMODEL_IO`] if the communication with the coordinator failed\n * - [`ERR_GLOBALMODEL_DATATYPE`] if the datatype is invalid\n * - [`ERR_GLOBALMODEL_LEN`] if the length of the buffer does not match the length of the model\n * - [`ERR_GLOBALMODEL_CONVERT`] if the conversion of the model failed\n *\n * # Note\n *\n *   It is **not** guaranteed, that the model configuration returned by\n *   [`xaynet_ffi_participant_local_model_config`] corresponds to the configuration of\n *   the global model. This means that the global model can have a different length / data type\n *   than it is defined in model configuration. That both model configurations are the same is\n *   only guaranteed if the model config **never** changes on the coordinator side.\n *\n * # Safety\n *\n * 1. 
When calling this method, you have to ensure that *either* the pointer is NULL\n *    *or* all of the following is true:\n *    - The pointer must be properly [aligned].\n *    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *      documentation.\n * 2. If `buffer` cannot hold `len` values of type `data_type`, this method will\n *    write past the end of `buffer` (buffer overflow).\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nint xaynet_ffi_participant_global_model(struct Participant *participant,\n                                        void *buffer,\n                                        unsigned char data_type,\n                                        unsigned int len);\n\n/**\n * Return the local model configuration of the model that is expected in the\n * [`xaynet_ffi_participant_set_model()`] function. The returned configuration must be\n * destroyed with [`xaynet_ffi_local_model_config_destroy()`].\n *\n * # Safety\n *\n * 1. When calling this method, you have to ensure that *either* the pointer is NULL\n *    *or* all of the following is true:\n *    - The pointer must be properly [aligned].\n *    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *      documentation.\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nstruct LocalModelConfig *xaynet_ffi_participant_local_model_config(const struct Participant *participant);\n\n/**\n * Destroy the settings created by [`xaynet_ffi_settings_new()`].\n *\n * # Return value\n *\n * - [`OK`] on success\n * - [`ERR_NULLPTR`] if `settings` is NULL\n *\n * # Safety\n *\n * 1. When calling this method, you have to ensure that *either* the pointer is NULL\n *    *or* all of the following is true:\n *    - The pointer must be properly [aligned].\n *    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *      documentation.\n * 2. After destroying the `Settings`, the pointer becomes invalid and must not be\n *    used.\n * 3. 
This function should only be called on a pointer that has been created by\n *    [`xaynet_ffi_settings_new`].\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nint xaynet_ffi_settings_destroy(struct Settings *settings);\n\n/**\n * Create new [`Settings`] and return a pointer to it.\n *\n * # Safety\n *\n * The `Settings` created by this function must be destroyed with\n * [`xaynet_ffi_settings_destroy()`]. Attempting to free the memory from the other side\n * of the FFI is UB.\n */\nstruct Settings *xaynet_ffi_settings_new(void);\n\n/**\n * Set scalar setting.\n *\n * # Return value\n *\n * - [`OK`] if successful\n * - [`ERR_NULLPTR`] if `settings` is `NULL`\n *\n * # Safety\n *\n * When calling this method, you have to ensure that *either* the pointer is NULL *or*\n * all of the following is true:\n * - The pointer must be properly [aligned].\n * - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *   documentation.\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nint xaynet_ffi_settings_set_scalar(struct Settings *settings, double scalar);\n\n/**\n * Set coordinator URL.\n *\n * # Return value\n *\n * - [`OK`] if successful\n * - [`ERR_INVALID_URL`] if `url` is not a valid string\n * - [`ERR_NULLPTR`] if `settings` is `NULL`\n *\n * # Safety\n *\n * When calling this method, you have to ensure that *either* the pointers are NULL\n * *or* all of the following is true:\n * - The pointers must be properly [aligned].\n * - They must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *   documentation.\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nint xaynet_ffi_settings_set_url(struct Settings *settings, FfiStr 
url);\n\n/**\n * Generate a new signing key pair that can be used in the [`Settings`]. **Before\n * calling this function you must initialize the crypto library with\n * [`xaynet_ffi_crypto_init()`]**.\n *\n * The returned value contains a pointer to the secret key. For security reasons, you\n * must make sure that the lifetime of this buffer is as short as possible, and call\n * [`xaynet_ffi_forget_key_pair`] to destroy it.\n *\n * [`xaynet_ffi_crypto_init()`]: crate::ffi::xaynet_ffi_crypto_init\n *\n * # Safety\n *\n * This function is safe to call\n */\nconst struct KeyPair *xaynet_ffi_generate_key_pair(void);\n\n/**\n * De-allocate the buffers that contain the signing keys, and zero out the content of\n * the buffer that contains the secret key.\n *\n * # Return value\n *\n * - [`ERR_NULLPTR`] if `key_pair` is NULL\n * - [`OK`] otherwise\n *\n * # Safety\n *\n * When calling this method, you have to ensure that *either* the pointer is NULL *or*\n * all of the following is true:\n * - The pointer must be properly [aligned].\n * - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *   documentation.\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nint xaynet_ffi_forget_key_pair(const struct KeyPair *key_pair);\n\n/**\n * Set participant signing keys.\n *\n * # Return value\n *\n * - [`OK`] if successful\n * - [`ERR_NULLPTR`] if `settings` or `key_pair` is `NULL`\n * - [`ERR_CRYPTO_PUBLIC_KEY`] if the given `key_pair` contains an invalid public key\n * - [`ERR_CRYPTO_SECRET_KEY`] if the given `key_pair` contains an invalid secret key\n *\n * # Safety\n *\n * When calling this method, you have to ensure that *either* the pointers are NULL\n * *or* all of the following is true:\n * - The pointers must be properly [aligned].\n * - They must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *   documentation.\n *\n * [`std::ptr`]: 
https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nint xaynet_ffi_settings_set_keys(struct Settings *settings, const struct KeyPair *key_pair);\n\n/**\n * Check whether the given settings are valid and can be used to instantiate a\n * participant (see [`xaynet_ffi_participant_new()`]).\n *\n * # Return value\n *\n * - [`OK`] on success\n * - [`ERR_SETTINGS_URL`] if the URL has not been set\n * - [`ERR_SETTINGS_KEYS`] if the signing keys have not been set\n * - [`ERR_SETTINGS_SCALAR`] if the scalar is out of bounds\n *\n * # Safety\n *\n * When calling this method, you have to ensure that *either* the pointer is NULL *or*\n * all of the following is true:\n *\n * - The pointer must be properly [aligned].\n * - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *   documentation.\n *\n * [`xaynet_ffi_participant_new()`]: crate::ffi::xaynet_ffi_participant_new\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n */\nint xaynet_ffi_check_settings(const struct Settings *settings);\n\n/**\n * Destroy the model configuration created by [`xaynet_ffi_participant_local_model_config()`].\n *\n * # Return value\n *\n * - [`OK`] on success\n * - [`ERR_NULLPTR`] if `local_model_config` is NULL\n *\n * # Safety\n *\n * 1. When calling this method, you have to ensure that *either* the pointer is NULL\n *    *or* all of the following is true:\n *    - The pointer must be properly [aligned].\n *    - It must be \"dereferencable\" in the sense defined in the [`std::ptr`] module\n *      documentation.\n * 2. After destroying the `LocalModelConfig`, the pointer becomes invalid and must not be\n *    used.\n * 3. 
This function should only be called on a pointer that has been created by\n *    [`xaynet_ffi_participant_local_model_config()`].\n *\n * [`std::ptr`]: https://doc.rust-lang.org/std/ptr/index.html#safety\n * [aligned]: https://doc.rust-lang.org/std/ptr/index.html#alignment\n * [`xaynet_ffi_participant_local_model_config()`]: crate::ffi::xaynet_ffi_participant_local_model_config\n */\nint xaynet_ffi_local_model_config_destroy(struct LocalModelConfig *local_model_config);\n"
  },
  {
    "path": "rust/xaynet-sdk/Cargo.toml",
    "content": "[package]\nname = \"xaynet-sdk\"\nversion = \"0.1.0\"\nauthors = [\"Xayn Engineering <engineering@xaynet.dev>\"]\nedition = \"2018\"\ndescription = \"The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant.\"\nreadme = \"../../README.md\"\nhomepage = \"https://xaynet.dev/\"\nrepository = \"https://github.com/xaynetwork/xaynet/\"\nlicense-file = \"../../LICENSE\"\nkeywords = [\"federated-learning\", \"fl\", \"ai\", \"machine-learning\"]\ncategories = [\"science\", \"cryptography\"]\n\n[package.metadata.docs.rs]\nall-features = true\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n\n[dependencies]\nasync-trait = \"0.1.57\"\nbase64 = \"0.13.0\"\nbincode = \"1.3.3\"\nderive_more = { version = \"0.99.17\", default-features = false, features = [\"from\"] }\n# TODO: remove once concurrent_futures.rs was moved to the e2e package\nfutures = \"0.3.24\"\npaste = \"1.0.8\"\nserde = { version = \"1.0.144\", features = [\"derive\"] }\nsodiumoxide = \"0.2.7\"\nthiserror = \"1.0.32\"\n# TODO: move to dev-dependencies once concurrent_futures.rs was moved to the e2e package\ntokio = { version = \"1.20.1\", features = [\"rt\", \"macros\"] }\ntracing = \"0.1.36\"\nurl = \"2.2.2\"\nxaynet-core = { path = \"../xaynet-core\", version = \"0.2.0\" }\n\n# feature: reqwest client\nreqwest = { version = \"0.11.10\", default-features = false, optional = true }\n# This has to match the version used by reqwest. 
It would be nice if\n# reqwest just re-exported it\nbytes = { version = \"1.0.1\", optional = true }\nrand = \"0.8.5\"\n\n[dev-dependencies]\nmockall = \"0.11.2\"\nnum = { version = \"0.4.0\", features = [\"serde\"] }\nserde_json = \"1.0.85\"\ntokio-test = \"0.4.1\"\nxaynet-core = { path = \"../xaynet-core\", features = [\"testutils\"] }\n\n[features]\ndefault = []\nreqwest-client = [\"reqwest\", \"bytes\"]\n"
  },
  {
    "path": "rust/xaynet-sdk/src/client.rs",
    "content": "use async_trait::async_trait;\nuse thiserror::Error;\nuse url::Url;\n\nuse crate::XaynetClient;\nuse xaynet_core::{\n    common::RoundParameters,\n    crypto::{ByteObject, PublicSigningKey},\n    mask::Model,\n    SumDict,\n    UpdateSeedDict,\n};\n\n/// Error returned upon failing to build a new [`Client`]\n#[derive(Debug, Error)]\npub enum ClientError {\n    #[error(\"failed to deserialize data: {0}\")]\n    Deserialize(String),\n\n    #[error(\"HTTP request failed: {0}\")]\n    Http(String),\n\n    #[error(\"{0}\")]\n    Other(String),\n\n    #[error(\"Reading from file failed: {0}\")]\n    Io(#[from] std::io::Error),\n\n    #[error(\"Unexpected response\")]\n    UnexpectedResponse(u16),\n\n    #[error(\"Unexpected certificate extension\")]\n    UnexpectedCertificate,\n\n    #[error(\"No certificate found\")]\n    NoCertificate,\n}\n\n#[cfg_attr(not(feature = \"reqwest-client\"), allow(dead_code))]\nimpl ClientError {\n    fn http_error<E: std::error::Error>(e: E) -> Self {\n        Self::Http(format!(\"{}\", e))\n    }\n}\n\nimpl From<bincode::Error> for ClientError {\n    fn from(e: bincode::Error) -> Self {\n        Self::Deserialize(format!(\"{}\", e))\n    }\n}\n\nimpl From<std::num::ParseIntError> for ClientError {\n    fn from(e: std::num::ParseIntError) -> Self {\n        Self::Deserialize(format!(\"{}\", e))\n    }\n}\n\n/// A basic HTTP interface that [`Client`] HTTP backends must implement.\n#[async_trait]\npub trait XaynetHttpClient {\n    /// Error type for all the trait's methods\n    type Error: std::error::Error;\n    /// Reponse type for `GET` requests\n    type GetResponse: AsRef<[u8]>;\n\n    /// Perform an HTTP `GET` on the given URL.\n    ///\n    /// If the response is `NO_CONTENT`, the implementor must return `Ok(None)`. 
Otherwise, the\n    /// response body must be returned.\n    async fn get(&mut self, url: &str) -> Result<Option<Self::GetResponse>, ClientError>;\n\n    /// Perform an HTTP `POST` on the given URL, with the given body.\n    async fn post(&mut self, url: &str, body: Vec<u8>) -> Result<(), ClientError>;\n}\n\n#[derive(Debug, Clone)]\n/// A client that communicates with the coordinator's API via HTTP(S).\npub struct Client<C> {\n    /// HTTP(S) client\n    client: C,\n    /// Coordinator URL\n    base_url: Url,\n}\n\n/// Error returned when trying to create a [`Client`] with an invalid\n/// address for the Xaynet coordinator.\n#[derive(Debug, Error)]\n#[error(\"Invalid base URL: {}\", .0)]\npub struct InvalidBaseUrl(String);\n\nimpl<C> Client<C>\nwhere\n    C: XaynetHttpClient,\n{\n    /// Create a new client.\n    ///\n    /// # Args\n    ///\n    /// - `http_client` is the HTTP client that will be used to perform the HTTP requests. Any HTTP\n    ///   client can be used, as long as it implements the [`XaynetHttpClient`] trait.\n    /// - `base_url` is the URL to the Xaynet coordinator\n    ///\n    /// # Errors\n    ///\n    /// An error is returned if `base_url` is not a valid URL, or if it cannot be used as a base URL\n    pub fn new(http_client: C, base_url: &str) -> Result<Self, InvalidBaseUrl> {\n        let base_url = Url::parse(base_url).map_err(|e| InvalidBaseUrl(format!(\"{}\", e)))?;\n        if base_url.cannot_be_a_base() {\n            return Err(InvalidBaseUrl(String::from(\"cannot be a base URL\")));\n        }\n        Ok(Self {\n            client: http_client,\n            base_url,\n        })\n    }\n\n    /// Append the given segment to the client base URL\n    fn url(&self, segment: &str) -> Url {\n        let mut url = self.base_url.clone();\n        url.path_segments_mut().unwrap().push(segment);\n        url\n    }\n\n    async fn get<T>(&mut self, url: &Url) -> Result<Option<T>, ClientError>\n    where\n        T: for<'a> serde::Deserialize<'a>,\n    {\n        Ok(match 
self.client.get(url.as_str()).await? {\n            Some(data) => Some(bincode::deserialize::<T>(data.as_ref())?),\n            None => None,\n        })\n    }\n\n    async fn post(&mut self, url: &Url, data: Vec<u8>) -> Result<(), ClientError> {\n        self.client.post(url.as_str(), data).await\n    }\n}\n\n#[async_trait]\nimpl<C> XaynetClient for Client<C>\nwhere\n    C: XaynetHttpClient + Send,\n{\n    type Error = ClientError;\n\n    async fn get_round_params(&mut self) -> Result<RoundParameters, Self::Error> {\n        let url = self.url(\"params\");\n        let round_params: Option<RoundParameters> = self.get(&url).await?;\n        round_params.ok_or_else(|| {\n            ClientError::Other(\"failed to fetch round parameters: empty response\".to_string())\n        })\n    }\n\n    async fn get_sums(&mut self) -> Result<Option<SumDict>, Self::Error> {\n        let url = self.url(\"sums\");\n        Ok(self.get(&url).await?)\n    }\n\n    async fn get_seeds(\n        &mut self,\n        pk: PublicSigningKey,\n    ) -> Result<Option<UpdateSeedDict>, Self::Error> {\n        let mut url = self.url(\"seeds\");\n        url.query_pairs_mut()\n            .append_pair(\"pk\", &base64::encode(pk.as_slice()));\n        self.get(&url).await\n    }\n\n    async fn get_model(&mut self) -> Result<Option<Model>, Self::Error> {\n        let url = self.url(\"model\");\n        Ok(self.get(&url).await?)\n    }\n\n    async fn send_message(&mut self, msg: Vec<u8>) -> Result<(), Self::Error> {\n        let url = self.url(\"message\");\n        self.post(&url, msg).await\n    }\n}\n\n#[cfg(feature = \"reqwest-client\")]\n#[cfg_attr(docsrs, doc(cfg(feature = \"reqwest-client\")))]\n#[async_trait]\nimpl XaynetHttpClient for reqwest::Client {\n    type Error = reqwest::Error;\n    type GetResponse = bytes::Bytes;\n\n    async fn get(&mut self, url: &str) -> Result<Option<Self::GetResponse>, ClientError> {\n        let resp = reqwest::Client::get(self, url)\n            
.send()\n            .await\n            .map_err(ClientError::http_error)?\n            .error_for_status()\n            .map_err(ClientError::http_error)?;\n        match resp.status() {\n            reqwest::StatusCode::OK => {\n                Ok(Some(resp.bytes().await.map_err(ClientError::http_error)?))\n            }\n            reqwest::StatusCode::NO_CONTENT => Ok(None),\n            status => Err(ClientError::UnexpectedResponse(status.as_u16())),\n        }\n    }\n\n    async fn post(&mut self, url: &str, body: Vec<u8>) -> Result<(), ClientError> {\n        let _resp = reqwest::Client::post(self, url)\n            .body(body)\n            .send()\n            .await\n            .map_err(ClientError::http_error)?\n            .error_for_status()\n            .map_err(ClientError::http_error)?;\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/lib.rs",
    "content": "#![cfg_attr(docsrs, feature(doc_cfg))]\n#![cfg_attr(\n    doc,\n    forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links)\n)]\n#![doc(\n    html_logo_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png\",\n    html_favicon_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png\",\n    issue_tracker_base_url = \"https://github.com/xaynetwork/xaynet/issues\"\n)]\n//! This crate provides building blocks for implementing participants for\n//! the [Xaynet Federated Learning platform](https://www.xaynet.dev/).\n//!\n//! The PET protocol states that in any given round of federated learning,\n//! each participant of the protocol may be selected to carry out one of\n//! two tasks:\n//!\n//! - **update**: participants selected for the update task\n//!   (a.k.a. _update participants_) are responsible for sending a machine\n//!   learning model they trained\n//! - **sum**: participants selected for the sum task (a.k.a. _sum\n//!   participants_) are responsible for computing a global mask from local mask seeds sent by\n//!   the update participants\n//!\n//! Participants may also not be selected for any of these tasks, in which\n//! case they simply wait for the next round.\n//!\n//! # Running a participant\n//!\n//! The communication with the Xaynet coordinator is managed by a\n//! background task that runs the PET protocol. We call it the PET\n//! agent. In practice, the agent is a simple wrapper around the\n//! [`StateMachine`].\n//!\n//! To run a participant, you need to start an agent, and\n//! interact with it. There are two types of interactions:\n//!\n//! - reacting to notifications for the agents, which include:\n//!   - start of a new round of training\n//!   - selection for the sum task\n//!   - selection for the update task\n//!   - end of a task\n//! - providing the agent with a Machine Learning model and a corresponding\n//!   
scalar for aggregation when the participant takes part in the update task\n//!\n//! ## Implementing an agent\n//!\n//! A simple agent can be implemented as a function.\n//!\n//! ```\n//! use std::time::Duration;\n//!\n//! use tokio::time::sleep;\n//! use xaynet_sdk::{StateMachine, TransitionOutcome};\n//!\n//! async fn run_agent(mut state_machine: StateMachine, tick: Duration) {\n//!     loop {\n//!         state_machine = match state_machine.transition().await {\n//!             // The state machine is stuck waiting for some data,\n//!             // either from the coordinator or from the\n//!             // participant. Let's wait a little and try again\n//!             TransitionOutcome::Pending(state_machine) => {\n//!                 sleep(tick).await;\n//!                 state_machine\n//!             }\n//!             // The state machine moved forward in the PET protocol.\n//!             // We simply continue looping, trying to make more progress.\n//!             TransitionOutcome::Complete(state_machine) => state_machine,\n//!         };\n//!     }\n//! }\n//! ```\n//!\n//! This agent needs to be fed a [`StateMachine`] in order to run. A\n//! state machine requires four components:\n//!\n//! - a cryptographic key identifying the participant, see [`PetSettings`]\n//! - a store from which it can load a model when the participant is\n//!   selected for the update task. This can be any type that\n//!   implements the [`ModelStore`] trait. In our case, we'll use a\n//!   dummy in-memory store that always returns the same model.\n//! - a client to talk with the Xaynet coordinator. This can be any\n//!   type that implements the [`XaynetClient`] trait, like the [`Client`].\n//!   For this we're going to use the trait implementations on the `reqwest`\n//!   client that is available when compiling with `--features reqwest-client`.\n//! - a notifier that the state machine can use to send\n//!   notifications. This can be any type that implements the\n//!   
[`Notify`] trait. We'll use channels for this.\n//!\n//! [`PetSettings`]: crate::settings::PetSettings\n//! [`Client`]: crate::client::Client\n//!\n//! Finally we can start our agent and log the events it emits. Here\n//! is the full code:\n//!\n//! ```no_run\n//! # #[cfg(all(feature = \"reqwest-client\", feature = \"tokio/rt-muli-thread\"))]\n//! # mod feature_reqwest_client {\n//! use std::{\n//!     sync::{mpsc, Arc},\n//!     time::Duration,\n//! };\n//!\n//! use async_trait::async_trait;\n//! use reqwest::Client as ReqwestClient;\n//! use tokio::time::sleep;\n//!\n//! use xaynet_core::{\n//!     crypto::SigningKeyPair,\n//!     mask::{BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Model, ModelType},\n//! };\n//! use xaynet_sdk::{\n//!     client::Client,\n//!     settings::PetSettings,\n//!     ModelStore,\n//!     Notify,\n//!     StateMachine,\n//!     TransitionOutcome,\n//! };\n//!\n//! async fn run_agent(mut state_machine: StateMachine, tick: Duration) {\n//!     loop {\n//!         state_machine = match state_machine.transition().await {\n//!             TransitionOutcome::Pending(state_machine) => {\n//!                 sleep(tick.clone()).await;\n//!                 state_machine\n//!             }\n//!             TransitionOutcome::Complete(state_machine) => state_machine,\n//!         };\n//!     }\n//! }\n//!\n//! #[derive(Debug)]\n//! enum Event {\n//!     // event sent by the state machine when the participant is\n//!     // selected for the update task\n//!     Update,\n//!     // event sent by the state machine when the participant is\n//!     // selected for the sum task\n//!     Sum,\n//!     // event sent by the state machine when a new round starts\n//!     NewRound,\n//!     // event sent by the state machine when the participant\n//!     // becomes inactive (after finishing a task for instance)\n//!     Idle,\n//!     // event sent by the state machine when the participant\n//!     
// is supposed to populate the model store\n//!     LoadModel,\n//! }\n//!\n//! // Our notifier is a simple wrapper around a channel.\n//! struct Notifier(mpsc::Sender<Event>);\n//!\n//! impl Notify for Notifier {\n//!     fn new_round(&mut self) {\n//!         self.0.send(Event::NewRound).unwrap();\n//!     }\n//!     fn sum(&mut self) {\n//!         self.0.send(Event::Sum).unwrap();\n//!     }\n//!     fn update(&mut self) {\n//!         self.0.send(Event::Update).unwrap();\n//!     }\n//!     fn idle(&mut self) {\n//!         self.0.send(Event::Idle).unwrap();\n//!     }\n//!     fn load_model(&mut self) {\n//!         self.0.send(Event::LoadModel).unwrap();\n//!     }\n//! }\n//!\n//! // Our store will always load the same model.\n//! // In practice the model should be updated with\n//! // the model the participant trains when it is selected\n//! // for the update task.\n//! struct LocalModel(Arc<Model>);\n//!\n//! #[async_trait]\n//! impl ModelStore for LocalModel {\n//!     type Model = Arc<Model>;\n//!     type Error = std::convert::Infallible;\n//!\n//!     async fn load_model(&mut self) -> Result<Option<Self::Model>, Self::Error> {\n//!         Ok(Some(self.0.clone()))\n//!     }\n//! }\n//!\n//! #[tokio::main]\n//! async fn main() -> Result<(), std::convert::Infallible> {\n//!     let keys = SigningKeyPair::generate();\n//!     let settings = PetSettings::new(keys);\n//!     let xaynet_client = Client::new(ReqwestClient::new(), \"http://localhost:8081\").unwrap();\n//!     let (tx, rx) = mpsc::channel::<Event>();\n//!     let notifier = Notifier(tx);\n//!     let model = Model::from_primitives(vec![0; 100].into_iter()).unwrap();\n//!     let model_store = LocalModel(Arc::new(model));\n//!\n//!     let mut state_machine = StateMachine::new(settings, xaynet_client, model_store, notifier);\n//!     // Start the agent\n//!     tokio::spawn(async move {\n//!         run_agent(state_machine, Duration::from_secs(1)).await;\n//!     });\n//!\n//!     loop {\n//!  
       println!(\"{:?}\", rx.recv().unwrap());\n//!     }\n//! }\n//! # }\n//! # fn main() {} // don't actually run anything, because the client never terminates\n//! ```\n\npub mod client;\nmod message_encoder;\npub mod settings;\nmod state_machine;\nmod traits;\npub(crate) mod utils;\n\npub(crate) use self::message_encoder::MessageEncoder;\npub use self::traits::{ModelStore, Notify, XaynetClient};\npub use state_machine::{LocalModelConfig, SerializableState, StateMachine, TransitionOutcome};\n"
  },
  {
    "path": "rust/xaynet-sdk/src/message_encoder/chunker.rs",
    "content": "#![allow(dead_code)]\n\nuse std::cmp;\n\n/// Default chunk size, for [`Chunker`]\npub const DEFAULT_CHUNK_SIZE: usize = 4096;\n\n/// A struct that yields chunks of the given data.\npub struct Chunker<'a, T: AsRef<[u8]>> {\n    data: &'a T,\n    max_chunk_size: usize,\n}\n\nimpl<'a, T> Chunker<'a, T>\nwhere\n    T: AsRef<[u8]>,\n{\n    /// Create a new [`Chunker`] that yields chunks of `T` of size\n    /// `max_chunk_size`. If `max_chunk_size` is `0`, then the max\n    /// chunk size will be set to [`DEFAULT_CHUNK_SIZE`].\n    pub fn new(data: &'a T, max_chunk_size: usize) -> Self {\n        let max_chunk_size = if max_chunk_size == 0 {\n            DEFAULT_CHUNK_SIZE\n        } else {\n            max_chunk_size\n        };\n        Self {\n            data,\n            max_chunk_size,\n        }\n    }\n\n    /// Get the total number of chunks\n    pub fn nb_chunks(&self) -> usize {\n        let data_len = self.data.as_ref().len();\n        ceiling_div(data_len, self.max_chunk_size)\n    }\n\n    /// Get the chunk with the given ID.\n    ///\n    /// # Panics\n    ///\n    /// This method panics if the given `id` is bigger than `self.nb_chunks()`.\n    pub fn get_chunk(&self, id: usize) -> &'a [u8] {\n        if id >= self.nb_chunks() {\n            panic!(\"no chunk with ID {}\", id);\n        }\n        let start = id * self.max_chunk_size;\n        let end = cmp::min(start + self.max_chunk_size, self.data.as_ref().len());\n        let range = start..end;\n        &self.data.as_ref()[range]\n    }\n}\n\n/// A helper that performs division with ceil.\n///\n/// # Panic\n///\n/// This function panic if `d` is 0.\nfn ceiling_div(n: usize, d: usize) -> usize {\n    (n + d - 1) / d\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    #[should_panic(expected = \"no chunk with ID 0\")]\n    fn test_0() {\n        let data = vec![];\n        let chunker = Chunker::new(&data, 0);\n        assert_eq!(chunker.nb_chunks(), 0);\n        
chunker.get_chunk(0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"no chunk with ID 5\")]\n    fn test_1() {\n        let data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];\n        let chunker = Chunker::new(&data, 2);\n        assert_eq!(chunker.nb_chunks(), 5);\n        assert_eq!(chunker.get_chunk(0), &[0, 1]);\n        assert_eq!(chunker.get_chunk(1), &[2, 3]);\n        assert_eq!(chunker.get_chunk(2), &[4, 5]);\n        assert_eq!(chunker.get_chunk(3), &[6, 7]);\n        assert_eq!(chunker.get_chunk(4), &[8, 9]);\n        chunker.get_chunk(5);\n    }\n\n    #[test]\n    #[should_panic(expected = \"no chunk with ID 5\")]\n    fn test_2() {\n        let data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];\n        let chunker = Chunker::new(&data, 2);\n        assert_eq!(chunker.nb_chunks(), 5);\n        assert_eq!(chunker.get_chunk(0), &[0, 1]);\n        assert_eq!(chunker.get_chunk(1), &[2, 3]);\n        assert_eq!(chunker.get_chunk(2), &[4, 5]);\n        assert_eq!(chunker.get_chunk(3), &[6, 7]);\n        assert_eq!(chunker.get_chunk(4), &[8]);\n        chunker.get_chunk(5);\n    }\n\n    #[test]\n    #[should_panic(expected = \"no chunk with ID 4\")]\n    fn test_3() {\n        let data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];\n        let chunker = Chunker::new(&data, 3);\n        assert_eq!(chunker.nb_chunks(), 4);\n        assert_eq!(chunker.get_chunk(0), &[0, 1, 2]);\n        assert_eq!(chunker.get_chunk(1), &[3, 4, 5]);\n        assert_eq!(chunker.get_chunk(2), &[6, 7, 8]);\n        assert_eq!(chunker.get_chunk(3), &[9]);\n        chunker.get_chunk(4);\n    }\n\n    #[test]\n    #[should_panic(expected = \"no chunk with ID 1\")]\n    fn test_4() {\n        let data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];\n        let chunker = Chunker::new(&data, 10);\n        assert_eq!(chunker.nb_chunks(), 1);\n        assert_eq!(chunker.get_chunk(0), data.as_slice());\n        chunker.get_chunk(1);\n    }\n\n    #[test]\n    #[should_panic(expected = \"no chunk with ID 1\")]\n    
fn test_5() {\n        let data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];\n        let chunker = Chunker::new(&data, 0);\n        assert_eq!(chunker.max_chunk_size, DEFAULT_CHUNK_SIZE);\n        assert_eq!(chunker.nb_chunks(), 1);\n        assert_eq!(chunker.get_chunk(0), data.as_slice());\n        chunker.get_chunk(1);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/message_encoder/encoder.rs",
    "content": "use serde::{Deserialize, Serialize};\nuse thiserror::Error;\n\nuse super::Chunker;\nuse xaynet_core::{\n    crypto::{PublicEncryptKey, SecretSigningKey, SigningKeyPair},\n    message::{Chunk, Message, Payload, Tag, ToBytes},\n};\n\n/// An encoder for multipart messages. It implements\n/// `Iterator<Item=Vec<u8>>`, which yields message parts ready to be\n/// sent over the wire.\n#[derive(Serialize, Deserialize, Debug)]\npub struct MultipartEncoder {\n    keys: SigningKeyPair,\n    /// The coordinator public key. It should be the key used to\n    /// encrypt the message.\n    coordinator_pk: PublicEncryptKey,\n    /// Serialized message payload.\n    data: Vec<u8>,\n    /// Next chunk ID to be produced by the iterator\n    id: u16,\n    /// Message tag\n    tag: Tag,\n    /// The maximum size allowed for the payload. `self.data` is split\n    /// in chunks of this size.\n    payload_size: usize,\n    /// A random ID common to all the message chunks.\n    message_id: u16,\n}\n\n/// Overhead induced by wrapping the data in [`Payload::Chunk`]\npub const CHUNK_OVERHEAD: usize = 8;\npub const MIN_PAYLOAD_SIZE: usize = CHUNK_OVERHEAD + 1;\n\nimpl Iterator for MultipartEncoder {\n    type Item = Vec<u8>;\n\n    fn next(&mut self) -> Option<Self::Item> {\n        let chunker = Chunker::new(&self.data, self.payload_size - CHUNK_OVERHEAD);\n\n        if self.id as usize >= chunker.nb_chunks() {\n            return None;\n        }\n\n        let chunk = Chunk {\n            id: self.id,\n            message_id: self.message_id,\n            last: self.id as usize == chunker.nb_chunks() - 1,\n            data: chunker.get_chunk(self.id as usize).to_vec(),\n        };\n        self.id += 1;\n\n        let message = Message {\n            // The signature is computed when serializing the message\n            signature: None,\n            participant_pk: self.keys.public,\n            is_multipart: true,\n            tag: self.tag,\n            payload: 
Payload::Chunk(chunk),\n            coordinator_pk: self.coordinator_pk,\n        };\n        let data = serialize_message(&message, &self.keys.secret);\n        Some(data)\n    }\n}\n\n/// An encoder for a [`Payload`] representing a sum, update or sum2\n/// message. If the [`Payload`] is small enough, a [`Message`] header\n/// is added, and the message is serialized and signed. If\n/// the [`Payload`] is too large to fit in a single message, it is\n/// split into chunks which are also serialized and signed.\n#[derive(Serialize, Deserialize, Debug)]\npub enum MessageEncoder {\n    /// Encoder for a payload that fits in a single message.\n    Simple(Option<Vec<u8>>),\n    /// Encoder for a large payload that needs to be split into several\n    /// parts.\n    Multipart(MultipartEncoder),\n}\n\nimpl Iterator for MessageEncoder {\n    type Item = Vec<u8>;\n\n    fn next(&mut self) -> Option<Self::Item> {\n        match self {\n            MessageEncoder::Simple(ref mut data) => data.take(),\n            MessageEncoder::Multipart(ref mut multipart_encoder) => multipart_encoder.next(),\n        }\n    }\n}\n\n#[derive(Error, Debug)]\npub enum InvalidEncodingInput {\n    #[error(\"only sum, update, and sum2 messages can be encoded\")]\n    Payload,\n    #[error(\"the max payload size is too small\")]\n    PayloadSize,\n}\n\nimpl MessageEncoder {\n    // NOTE: the only reason we need to consume the payload is because creating the Message\n    // consumes it.\n    /// Create a new encoder for the given payload. The `keys`\n    /// are used to sign the message(s). If the serialized payload is\n    /// larger than `max_payload_size`, the message will be split into\n    /// multiple chunks. If `max_payload_size` is `0`, the message\n    /// will not be split.\n    ///\n    /// # Errors\n    ///\n    /// An [`InvalidEncodingInput`] error is returned when `payload` is of\n    /// type [`Payload::Chunk`]. 
Only [`Payload::Sum`],\n    /// [`Payload::Update`], [`Payload::Sum2`] are accepted. An\n    /// [`InvalidEncodingInput`] error is also returned when\n    /// `max_payload_size` is non-zero but not larger than `MIN_PAYLOAD_SIZE`.\n    pub fn new(\n        keys: SigningKeyPair,\n        payload: Payload,\n        coordinator_pk: PublicEncryptKey,\n        max_payload_size: usize,\n    ) -> Result<Self, InvalidEncodingInput> {\n        // Reject payloads of type Payload::Chunk. It is the job of the encoder to produce those if\n        // the payload is deemed too big to be sent in a single message\n        if payload.is_chunk() {\n            return Err(InvalidEncodingInput::Payload);\n        }\n\n        if max_payload_size != 0 && max_payload_size <= MIN_PAYLOAD_SIZE {\n            return Err(InvalidEncodingInput::PayloadSize);\n        }\n\n        if max_payload_size != 0 && payload.buffer_length() > max_payload_size {\n            Ok(Self::new_multipart(\n                keys,\n                coordinator_pk,\n                payload,\n                max_payload_size,\n            ))\n        } else {\n            Ok(Self::new_simple(keys, coordinator_pk, payload))\n        }\n    }\n\n    fn new_simple(\n        keys: SigningKeyPair,\n        coordinator_pk: PublicEncryptKey,\n        payload: Payload,\n    ) -> Self {\n        let message = Message {\n            // The signature is computed when serializing the message\n            signature: None,\n            participant_pk: keys.public,\n            is_multipart: false,\n            coordinator_pk,\n            tag: Self::get_tag_from_payload(&payload),\n            payload,\n        };\n        let data = serialize_message(&message, &keys.secret);\n        Self::Simple(Some(data))\n    }\n\n    fn new_multipart(\n        keys: SigningKeyPair,\n        coordinator_pk: PublicEncryptKey,\n        payload: Payload,\n        payload_size: usize,\n    ) -> Self {\n        let tag = Self::get_tag_from_payload(&payload);\n        let mut data = vec![0; payload.buffer_length()];\n        payload.to_bytes(&mut data);\n        
Self::Multipart(MultipartEncoder {\n            keys,\n            data,\n            id: 0,\n            tag,\n            coordinator_pk,\n            payload_size,\n            message_id: rand::random::<u16>(),\n        })\n    }\n\n    fn get_tag_from_payload(payload: &Payload) -> Tag {\n        match payload {\n            Payload::Sum(_) => Tag::Sum,\n            Payload::Update(_) => Tag::Update,\n            Payload::Sum2(_) => Tag::Sum2,\n            Payload::Chunk(_) => panic!(\"no tag associated to Payload::Chunk\"),\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use xaynet_core::{\n        crypto::{ByteObject, EncryptKeyPair, EncryptKeySeed, SigningKeyPair, SigningKeySeed},\n        message::{FromBytes, Update},\n        testutils::multipart as helpers,\n    };\n\n    use super::*;\n\n    fn participant_keys() -> SigningKeyPair {\n        let seed = SigningKeySeed::from_slice(vec![0; 32].as_slice()).unwrap();\n        SigningKeyPair::derive_from_seed(&seed)\n    }\n\n    fn coordinator_keys() -> EncryptKeyPair {\n        let seed = EncryptKeySeed::from_slice(vec![0; 32].as_slice()).unwrap();\n        EncryptKeyPair::derive_from_seed(&seed)\n    }\n\n    fn message(dict_len: usize, mask_obj_len: usize) -> Message {\n        let payload = helpers::update(dict_len, mask_obj_len).into();\n        Message {\n            signature: None,\n            participant_pk: participant_keys().public,\n            is_multipart: false,\n            tag: Tag::Update,\n            payload,\n            coordinator_pk: coordinator_keys().public,\n        }\n    }\n\n    fn small_message() -> Message {\n        let dict_len = 80 + 32 + 4; // 116 => dict with a single entry\n        let model_len = 6 + 18; // 24 => masked model with single weight\n        let message = message(dict_len, model_len);\n        let payload_len = dict_len + model_len + 64 * 2; // 268\n        let message_len = payload_len + 136; // 404\n        
assert_eq!(message.payload.buffer_length(), payload_len);\n        assert_eq!(message.buffer_length(), message_len);\n        message\n    }\n\n    #[test]\n    fn no_chunk() {\n        let msg = small_message();\n\n        let mut enc = MessageEncoder::new(\n            participant_keys(),\n            msg.clone().payload,\n            msg.coordinator_pk,\n            272,\n        )\n        .unwrap();\n\n        let data = enc.next().unwrap();\n        let parsed = Message::from_byte_slice(&data.as_slice()).unwrap();\n        assert!(!parsed.is_multipart);\n        assert_eq!(parsed.payload, msg.payload);\n        assert!(enc.next().is_none());\n    }\n\n    #[test]\n    fn two_chunks() {\n        let msg = small_message();\n\n        let mut enc = MessageEncoder::new(\n            participant_keys(),\n            msg.clone().payload,\n            msg.coordinator_pk,\n            200,\n        )\n        .unwrap();\n\n        let data = enc.next().unwrap();\n        // The payload should be 200 bytes + 136 bytes for the\n        // message header.\n        //\n        // 8 of these 200 payload bytes are for the Chunk payload\n        // header. So this chunk actually only contains 192 bytes (out\n        // of 268) from the Update payload. 
So 76 bytes remain.\n        assert_eq!(data.len(), 200 + 136);\n        let parsed = Message::from_byte_slice(&data.as_slice()).unwrap();\n        assert!(parsed.is_multipart);\n        let chunk1 = extract_chunk(parsed);\n        assert!(!chunk1.last);\n        assert_eq!(chunk1.id, 0);\n        assert_eq!(chunk1.data.len(), 192);\n\n        let data = enc.next().unwrap();\n        // The payload should be 76 bytes + 8 bytes of CHUNK_OVERHEAD,\n        // plus 136 byte for the message header\n        assert_eq!(data.len(), 84 + 136);\n        let parsed = Message::from_byte_slice(&data.as_slice()).unwrap();\n        assert!(parsed.is_multipart);\n        let chunk2 = extract_chunk(parsed);\n        assert!(chunk2.last);\n        assert_eq!(chunk2.id, 1);\n        assert_eq!(chunk2.data.len(), 76);\n\n        let payload_data: Vec<u8> = [chunk1.data, chunk2.data].concat();\n        let update = Update::from_byte_slice(&payload_data).unwrap();\n        assert_eq!(update, extract_update(msg));\n    }\n\n    fn extract_chunk(message: Message) -> Chunk {\n        if let Payload::Chunk(c) = message.payload {\n            c\n        } else {\n            panic!(\"not a chunk message\");\n        }\n    }\n\n    fn extract_update(message: Message) -> Update {\n        if let Payload::Update(u) = message.payload {\n            u\n        } else {\n            panic!(\"not an update message\");\n        }\n    }\n}\n\nfn serialize_message(message: &Message, sk: &SecretSigningKey) -> Vec<u8> {\n    let mut buf = vec![0; message.buffer_length()];\n    message.to_bytes(&mut buf, sk);\n    buf\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/message_encoder/mod.rs",
    "content": "mod chunker;\nmod encoder;\n\nuse chunker::Chunker;\npub use encoder::MessageEncoder;\n"
  },
  {
    "path": "rust/xaynet-sdk/src/settings/max_message_size.rs",
    "content": "use serde::{de::Error as SerdeError, Deserialize, Deserializer, Serialize};\nuse thiserror::Error;\n\npub use xaynet_core::message::MESSAGE_HEADER_LENGTH;\n\n/// The minimum message payload size\npub const MINIMUM_PAYLOAD_SIZE: usize = 1;\n/// Length of the encryption header in encrypted messages\npub const ENCRYPTION_HEADER_LENGTH: usize = xaynet_core::crypto::SEALBYTES;\n/// The minimum size a message can have\npub const MIN_MESSAGE_SIZE: usize =\n    MESSAGE_HEADER_LENGTH + ENCRYPTION_HEADER_LENGTH + MINIMUM_PAYLOAD_SIZE;\n\n/// Invalid [`MaxMessageSize`] value\n#[derive(Debug, Error)]\n#[error(\"max message size must be at least {}\", MIN_MESSAGE_SIZE)]\npub struct InvalidMaxMessageSize;\n\n/// Represent the maximum size messages sent by a participant can\n/// have. If a larger message needs to be sent, it will be chunked and\n/// sent in several parts. Note that messages have a minimal size of\n/// [`MIN_MESSAGE_SIZE`].\n#[derive(Serialize, Deserialize, Clone, Copy, Debug)]\npub struct MaxMessageSize(#[serde(deserialize_with = \"deserialize\")] Option<usize>);\n\nimpl Default for MaxMessageSize {\n    fn default() -> Self {\n        MaxMessageSize(Some(\n            4096 - MESSAGE_HEADER_LENGTH - ENCRYPTION_HEADER_LENGTH,\n        ))\n    }\n}\n\nimpl MaxMessageSize {\n    /// An arbitrary large maximum message size. 
With this setting,\n    /// messages will never be split.\n    pub fn unlimited() -> Self {\n        MaxMessageSize(None)\n    }\n\n    /// Create a max message size of `size`.\n    ///\n    /// # Errors\n    ///\n    /// This method returns an [`InvalidMaxMessageSize`] error if\n    /// `size` is smaller than [`MIN_MESSAGE_SIZE`];\n    pub fn capped(size: usize) -> Result<Self, InvalidMaxMessageSize> {\n        if size >= MIN_MESSAGE_SIZE {\n            Ok(MaxMessageSize(Some(size)))\n        } else {\n            Err(InvalidMaxMessageSize)\n        }\n    }\n\n    /// Get the maximum payload size corresponding to the maximum\n    /// message size. `None` means that the payload size is unlimited.\n    pub fn max_payload_size(&self) -> Option<usize> {\n        self.0\n            .map(|size| size - MESSAGE_HEADER_LENGTH - ENCRYPTION_HEADER_LENGTH)\n    }\n}\n\nfn deserialize<'de, D>(deserializer: D) -> Result<Option<usize>, D::Error>\nwhere\n    D: Deserializer<'de>,\n{\n    let value: Option<usize> = Option::deserialize(deserializer)?;\n    match value {\n        Some(size) => {\n            if size >= MIN_MESSAGE_SIZE {\n                Ok(Some(size))\n            } else {\n                Err(SerdeError::custom(format!(\n                    \"max_message_size must be at least {} (got {})\",\n                    MIN_MESSAGE_SIZE, size\n                )))\n            }\n        }\n        None => Ok(None),\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use serde_json::json;\n\n    use super::*;\n\n    #[test]\n    fn max_message_size_deserialization_capped() {\n        let input = r#\"{\"some\":1000}\"#;\n        let expected = json!({\"some\": MaxMessageSize::capped(1000).unwrap()});\n        let actual: serde_json::Value = serde_json::from_str(input).unwrap();\n        assert_eq!(expected, actual);\n    }\n\n    #[test]\n    fn max_message_size_deserialization_unlimited() {\n        let input = r#\"{\"none\":null}\"#;\n        let expected = json!({ \"none\": 
MaxMessageSize::unlimited() });\n        let actual: serde_json::Value = serde_json::from_str(input).unwrap();\n        assert_eq!(expected, actual);\n    }\n\n    #[test]\n    fn max_message_size_deserialization_err() {\n        // Use a dummy struct, otherwise, serde deserializes the value\n        // as an integer.\n        #[derive(Deserialize, Serialize, Debug)]\n        struct Dummy {\n            mms: MaxMessageSize,\n        }\n        let input = r#\"{\"mms\":123}\"#;\n        let expected =\n            \"max_message_size must be at least 185 (got 123) at line 1 column 11\".to_string();\n        let actual = serde_json::from_str::<Dummy>(input).unwrap_err();\n        assert_eq!(expected, format!(\"{}\", actual));\n    }\n\n    #[test]\n    fn max_message_size_serialization_capped() {\n        let input = json!({\"some\": MaxMessageSize::capped(1000).unwrap()});\n        let expected = r#\"{\"some\":1000}\"#;\n        let actual = serde_json::to_string(&input).unwrap();\n        assert_eq!(expected, actual);\n    }\n\n    #[test]\n    fn max_message_size_serialization_unlimited() {\n        let input = json!({ \"none\": MaxMessageSize::unlimited() });\n        let expected = r#\"{\"none\":null}\"#;\n        let actual = serde_json::to_string(&input).unwrap();\n        assert_eq!(expected, actual);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/settings/mod.rs",
    "content": "mod max_message_size;\n\nuse serde::{Deserialize, Serialize};\n\npub use max_message_size::{InvalidMaxMessageSize, MaxMessageSize, MIN_MESSAGE_SIZE};\nuse xaynet_core::{crypto::SigningKeyPair, mask::Scalar};\n\n#[derive(Serialize, Deserialize, Debug)]\npub struct PetSettings {\n    pub keys: SigningKeyPair,\n    pub scalar: Scalar,\n    pub max_message_size: MaxMessageSize,\n}\n\nimpl PetSettings {\n    pub fn new(keys: SigningKeyPair) -> Self {\n        PetSettings {\n            keys,\n            scalar: Scalar::unit(),\n            max_message_size: MaxMessageSize::default(),\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/io.rs",
    "content": "use std::error::Error;\n\nuse async_trait::async_trait;\n\nuse xaynet_core::{\n    common::RoundParameters,\n    mask::Model,\n    SumDict,\n    SumParticipantPublicKey,\n    UpdateSeedDict,\n};\n\nuse crate::{ModelStore, Notify, XaynetClient};\n\n/// Returned a dynamically dispatched [`IO`] object\npub(crate) fn boxed_io<X, M, N>(\n    xaynet_client: X,\n    model_store: M,\n    notifier: N,\n) -> Box<dyn IO<Model = Box<dyn AsRef<Model> + Send>>>\nwhere\n    X: XaynetClient + Send + 'static,\n    M: ModelStore + Send + 'static,\n    N: Notify + Send + 'static,\n{\n    Box::new(StateMachineIO::new(xaynet_client, model_store, notifier))\n}\n\n#[cfg(test)]\ntype DynModel = Box<(dyn std::convert::AsRef<xaynet_core::mask::Model> + Send)>;\n/// A trait that gathers all the [`Notify`], [`XaynetClient`] and [`ModelStore`]\n/// methods.\n///\n/// This trait is intended not to be exposed. It is a convenience for avoiding the\n/// proliferation of generic parameters in the state machine: instead of three traits,\n/// we now have only one.\n///\n/// Note that by having only one trait, we can also use dynamic dispatch and actually\n/// get rid of all the generic parameters in the state machine.\n///\n/// ```compile_fail\n/// Box<dyn IO> // allowed\n/// Box<dyn ModelStore + Notify + XaynetClient> // not allowed\n/// ```\n#[cfg_attr(test, mockall::automock(type Model=DynModel;))]\n#[async_trait]\npub(crate) trait IO: Send + 'static {\n    type Model;\n\n    /// Attempt to load the model from the store.\n    async fn load_model(&mut self) -> Result<Option<Self::Model>, Box<dyn Error>>;\n\n    /// Fetch the round parameters from the coordinator\n    async fn get_round_params(&mut self) -> Result<RoundParameters, Box<dyn Error>>;\n    /// Fetch the sum dictionary from the coordinator\n    async fn get_sums(&mut self) -> Result<Option<SumDict>, Box<dyn Error>>;\n    /// Fetch the seed dictionary for the given sum participant from the coordinator\n    async fn 
get_seeds(\n        &mut self,\n        pk: SumParticipantPublicKey,\n    ) -> Result<Option<UpdateSeedDict>, Box<dyn Error>>;\n    /// Fetch the latest global model from the coordinator\n    async fn get_model(&mut self) -> Result<Option<Model>, Box<dyn Error>>;\n    /// Send the given signed and encrypted PET message to the coordinator\n    async fn send_message(&mut self, msg: Vec<u8>) -> Result<(), Box<dyn Error>>;\n\n    /// Notify the participant that a new round started\n    fn notify_new_round(&mut self);\n    /// Notify the participant that they have been selected for the sum task for the current\n    /// round\n    fn notify_sum(&mut self);\n    /// Notify the participant that they have been selected for the update task for the current\n    /// round\n    fn notify_update(&mut self);\n    /// Notify the participant that they are done with their current task and are waiting to\n    /// be selected for a task\n    fn notify_idle(&mut self);\n    /// Notify the participant that they are expected to provide a model to the state\n    /// machine by loading it into the store\n    fn notify_load_model(&mut self);\n}\n\n/// Internal struct that implements the [`IO`] trait. It is not used as is in the state\n/// machine. 
Instead, we box it and use it as a `dyn IO` object.\nstruct StateMachineIO<X, M, N> {\n    xaynet_client: X,\n    model_store: M,\n    notifier: N,\n}\n\nimpl<X, M, N> StateMachineIO<X, M, N> {\n    /// Create a new `StateMachineIO`\n    pub fn new(xaynet_client: X, model_store: M, notifier: N) -> Self {\n        Self {\n            xaynet_client,\n            model_store,\n            notifier,\n        }\n    }\n}\n\n#[async_trait]\nimpl<X, M, N> IO for StateMachineIO<X, M, N>\nwhere\n    X: XaynetClient + Send + 'static,\n    M: ModelStore + Send + 'static,\n    N: Notify + Send + 'static,\n{\n    type Model = Box<dyn AsRef<Model> + Send>;\n\n    async fn load_model(&mut self) -> Result<Option<Self::Model>, Box<dyn Error>> {\n        self.model_store\n            .load_model()\n            .await\n            .map_err(|e| Box::new(e) as Box<dyn Error>)\n            .map(|opt| opt.map(|model| Box::new(model) as Box<dyn AsRef<Model> + Send>))\n    }\n\n    async fn get_round_params(&mut self) -> Result<RoundParameters, Box<dyn Error>> {\n        self.xaynet_client\n            .get_round_params()\n            .await\n            .map_err(|e| Box::new(e) as Box<dyn Error>)\n    }\n\n    async fn get_sums(&mut self) -> Result<Option<SumDict>, Box<dyn Error>> {\n        self.xaynet_client\n            .get_sums()\n            .await\n            .map_err(|e| Box::new(e) as Box<dyn Error>)\n    }\n\n    async fn get_seeds(\n        &mut self,\n        pk: SumParticipantPublicKey,\n    ) -> Result<Option<UpdateSeedDict>, Box<dyn Error>> {\n        self.xaynet_client\n            .get_seeds(pk)\n            .await\n            .map_err(|e| Box::new(e) as Box<dyn Error>)\n    }\n\n    async fn get_model(&mut self) -> Result<Option<Model>, Box<dyn Error>> {\n        self.xaynet_client\n            .get_model()\n            .await\n            .map_err(|e| Box::new(e) as Box<dyn Error>)\n    }\n\n    async fn send_message(&mut self, msg: Vec<u8>) -> Result<(), Box<dyn 
Error>> {\n        self.xaynet_client\n            .send_message(msg)\n            .await\n            .map_err(|e| Box::new(e) as Box<dyn Error>)\n    }\n\n    fn notify_new_round(&mut self) {\n        self.notifier.new_round()\n    }\n\n    fn notify_sum(&mut self) {\n        self.notifier.sum()\n    }\n\n    fn notify_update(&mut self) {\n        self.notifier.update()\n    }\n\n    fn notify_idle(&mut self) {\n        self.notifier.idle()\n    }\n\n    fn notify_load_model(&mut self) {\n        self.notifier.load_model()\n    }\n}\n\n#[async_trait]\nimpl IO for Box<dyn IO<Model = Box<dyn AsRef<Model> + Send>>> {\n    type Model = Box<dyn AsRef<Model> + Send>;\n\n    async fn load_model(&mut self) -> Result<Option<Self::Model>, Box<dyn Error>> {\n        self.as_mut().load_model().await\n    }\n\n    async fn get_round_params(&mut self) -> Result<RoundParameters, Box<dyn Error>> {\n        self.as_mut().get_round_params().await\n    }\n\n    async fn get_sums(&mut self) -> Result<Option<SumDict>, Box<dyn Error>> {\n        self.as_mut().get_sums().await\n    }\n\n    async fn get_seeds(\n        &mut self,\n        pk: SumParticipantPublicKey,\n    ) -> Result<Option<UpdateSeedDict>, Box<dyn Error>> {\n        self.as_mut().get_seeds(pk).await\n    }\n\n    async fn get_model(&mut self) -> Result<Option<Model>, Box<dyn Error>> {\n        self.as_mut().get_model().await\n    }\n\n    async fn send_message(&mut self, msg: Vec<u8>) -> Result<(), Box<dyn Error>> {\n        self.as_mut().send_message(msg).await\n    }\n\n    fn notify_new_round(&mut self) {\n        self.as_mut().notify_new_round()\n    }\n\n    fn notify_sum(&mut self) {\n        self.as_mut().notify_sum()\n    }\n\n    fn notify_update(&mut self) {\n        self.as_mut().notify_update()\n    }\n\n    fn notify_idle(&mut self) {\n        self.as_mut().notify_idle()\n    }\n\n    fn notify_load_model(&mut self) {\n        self.as_mut().notify_load_model()\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/mod.rs",
    "content": "// Important: the macro_use modules must be declared first for the\n// macro to be used in the other modules (until declarative macros are stable)\n#[macro_use]\nmod phase;\nmod io;\nmod phases;\n#[allow(clippy::module_inception)]\nmod state_machine;\n\n// It is useful to re-export everything within this module because\n// there are a lot of interdependencies between all the sub-modules\n#[cfg(test)]\nuse self::io::MockIO;\nuse self::{\n    io::{boxed_io, IO},\n    phase::{IntoPhase, Phase, PhaseIo, Progress, SharedState, State, Step},\n    phases::{Awaiting, NewRound, SendingSum, SendingSum2, SendingUpdate, Sum, Sum2, Update},\n};\n\npub use self::{\n    phase::{LocalModelConfig, SerializableState},\n    state_machine::{StateMachine, TransitionOutcome},\n};\n\n#[cfg(test)]\npub mod tests;\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/phase.rs",
    "content": "use async_trait::async_trait;\nuse derive_more::From;\nuse serde::{Deserialize, Serialize};\nuse thiserror::Error;\nuse tracing::{debug, error, info, warn};\n\nuse super::{Awaiting, NewRound, SendingSum, SendingSum2, SendingUpdate, Sum, Sum2, Update, IO};\nuse crate::{\n    settings::{MaxMessageSize, PetSettings},\n    state_machine::{StateMachine, TransitionOutcome},\n    MessageEncoder,\n};\nuse xaynet_core::{\n    common::{RoundParameters, RoundSeed},\n    crypto::{ByteObject, PublicEncryptKey, SigningKeyPair},\n    mask::{self, DataType, MaskConfig, Model, Scalar},\n    message::Payload,\n};\n\n/// State of the state machine\n#[derive(Debug, Serialize, Deserialize)]\npub struct State<P> {\n    /// data specific to the current phase\n    pub private: Box<P>,\n    /// data common to most of the phases\n    pub shared: Box<SharedState>,\n}\n\nimpl<P> State<P> {\n    /// Create a new state\n    pub fn new(shared: Box<SharedState>, private: Box<P>) -> Self {\n        Self { private, shared }\n    }\n}\n\n/// A dynamically dispatched [`IO`] object.\npub(crate) type PhaseIo = Box<dyn IO<Model = Box<dyn AsRef<Model> + Send>>>;\n\n/// Represent the state machine in a specific phase\npub struct Phase<P> {\n    /// State of the phase.\n    pub(super) state: State<P>,\n    /// Opaque client for performing IO tasks: talking with the\n    /// coordinator API, loading models, etc.\n    pub(super) io: PhaseIo,\n}\n\nimpl<P> std::fmt::Debug for Phase<P>\nwhere\n    P: std::fmt::Debug,\n{\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        f.debug_struct(\"Phase\")\n            .field(\"state\", &self.state)\n            .field(\"io\", &\"PhaseIo\")\n            .finish()\n    }\n}\n\n/// Store for all the data that are common to all the phases\n#[derive(Serialize, Deserialize, Debug)]\npub struct SharedState {\n    /// Keys that identify the participant. 
They are used to sign the\n    /// PET message sent by the participant.\n    pub keys: SigningKeyPair,\n    /// Scalar used for masking\n    pub scalar: Scalar,\n    /// Maximum message size the participant can send. Messages larger\n    /// than `message_size` are split in several parts.\n    pub message_size: MaxMessageSize,\n    /// Current round parameters\n    pub round_params: RoundParameters,\n}\n\n/// Get arbitrary round parameters. These round parameters are never used, we just\n/// temporarily use them in the [`SharedState`] when creating a new state machine. The\n/// first thing the state machine does when it runs, is to fetch the real round\n/// parameters from the coordinator.\nfn dummy_round_parameters() -> RoundParameters {\n    RoundParameters {\n        pk: PublicEncryptKey::zeroed(),\n        sum: 0.0,\n        update: 0.0,\n        seed: RoundSeed::zeroed(),\n        mask_config: MaskConfig {\n            group_type: mask::GroupType::Integer,\n            data_type: mask::DataType::F32,\n            bound_type: mask::BoundType::B0,\n            model_type: mask::ModelType::M3,\n        }\n        .into(),\n        model_length: 0,\n    }\n}\n\nimpl SharedState {\n    pub fn new(settings: PetSettings) -> Self {\n        Self {\n            keys: settings.keys,\n            scalar: settings.scalar,\n            message_size: settings.max_message_size,\n            round_params: dummy_round_parameters(),\n        }\n    }\n}\n\n/// A trait that each `Phase<P>` implements. When `Step::step` is called, the phase\n/// tries to do a small piece of work.\n#[async_trait]\npub trait Step {\n    /// Represent an attempt to make progress within a phase. If the step results in a\n    /// change in the phase state, the updated state machine is returned as\n    /// `TransitionOutcome::Complete`. 
If no progress can be made, the state machine is\n    /// returned unchanged as `TransitionOutcome::Pending`.\n    async fn step(mut self) -> TransitionOutcome;\n}\n\n#[macro_export]\nmacro_rules! try_progress {\n    ($progress:expr) => {{\n        use $crate::state_machine::{Progress, TransitionOutcome};\n        match $progress {\n            // No progress can be made. Return the state machine as is\n            Progress::Stuck(phase) => return TransitionOutcome::Pending(phase.into()),\n            // Further progress can be made but require more work, so don't return\n            Progress::Continue(phase) => phase,\n            // Progress has been made, return the updated state machine\n            Progress::Updated(state_machine) => return TransitionOutcome::Complete(state_machine),\n        }\n    }};\n}\n\n/// Represent the presence or absence of progress being made during a phase.\n#[derive(Debug)]\npub enum Progress<P> {\n    /// No progress can be made currently.\n    Stuck(Phase<P>),\n    /// More work needs to be done for progress to be made.\n    Continue(Phase<P>),\n    /// Progress has been made and resulted in this new state machine.\n    Updated(StateMachine),\n}\n\nimpl<P> Phase<P>\nwhere\n    Phase<P>: Step + Into<StateMachine>,\n{\n    /// Try to make some progress in the execution of the PET protocol. There are three\n    /// possible outcomes:\n    ///\n    /// 1. no progress can currently be made and the phase state is unchanged\n    /// 2. progress is made but the state machine does not transition to a new\n    ///    phase. Internally, the phase state is changed though.\n    /// 3. progress is made and the state machine transitions to a new phase.\n    ///\n    /// In case `1.`, the state machine is returned unchanged, wrapped in\n    /// [`TransitionOutcome::Pending`] to indicate to the caller that the state machine\n    /// wasn't updated. 
In case `2.` and `3.` the updated state machine is returned\n    /// wrapped in [`TransitionOutcome::Complete`].\n    pub async fn step(mut self) -> TransitionOutcome {\n        match self.check_round_freshness().await {\n            RoundFreshness::Unknown => TransitionOutcome::Pending(self.into()),\n            RoundFreshness::Outdated => {\n                info!(\"a new round started: updating the round parameters and resetting the state machine\");\n                self.io.notify_new_round();\n                TransitionOutcome::Complete(\n                    Phase::<NewRound>::new(\n                        State::new(self.state.shared, Box::new(NewRound)),\n                        self.io,\n                    )\n                    .into(),\n                )\n            }\n            RoundFreshness::Fresh => {\n                debug!(\"round is still fresh, continuing from where we left off\");\n                <Self as Step>::step(self).await\n            }\n        }\n    }\n\n    /// Check whether the coordinator has published new round parameters. In other\n    /// words, this checks whether a new round has started.\n    async fn check_round_freshness(&mut self) -> RoundFreshness {\n        match self.io.get_round_params().await {\n            Err(e) => {\n                warn!(\"failed to fetch round parameters {:?}\", e);\n                RoundFreshness::Unknown\n            }\n            Ok(params) => {\n                if params == self.state.shared.round_params {\n                    debug!(\"round parameters didn't change\");\n                    RoundFreshness::Fresh\n                } else {\n                    info!(\"fetched fresh round parameters\");\n                    self.state.shared.round_params = params;\n                    RoundFreshness::Outdated\n                }\n            }\n        }\n    }\n}\n\n/// Trait for building [`Phase<P>`] from a [`State<P>`].\n///\n/// Note that we could just use [`Phase::new`] for this. 
However we want to be able to\n/// customize the conversion for each phase. For instance, when building a\n/// `Phase<Update>` from an `Update`, we want to emit some events with the `io`\n/// object. It is cleaner to wrap this custom logic in a trait impl.\npub(crate) trait IntoPhase<P> {\n    /// Build the phase with the given `io` object\n    fn into_phase(self, io: PhaseIo) -> Phase<P>;\n}\n\nimpl<P> Phase<P> {\n    /// Build a new phase with the given state and io object. This should not be called\n    /// directly. Instead, use the [`IntoPhase`] trait to construct a phase.\n    pub(crate) fn new(state: State<P>, io: PhaseIo) -> Self {\n        Phase { state, io }\n    }\n\n    /// Instantiate a message encoder for the given payload.\n    ///\n    /// The encoder takes care of converting the given `payload` into one or several\n    /// signed and encrypted PET messages.\n    pub fn message_encoder(&self, payload: Payload) -> MessageEncoder {\n        MessageEncoder::new(\n            self.state.shared.keys.clone(),\n            payload,\n            self.state.shared.round_params.pk,\n            self.state\n                .shared\n                .message_size\n                .max_payload_size()\n                .unwrap_or(0),\n        )\n        // the encoder rejects Chunk payload, but in the state\n        // machine, we never manually create such payloads so\n        // unwrapping is fine\n        .unwrap()\n    }\n\n    /// Return the local model configuration of the model that is expected in the update phase.\n    pub fn local_model_config(&self) -> LocalModelConfig {\n        LocalModelConfig {\n            data_type: self.state.shared.round_params.mask_config.vect.data_type,\n            len: self.state.shared.round_params.model_length,\n        }\n    }\n\n    #[cfg(test)]\n    pub(crate) fn with_io_mock<F>(&mut self, f: F)\n    where\n        F: FnOnce(&mut super::MockIO),\n    {\n        let mut mock = super::MockIO::new();\n        f(&mut 
mock);\n        self.io = Box::new(mock);\n    }\n\n    #[cfg(test)]\n    pub(crate) fn check_io_mock(&mut self) {\n        // dropping the mock forces the checks to run. We replace it\n        // by an empty one, so that we detect if a method is called\n        // un-expectedly afterwards\n        let _ = std::mem::replace(&mut self.io, Box::new(super::MockIO::new()));\n    }\n}\n\n#[derive(Debug)]\n/// The local model configuration of the model that is expected in the update phase.\npub struct LocalModelConfig {\n    /// The expected data type of the local model.\n    // In the current state it is not possible to configure a coordinator in which\n    // the scalar data type and the model data type are different. Therefore, we assume here\n    // that the scalar data type is the same as the model data type.\n    pub data_type: DataType,\n    /// The expected length of the local model.\n    pub len: usize,\n}\n\n#[derive(Error, Debug)]\n#[error(\"failed to send a PET message\")]\npub struct SendMessageError;\n\n/// Round freshness indicator\npub enum RoundFreshness {\n    /// A new round started. 
The current round is outdated\n    Outdated,\n    /// We were not able to check whether a new round started\n    Unknown,\n    /// The current round is still going\n    Fresh,\n}\n\n/// A serializable representation of a phase state.\n///\n/// We cannot serialize the state directly, even though it implements `Serialize`, because deserializing it would require knowing its type in advance:\n///\n/// ```compile_fail\n/// // `buf` is a Vec<u8> that contains a serialized state that we want to deserialize\n/// let state: State<???> = State::deserialize(&buf[..]).unwrap();\n/// ```\n#[derive(Serialize, Deserialize, From, Debug)]\npub enum SerializableState {\n    NewRound(State<NewRound>),\n    Awaiting(State<Awaiting>),\n    Sum(State<Sum>),\n    Update(State<Update>),\n    Sum2(State<Sum2>),\n    SendingSum(State<SendingSum>),\n    SendingUpdate(State<SendingUpdate>),\n    SendingSum2(State<SendingSum2>),\n}\n\nimpl<P> From<Phase<P>> for SerializableState\nwhere\n    State<P>: Into<SerializableState>,\n{\n    fn from(phase: Phase<P>) -> Self {\n        phase.state.into()\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/phases/awaiting.rs",
    "content": "use async_trait::async_trait;\nuse serde::{Deserialize, Serialize};\nuse tracing::info;\n\nuse crate::state_machine::{IntoPhase, Phase, PhaseIo, State, Step, TransitionOutcome};\n\n#[derive(Serialize, Deserialize, Debug)]\npub struct Awaiting;\n\n#[async_trait]\nimpl Step for Phase<Awaiting> {\n    async fn step(mut self) -> TransitionOutcome {\n        info!(\"awaiting task\");\n        TransitionOutcome::Pending(self.into())\n    }\n}\n\nimpl IntoPhase<Awaiting> for State<Awaiting> {\n    fn into_phase(self, mut io: PhaseIo) -> Phase<Awaiting> {\n        io.notify_idle();\n        Phase::<_>::new(self, io)\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/phases/mod.rs",
    "content": "mod awaiting;\nmod new_round;\nmod sending;\nmod sum;\nmod sum2;\nmod update;\n\npub use self::{\n    awaiting::Awaiting,\n    new_round::NewRound,\n    sending::{SendingSum, SendingSum2, SendingUpdate},\n    sum::Sum,\n    sum2::Sum2,\n    update::Update,\n};\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/phases/new_round.rs",
    "content": "use async_trait::async_trait;\nuse serde::{Deserialize, Serialize};\nuse tracing::info;\nuse xaynet_core::crypto::{ByteObject, Signature};\n\nuse crate::state_machine::{\n    Awaiting,\n    IntoPhase,\n    Phase,\n    PhaseIo,\n    State,\n    Step,\n    Sum,\n    TransitionOutcome,\n    Update,\n};\n\n#[derive(Serialize, Deserialize, Debug)]\npub struct NewRound;\n\nimpl IntoPhase<NewRound> for State<NewRound> {\n    fn into_phase(self, mut io: PhaseIo) -> Phase<NewRound> {\n        io.notify_new_round();\n        Phase::<_>::new(self, io)\n    }\n}\n\n#[async_trait]\nimpl Step for Phase<NewRound> {\n    async fn step(mut self) -> TransitionOutcome {\n        info!(\"new_round task\");\n\n        info!(\"checking eligibility for sum task\");\n        let sum_signature = self.sign(b\"sum\");\n        if sum_signature.is_eligible(self.state.shared.round_params.sum) {\n            info!(\"eligible for sum task\");\n            return TransitionOutcome::Complete(self.into_sum(sum_signature).into());\n        }\n\n        info!(\"not eligible for sum task, checking eligibility for update task\");\n        let update_signature = self.sign(b\"update\");\n        if update_signature.is_eligible(self.state.shared.round_params.update) {\n            info!(\"eligible for update task\");\n            return TransitionOutcome::Complete(\n                self.into_update(sum_signature, update_signature).into(),\n            );\n        }\n\n        info!(\"not eligible for update task, going to sleep until next round\");\n        let awaiting: Phase<Awaiting> = self.into();\n        TransitionOutcome::Complete(awaiting.into())\n    }\n}\n\nimpl From<Phase<NewRound>> for Phase<Awaiting> {\n    fn from(new_round: Phase<NewRound>) -> Self {\n        State::new(new_round.state.shared, Box::new(Awaiting)).into_phase(new_round.io)\n    }\n}\n\nimpl Phase<NewRound> {\n    fn sign(&self, data: &[u8]) -> Signature {\n        let sk = &self.state.shared.keys.secret;\n     
   let seed = self.state.shared.round_params.seed.as_slice();\n        sk.sign_detached(&[seed, data].concat())\n    }\n\n    fn into_sum(self, sum_signature: Signature) -> Phase<Sum> {\n        let sum = Box::new(Sum::new(sum_signature));\n        let state = State::new(self.state.shared, sum);\n        state.into_phase(self.io)\n    }\n\n    fn into_update(self, sum_signature: Signature, update_signature: Signature) -> Phase<Update> {\n        let update = Box::new(Update::new(sum_signature, update_signature));\n        let state = State::new(self.state.shared, update);\n        state.into_phase(self.io)\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/phases/sending.rs",
    "content": "use async_trait::async_trait;\nuse paste::paste;\nuse serde::{Deserialize, Serialize};\nuse tracing::{debug, error, info};\n\nuse crate::{\n    state_machine::{\n        phases::Sum2,\n        Awaiting,\n        IntoPhase,\n        Phase,\n        PhaseIo,\n        Progress,\n        State,\n        Step,\n        TransitionOutcome,\n        IO,\n    },\n    MessageEncoder,\n};\n\n/// Implements the `SendingSum`, `SendingUpdate` and `SendingSum2` phases and transitions.\nmacro_rules! impl_sending {\n    ($Phase: ty, $Next: ty, $phase: expr, $next: expr) => {\n        paste! {\n            #[doc = \"The state of the \" $phase \" sending phase.\"]\n            #[derive(Serialize, Deserialize, Debug)]\n            pub struct [<Sending $Phase>] {\n                /// The message to send.\n                message: MessageEncoder,\n\n                /// Chunk that couldn't be sent and should be tried again.\n                failed: Option<Vec<u8>>,\n\n                /// State of the phase to transition to, after this one completes.\n                next: $Next,\n            }\n\n            impl [<Sending $Phase>] {\n                #[doc = \"Creates a new \" $phase \" sending state.\"]\n                pub fn new(message: MessageEncoder, next: $Next) -> Self {\n                    Self {\n                        message,\n                        failed: None,\n                        next,\n                    }\n                }\n            }\n\n            impl IntoPhase<[<Sending $Phase>]> for State<[<Sending $Phase>]> {\n                fn into_phase(self, io: PhaseIo) -> Phase<[<Sending $Phase>]> {\n                    Phase::<_>::new(self, io)\n                }\n            }\n\n            #[async_trait]\n            impl Step for Phase<[<Sending $Phase>]> {\n                async fn step(mut self) -> TransitionOutcome {\n                    info!(\"sending {} message\", $phase);\n                    self = 
try_progress!(self.send_next().await);\n\n                    info!(\"done sending {} message, going to {} phase\", $phase, $next);\n                    let phase: Phase<$Next> = self.into();\n                    TransitionOutcome::Complete(phase.into())\n                }\n            }\n\n            impl From<Phase<[<Sending $Phase>]>> for Phase<$Next> {\n                fn from(sending: Phase<[<Sending $Phase>]>) -> Self {\n                    State::new(sending.state.shared, Box::new(sending.state.private.next))\n                        .into_phase(sending.io)\n                }\n            }\n\n            impl Phase<[<Sending $Phase>]> {\n                #[doc = \"Tries to send a \" $phase \" message and reports back on the progress made.\"]\n                async fn try_send(mut self, data: Vec<u8>) -> Progress<[<Sending $Phase>]> {\n                    info!(\"sending {} message (size = {})\", $phase, data.len());\n                    if let Err(e) = self.io.send_message(data.clone()).await {\n                        error!(\"failed to send {} message: {:?}\", $phase, e);\n                        self.state.private.failed = Some(data);\n                        Progress::Stuck(self)\n                    } else {\n                        Progress::Updated(self.into())\n                    }\n                }\n\n                #[doc =\n                    \"Sends the next \" $phase \" message and reports back on the progress made.\\n\"\n                    \"\\n\"\n                    \"Retries to send a previously failed message. 
Otherwise, tries to send the \"\n                    \"next message.\"\n                ]\n                async fn send_next(mut self) -> Progress<[<Sending $Phase>]> {\n                    if let Some(data) = self.state.private.failed.take() {\n                        debug!(\n                            \"retrying to send {} message that couldn't be sent previously\",\n                            $phase\n                        );\n                        self.try_send(data).await\n                    } else {\n                        match self.state.private.message.next() {\n                            Some(data) => {\n                                let data = self.state.shared.round_params.pk.encrypt(data.as_slice());\n                                self.try_send(data).await\n                            }\n                            None => {\n                                debug!(\"nothing left to send\");\n                                Progress::Continue(self)\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\nimpl_sending!(Sum, Sum2, \"sum\", \"sum2\");\nimpl_sending!(Update, Awaiting, \"update\", \"awaiting\");\nimpl_sending!(Sum2, Awaiting, \"sum2\", \"awaiting\");\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/phases/sum.rs",
    "content": "use async_trait::async_trait;\nuse serde::{Deserialize, Serialize};\nuse tracing::{debug, info};\n\nuse crate::{\n    state_machine::{IntoPhase, Phase, PhaseIo, SendingSum, State, Step, Sum2, TransitionOutcome},\n    MessageEncoder,\n};\nuse xaynet_core::{\n    crypto::{EncryptKeyPair, Signature},\n    message::Sum as SumMessage,\n};\n\nuse super::Awaiting;\n\n/// The state of the sum phase.\n#[derive(Serialize, Deserialize, Debug)]\npub struct Sum {\n    /// The sum participant ephemeral keys. They are used to decrypt\n    /// the encrypted mask seeds.\n    pub ephm_keys: EncryptKeyPair,\n    /// Signature that proves that the participant has been selected\n    /// for the sum task.\n    pub sum_signature: Signature,\n}\n\nimpl Sum {\n    /// Creates a new sum state.\n    pub fn new(sum_signature: Signature) -> Self {\n        Sum {\n            ephm_keys: EncryptKeyPair::generate(),\n            sum_signature,\n        }\n    }\n}\n\nimpl IntoPhase<Sum> for State<Sum> {\n    fn into_phase(self, mut io: PhaseIo) -> Phase<Sum> {\n        io.notify_sum();\n        Phase::<_>::new(self, io)\n    }\n}\n\n#[async_trait]\nimpl Step for Phase<Sum> {\n    async fn step(mut self) -> TransitionOutcome {\n        info!(\"sum task\");\n        let sending: Phase<SendingSum> = self.into();\n        TransitionOutcome::Complete(sending.into())\n    }\n}\n\nimpl From<Phase<Sum>> for Phase<SendingSum> {\n    fn from(sum: Phase<Sum>) -> Self {\n        debug!(\"composing sum message\");\n        let message = sum.compose_message();\n\n        debug!(\"going to sending phase\");\n        let sum2 = Sum2::new(sum.state.private.ephm_keys, sum.state.private.sum_signature);\n        let sending = Box::new(SendingSum::new(message, sum2));\n        let state = State::new(sum.state.shared, sending);\n        state.into_phase(sum.io)\n    }\n}\n\nimpl From<Phase<Sum>> for Phase<Awaiting> {\n    fn from(sum: Phase<Sum>) -> Self {\n        State::new(sum.state.shared, 
Box::new(Awaiting)).into_phase(sum.io)\n    }\n}\n\nimpl Phase<Sum> {\n    /// Creates and encodes the sum message from the sum state.\n    pub fn compose_message(&self) -> MessageEncoder {\n        let sum = SumMessage {\n            sum_signature: self.state.private.sum_signature,\n            ephm_pk: self.state.private.ephm_keys.public,\n        };\n        self.message_encoder(sum.into())\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/phases/sum2.rs",
    "content": "use async_trait::async_trait;\nuse serde::{Deserialize, Serialize};\nuse tracing::{debug, error, info, warn};\nuse xaynet_core::{\n    crypto::{EncryptKeyPair, Signature},\n    mask::{Aggregation, MaskObject, MaskSeed},\n    message::Sum2 as Sum2Message,\n    UpdateSeedDict,\n};\n\nuse crate::{\n    state_machine::{\n        IntoPhase,\n        Phase,\n        PhaseIo,\n        Progress,\n        SendingSum2,\n        State,\n        Step,\n        TransitionOutcome,\n        IO,\n    },\n    MessageEncoder,\n};\n\nuse super::Awaiting;\n\n/// The state of the sum2 phase.\n#[derive(Serialize, Deserialize, Debug)]\npub struct Sum2 {\n    /// The sum participant ephemeral keys. They are used to decrypt\n    /// the encrypted mask seeds.\n    pub ephm_keys: EncryptKeyPair,\n    /// Signature that proves that the participant has been selected\n    /// for the sum task.\n    pub sum_signature: Signature,\n    /// Dictionary containing the encrypted mask seed of every update\n    /// participants.\n    pub seed_dict: Option<UpdateSeedDict>,\n    /// The decrypted mask seeds\n    pub seeds: Option<Vec<MaskSeed>>,\n    /// The global mask, obtained by aggregating the masks derived\n    /// from the mask seeds.\n    pub mask: Option<MaskObject>,\n}\n\nimpl Sum2 {\n    /// Creates a new sum2 state.\n    pub fn new(ephm_keys: EncryptKeyPair, sum_signature: Signature) -> Self {\n        Self {\n            ephm_keys,\n            sum_signature,\n            seed_dict: None,\n            seeds: None,\n            mask: None,\n        }\n    }\n\n    /// Checks if the seed dict has already been fetched.\n    fn has_fetched_seed_dict(&self) -> bool {\n        self.seed_dict.is_some() || self.has_decrypted_seeds()\n    }\n\n    /// Checks if the seeds have already been decrypted.\n    fn has_decrypted_seeds(&self) -> bool {\n        self.seeds.is_some() || self.has_aggregated_masks()\n    }\n\n    /// Checks if the masks have already been aggregated.\n    fn 
has_aggregated_masks(&self) -> bool {\n        self.mask.is_some()\n    }\n}\n\nimpl IntoPhase<Sum2> for State<Sum2> {\n    fn into_phase(self, io: PhaseIo) -> Phase<Sum2> {\n        Phase::<_>::new(self, io)\n    }\n}\n\n#[async_trait]\nimpl Step for Phase<Sum2> {\n    async fn step(mut self) -> TransitionOutcome {\n        info!(\"sum2 task\");\n        self = try_progress!(self.fetch_seed_dict().await);\n        self = try_progress!(self.decrypt_seeds());\n        self = try_progress!(self.aggregate_masks());\n        let sending: Phase<SendingSum2> = self.into();\n        TransitionOutcome::Complete(sending.into())\n    }\n}\n\nimpl From<Phase<Sum2>> for Phase<SendingSum2> {\n    fn from(mut sum2: Phase<Sum2>) -> Self {\n        debug!(\"composing sum2 message\");\n        let message = sum2.compose_message();\n\n        debug!(\"going to sending phase\");\n        let sending = Box::new(SendingSum2::new(message, Awaiting));\n        let state = State::new(sum2.state.shared, sending);\n        state.into_phase(sum2.io)\n    }\n}\n\nimpl From<Phase<Sum2>> for Phase<Awaiting> {\n    fn from(sum2: Phase<Sum2>) -> Self {\n        State::new(sum2.state.shared, Box::new(Awaiting)).into_phase(sum2.io)\n    }\n}\n\nimpl Phase<Sum2> {\n    /// Retrieve the encrypted mask seeds.\n    pub(crate) async fn fetch_seed_dict(mut self) -> Progress<Sum2> {\n        if self.state.private.has_fetched_seed_dict() {\n            return Progress::Continue(self);\n        }\n        debug!(\"polling for update seeds\");\n        match self.io.get_seeds(self.state.shared.keys.public).await {\n            Err(e) => {\n                warn!(\"failed to fetch seeds: {}\", e);\n                Progress::Stuck(self)\n            }\n            Ok(None) => {\n                debug!(\"seeds not available yet\");\n                Progress::Stuck(self)\n            }\n            Ok(Some(seeds)) => {\n                self.state.private.seed_dict = Some(seeds);\n                
Progress::Updated(self.into())\n            }\n        }\n    }\n\n    /// Decrypt the mask seeds that the update participants generated.\n    pub(crate) fn decrypt_seeds(mut self) -> Progress<Sum2> {\n        if self.state.private.has_decrypted_seeds() {\n            return Progress::Continue(self);\n        }\n\n        let keys = &self.state.private.ephm_keys;\n        // UNWRAP_SAFE: the seed dict is set in\n        // `self.fetch_seed_dict()` which is called before this method\n        let seeds: Result<Vec<MaskSeed>, ()> = self\n            .state\n            .private\n            .seed_dict\n            .take()\n            .unwrap()\n            .into_iter()\n            .map(|(_, seed)| seed.decrypt(&keys.public, &keys.secret).map_err(|_| ()))\n            .collect();\n\n        match seeds {\n            Ok(seeds) => {\n                self.state.private.seeds = Some(seeds);\n                Progress::Updated(self.into())\n            }\n            Err(_) => {\n                warn!(\"failed to decrypt mask seeds, going back to waiting phase\");\n                self.io.notify_idle();\n                let awaiting: Phase<Awaiting> = self.into();\n                Progress::Updated(awaiting.into())\n            }\n        }\n    }\n\n    /// Derive the masks from the decrypted mask seeds, and aggregate\n    /// them. 
The resulting mask will later be added to the sum2\n    /// message to be sent to the coordinator.\n    pub(crate) fn aggregate_masks(mut self) -> Progress<Sum2> {\n        if self.state.private.has_aggregated_masks() {\n            return Progress::Continue(self);\n        }\n\n        info!(\"aggregating masks\");\n        let config = self.state.shared.round_params.mask_config;\n        let mask_len = self.state.shared.round_params.model_length;\n        let mut mask_agg = Aggregation::new(config, mask_len as usize);\n        // UNWRAP_SAFE: the seeds are set in `decrypt_seeds()` which is called before this method\n        for seed in self.state.private.seeds.take().unwrap().into_iter() {\n            let mask = seed.derive_mask(mask_len as usize, config);\n            if let Err(e) = mask_agg.validate_aggregation(&mask) {\n                error!(\"sum2 phase failed: cannot aggregate masks: {}\", e);\n                error!(\"going to awaiting phase\");\n                let awaiting: Phase<Awaiting> = self.into();\n                return Progress::Updated(awaiting.into());\n            } else {\n                mask_agg.aggregate(mask);\n            }\n        }\n        self.state.private.mask = Some(mask_agg.into());\n        Progress::Updated(self.into())\n    }\n\n    /// Creates and encodes the sum2 message from the sum2 state.\n    pub fn compose_message(&mut self) -> MessageEncoder {\n        let sum2 = Sum2Message {\n            sum_signature: self.state.private.sum_signature,\n            // UNWRAP_SAFE: the mask is set in `aggregate_masks()` which is called before this method\n            model_mask: self.state.private.mask.take().unwrap(),\n        };\n        self.message_encoder(sum2.into())\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/phases/update.rs",
    "content": "use std::ops::Deref;\n\nuse async_trait::async_trait;\nuse derive_more::From;\nuse serde::{Deserialize, Serialize};\nuse tracing::{debug, info, warn};\n\nuse xaynet_core::{\n    crypto::Signature,\n    mask::{MaskObject, MaskSeed, Masker, Model},\n    message::Update as UpdateMessage,\n    LocalSeedDict,\n    ParticipantTaskSignature,\n    SumDict,\n};\n\nuse crate::{\n    state_machine::{\n        Awaiting,\n        IntoPhase,\n        Phase,\n        PhaseIo,\n        Progress,\n        SendingUpdate,\n        State,\n        Step,\n        TransitionOutcome,\n        IO,\n    },\n    MessageEncoder,\n};\n\n#[derive(From)]\npub enum LocalModel {\n    Dyn(Box<dyn AsRef<Model> + Send>),\n    Owned(Model),\n}\n\nimpl std::fmt::Debug for LocalModel {\n    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {\n        match self {\n            LocalModel::Dyn(_) => fmt.debug_tuple(\"LocalModel::Dyn\"),\n            LocalModel::Owned(_) => fmt.debug_tuple(\"LocalModel::Owned\"),\n        }\n        .field(&\"...\")\n        .finish()\n    }\n}\n\nimpl AsRef<Model> for LocalModel {\n    fn as_ref(&self) -> &Model {\n        match self {\n            LocalModel::Dyn(model) => model.deref().as_ref(),\n            LocalModel::Owned(model) => model,\n        }\n    }\n}\n\nimpl serde::ser::Serialize for LocalModel {\n    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>\n    where\n        S: serde::ser::Serializer,\n    {\n        match self {\n            LocalModel::Dyn(model) => model.as_ref().as_ref().serialize(serializer),\n            LocalModel::Owned(model) => model.serialize(serializer),\n        }\n    }\n}\n\nimpl<'de> serde::de::Deserialize<'de> for LocalModel {\n    fn deserialize<D>(deserializer: D) -> Result<LocalModel, D::Error>\n    where\n        D: serde::de::Deserializer<'de>,\n    {\n        let model = <Model as serde::de::Deserialize>::deserialize(deserializer)?;\n        Ok(LocalModel::Owned(model))\n    
}\n}\n\n/// The state of the update phase.\n#[derive(Serialize, Deserialize, Debug)]\npub struct Update {\n    pub sum_signature: ParticipantTaskSignature,\n    pub update_signature: ParticipantTaskSignature,\n    pub sum_dict: Option<SumDict>,\n    pub seed_dict: Option<LocalSeedDict>,\n    pub model: Option<LocalModel>,\n    pub mask: Option<(MaskSeed, MaskObject)>,\n}\n\nimpl Update {\n    /// Creates a new update state.\n    pub fn new(sum_signature: Signature, update_signature: Signature) -> Self {\n        Update {\n            sum_signature,\n            update_signature,\n            sum_dict: None,\n            seed_dict: None,\n            model: None,\n            mask: None,\n        }\n    }\n\n    fn has_fetched_sum_dict(&self) -> bool {\n        self.sum_dict.is_some() || self.has_loaded_model()\n    }\n\n    fn has_loaded_model(&self) -> bool {\n        self.model.is_some() || self.has_masked_model()\n    }\n\n    fn has_masked_model(&self) -> bool {\n        self.mask.is_some() || self.has_built_seed_dict()\n    }\n\n    fn has_built_seed_dict(&self) -> bool {\n        self.seed_dict.is_some()\n    }\n}\n\nimpl IntoPhase<Update> for State<Update> {\n    fn into_phase(self, mut io: PhaseIo) -> Phase<Update> {\n        io.notify_update();\n        if !self.private.has_loaded_model() {\n            io.notify_load_model();\n        }\n        Phase::<_>::new(self, io)\n    }\n}\n\n#[async_trait]\nimpl Step for Phase<Update> {\n    async fn step(mut self) -> TransitionOutcome {\n        self = try_progress!(self.fetch_sum_dict().await);\n        self = try_progress!(self.load_model().await);\n        self = try_progress!(self.mask_model());\n        self = try_progress!(self.build_seed_dict());\n        let sending: Phase<SendingUpdate> = self.into();\n        TransitionOutcome::Complete(sending.into())\n    }\n}\n\nimpl From<Phase<Update>> for Phase<SendingUpdate> {\n    fn from(mut update: Phase<Update>) -> Self {\n        debug!(\"composing update 
message\");\n        let message = update.compose_message();\n\n        debug!(\"going to sending phase\");\n        let sending = Box::new(SendingUpdate::new(message, Awaiting));\n        let state = State::new(update.state.shared, sending);\n        state.into_phase(update.io)\n    }\n}\n\nimpl From<Phase<Update>> for Phase<Awaiting> {\n    fn from(update: Phase<Update>) -> Self {\n        State::new(update.state.shared, Box::new(Awaiting)).into_phase(update.io)\n    }\n}\n\nimpl Phase<Update> {\n    pub(crate) async fn fetch_sum_dict(mut self) -> Progress<Update> {\n        if self.state.private.has_fetched_sum_dict() {\n            debug!(\"already fetched the sum dictionary, continuing\");\n            return Progress::Continue(self);\n        }\n        debug!(\"fetching sum dictionary\");\n        match self.io.get_sums().await {\n            Ok(Some(dict)) => {\n                self.state.private.sum_dict = Some(dict);\n                Progress::Updated(self.into())\n            }\n            Ok(None) => {\n                debug!(\"sum dictionary is not available yet\");\n                Progress::Stuck(self)\n            }\n            Err(e) => {\n                warn!(\"failed to fetch sum dictionary: {:?}\", e);\n                Progress::Stuck(self)\n            }\n        }\n    }\n\n    pub(crate) async fn load_model(mut self) -> Progress<Update> {\n        if self.state.private.has_loaded_model() {\n            debug!(\"already loaded the model, continuing\");\n            return Progress::Continue(self);\n        }\n\n        debug!(\"loading local model\");\n        match self.io.load_model().await {\n            Ok(Some(model)) => {\n                self.state.private.model = Some(model.into());\n                Progress::Updated(self.into())\n            }\n            Ok(None) => {\n                debug!(\"model is not ready\");\n                Progress::Stuck(self)\n            }\n            Err(e) => {\n                warn!(\"failed to 
load model: {:?}\", e);\n                Progress::Stuck(self)\n            }\n        }\n    }\n\n    /// Generate a mask seed and mask a local model.\n    pub(crate) fn mask_model(mut self) -> Progress<Update> {\n        if self.state.private.has_masked_model() {\n            debug!(\"already computed the masked model, continuing\");\n            return Progress::Continue(self);\n        }\n        info!(\"computing masked model\");\n        let config = self.state.shared.round_params.mask_config;\n        let masker = Masker::new(config);\n        // UNWRAP_SAFE: the model is set, per the `has_masked_model()` check above\n        let model = self.state.private.model.take().unwrap();\n        let scalar = self.state.shared.scalar.clone();\n        self.state.private.mask = Some(masker.mask(scalar, model.as_ref()));\n        Progress::Updated(self.into())\n    }\n\n    // Create a local seed dictionary from a sum dictionary.\n    pub(crate) fn build_seed_dict(mut self) -> Progress<Update> {\n        if self.state.private.has_built_seed_dict() {\n            debug!(\"already built the seed dictionary, continuing\");\n            return Progress::Continue(self);\n        }\n        // UNWRAP_SAFE: the mask is set in `mask_model()` which is called before this method\n        let mask_seed = &self.state.private.mask.as_ref().unwrap().0;\n        info!(\"building local seed dictionary\");\n        let seeds = self\n            .state\n            .private\n            .sum_dict\n            .take()\n            .unwrap()\n            .into_iter()\n            .map(|(pk, ephm_pk)| (pk, mask_seed.encrypt(&ephm_pk)))\n            .collect();\n        self.state.private.seed_dict = Some(seeds);\n        Progress::Updated(self.into())\n    }\n\n    /// Creates and encodes the update message from the update state.\n    pub fn compose_message(&mut self) -> MessageEncoder {\n        let update = UpdateMessage {\n            sum_signature: self.state.private.sum_signature,\n    
        update_signature: self.state.private.update_signature,\n            // UNWRAP_SAFE: the mask is set in `mask_model()` which is called before this method\n            masked_model: self.state.private.mask.take().unwrap().1,\n            // UNWRAP_SAFE: the dict is set in `build_seed_dict()` which is called before this method\n            local_seed_dict: self.state.private.seed_dict.take().unwrap(),\n        };\n        self.message_encoder(update.into())\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/state_machine.rs",
    "content": "use derive_more::From;\n\nuse super::{\n    boxed_io,\n    Awaiting,\n    IntoPhase,\n    LocalModelConfig,\n    NewRound,\n    Phase,\n    SendingSum,\n    SendingSum2,\n    SendingUpdate,\n    SerializableState,\n    SharedState,\n    State,\n    Sum,\n    Sum2,\n    Update,\n};\nuse crate::{settings::PetSettings, ModelStore, Notify, XaynetClient};\n\n/// Outcome of a state machine transition attempt.\n#[derive(Debug)]\npub enum TransitionOutcome {\n    /// Outcome when the state machine cannot make immediate progress. The state machine\n    /// is returned unchanged.\n    Pending(StateMachine),\n    /// Outcome when a transition occured and the state machine was updated.\n    Complete(StateMachine),\n}\n\n/// PET state machine.\n#[derive(From, Debug)]\npub enum StateMachine {\n    /// PET state machine in the \"new round\" phase\n    NewRound(Phase<NewRound>),\n    /// PET state machine in the \"awaiting\" phase\n    Awaiting(Phase<Awaiting>),\n    /// PET state machine in the \"sum\" phase\n    Sum(Phase<Sum>),\n    /// PET state machine in the \"update\" phase\n    Update(Phase<Update>),\n    /// PET state machine in the \"sum2\" phase\n    Sum2(Phase<Sum2>),\n    /// PET state machine in the \"sending sum message\" phase\n    SendingSum(Phase<SendingSum>),\n    /// PET state machine in the \"sending update message\" phase\n    SendingUpdate(Phase<SendingUpdate>),\n    /// PET state machine in the \"sending sum2 message\" phase\n    SendingSum2(Phase<SendingSum2>),\n}\n\nimpl StateMachine {\n    /// Try to make progress in the PET protocol\n    pub async fn transition(self) -> TransitionOutcome {\n        match self {\n            StateMachine::NewRound(phase) => phase.step().await,\n            StateMachine::Awaiting(phase) => phase.step().await,\n            StateMachine::Sum(phase) => phase.step().await,\n            StateMachine::Update(phase) => phase.step().await,\n            StateMachine::Sum2(phase) => phase.step().await,\n            
StateMachine::SendingSum(phase) => phase.step().await,\n            StateMachine::SendingUpdate(phase) => phase.step().await,\n            StateMachine::SendingSum2(phase) => phase.step().await,\n        }\n    }\n\n    /// Convert the state machine into a serializable data structure so\n    /// that it can be saved.\n    pub fn save(self) -> SerializableState {\n        match self {\n            StateMachine::NewRound(phase) => phase.state.into(),\n            StateMachine::Awaiting(phase) => phase.state.into(),\n            StateMachine::Sum(phase) => phase.state.into(),\n            StateMachine::Update(phase) => phase.state.into(),\n            StateMachine::Sum2(phase) => phase.state.into(),\n            StateMachine::SendingSum(phase) => phase.state.into(),\n            StateMachine::SendingUpdate(phase) => phase.state.into(),\n            StateMachine::SendingSum2(phase) => phase.state.into(),\n        }\n    }\n\n    /// Return the local model configuration of the model that is expected in the update phase.\n    pub fn local_model_config(&self) -> LocalModelConfig {\n        match self {\n            StateMachine::NewRound(ref phase) => phase.local_model_config(),\n            StateMachine::Awaiting(ref phase) => phase.local_model_config(),\n            StateMachine::Sum(ref phase) => phase.local_model_config(),\n            StateMachine::Update(ref phase) => phase.local_model_config(),\n            StateMachine::Sum2(ref phase) => phase.local_model_config(),\n            StateMachine::SendingSum(ref phase) => phase.local_model_config(),\n            StateMachine::SendingUpdate(ref phase) => phase.local_model_config(),\n            StateMachine::SendingSum2(ref phase) => phase.local_model_config(),\n        }\n    }\n}\n\nimpl StateMachine {\n    /// Instantiate a new PET state machine.\n    ///\n    /// # Args\n    ///\n    /// - `settings`: PET settings\n    /// - `xaynet_client`: a client for communicating with the Xaynet coordinator\n    /// - 
`model_store`: a store from which the trained model can be\n    ///   loaded, when the participant is selected for the update task\n    /// - `notifier`: a type that the state machine can use to emit notifications\n    pub fn new<X, M, N>(\n        settings: PetSettings,\n        xaynet_client: X,\n        model_store: M,\n        notifier: N,\n    ) -> Self\n    where\n        X: XaynetClient + Send + 'static,\n        M: ModelStore + Send + 'static,\n        N: Notify + Send + 'static,\n    {\n        let io = boxed_io(xaynet_client, model_store, notifier);\n        let state = State::new(Box::new(SharedState::new(settings)), Box::new(Awaiting));\n        state.into_phase(io).into()\n    }\n\n    /// Restore the PET state machine from the given `state`.\n    pub fn restore<X, M, N>(\n        state: SerializableState,\n        xaynet_client: X,\n        model_store: M,\n        notifier: N,\n    ) -> Self\n    where\n        X: XaynetClient + Send + 'static,\n        M: ModelStore + Send + 'static,\n        N: Notify + Send + 'static,\n    {\n        let io = boxed_io(xaynet_client, model_store, notifier);\n        match state {\n            SerializableState::NewRound(state) => state.into_phase(io).into(),\n            SerializableState::Awaiting(state) => state.into_phase(io).into(),\n            SerializableState::Sum(state) => state.into_phase(io).into(),\n            SerializableState::Sum2(state) => state.into_phase(io).into(),\n            SerializableState::Update(state) => state.into_phase(io).into(),\n            SerializableState::SendingSum(state) => state.into_phase(io).into(),\n            SerializableState::SendingUpdate(state) => state.into_phase(io).into(),\n            SerializableState::SendingSum2(state) => state.into_phase(io).into(),\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/tests/mod.rs",
    "content": "mod phases;\npub mod utils;\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/tests/phases/mod.rs",
    "content": "mod new_round;\nmod sum;\nmod sum2;\nmod update;\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/tests/phases/new_round.rs",
    "content": "use crate::{\n    state_machine::{\n        tests::utils::{shared_state, SelectFor},\n        IntoPhase,\n        MockIO,\n        NewRound,\n        Phase,\n        State,\n    },\n    unwrap_step,\n};\n\n#[tokio::test]\nasync fn test_selected_for_sum() {\n    let mut io = MockIO::new();\n    io.expect_notify_sum().return_const(());\n    let phase = make_phase(SelectFor::Sum, io);\n    unwrap_step!(phase, complete, sum);\n}\n\n#[tokio::test]\nasync fn test_selected_for_update() {\n    let mut io = MockIO::new();\n    io.expect_notify_update().times(1).return_const(());\n    io.expect_notify_load_model().times(1).return_const(());\n    let phase = make_phase(SelectFor::Update, io);\n    unwrap_step!(phase, complete, update);\n}\n\n#[tokio::test]\nasync fn test_not_selected() {\n    let mut io = MockIO::new();\n    io.expect_notify_idle().times(1).return_const(());\n    let phase = make_phase(SelectFor::None, io);\n    unwrap_step!(phase, complete, awaiting);\n}\n\n/// Instantiate a new round phase.\n///\n/// - `task` is the task we want the simulated participant to be selected for. If you want a\n///   sum participant, pass `SelectFor::Sum` for example.\n/// - `io` is the mock the test wants to use. It should contain all the test expectations. The\n///   reason for setting the mocked IO object in this helper is that once the phase is\n///   created, `phase.io` is a `Box<dyn IO>`, not a `MockIO`. Therefore, it doesn't have any of\n///   the mock methods (`expect_xxx()`, `checkpoint()`, etc.)
 so we cannot set any expectation\n///   a posteriori\nfn make_phase(task: SelectFor, io: MockIO) -> Phase<NewRound> {\n    let shared = shared_state(task);\n\n    // Check IntoPhase<NewRound> implementation\n    let mut mock = MockIO::new();\n    mock.expect_notify_new_round().times(1).return_const(());\n    let mut phase: Phase<NewRound> =\n        State::new(shared, Box::new(NewRound)).into_phase(Box::new(mock));\n\n    // Set `phase.io` to the mock the test wants to use. Note that this drops the `mock` we\n    // created above, so the expectations we set on `mock` run now.\n    let _ = std::mem::replace(&mut phase.io, Box::new(io));\n    phase\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/tests/phases/sum.rs",
    "content": "use thiserror::Error;\nuse xaynet_core::crypto::{ByteObject, EncryptKeyPair, EncryptKeySeed};\n\nuse crate::{\n    state_machine::{\n        tests::utils::{shared_state, SelectFor},\n        IntoPhase,\n        MockIO,\n        Phase,\n        SharedState,\n        State,\n        Sum,\n    },\n    unwrap_step,\n};\n\n/// Instantiate a sum phase.\nfn make_phase(io: MockIO) -> Phase<Sum> {\n    let shared = shared_state(SelectFor::Sum);\n    let sum = make_sum(&shared);\n\n    // Check IntoPhase<Sum> implementation\n    let mut mock = MockIO::new();\n    mock.expect_notify_sum().times(1).return_const(());\n    let mut phase: Phase<Sum> = State::new(shared, sum).into_phase(Box::new(mock));\n\n    // Set `phase.io` to the mock the test wants to use. Note that this drops the `mock` we\n    // created above, so the expectations we set on `mock` run now.\n    let _ = std::mem::replace(&mut phase.io, Box::new(io));\n    phase\n}\n\nfn make_sum(shared: &SharedState) -> Box<Sum> {\n    let ephm_keys = EncryptKeyPair::derive_from_seed(&EncryptKeySeed::zeroed());\n    let sk = &shared.keys.secret;\n    let seed = shared.round_params.seed.as_slice();\n    let signature = sk.sign_detached(&[seed, b\"sum\"].concat());\n    Box::new(Sum {\n        ephm_keys,\n        sum_signature: signature,\n    })\n}\n\n#[tokio::test]\nasync fn test_phase() {\n    let io = MockIO::new();\n    let phase = make_phase(io);\n    let _phase = unwrap_step!(phase, complete, sending_sum);\n}\n\n#[derive(Error, Debug)]\n#[error(\"error\")]\nstruct DummyErr;\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/tests/phases/sum2.rs",
    "content": "use mockall::Sequence;\nuse xaynet_core::{\n    crypto::{ByteObject, EncryptKeyPair, EncryptKeySeed, PublicEncryptKey},\n    mask::{FromPrimitives, MaskConfigPair, MaskObject, MaskSeed, Masker, Model, Scalar},\n    UpdateSeedDict,\n};\n\nuse crate::{\n    state_machine::{\n        tests::utils::{shared_state, SelectFor, SigningKeyGenerator},\n        IntoPhase,\n        MockIO,\n        Phase,\n        SendingSum2,\n        SharedState,\n        State,\n        Sum2,\n    },\n    unwrap_progress_continue,\n    unwrap_step,\n};\n\n/// Instantiate a sum2 phase.\nfn make_phase() -> Phase<Sum2> {\n    let shared = shared_state(SelectFor::Sum);\n    let sum2 = make_sum2(&shared);\n\n    // Check IntoPhase<Sum2> implementation\n    let mock = MockIO::new();\n    let mut phase: Phase<Sum2> = State::new(shared, sum2).into_phase(Box::new(mock));\n\n    phase.check_io_mock();\n    phase\n}\n\nfn make_sum2(shared: &SharedState) -> Box<Sum2> {\n    let ephm_keys = EncryptKeyPair::derive_from_seed(&EncryptKeySeed::zeroed());\n    let sk = &shared.keys.secret;\n    let seed = shared.round_params.seed.as_slice();\n    let signature = sk.sign_detached(&[seed, b\"sum\"].concat());\n    Box::new(Sum2 {\n        ephm_keys,\n        sum_signature: signature,\n        seed_dict: None,\n        seeds: None,\n        mask: None,\n    })\n}\n\nfn make_seed_dict(mask_config: MaskConfigPair, ephm_pk: PublicEncryptKey) -> UpdateSeedDict {\n    let (seed, _mask) = make_masked_model(mask_config);\n    let mut key_gen = SigningKeyGenerator::new();\n    let mut dict = UpdateSeedDict::new();\n    for _ in 0..4 {\n        let pk = key_gen.next().public;\n        dict.insert(pk, seed.encrypt(&ephm_pk));\n    }\n    dict\n}\n\nfn make_model() -> Model {\n    Model::from_primitives(vec![1.0, 2.0, 3.0, 4.0].into_iter()).unwrap()\n}\n\nfn make_masked_model(mask_config: MaskConfigPair) -> (MaskSeed, MaskObject) {\n    let masker = Masker::new(mask_config);\n    let scalar =
 Scalar::unit();\n    let model = make_model();\n    masker.mask(scalar, &model)\n}\n\nasync fn step1_fetch_seed_dict(mut phase: Phase<Sum2>) -> Phase<Sum2> {\n    let mask_config = phase.state.shared.round_params.mask_config;\n    let ephm_pk = phase.state.private.ephm_keys.public;\n    phase.with_io_mock(move |mock| {\n        let mut seq = Sequence::new();\n        // The first time the state machine fetches the seed dict,\n        // pretend it's not published yet\n        mock.expect_get_seeds()\n            .times(1)\n            .in_sequence(&mut seq)\n            .returning(|_| Ok(None));\n        // The second time, return it\n        mock.expect_get_seeds()\n            .times(1)\n            .in_sequence(&mut seq)\n            .returning(move |_| Ok(Some(make_seed_dict(mask_config, ephm_pk))));\n    });\n\n    // First time: no progress should be made, since we didn't\n    // fetch any seed dict yet\n    let phase = unwrap_step!(phase, pending, sum2);\n\n    // Second time: now the state machine should have made progress\n    let phase = unwrap_step!(phase, complete, sum2);\n\n    // Calling `fetch_seed_dict` again should return Progress::Continue\n    let mut phase = unwrap_progress_continue!(phase, fetch_seed_dict, async);\n    phase.check_io_mock();\n    phase\n}\n\nasync fn step2_decrypt_seeds(phase: Phase<Sum2>) -> Phase<Sum2> {\n    let phase = unwrap_step!(phase, complete, sum2);\n    assert!(phase.state.private.seeds.is_some());\n    // Make sure this step consumes the seed dict.\n    assert!(phase.state.private.seed_dict.is_none());\n    phase\n}\n\nasync fn step3_aggregate_masks(phase: Phase<Sum2>) -> Phase<Sum2> {\n    let phase = unwrap_step!(phase, complete, sum2);\n    assert!(phase.state.private.mask.is_some());\n    // Make sure this step consumes the seeds.\n    assert!(phase.state.private.seeds.is_none());\n    phase\n}\n\nasync fn step4_into_sending_phase(phase: Phase<Sum2>) -> Phase<SendingSum2> {\n    let phase = unwrap_step!(phase,
 complete, sending_sum2);\n    phase\n}\n\n#[tokio::test]\nasync fn test_phase() {\n    let phase = make_phase();\n    let phase = step1_fetch_seed_dict(phase).await;\n    let phase = step2_decrypt_seeds(phase).await;\n    let phase = step3_aggregate_masks(phase).await;\n    let _phase = step4_into_sending_phase(phase).await;\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/tests/phases/update.rs",
    "content": "use mockall::Sequence;\nuse xaynet_core::{\n    crypto::ByteObject,\n    mask::{FromPrimitives, Model},\n    SumDict,\n};\n\nuse crate::{\n    save_and_restore,\n    state_machine::{\n        tests::utils::{shared_state, EncryptKeyGenerator, SelectFor, SigningKeyGenerator},\n        IntoPhase,\n        MockIO,\n        Phase,\n        SendingUpdate,\n        SharedState,\n        State,\n        Update,\n    },\n    unwrap_progress_continue,\n    unwrap_step,\n};\n\n/// Instantiate a sum phase.\nfn make_phase() -> Phase<Update> {\n    let shared = shared_state(SelectFor::Update);\n    let update = make_update(&shared);\n\n    // Check IntoPhase<Update> implementation\n    let mut mock = MockIO::new();\n    mock.expect_notify_update().times(1).return_const(());\n    mock.expect_notify_load_model().times(1).return_const(());\n    let mut phase: Phase<Update> = State::new(shared, update).into_phase(Box::new(mock));\n\n    phase.check_io_mock();\n    phase\n}\n\nfn make_update(shared: &SharedState) -> Box<Update> {\n    let sk = &shared.keys.secret;\n    let seed = shared.round_params.seed.as_slice();\n    let sum_signature = sk.sign_detached(&[seed, b\"sum\"].concat());\n    let update_signature = sk.sign_detached(&[seed, b\"update\"].concat());\n    Box::new(Update {\n        sum_signature,\n        update_signature,\n        sum_dict: None,\n        seed_dict: None,\n        model: None,\n        mask: None,\n    })\n}\n\nfn make_model() -> Model {\n    let weights: Vec<f32> = vec![1.1, 2.2, 3.3, 4.4];\n    Model::from_primitives(weights.into_iter()).unwrap()\n}\n\nfn make_sum_dict() -> SumDict {\n    let mut dict = SumDict::new();\n\n    let mut signing_keys = SigningKeyGenerator::new();\n    let mut encrypt_keys = EncryptKeyGenerator::new();\n\n    dict.insert(signing_keys.next().public, encrypt_keys.next().public);\n    dict.insert(signing_keys.next().public, encrypt_keys.next().public);\n\n    dict\n}\n\nasync fn step1_fetch_sum_dict(mut phase: 
Phase<Update>) -> Phase<Update> {\n    phase.with_io_mock(|mock| {\n        let mut seq = Sequence::new();\n        // The first time the state machine fetches the sum dict,\n        // pretend it's not published yet\n        mock.expect_get_sums()\n            .times(1)\n            .in_sequence(&mut seq)\n            .returning(|| Ok(None));\n        // The second time, return a sum dictionary.\n        mock.expect_get_sums()\n            .times(1)\n            .in_sequence(&mut seq)\n            .returning(|| Ok(Some(make_sum_dict())));\n    });\n\n    // First time: no progress should be made, since we didn't\n    // fetch any sum dict yet\n    let phase = unwrap_step!(phase, pending, update);\n\n    // Second time: now the state machine should have made progress\n    let phase = unwrap_step!(phase, complete, update);\n\n    // Calling `fetch_sum_dict` again should return Progress::Continue\n    let mut phase = unwrap_progress_continue!(phase, fetch_sum_dict, async);\n    phase.check_io_mock();\n    phase\n}\n\nasync fn step2_load_model(mut phase: Phase<Update>) -> Phase<Update> {\n    phase.with_io_mock(|mock| {\n        let mut seq = Sequence::new();\n        // The first time the state machine tries to load the model,\n        // pretend it's not available yet\n        mock.expect_load_model()\n            .times(1)\n            .in_sequence(&mut seq)\n            .returning(|| Ok(None));\n        // The second time, return a model.\n        mock.expect_load_model()\n            .times(1)\n            .in_sequence(&mut seq)\n            .returning(|| Ok(Some(Box::new(make_model()))));\n    });\n\n    // First time: no progress should be made, since we didn't\n    // load any model\n    let phase = unwrap_step!(phase, pending, update);\n\n    // Second time: now the state machine should have made progress\n    let phase = unwrap_step!(phase, complete, update);\n\n    // Calling `load_model` again should return Progress::Continue\n    let mut phase = 
unwrap_progress_continue!(phase, load_model, async);\n    phase.check_io_mock();\n    phase\n}\n\nasync fn step3_mask_model(phase: Phase<Update>) -> Phase<Update> {\n    let phase = unwrap_step!(phase, complete, update);\n    let mut phase = unwrap_progress_continue!(phase, mask_model);\n    phase.check_io_mock();\n    phase\n}\n\nasync fn step4_build_seed_dict(phase: Phase<Update>) -> Phase<Update> {\n    let phase = unwrap_step!(phase, complete, update);\n    let mut phase = unwrap_progress_continue!(phase, build_seed_dict);\n    phase.check_io_mock();\n    phase\n}\n\nasync fn step5_into_sending_phase(phase: Phase<Update>) -> Phase<SendingUpdate> {\n    let phase = unwrap_step!(phase, complete, sending_update);\n    phase\n}\n\n#[tokio::test]\nasync fn test_update_phase() {\n    let phase = make_phase();\n    let phase = step1_fetch_sum_dict(phase).await;\n    let phase = step2_load_model(phase).await;\n    let phase = step3_mask_model(phase).await;\n    let phase = step4_build_seed_dict(phase).await;\n    let _phase = step5_into_sending_phase(phase).await;\n}\n\n#[tokio::test]\nasync fn test_save_and_restore() {\n    let phase = make_phase();\n    let mut phase = step1_fetch_sum_dict(phase).await;\n\n    phase.with_io_mock(|mock| {\n        let mut seq = Sequence::new();\n        mock.expect_notify_update()\n            .times(1)\n            .in_sequence(&mut seq)\n            .return_const(());\n        mock.expect_notify_load_model()\n            .times(1)\n            .in_sequence(&mut seq)\n            .return_const(());\n    });\n    let phase = save_and_restore!(phase, Update);\n\n    let mut phase = step2_load_model(phase).await;\n    phase.with_io_mock(|mock| {\n        mock.expect_notify_update().times(1).return_const(());\n    });\n    let phase = save_and_restore!(phase, Update);\n\n    let mut phase = step3_mask_model(phase).await;\n    phase.with_io_mock(|mock| {\n        mock.expect_notify_update().times(1).return_const(());\n    });\n    let 
phase = save_and_restore!(phase, Update);\n\n    let mut phase = step4_build_seed_dict(phase).await;\n    phase.with_io_mock(|mock| {\n        mock.expect_notify_update().times(1).return_const(());\n    });\n    let _phase = save_and_restore!(phase, Update);\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/state_machine/tests/utils.rs",
    "content": "use xaynet_core::{\n    common::{RoundParameters, RoundSeed},\n    crypto::{ByteObject, EncryptKeyPair, EncryptKeySeed, SigningKeyPair, SigningKeySeed},\n    mask::{self, MaskConfig, Scalar},\n};\n\nuse crate::{settings::MaxMessageSize, state_machine::SharedState};\n\n#[macro_export]\nmacro_rules! unwrap_as {\n    ($e:expr, $p:path) => {\n        match $e {\n            $p(s) => s,\n            x => panic!(\"Not a {}: {:?}\", stringify!($p), x),\n        }\n    };\n}\n\n#[macro_export]\nmacro_rules! unwrap_step {\n    ($phase:expr, complete, $state_machine:tt) => {\n        unwrap_step!(\n            $phase,\n            $crate::state_machine::TransitionOutcome::Complete,\n            $state_machine\n        )\n    };\n    ($phase:expr, pending, $state_machine:tt) => {\n        unwrap_step!(\n            $phase,\n            $crate::state_machine::TransitionOutcome::Pending,\n            $state_machine\n        )\n    };\n    ($phase:expr, $transition_outcome:path, awaiting) => {\n        unwrap_step!(\n            $phase,\n            $transition_outcome,\n            $crate::state_machine::StateMachine::Awaiting\n        )\n    };\n    ($phase:expr, $transition_outcome:path, sum) => {\n        unwrap_step!(\n            $phase,\n            $transition_outcome,\n            $crate::state_machine::StateMachine::Sum\n        )\n    };\n    ($phase:expr, $transition_outcome:path, sum2) => {\n        unwrap_step!(\n            $phase,\n            $transition_outcome,\n            $crate::state_machine::StateMachine::Sum2\n        )\n    };\n    ($phase:expr, $transition_outcome:path, update) => {\n        unwrap_step!(\n            $phase,\n            $transition_outcome,\n            $crate::state_machine::StateMachine::Update\n        )\n    };\n    ($phase:expr, $transition_outcome:path, sending_sum) => {\n        unwrap_step!(\n            $phase,\n            $transition_outcome,\n            $crate::state_machine::StateMachine::SendingSum\n    
    )\n    };\n    ($phase:expr, $transition_outcome:path, sending_update) => {\n        unwrap_step!(\n            $phase,\n            $transition_outcome,\n            $crate::state_machine::StateMachine::SendingUpdate\n        )\n    };\n    ($phase:expr, $transition_outcome:path, sending_sum2) => {\n        unwrap_step!(\n            $phase,\n            $transition_outcome,\n            $crate::state_machine::StateMachine::SendingSum2\n        )\n    };\n    ($phase:expr, $transition_outcome:path, $state_machine:path) => {{\n        let x = $crate::unwrap_as!(\n            $crate::state_machine::Step::step($phase).await,\n            $transition_outcome\n        );\n        $crate::unwrap_as!(x, $state_machine)\n    }};\n}\n\n#[macro_export]\nmacro_rules! unwrap_progress_continue {\n    ($expr:expr) => {\n        $crate::unwrap_as!($expr, $crate::state_machine::Progress::Continue)\n    };\n    ($phase:expr, $method:tt) => {\n        unwrap_progress_continue!($phase.$method())\n    };\n    ($phase:expr, $method:tt, async) => {\n        unwrap_progress_continue!($phase.$method().await)\n    };\n}\n\n#[macro_export]\nmacro_rules! 
save_and_restore {\n    ($phase:expr, $state:tt) => {{\n        let mut phase = $phase;\n        let io_mock = std::mem::replace(&mut phase.io, Box::new(MockIO::new()));\n        let serializable_state = Into::<$crate::state_machine::SerializableState>::into(phase);\n        // TODO: actually serialize the state here\n        let state = $crate::unwrap_as!(\n            serializable_state,\n            $crate::state_machine::SerializableState::$state\n        );\n        let mut phase = $crate::state_machine::IntoPhase::<$state>::into_phase(state, io_mock);\n        phase.check_io_mock();\n        phase\n    }};\n}\n\n/// Task for which the round parameters should be generated.\n#[derive(Debug, PartialEq, Eq)]\npub enum SelectFor {\n    /// Create round parameters that always select participants for the sum task\n    Sum,\n    /// Create round parameters that always select participants for the update task\n    Update,\n    /// Create round parameters that never select participants\n    None,\n}\n\npub fn mask_config() -> MaskConfig {\n    MaskConfig {\n        group_type: mask::GroupType::Prime,\n        data_type: mask::DataType::F32,\n        bound_type: mask::BoundType::B0,\n        model_type: mask::ModelType::M3,\n    }\n}\n\npub fn round_params(task: SelectFor) -> RoundParameters {\n    RoundParameters {\n        pk: EncryptKeySeed::zeroed().derive_encrypt_key_pair().0,\n        sum: if task == SelectFor::Sum { 1.0 } else { 0.0 },\n        update: if task == SelectFor::Update { 1.0 } else { 0.0 },\n        seed: RoundSeed::zeroed(),\n        mask_config: mask_config().into(),\n        model_length: 0,\n    }\n}\n\npub fn shared_state(task: SelectFor) -> Box<SharedState> {\n    Box::new(SharedState {\n        keys: SigningKeyPair::derive_from_seed(&SigningKeySeed::zeroed()),\n        scalar: Scalar::unit(),\n        message_size: MaxMessageSize::unlimited(),\n        round_params: round_params(task),\n    })\n}\n\npub struct 
EncryptKeyGenerator(EncryptKeySeed);\n\nimpl EncryptKeyGenerator {\n    pub fn new() -> Self {\n        Self(EncryptKeySeed::zeroed())\n    }\n\n    fn incr_seed(&mut self) {\n        let mut raw = self.0.as_slice().to_vec();\n        for b in &mut raw {\n            if *b < 0xff {\n                *b += 1;\n                return self.0 = EncryptKeySeed::from_slice(raw.as_slice()).unwrap();\n            }\n        }\n        panic!(\"max seed\");\n    }\n\n    pub fn next(&mut self) -> EncryptKeyPair {\n        let keys = EncryptKeyPair::derive_from_seed(&self.0);\n        self.incr_seed();\n        keys\n    }\n}\n\npub struct SigningKeyGenerator(SigningKeySeed);\n\nimpl SigningKeyGenerator {\n    pub fn new() -> Self {\n        Self(SigningKeySeed::zeroed())\n    }\n\n    fn incr_seed(&mut self) {\n        let mut raw = self.0.as_slice().to_vec();\n        for b in &mut raw {\n            if *b < 0xff {\n                *b += 1;\n                return self.0 = SigningKeySeed::from_slice(raw.as_slice()).unwrap();\n            }\n        }\n        panic!(\"max seed\");\n    }\n\n    pub fn next(&mut self) -> SigningKeyPair {\n        let item = SigningKeyPair::derive_from_seed(&self.0);\n        self.incr_seed();\n        item\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/traits.rs",
    "content": "use async_trait::async_trait;\n\nuse xaynet_core::{\n    common::RoundParameters,\n    mask::Model,\n    SumDict,\n    SumParticipantPublicKey,\n    UpdateSeedDict,\n};\n\n/// A trait used by the [`StateMachine`] to emit notifications upon\n/// certain events.\n///\n/// [`StateMachine`]: crate::StateMachine\npub trait Notify {\n    /// Emit a notification when a new round of federated learning\n    /// starts\n    fn new_round(&mut self) {}\n    /// Emit a notification when the participant has been selected for\n    /// the sum task\n    fn sum(&mut self) {}\n    /// Emit a notification when the participant has been selected for\n    /// the update task\n    fn update(&mut self) {}\n    /// Emit a notification when the participant is not selected for\n    /// any task and is waiting for another round to start\n    fn idle(&mut self) {}\n    /// Emit a notification when the participant should populate the\n    /// model store (see [`ModelStore`]).\n    fn load_model(&mut self) {}\n}\n\n/// A trait used by the [`StateMachine`] to load the model trained by\n/// the participant, when it has been selected for the update task.\n///\n/// [`StateMachine`]: crate::StateMachine\n#[async_trait]\npub trait ModelStore {\n    type Error: std::error::Error;\n    type Model: AsRef<Model> + Send;\n\n    /// Attempt to load the model. 
If the model is not yet available,\n    /// `Ok(None)` should be returned.\n    async fn load_model(&mut self) -> Result<Option<Self::Model>, Self::Error>;\n}\n\n/// A trait used by the [`StateMachine`] to communicate with the\n/// Xaynet coordinator.\n///\n/// [`StateMachine`]: crate::StateMachine\n#[async_trait]\npub trait XaynetClient {\n    type Error: std::error::Error;\n\n    /// Retrieve the current round parameters\n    async fn get_round_params(&mut self) -> Result<RoundParameters, Self::Error>;\n\n    /// Retrieve the current sum dictionary, if available.\n    async fn get_sums(&mut self) -> Result<Option<SumDict>, Self::Error>;\n\n    /// Retrieve the current seed dictionary for the given sum\n    /// participant, if available.\n    async fn get_seeds(\n        &mut self,\n        pk: SumParticipantPublicKey,\n    ) -> Result<Option<UpdateSeedDict>, Self::Error>;\n\n    /// Retrieve the current global model, if available.\n    async fn get_model(&mut self) -> Result<Option<Model>, Self::Error>;\n\n    /// Send an encrypted and signed PET message to the coordinator.\n    async fn send_message(&mut self, msg: Vec<u8>) -> Result<(), Self::Error>;\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/utils/concurrent_futures.rs",
    "content": "#![allow(dead_code)]\n\nuse std::{\n    collections::VecDeque,\n    pin::Pin,\n    task::{Context, Poll},\n};\n\nuse futures::{\n    stream::{FuturesUnordered, Stream},\n    Future,\n};\nuse tokio::task::{JoinError, JoinHandle};\n\n/// `ConcurrentFutures` can keep a capped number of futures running concurrently, and yield their\n/// result as they finish. When the max number of concurrent futures is reached, new tasks are\n/// queued until some in-flight futures finish.\npub struct ConcurrentFutures<T>\nwhere\n    T: Future + Send + 'static,\n    T::Output: Send + 'static,\n{\n    /// In-flight futures.\n    running: FuturesUnordered<JoinHandle<T::Output>>,\n    /// Buffered tasks.\n    queued: VecDeque<T>,\n    /// Max number of concurrent futures.\n    max_in_flight: usize,\n}\n\nimpl<T> ConcurrentFutures<T>\nwhere\n    T: Future + Send + 'static,\n    T::Output: Send + 'static,\n{\n    pub fn new(max_in_flight: usize) -> Self {\n        Self {\n            running: FuturesUnordered::new(),\n            queued: VecDeque::new(),\n            max_in_flight,\n        }\n    }\n\n    pub fn push(&mut self, task: T) {\n        self.queued.push_back(task)\n    }\n}\n\nimpl<T> Unpin for ConcurrentFutures<T>\nwhere\n    T: Future + Send + 'static,\n    T::Output: Send + 'static,\n{\n}\n\nimpl<T> Stream for ConcurrentFutures<T>\nwhere\n    T: Future + Send + 'static,\n    T::Output: Send + 'static,\n{\n    type Item = Result<T::Output, JoinError>;\n    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {\n        let this = self.get_mut();\n        while this.running.len() < this.max_in_flight {\n            if let Some(queued) = this.queued.pop_front() {\n                let handle = tokio::spawn(queued);\n                this.running.push(handle);\n            } else {\n                break;\n            }\n        }\n        Pin::new(&mut this.running).poll_next(cx)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use 
std::time::Duration;\n\n    use futures::stream::StreamExt;\n    use tokio::time::sleep;\n\n    use super::*;\n\n    // this can fail in rare occasions because of polling delays\n    #[tokio::test]\n    async fn test() {\n        let mut stream =\n            ConcurrentFutures::<Pin<Box<dyn Future<Output = u8> + Send + 'static>>>::new(2);\n\n        stream.push(Box::pin(async {\n            sleep(Duration::from_millis(10_u64)).await;\n            1_u8\n        }));\n\n        stream.push(Box::pin(async {\n            sleep(Duration::from_millis(28_u64)).await;\n            2_u8\n        }));\n\n        stream.push(Box::pin(async {\n            sleep(Duration::from_millis(8_u64)).await;\n            3_u8\n        }));\n\n        stream.push(Box::pin(async {\n            sleep(Duration::from_millis(2_u64)).await;\n            4_u8\n        }));\n\n        // poll_next() hasn't been called yet so all futures are queued\n        assert_eq!(stream.running.len(), 0);\n        assert_eq!(stream.queued.len(), 4);\n\n        // future 1 and 2 are spawned, then future 1 is ready\n        assert_eq!(stream.next().await.unwrap().unwrap(), 1);\n\n        // future 2 is pending, futures 3 and 4 are queued\n        assert_eq!(stream.running.len(), 1);\n        assert_eq!(stream.queued.len(), 2);\n\n        // future 3 is spawned, then future 3 is ready\n        assert_eq!(stream.next().await.unwrap().unwrap(), 3);\n\n        // future 2 is pending, future 4 is queued\n        assert_eq!(stream.running.len(), 1);\n        assert_eq!(stream.queued.len(), 1);\n\n        // future 4 is spawned, then future 4 is ready\n        assert_eq!(stream.next().await.unwrap().unwrap(), 4);\n\n        // future 2 is pending, then future 2 is ready\n        assert_eq!(stream.next().await.unwrap().unwrap(), 2);\n\n        // all futures have been resolved\n        assert_eq!(stream.running.len(), 0);\n        assert_eq!(stream.queued.len(), 0);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-sdk/src/utils/mod.rs",
    "content": "// TODO: move to the e2e package\npub mod concurrent_futures;\n"
  },
  {
    "path": "rust/xaynet-server/Cargo.toml",
    "content": "[package]\nname = \"xaynet-server\"\nversion = \"0.2.0\"\nauthors = [\"Xayn Engineering <engineering@xaynet.dev>\"]\nedition = \"2018\"\ndescription = \"The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant.\"\nreadme = \"../../README.md\"\nhomepage = \"https://xaynet.dev/\"\nrepository = \"https://github.com/xaynetwork/xaynet/\"\nlicense-file = \"../../LICENSE\"\nkeywords = [\"federated-learning\", \"fl\", \"ai\", \"machine-learning\"]\ncategories = [\"science\", \"cryptography\"]\n\n[package.metadata.docs.rs]\nall-features = true\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n\n[dependencies]\nanyhow = \"1.0.62\"\nasync-trait = \"0.1.57\"\nbase64 = \"0.13.0\"\nbincode = \"1.3.3\"\nbitflags = \"1.3.2\"\nbytes = \"1.0.1\"\nconfig = \"0.12.0\"\nchrono = \"0.4.22\"\nderive_more = { version = \"0.99.17\", default-features = false, features = [\n    \"as_mut\",\n    \"as_ref\",\n    \"deref\",\n    \"display\",\n    \"from\",\n    \"index\",\n    \"index_mut\",\n    \"into\",\n] }\ndisplaydoc = \"0.2.3\"\nfutures = \"0.3.24\"\nhex = \"0.4.3\"\nhttp = \"0.2.8\"\ninfluxdb = \"0.5.2\"\nnum = { version = \"0.4.0\", features = [\"serde\"] }\nnum_enum = \"0.5.7\"\nonce_cell = \"1.13.1\"\npaste = \"1.0.8\"\nrand = \"0.8.5\"\nrand_chacha = \"0.3.1\"\nserde = { version = \"1.0.144\", features = [\"derive\"] }\nrayon = \"1.5.3\"\nredis = { version = \"0.21.6\", default-features = false, features = [\n    \"aio\",\n    \"connection-manager\",\n    \"script\",\n    \"tokio-comp\",\n] }\nsodiumoxide = \"0.2.7\"\nstructopt = \"0.3.26\"\nthiserror = \"1.0.32\"\ntokio = { version = \"1.20.1\", features = [\n    \"macros\",\n    \"rt-multi-thread\",\n    \"signal\",\n    \"sync\",\n    \"net\",\n    \"time\",\n] }\ntower = { version = \"0.4.6\", default-features 
= false, features = [\n    \"buffer\",\n    \"load-shed\",\n    \"limit\"\n] }\ntracing = \"0.1.36\"\ntracing-futures = \"0.2.5\"\ntracing-subscriber = { version = \"0.3.15\", features = [\"env-filter\"] }\nvalidator = { version = \"0.16.0\", features = [\"derive\"] }\nwarp = \"0.3.1\"\nxaynet-core = { path = \"../xaynet-core\", version = \"0.2.0\" }\n\n# feature: model-persistence\nfancy-regex = { version = \"0.10.0\", optional = true }\nrusoto_core = { version = \"0.46.0\", optional = true }\nrusoto_s3 = { version = \"0.46.0\", optional = true }\n\n[dev-dependencies]\n# We can't run tarpaulin with the flag `--test-threads=1` because it can trigger a segfault:\n# https://github.com/xd009642/tarpaulin/issues/317. A workaround is to use `serial_test`.\nmockall = \"0.11.2\"\nserial_test = \"0.8.0\"\ntokio-test = \"0.4.1\"\ntower-test = \"0.4.0\"\n\n[[bin]]\nname = \"coordinator\"\npath = \"src/bin/main.rs\"\n\n[features]\ndefault = []\nfull = [\"metrics\", \"model-persistence\", \"tls\"]\nmetrics = []\nmodel-persistence = [\"fancy-regex\", \"rusoto_core\", \"rusoto_s3\"]\ntls = [\"warp/tls\"]\n"
  },
  {
    "path": "rust/xaynet-server/src/bin/main.rs",
    "content": "use std::{path::PathBuf, process};\n\nuse structopt::StructOpt;\nuse tokio::signal;\nuse tracing::warn;\nuse tracing_subscriber::*;\n\n#[cfg(feature = \"metrics\")]\nuse xaynet_server::{metrics, settings::InfluxSettings};\n\nuse xaynet_server::{\n    rest::{serve, RestError},\n    services,\n    settings::{LoggingSettings, RedisSettings, Settings},\n    state_machine::initializer::StateMachineInitializer,\n    storage::{coordinator_storage::redis, Storage, Store},\n};\n#[cfg(feature = \"model-persistence\")]\nuse xaynet_server::{settings::S3Settings, storage::model_storage::s3};\n\n#[derive(Debug, StructOpt)]\n#[structopt(name = \"Coordinator\")]\nstruct Opt {\n    /// Path of the configuration file\n    #[structopt(short, parse(from_os_str))]\n    config_path: PathBuf,\n}\n\n#[tokio::main]\nasync fn main() {\n    let opt = Opt::from_args();\n\n    let settings = Settings::new(opt.config_path).unwrap_or_else(|err| {\n        eprintln!(\"{}\", err);\n        process::exit(1);\n    });\n    let Settings {\n        pet: pet_settings,\n        mask: mask_settings,\n        api: api_settings,\n        log: log_settings,\n        model: model_settings,\n        redis: redis_settings,\n        ..\n    } = settings;\n\n    init_tracing(log_settings);\n\n    // This should already be called internally when instantiating the\n    // state machine, but it doesn't hurt to make sure the crypto layer\n    // is correctly initialized\n    sodiumoxide::init().unwrap();\n\n    #[cfg(feature = \"metrics\")]\n    init_metrics(settings.metrics.influxdb);\n\n    let store = init_store(\n        redis_settings,\n        #[cfg(feature = \"model-persistence\")]\n        settings.s3,\n    )\n    .await;\n\n    let (state_machine, requests_tx, event_subscriber) = StateMachineInitializer::new(\n        pet_settings,\n        mask_settings,\n        model_settings,\n        #[cfg(feature = \"model-persistence\")]\n        settings.restore,\n        store,\n    )\n    .init()\n    
.await\n    .expect(\"failed to initialize state machine\");\n\n    let fetcher = services::fetchers::fetcher(&event_subscriber);\n    let message_handler =\n        services::messages::PetMessageHandler::new(&event_subscriber, requests_tx);\n\n    tokio::select! {\n        biased;\n\n        _ =  signal::ctrl_c() => {}\n        _ = state_machine.run() => {\n            warn!(\"shutting down: Service terminated\");\n        }\n        result = serve(api_settings, fetcher, message_handler) => {\n            match result {\n                Ok(()) => warn!(\"shutting down: REST server terminated\"),\n                Err(RestError::InvalidTlsConfig) => {\n                    warn!(\"shutting down: invalid TLS settings for REST server\");\n                },\n            }\n        }\n    }\n}\n\nfn init_tracing(settings: LoggingSettings) {\n    let _fmt_subscriber = FmtSubscriber::builder()\n        .with_env_filter(settings.filter)\n        .with_ansi(true)\n        .init();\n}\n\n#[cfg(feature = \"metrics\")]\nfn init_metrics(settings: InfluxSettings) {\n    let recorder = metrics::Recorder::new(settings);\n    if metrics::GlobalRecorder::install(recorder).is_err() {\n        warn!(\"failed to install metrics recorder\");\n    };\n}\n\nasync fn init_store(\n    redis_settings: RedisSettings,\n    #[cfg(feature = \"model-persistence\")] s3_settings: S3Settings,\n) -> impl Storage {\n    let coordinator_store = redis::Client::new(redis_settings.url)\n        .await\n        .expect(\"failed to establish a connection to Redis\");\n\n    let model_store = {\n        #[cfg(not(feature = \"model-persistence\"))]\n        {\n            xaynet_server::storage::model_storage::noop::NoOp\n        }\n\n        #[cfg(feature = \"model-persistence\")]\n        {\n            let s3 = s3::Client::new(s3_settings).expect(\"failed to create S3 client\");\n            s3.create_global_models_bucket()\n                .await\n                .expect(\"failed to create bucket for 
global models\");\n            s3\n        }\n    };\n\n    Store::new(coordinator_store, model_store)\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/examples.rs",
    "content": "/*!\nA guide to getting started with the XayNet examples.\n\n# Examples\n\nThe XayNet examples code can be found under the `rust/examples` directory of the\n[`xaynet`](https://github.com/xaynetwork/xaynet/) repository.\n\nThis Getting Started guide will cover only the general ideas around usage of the\nexamples. Also see the source code of the individual examples themselves, which\nhave plenty of comments.\n\nRunning an example typically requires having a *coordinator* already running,\nwhich is the core component of XayNet.\n\n# Federated Learning\n\nA federated learning session over XayNet consists of two kinds of parties - a\n*coordinator* and (multiple) *participants*. The two parties engage in a\nprotocol (called PET) over a series of rounds. The over-simplified idea is that\nin each round:\n\n1. The coordinator makes available a *global* model, from which selected\nparticipants will train model updates (or, *local* models) to be sent back to\nthe coordinator.\n\n2. As a round progresses, the coordinator aggregates these updates\ninto a new global model.\n\nFrom this description, it might appear that individual local models are plainly\nvisible to the coordinator. What if sensitive data could be extracted from them?\nWould this not be a violation of participants' data privacy?\n\nIn fact, a key point about this process is that the updates are **not** sent in\nthe plain! Rather, they are sent encrypted (or *masked*) so that the coordinator\n(and by extension, XayNet) learns almost nothing about the individual updates.\nYet, it is nevertheless able to aggregate them in such a way that the resulting\nglobal model is unmasked.\n\nThis is essentially what is meant by federated learning that is\n*privacy-preserving*, and is a key feature enabled by the PET protocol.\n\n## PET Protocol\n\nIt is worth describing the protocol very briefly here, if only to better\nunderstand some of the configuration settings we will meet later. 
It is helpful\nto think of each round being divided up into several contiguous phases:\n\n**Start.** At the start of a round, the coordinator generates a collection of random *round\nparameters* for all participants. From these parameters, each participant is\nable to determine whether it is selected for the round and if so, which of the\ntwo roles it is:\n\n- *update* participants.\n\n- *sum* participants.\n\n**Sum.** In the Sum phase, sum participants send `sum` messages to the\ncoordinator (the details of which are not so important here, but vital for\ncomputing `sum2` messages later).\n\n**Update.** In the Update phase, each update participant obtains the global\nmodel from the coordinator, trains a local model from it, masks it, and sends it\nto the coordinator in the form of `update` messages. The coordinator will\ninternally aggregate these (masked) local models.\n\n**Sum2.** In the Sum2 phase, sum participants compute the sum of masks over all\nthe local models, and send it to the coordinator in the form of `sum2` messages.\n\nEquipped with the sum of masks, the coordinator is able to *unmask* the\naggregated global model, for the next round.\n\nThis short description of the protocol skips over many details, but is\nsufficient for the purposes of this guide. For a much more complete\nspecification, see the [white paper](https://uploads-ssl.webflow.com/5f0c5c0bb18a279f0a62919e/5f157004da6585f299fa542b_XayNet%20Whitepaper%202.1.pdf).\n\n# Coordinator\n\nThe coordinator is configurable via various settings. The project contains\nvarious ready-made configuration files that can be used, found under the\n`configs` directory of the repository. 
Typically they look something like\nthe following (in TOML format):\n\n```toml\n[api]\nbind_address = \"127.0.0.1:8081\"\n\n[pet.sum]\nprob = 0.1\ncount = { min = 1, max = 100 }\ntime = { min = 5, max = 3600 }\n\n[pet.update]\nprob = 0.9\ncount = { min = 3, max = 10000 }\ntime = { min = 10, max = 3600 }\n\n[pet.sum2]\ncount = { min = 1, max = 100 }\ntime = { min = 5, max = 3600 }\n\n[mask]\ngroup_type = \"Prime\"\ndata_type = \"F32\"\nbound_type = \"B0\"\nmodel_type = \"M3\"\n\n[model]\nlength = 4\n```\n\nThe actual files contain more settings than this, but we mention just the\nselection above because they will be the most relevant for this guide.\n\n## Settings\n\nGoing from the top, the [`ApiSettings`] include the\naddress the coordinator should listen on for requests from participants. This\naddress should be known to all participants. Optionally, it also contains configurations for TLS\nserver and client authentication.\n\nThe [`PetSettings`] specify various parameters of the PET protocol:\n\n- The most important are [`sum.prob`] and [`update.prob`], which are the probabilities assigned to\nthe selection of sum and update participants, respectively (note that if a participant is selected for\nboth roles, the *sum* role takes precedence).\n\n- The settings [`sum.count.min`], [`update.count.min`] and [`sum2.count.min`] specify, respectively,\nthe minimum number of `sum`, `update` and `sum2` messages the coordinator should accept. Similarly,\nthe [`sum.count.max`], [`update.count.max`] and [`sum2.count.max`] specify the maximum number of\n`sum`, `update` and `sum2` messages the coordinator should accept.\n\n- To complement, the settings [`sum.time.min`], [`update.time.min`] and [`sum2.time.min`] specify,\nrespectively, the minimum amount of time (in seconds) the coordinator should wait for `sum`,\n`update` and `sum2` messages. 
To allow for more messages to be processed, increase these times.\nSimilarly, the [`sum.time.max`], [`update.time.max`] and [`sum2.time.max`] specify the maximum\namount of time (in seconds) the coordinator should wait for `sum`, `update` and `sum2` messages.\n\nThe [`MaskSettings`] determine the masking configuration, consisting of the\ngroup type, data type, bound type and model type. The [`ModelSettings`] specify\nthe length of the model used. Both of these settings should be agreed upon in advance\nby the coordinator and the participants.\n\n## Running\n\nThe coordinator can be run as follows:\n\n```text\n$ git clone https://github.com/xaynetwork/xaynet\n$ cd xaynet/rust\n$ cargo run --bin coordinator -- -c ../configs/config.toml\n```\n\n\n## Running participants\n\nYou can run the example from the xaynet repository:\n\n```text\n$ git clone https://github.com/xaynetwork/xaynet\n$ cd xaynet/rust/examples\n$ RUST_LOG=info cargo run --example test-drive -- -n 10\n```\n\n[`ApiSettings`]: crate::settings::ApiSettings\n[`PetSettings`]: crate::settings::PetSettings\n[`sum.prob`]: crate::settings::PetSettingsSum::prob\n[`update.prob`]: crate::settings::PetSettingsUpdate::prob\n[`sum.count.min`]: crate::settings::PetSettingsSum::count\n[`update.count.min`]: crate::settings::PetSettingsUpdate::count\n[`sum2.count.min`]: crate::settings::PetSettingsSum2::count\n[`sum.count.max`]: crate::settings::PetSettingsSum::count\n[`update.count.max`]: crate::settings::PetSettingsUpdate::count\n[`sum2.count.max`]: crate::settings::PetSettingsSum2::count\n[`sum.time.min`]: crate::settings::PetSettingsSum::time\n[`update.time.min`]: crate::settings::PetSettingsUpdate::time\n[`sum2.time.min`]: crate::settings::PetSettingsSum2::time\n[`sum.time.max`]: crate::settings::PetSettingsSum::time\n[`update.time.max`]: crate::settings::PetSettingsUpdate::time\n[`sum2.time.max`]: crate::settings::PetSettingsSum2::time\n[`MaskSettings`]: crate::settings::MaskSettings\n[`ModelSettings`]: 
crate::settings::ModelSettings\n*/\n"
  },
  {
    "path": "rust/xaynet-server/src/lib.rs",
    "content": "#![cfg_attr(docsrs, feature(doc_cfg))]\n#![cfg_attr(\n    doc,\n    forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links)\n)]\n#![doc(\n    html_logo_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/xaynet_banner.png\",\n    html_favicon_url = \"https://raw.githubusercontent.com/xaynetwork/xaynet/master/assets/favicon.png\",\n    issue_tracker_base_url = \"https://github.com/xaynetwork/xaynet/issues\"\n)]\n//! `xaynet_server` is a backend for federated machine learning. It\n//! ensures the users' privacy using the _Privacy-Enhancing Technology_\n//! (PET). Download the [whitepaper] for an introduction to the\n//! protocol.\n//!\n//! [whitepaper]: https://uploads-ssl.webflow.com/5f0c5c0bb18a279f0a62919e/5f157004da6585f299fa542b_XayNet%20Whitepaper%202.1.pdf\n\npub mod examples;\n\npub mod metrics;\npub mod rest;\npub mod services;\npub mod settings;\npub mod state_machine;\npub mod storage;\n"
  },
  {
    "path": "rust/xaynet-server/src/metrics/mod.rs",
    "content": "//! Utils to record metrics.\n\npub mod recorders;\n\nuse once_cell::sync::OnceCell;\n\npub use self::recorders::influxdb::{Measurement, Recorder, Tags};\n\nstatic RECORDER: OnceCell<Recorder> = OnceCell::new();\n\n/// A wrapper around a static global metrics/events recorder.\npub struct GlobalRecorder;\n\nimpl GlobalRecorder {\n    /// Gets the reference to the global recorder.\n    ///\n    /// Returns `None` if no recorder is set or is currently being initialized.\n    /// This method never blocks.\n    pub fn global() -> Option<&'static Recorder> {\n        RECORDER.get()\n    }\n\n    /// Installs a new global recorder.\n    ///\n    /// Returns Err(Recorder) if a recorder has already been set.\n    pub fn install(recorder: Recorder) -> Result<(), Recorder> {\n        RECORDER.set(recorder)\n    }\n}\n\n/// Records an event.\n///\n/// # Example\n///\n/// ```compile_fail\n/// // An event with just a title:\n/// event!(\"Error\");\n///\n/// // An event with a title and a description:\n/// event!(\"Error\", \"something went wrong\");\n///\n/// // An event with a title, a description and tags:\n/// event!(\n///     \"Error\",\n///     \"something went wrong\",\n///     [\"phase error\", \"coordinator\"],\n/// );\n/// ```\n#[macro_export]\nmacro_rules! event {\n    ($title: expr $(,)?) => {\n        if let Some(recorder) = crate::metrics::GlobalRecorder::global() {\n            recorder.event::<_, _, &str, _, &[_], &str>($title, None, None);\n        }\n    };\n    ($title: expr, $description: expr $(,)?) => {\n        if let Some(recorder) = crate::metrics::GlobalRecorder::global() {\n            recorder.event::<_, _, _, _, &[_], &str>($title, $description, None);\n        }\n    };\n    ($title: expr, $description: expr, [$($tags: expr),+] $(,)?) 
=> {\n        if let Some(recorder) = crate::metrics::GlobalRecorder::global() {\n            recorder.event($title, $description, [$($tags),+])\n        }\n    };\n}\n\n/// Records a metric.\n///\n/// # Example\n///\n/// ```compile_fail\n/// // A basic metric:\n/// metric!(Measurement::RoundTotalNumber, 1);\n///\n/// // A metric with one tag:\n/// metric!(Measurement::RoundParamSum, 0.7, (\"round_id\", 1));\n///\n/// // A metric with multiple tags:\n/// metric!(\n///     Measurement::RoundParamSum,\n///     0.7,\n///     (\"round_id\", 1),\n///     (\"phase\", 2),\n/// );\n/// ```\n#[macro_export]\nmacro_rules! metric {\n    ($measurement: expr, $value: expr $(,)?) => {\n        if let Some(recorder) = crate::metrics::GlobalRecorder::global() {\n            recorder.metric::<_, _, crate::metrics::Tags>($measurement, $value, None);\n        }\n    };\n    ($measurement: expr, $value: expr, $(($tag: expr, $val: expr)),+ $(,)?) => {\n        if let Some(recorder) = crate::metrics::GlobalRecorder::global() {\n            let mut tags = crate::metrics::Tags::new();\n            $(\n                tags.add($tag, $val);\n            )+\n            recorder.metric($measurement, $value, tags);\n        }\n    };\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/metrics/recorders/influxdb/dispatcher.rs",
    "content": "use super::models::{Event, Metric};\nuse derive_more::From;\nuse futures::future::BoxFuture;\nuse influxdb::{Client as InfluxClient, WriteQuery};\nuse std::task::{Context, Poll};\nuse tower::Service;\nuse tracing::debug;\n\n#[derive(From)]\npub(in crate::metrics) enum Request {\n    Metric(Metric),\n    Event(Event),\n}\n\nimpl From<Request> for WriteQuery {\n    fn from(req: Request) -> Self {\n        match req {\n            Request::Metric(metric) => metric.into(),\n            Request::Event(event) => event.into(),\n        }\n    }\n}\n\n#[derive(Clone)]\npub(in crate::metrics) struct Dispatcher {\n    client: InfluxClient,\n}\n\nimpl Dispatcher {\n    pub fn new(url: impl Into<String>, database: impl Into<String>) -> Self {\n        let client = InfluxClient::new(url, database);\n        Self { client }\n    }\n}\n\nimpl Service<Request> for Dispatcher {\n    type Response = ();\n    type Error = anyhow::Error;\n    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;\n\n    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        Poll::Ready(Ok(()))\n    }\n\n    fn call(&mut self, req: Request) -> Self::Future {\n        let client = self.client.clone();\n        let fut = async move {\n            debug!(\"dispatch metric\");\n            client\n                .query(&WriteQuery::from(req))\n                .await\n                .map_err(|err| anyhow::anyhow!(\"failed to dispatch metric {}\", err))?;\n            Ok(())\n        };\n\n        Box::pin(fut)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use tokio_test::assert_ready;\n    use tower_test::mock::Spawn;\n\n    use super::*;\n    use crate::{\n        metrics::{\n            recorders::influxdb::models::{Event, Metric},\n            Measurement,\n        },\n        settings::InfluxSettings,\n    };\n\n    fn influx_settings() -> InfluxSettings {\n        InfluxSettings {\n            url: 
\"http://127.0.0.1:8086\".to_string(),\n            db: \"metrics\".to_string(),\n        }\n    }\n\n    #[tokio::test]\n    #[ignore]\n    async fn integration_dispatch_metric() {\n        let settings = influx_settings();\n        let mut task = Spawn::new(Dispatcher::new(settings.url, settings.db));\n\n        let metric = Metric::new(Measurement::Phase, 1);\n        assert_ready!(task.poll_ready()).unwrap();\n        let resp = task.call(metric.into()).await;\n        assert!(resp.is_ok());\n    }\n\n    #[tokio::test]\n    #[ignore]\n    async fn integration_dispatch_event() {\n        let settings = influx_settings();\n        let mut task = Spawn::new(Dispatcher::new(settings.url, settings.db));\n\n        let event = Event::new(\"event\");\n        assert_ready!(task.poll_ready()).unwrap();\n        let resp = task.call(event.into()).await;\n        assert!(resp.is_ok());\n    }\n\n    #[tokio::test]\n    #[ignore]\n    async fn integration_wrong_url() {\n        let settings = influx_settings();\n        let mut task = Spawn::new(Dispatcher::new(\"http://127.0.0.1:9998\", settings.db));\n\n        let event = Event::new(\"event\");\n        assert_ready!(task.poll_ready()).unwrap();\n        let resp = task.call(event.into()).await;\n        assert!(resp.is_err());\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/metrics/recorders/influxdb/mod.rs",
    "content": "mod dispatcher;\nmod models;\nmod recorder;\nmod service;\n\npub(in crate::metrics) use self::{\n    dispatcher::{Dispatcher, Request},\n    models::{Event, Metric},\n    service::InfluxDbService,\n};\npub use self::{\n    models::{Measurement, Tags},\n    recorder::Recorder,\n};\n"
  },
  {
    "path": "rust/xaynet-server/src/metrics/recorders/influxdb/models.rs",
    "content": "use std::{borrow::Borrow, iter::IntoIterator};\n\nuse chrono::{DateTime, Utc};\nuse influxdb::{InfluxDbWriteable, Timestamp, Type, WriteQuery};\n\n/// An enum that contains all supported measurements.\npub enum Measurement {\n    RoundParamSum,\n    RoundParamUpdate,\n    Phase,\n    MasksTotalNumber,\n    RoundTotalNumber,\n    MessageAccepted,\n    MessageDiscarded,\n    MessageRejected,\n}\n\nimpl From<Measurement> for &'static str {\n    fn from(measurement: Measurement) -> &'static str {\n        match measurement {\n            Measurement::RoundParamSum => \"round_param_sum\",\n            Measurement::RoundParamUpdate => \"round_param_update\",\n            Measurement::Phase => \"phase\",\n            Measurement::MasksTotalNumber => \"masks_total_number\",\n            Measurement::RoundTotalNumber => \"round_total_number\",\n            Measurement::MessageAccepted => \"message_accepted\",\n            Measurement::MessageDiscarded => \"message_discarded\",\n            Measurement::MessageRejected => \"message_rejected\",\n        }\n    }\n}\n\nimpl From<Measurement> for String {\n    fn from(measurement: Measurement) -> Self {\n        <&str>::from(measurement).into()\n    }\n}\n\n/// A container that contains the tags of a metric.\npub struct Tags(Vec<(String, Type)>);\n\nimpl Tags {\n    /// Creates a new empty container for tags.\n    pub fn new() -> Self {\n        Self(Vec::new())\n    }\n\n    /// Adds a tag to the metric.\n    pub fn add(&mut self, tag: impl Into<String>, value: impl Into<Type>) {\n        self.0.push((tag.into(), value.into()))\n    }\n}\n\nimpl Default for Tags {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl IntoIterator for Tags {\n    type Item = <Vec<(String, Type)> as IntoIterator>::Item;\n    type IntoIter = <Vec<(String, Type)> as IntoIterator>::IntoIter;\n\n    fn into_iter(self) -> Self::IntoIter {\n        self.0.into_iter()\n    }\n}\n\n/// A metrics data point.\npub(in 
crate::metrics) struct Metric {\n    name: Measurement,\n    time: DateTime<Utc>,\n    value: Type,\n    tags: Option<Tags>,\n}\n\nimpl Metric {\n    pub(in crate::metrics) fn new(measurement: Measurement, value: impl Into<Type>) -> Self {\n        Self {\n            name: measurement,\n            time: Utc::now(),\n            value: value.into(),\n            tags: None,\n        }\n    }\n\n    pub(in crate::metrics) fn with_tags<T, I>(mut self, tags: T) -> Self\n    where\n        T: Into<Option<I>>,\n        I: Into<Tags>,\n    {\n        // It is by design that this function should only be called once.\n        // see `Recorder::metric`\n        // Therefore, we don't cover the case where we would extend `self.tags`\n        // when `self.tags` already contains tags.\n        self.tags = tags.into().map(Into::into);\n        self\n    }\n}\n\nimpl From<Metric> for WriteQuery {\n    fn from(metric: Metric) -> Self {\n        let mut query = Timestamp::from(metric.time).into_query(metric.name);\n        query = query.add_field(\"value\", metric.value);\n\n        if let Some(tags) = metric.tags {\n            for (tag, value) in tags {\n                query = query.add_tag(tag, value);\n            }\n        }\n\n        query\n    }\n}\n\n/// An event data point.\npub(in crate::metrics) struct Event {\n    name: &'static str,\n    time: DateTime<Utc>,\n    title: String,\n    description: Option<String>,\n    tags: Option<String>,\n}\n\nimpl Event {\n    pub(in crate::metrics) fn new(title: impl Into<String>) -> Self {\n        Self {\n            name: \"event\",\n            time: Utc::now(),\n            title: title.into(),\n            description: None,\n            tags: None,\n        }\n    }\n\n    pub(in crate::metrics) fn with_description<D, S>(mut self, description: D) -> Self\n    where\n        D: Into<Option<S>>,\n        S: Into<String>,\n    {\n        self.description = description.into().map(Into::into);\n        self\n    }\n\n    
pub(in crate::metrics) fn with_tags<T, A, B>(mut self, tags: T) -> Self\n    where\n        T: Into<Option<A>>,\n        A: AsRef<[B]>,\n        B: Borrow<str>,\n    {\n        // It is by design that this function should only be called once.\n        // see `Recorder::metric`\n        // Therefore, we don't cover the case where we would extend `self.tags`\n        // when `self.tags` already contains tags.\n        self.tags = tags.into().map(|tags| tags.as_ref().join(\",\"));\n        self\n    }\n}\n\nimpl From<Event> for WriteQuery {\n    fn from(event: Event) -> Self {\n        let mut query = Timestamp::from(event.time).into_query(event.name);\n        query = query.add_field(\"title\", event.title);\n\n        if let Some(description) = event.description {\n            query = query.add_field(\"description\", description);\n        }\n\n        if let Some(tags) = event.tags {\n            query = query.add_field(\"tags\", tags);\n        }\n\n        query\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use influxdb::Query;\n\n    use super::*;\n\n    /// Creates key-value tags for metrics.\n    macro_rules! tags {\n        ($(($tag: expr, $val: expr)),+ $(,)?) 
=> {\n            {\n                let mut tags = crate::metrics::Tags::new();\n                $(\n                    tags.add($tag, $val);\n                )+\n                tags\n            }\n        };\n    }\n\n    #[test]\n    fn test_basic_metric() {\n        let metric = Metric::new(Measurement::Phase, 1);\n        assert!(WriteQuery::from(metric)\n            .build()\n            .unwrap()\n            .get()\n            .starts_with(\"phase value=1i \"))\n    }\n\n    #[test]\n    fn test_metric_with_tag() {\n        let metric = Metric::new(Measurement::Phase, 1).with_tags(tags![(\"key\", 42)]);\n        assert!(WriteQuery::from(metric)\n            .build()\n            .unwrap()\n            .get()\n            .starts_with(\"phase,key=42 value=1i \"))\n    }\n\n    #[test]\n    fn test_metric_with_tags() {\n        let metric = Metric::new(Measurement::Phase, 1).with_tags(tags![\n            (\"key_1\", 42),\n            (\"key_2\", \"42\"),\n            (\"key_3\", 1.0f32),\n        ]);\n        assert!(WriteQuery::from(metric)\n            .build()\n            .unwrap()\n            .get()\n            .starts_with(\"phase,key_1=42,key_2=42,key_3=1 value=1i \"))\n    }\n\n    #[test]\n    fn test_basic_event() {\n        let event = Event::new(\"error\");\n        assert!(WriteQuery::from(event)\n            .build()\n            .unwrap()\n            .get()\n            .starts_with(\"event title=\\\"error\\\" \"))\n    }\n\n    #[test]\n    fn test_event_with_description() {\n        let event = Event::new(\"error\").with_description(\"description\");\n        assert!(WriteQuery::from(event)\n            .build()\n            .unwrap()\n            .get()\n            .starts_with(\"event title=\\\"error\\\",description=\\\"description\\\" \"))\n    }\n\n    #[test]\n    fn test_event_with_description_and_tag() {\n        let event = Event::new(\"error\")\n            .with_description(\"description\")\n            
.with_tags([\"tag\"]);\n        assert!(WriteQuery::from(event)\n            .build()\n            .unwrap()\n            .get()\n            .starts_with(\"event title=\\\"error\\\",description=\\\"description\\\",tags=\\\"tag\\\" \"))\n    }\n\n    #[test]\n    fn test_event_with_description_and_tags() {\n        let event = Event::new(\"error\")\n            .with_description(\"description\")\n            .with_tags([\"tag_1\", \"tag_2\"]);\n        assert!(WriteQuery::from(event)\n            .build()\n            .unwrap()\n            .get()\n            .starts_with(\"event title=\\\"error\\\",description=\\\"description\\\",tags=\\\"tag_1,tag_2\\\" \"))\n    }\n\n    #[test]\n    fn test_event_with_tag() {\n        let event = Event::new(\"error\").with_tags([\"tag\"]);\n        assert!(WriteQuery::from(event)\n            .build()\n            .unwrap()\n            .get()\n            .starts_with(\"event title=\\\"error\\\",tags=\\\"tag\\\" \"))\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/metrics/recorders/influxdb/recorder.rs",
    "content": "use std::borrow::Borrow;\n\nuse futures::future::poll_fn;\nuse influxdb::Type;\nuse tower::Service;\nuse tracing::{error, warn};\n\nuse super::{Dispatcher, Event, InfluxDbService, Measurement, Metric, Request, Tags};\nuse crate::settings::InfluxSettings;\n\n/// An InfluxDB metrics / events recorder.\npub struct Recorder {\n    /// A service that dispatches the recorded metrics / events to an InfluxDB instance.\n    service: InfluxDbService,\n}\n\nimpl Recorder {\n    /// Creates a new InfluxDB recorder.\n    pub fn new(settings: InfluxSettings) -> Self {\n        let dispatcher = Dispatcher::new(settings.url, settings.db);\n        Self {\n            service: InfluxDbService::new(dispatcher),\n        }\n    }\n\n    /// Records a new metric and dispatches it to an InfluxDB instance.\n    pub fn metric<V, T, I>(&self, measurement: Measurement, value: V, tags: T)\n    where\n        V: Into<Type>,\n        T: Into<Option<I>>,\n        I: Into<Tags>,\n    {\n        let metric = Metric::new(measurement, value).with_tags(tags);\n        self.call(metric.into());\n    }\n\n    /// Records a new event and dispatches it to an InfluxDB instance.\n    pub fn event<H, D, S, T, A, B>(&self, title: H, description: D, tags: T)\n    where\n        H: Into<String>,\n        D: Into<Option<S>>,\n        S: Into<String>,\n        T: Into<Option<A>>,\n        A: AsRef<[B]>,\n        B: Borrow<str>,\n    {\n        let event = Event::new(title)\n            .with_description(description)\n            .with_tags(tags);\n        self.call(event.into());\n    }\n\n    fn call(&self, req: Request) {\n        let mut handle = self.service.0.clone();\n        tokio::spawn(async move {\n            if let Err(err) = poll_fn(|cx| handle.poll_ready(cx)).await {\n                error!(\"influx service temporarily unavailable: {}\", err)\n            }\n\n            if let Err(err) = handle.call(req).await {\n                warn!(\"influx service error: {}\", err)\n        
    }\n        });\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/metrics/recorders/influxdb/service.rs",
    "content": "use super::{Dispatcher, Request};\nuse tower::{buffer::Buffer, limit::ConcurrencyLimit, load_shed::LoadShed, ServiceBuilder};\n\npub(in crate::metrics) struct InfluxDbService(\n    pub LoadShed<Buffer<ConcurrencyLimit<Dispatcher>, Request>>,\n);\n\nimpl InfluxDbService {\n    pub fn new(dispatcher: Dispatcher) -> Self {\n        let service = ServiceBuilder::new()\n            .load_shed()\n            .buffer(4048)\n            .concurrency_limit(50)\n            .service(dispatcher);\n        Self(service)\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/metrics/recorders/mod.rs",
    "content": "pub mod influxdb;\n"
  },
  {
    "path": "rust/xaynet-server/src/rest.rs",
    "content": "//! A HTTP API for the PET protocol interactions.\n\nuse std::convert::Infallible;\n#[cfg(feature = \"tls\")]\nuse std::path::PathBuf;\n\nuse bytes::Bytes;\nuse serde::{Deserialize, Serialize};\nuse thiserror::Error;\nuse tracing::{error, warn};\nuse warp::{\n    http::{Response, StatusCode},\n    reply::Reply,\n    Filter,\n};\n#[cfg(feature = \"tls\")]\nuse warp::{Server, TlsServer};\n\nuse crate::{\n    services::{fetchers::Fetcher, messages::PetMessageHandler},\n    settings::ApiSettings,\n};\nuse xaynet_core::{crypto::ByteObject, ParticipantPublicKey};\n\n#[derive(Deserialize, Serialize)]\nstruct SeedDictQuery {\n    pk: String,\n}\n\n/// Starts a HTTP server at the given address, listening to GET requests for\n/// data and POST requests containing PET messages.\n///\n/// * `api_settings`: address of the server and optional certificate and key for TLS server\n///   authentication as well as trusted anchors for TLS client authentication.\n/// * `fetcher`: fetcher for responding to data requests.\n/// * `pet_message_handler`: handler for responding to PET messages.\n///\n/// # Errors\n/// Fails if the TLS settings are invalid.\npub async fn serve<F>(\n    api_settings: ApiSettings,\n    fetcher: F,\n    pet_message_handler: PetMessageHandler,\n) -> Result<(), RestError>\nwhere\n    F: Fetcher + Sync + Send + 'static + Clone,\n{\n    let message = warp::path!(\"message\")\n        .and(warp::post())\n        .and(warp::body::bytes())\n        .and(with_message_handler(pet_message_handler.clone()))\n        .and_then(handle_message);\n\n    let sum_dict = warp::path!(\"sums\")\n        .and(warp::get())\n        .and(with_fetcher(fetcher.clone()))\n        .and_then(handle_sums);\n\n    let seed_dict = warp::path!(\"seeds\")\n        .and(warp::get())\n        .and(warp::query::<SeedDictQuery>())\n        .and_then(part_pk)\n        .and(with_fetcher(fetcher.clone()))\n        .and_then(handle_seeds);\n\n    let round_params = 
warp::path!(\"params\")\n        .and(warp::get())\n        .and(with_fetcher(fetcher.clone()))\n        .and_then(handle_params);\n\n    let model = warp::path!(\"model\")\n        .and(warp::get())\n        .and(with_fetcher(fetcher.clone()))\n        .and_then(handle_model);\n\n    let routes = message\n        .or(round_params)\n        .or(sum_dict)\n        .or(seed_dict)\n        .or(model)\n        .recover(handle_reject)\n        .with(warp::log(\"http\"));\n\n    #[cfg(not(feature = \"tls\"))]\n    return run_http(routes, api_settings)\n        .await\n        .map_err(RestError::from);\n    #[cfg(feature = \"tls\")]\n    return run_https(routes, api_settings).await;\n}\n\n/// Handles and responds to a PET message.\nasync fn handle_message(\n    body: Bytes,\n    mut handler: PetMessageHandler,\n) -> Result<impl warp::Reply, Infallible> {\n    let _ = handler.handle_message(body.to_vec()).await.map_err(|e| {\n        warn!(\"failed to handle message: {:?}\", e);\n    });\n    Ok(warp::reply())\n}\n\n/// Handles and responds to a request for the sum dictionary.\nasync fn handle_sums<F: Fetcher>(mut fetcher: F) -> Result<impl warp::Reply, Infallible> {\n    Ok(match fetcher.sum_dict().await {\n        Err(e) => {\n            warn!(\"failed to handle sum dict request: {:?}\", e);\n            Response::builder()\n                .status(StatusCode::INTERNAL_SERVER_ERROR)\n                .body(Vec::new())\n                .unwrap()\n        }\n        Ok(None) => Response::builder()\n            .status(StatusCode::NO_CONTENT)\n            .body(Vec::new())\n            .unwrap(),\n        Ok(Some(dict)) => {\n            let bytes = bincode::serialize(dict.as_ref()).unwrap();\n            Response::builder()\n                .header(\"Content-Type\", \"application/octet-stream\")\n                .status(StatusCode::OK)\n                .body(bytes)\n                .unwrap()\n        }\n    })\n}\n\n/// Handles and responds to a request for the seed 
dictionary.\nasync fn handle_seeds<F: Fetcher>(\n    pk: ParticipantPublicKey,\n    mut fetcher: F,\n) -> Result<impl warp::Reply, Infallible> {\n    Ok(match fetcher.seed_dict().await {\n        Err(e) => {\n            warn!(\"failed to handle seed dict request: {:?}\", e);\n            Response::builder()\n                .status(StatusCode::INTERNAL_SERVER_ERROR)\n                .body(Vec::new())\n                .unwrap()\n        }\n        Ok(Some(dict)) if dict.get(&pk).is_some() => {\n            let bytes = bincode::serialize(dict.as_ref().get(&pk).unwrap()).unwrap();\n            Response::builder()\n                .header(\"Content-Type\", \"application/octet-stream\")\n                .status(StatusCode::OK)\n                .body(bytes)\n                .unwrap()\n        }\n        _ => Response::builder()\n            .status(StatusCode::NO_CONTENT)\n            .body(Vec::new())\n            .unwrap(),\n    })\n}\n\n/// Handles and responds to a request for the global model.\nasync fn handle_model<F: Fetcher>(mut fetcher: F) -> Result<impl warp::Reply, Infallible> {\n    Ok(match fetcher.model().await {\n        Ok(Some(model)) => Response::builder()\n            .status(StatusCode::OK)\n            .body(bincode::serialize(model.as_ref()).unwrap())\n            .unwrap(),\n        Ok(None) => Response::builder()\n            .status(StatusCode::NO_CONTENT)\n            .body(Vec::new())\n            .unwrap(),\n        Err(e) => {\n            warn!(\"failed to handle model request: {:?}\", e);\n            Response::builder()\n                .status(StatusCode::INTERNAL_SERVER_ERROR)\n                .body(Vec::new())\n                .unwrap()\n        }\n    })\n}\n\n/// Handles and responds to a request for the round parameters.\nasync fn handle_params<F: Fetcher>(mut fetcher: F) -> Result<impl warp::Reply, Infallible> {\n    Ok(match fetcher.round_params().await {\n        Ok(params) => Response::builder()\n            
.status(StatusCode::OK)\n            .body(bincode::serialize(&params).unwrap())\n            .unwrap(),\n        Err(e) => {\n            warn!(\"failed to handle round parameters request: {:?}\", e);\n            Response::builder()\n                .status(StatusCode::INTERNAL_SERVER_ERROR)\n                .body(Vec::new())\n                .unwrap()\n        }\n    })\n}\n\n/// Converts a PET message handler into a `warp` filter.\nfn with_message_handler(\n    handler: PetMessageHandler,\n) -> impl Filter<Extract = (PetMessageHandler,), Error = Infallible> + Clone {\n    warp::any().map(move || handler.clone())\n}\n\n/// Converts a data fetcher into a `warp` filter.\nfn with_fetcher<F: Fetcher + Sync + Send + 'static + Clone>(\n    fetcher: F,\n) -> impl Filter<Extract = (F,), Error = Infallible> + Clone {\n    warp::any().map(move || fetcher.clone())\n}\n\n/// Extracts a participant public key from the url query string\nasync fn part_pk(query: SeedDictQuery) -> Result<ParticipantPublicKey, warp::Rejection> {\n    match base64::decode(query.pk.as_bytes()) {\n        Ok(bytes) => {\n            if let Some(pk) = ParticipantPublicKey::from_slice(&bytes[..]) {\n                Ok(pk)\n            } else {\n                Err(warp::reject::custom(InvalidPublicKey))\n            }\n        }\n        Err(_) => Err(warp::reject::custom(InvalidPublicKey)),\n    }\n}\n\n#[derive(Debug)]\nstruct InvalidPublicKey;\n\nimpl warp::reject::Reject for InvalidPublicKey {}\n\n/// Handles `warp` rejections of bad requests.\nasync fn handle_reject(err: warp::Rejection) -> Result<impl warp::Reply, Infallible> {\n    let code = if err.is_not_found() {\n        StatusCode::NOT_FOUND\n    } else if let Some(InvalidPublicKey) = err.find() {\n        StatusCode::BAD_REQUEST\n    } else {\n        error!(\"unhandled rejection: {:?}\", err);\n        StatusCode::INTERNAL_SERVER_ERROR\n    };\n    // reply with empty body; the status code is the interesting part\n    
Ok(warp::reply::with_status(Vec::new(), code))\n}\n\n#[derive(Debug, Error)]\n/// Errors of the rest server.\npub enum RestError {\n    #[error(\"invalid TLS configuration was provided\")]\n    InvalidTlsConfig,\n}\n\nimpl From<Infallible> for RestError {\n    fn from(infallible: Infallible) -> RestError {\n        match infallible {}\n    }\n}\n\n#[cfg(feature = \"tls\")]\n/// Configures a server for TLS server and client authentication.\n///\n/// # Errors\n/// Fails if the TLS settings are invalid.\nfn configure_tls<F>(\n    server: Server<F>,\n    tls_certificate: Option<PathBuf>,\n    tls_key: Option<PathBuf>,\n    tls_client_auth: Option<PathBuf>,\n) -> Result<TlsServer<F>, RestError>\nwhere\n    F: Filter + Clone + Send + Sync + 'static,\n    F::Extract: Reply,\n{\n    if tls_certificate.is_none() && tls_key.is_none() && tls_client_auth.is_none() {\n        return Err(RestError::InvalidTlsConfig);\n    }\n\n    let mut server = server.tls();\n    match (tls_certificate, tls_key) {\n        (Some(cert), Some(key)) => server = server.cert_path(cert).key_path(key),\n        (None, None) => {}\n        _ => return Err(RestError::InvalidTlsConfig),\n    }\n    if let Some(trust_anchor) = tls_client_auth {\n        server = server.client_auth_required_path(trust_anchor);\n    }\n    Ok(server)\n}\n\n#[cfg(not(feature = \"tls\"))]\n/// Runs a server with the provided filter routes.\nasync fn run_http<F>(filter: F, api_settings: ApiSettings) -> Result<(), Infallible>\nwhere\n    F: Filter + Clone + Send + Sync + 'static,\n    F::Extract: Reply,\n{\n    warp::serve(filter).run(api_settings.bind_address).await;\n    Ok(())\n}\n\n#[cfg(feature = \"tls\")]\n/// Runs a TLS server with the provided filter routes.\n///\n/// # Errors\n/// Fails if the TLS settings are invalid.\nasync fn run_https<F>(filter: F, api_settings: ApiSettings) -> Result<(), RestError>\nwhere\n    F: Filter + Clone + Send + Sync + 'static,\n    F::Extract: Reply,\n{\n    configure_tls(\n        
warp::serve(filter),\n        api_settings.tls_certificate,\n        api_settings.tls_key,\n        api_settings.tls_client_auth,\n    )?\n    .run(api_settings.bind_address)\n    .await;\n    Ok(())\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/fetchers/mod.rs",
    "content": "//! This module provides the services for serving data.\n//!\n//! There are multiple such services and the [`Fetcher`] trait\n//! provides a single unifying interface for all of these.\n\nmod model;\nmod round_parameters;\nmod seed_dict;\nmod sum_dict;\n\nuse std::task::{Context, Poll};\n\nuse async_trait::async_trait;\nuse futures::future::poll_fn;\nuse tower::{layer::Layer, Service, ServiceBuilder};\n\npub use self::{\n    model::{ModelRequest, ModelResponse, ModelService},\n    round_parameters::{RoundParamsRequest, RoundParamsResponse, RoundParamsService},\n    seed_dict::{SeedDictRequest, SeedDictResponse, SeedDictService},\n    sum_dict::{SumDictRequest, SumDictResponse, SumDictService},\n};\nuse crate::state_machine::events::EventSubscriber;\n\n/// A single interface for retrieving data from the coordinator.\n#[async_trait]\npub trait Fetcher {\n    /// Fetch the parameters for the current round\n    async fn round_params(&mut self) -> Result<RoundParamsResponse, FetchError>;\n\n    /// Fetch the latest global model.\n    async fn model(&mut self) -> Result<ModelResponse, FetchError>;\n\n    /// Fetch the global seed dictionary. Each sum2 participant needs a\n    /// different portion of that dictionary.\n    async fn seed_dict(&mut self) -> Result<SeedDictResponse, FetchError>;\n\n    /// Fetch the sum dictionary. 
The update participants need this\n    /// dictionary to encrypt their masking seed for each sum\n    /// participant.\n    async fn sum_dict(&mut self) -> Result<SumDictResponse, FetchError>;\n}\n\n/// An error returned by the [`Fetcher`]'s method.\npub type FetchError = anyhow::Error;\n\nfn into_fetch_error<E: Into<Box<dyn std::error::Error + 'static + Sync + Send>>>(\n    e: E,\n) -> FetchError {\n    anyhow::anyhow!(\"Fetcher failed: {:?}\", e.into())\n}\n\n#[async_trait]\nimpl<RoundParams, SumDict, SeedDict, Model> Fetcher\n    for Fetchers<RoundParams, SumDict, SeedDict, Model>\nwhere\n    Self: Send + Sync + 'static,\n\n    RoundParams: Service<RoundParamsRequest, Response = RoundParamsResponse> + Send + 'static,\n    <RoundParams as Service<RoundParamsRequest>>::Future: Send + Sync + 'static,\n    <RoundParams as Service<RoundParamsRequest>>::Error:\n        Into<Box<dyn std::error::Error + 'static + Sync + Send>>,\n\n    Model: Service<ModelRequest, Response = ModelResponse> + Send + 'static,\n    <Model as Service<ModelRequest>>::Future: Send + Sync + 'static,\n    <Model as Service<ModelRequest>>::Error:\n        Into<Box<dyn std::error::Error + 'static + Sync + Send>>,\n\n    SeedDict: Service<SeedDictRequest, Response = SeedDictResponse> + Send + 'static,\n    <SeedDict as Service<SeedDictRequest>>::Future: Send + Sync + 'static,\n    <SeedDict as Service<SeedDictRequest>>::Error:\n        Into<Box<dyn std::error::Error + 'static + Sync + Send>>,\n\n    SumDict: Service<SumDictRequest, Response = SumDictResponse> + Send + 'static,\n    <SumDict as Service<SumDictRequest>>::Future: Send + Sync + 'static,\n    <SumDict as Service<SumDictRequest>>::Error:\n        Into<Box<dyn std::error::Error + 'static + Sync + Send>>,\n{\n    async fn round_params(&mut self) -> Result<RoundParamsResponse, FetchError> {\n        poll_fn(|cx| {\n            <RoundParams as Service<RoundParamsRequest>>::poll_ready(&mut self.round_params, cx)\n        })\n        .await\n  
      .map_err(into_fetch_error)?;\n        Ok(<RoundParams as Service<RoundParamsRequest>>::call(\n            &mut self.round_params,\n            RoundParamsRequest,\n        )\n        .await\n        .map_err(into_fetch_error)?)\n    }\n\n    async fn model(&mut self) -> Result<ModelResponse, FetchError> {\n        poll_fn(|cx| <Model as Service<ModelRequest>>::poll_ready(&mut self.model, cx))\n            .await\n            .map_err(into_fetch_error)?;\n        Ok(\n            <Model as Service<ModelRequest>>::call(&mut self.model, ModelRequest)\n                .await\n                .map_err(into_fetch_error)?,\n        )\n    }\n\n    async fn seed_dict(&mut self) -> Result<SeedDictResponse, FetchError> {\n        poll_fn(|cx| <SeedDict as Service<SeedDictRequest>>::poll_ready(&mut self.seed_dict, cx))\n            .await\n            .map_err(into_fetch_error)?;\n        Ok(\n            <SeedDict as Service<SeedDictRequest>>::call(&mut self.seed_dict, SeedDictRequest)\n                .await\n                .map_err(into_fetch_error)?,\n        )\n    }\n\n    async fn sum_dict(&mut self) -> Result<SumDictResponse, FetchError> {\n        poll_fn(|cx| <SumDict as Service<SumDictRequest>>::poll_ready(&mut self.sum_dict, cx))\n            .await\n            .map_err(into_fetch_error)?;\n        Ok(\n            <SumDict as Service<SumDictRequest>>::call(&mut self.sum_dict, SumDictRequest)\n                .await\n                .map_err(into_fetch_error)?,\n        )\n    }\n}\n\npub(in crate::services) struct FetcherService<S>(S);\n\nimpl<S, R> Service<R> for FetcherService<S>\nwhere\n    S: Service<R>,\n{\n    type Response = S::Response;\n    type Error = S::Error;\n    type Future = S::Future;\n\n    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        self.0.poll_ready(cx)\n    }\n\n    fn call(&mut self, req: R) -> Self::Future {\n        self.0.call(req)\n    }\n}\n\npub(in crate::services) struct 
FetcherLayer;\n\nimpl<S> Layer<S> for FetcherLayer {\n    type Service = FetcherService<S>;\n\n    fn layer(&self, service: S) -> Self::Service {\n        FetcherService(service)\n    }\n}\n\n#[derive(Debug, Clone)]\npub struct Fetchers<RoundParams, SumDict, SeedDict, Model> {\n    round_params: RoundParams,\n    sum_dict: SumDict,\n    seed_dict: SeedDict,\n    model: Model,\n}\n\nimpl<RoundParams, SumDict, SeedDict, Model> Fetchers<RoundParams, SumDict, SeedDict, Model> {\n    pub fn new(\n        round_params: RoundParams,\n        sum_dict: SumDict,\n        seed_dict: SeedDict,\n        model: Model,\n    ) -> Self {\n        Self {\n            round_params,\n            sum_dict,\n            seed_dict,\n            model,\n        }\n    }\n}\n\n/// Construct a [`Fetcher`] service\npub fn fetcher(event_subscriber: &EventSubscriber) -> impl Fetcher + Sync + Send + Clone + 'static {\n    let round_params = ServiceBuilder::new()\n        .buffer(100)\n        .concurrency_limit(100)\n        .layer(FetcherLayer)\n        .service(RoundParamsService::new(event_subscriber));\n\n    let model = ServiceBuilder::new()\n        .buffer(100)\n        .concurrency_limit(100)\n        .layer(FetcherLayer)\n        .service(ModelService::new(event_subscriber));\n\n    let sum_dict = ServiceBuilder::new()\n        .buffer(100)\n        .concurrency_limit(100)\n        .layer(FetcherLayer)\n        .service(SumDictService::new(event_subscriber));\n\n    let seed_dict = ServiceBuilder::new()\n        .buffer(100)\n        .concurrency_limit(100)\n        .layer(FetcherLayer)\n        .service(SeedDictService::new(event_subscriber));\n\n    Fetchers::new(round_params, sum_dict, seed_dict, model)\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/fetchers/model.rs",
    "content": "use std::{\n    sync::Arc,\n    task::{Context, Poll},\n};\n\nuse futures::future::{self, Ready};\nuse tower::Service;\nuse tracing::error_span;\nuse tracing_futures::{Instrument, Instrumented};\n\nuse crate::state_machine::events::{EventListener, EventSubscriber, ModelUpdate};\nuse xaynet_core::mask::Model;\n\n/// [`ModelService`]'s request type\n#[derive(Default, Clone, Eq, PartialEq, Debug)]\npub struct ModelRequest;\n\n/// [`ModelService`]'s response type.\n///\n/// The response is `None` when no model is currently available.\npub type ModelResponse = Option<Arc<Model>>;\n\n/// A service that serves the latest available global model\npub struct ModelService(EventListener<ModelUpdate>);\n\nimpl ModelService {\n    pub fn new(events: &EventSubscriber) -> Self {\n        Self(events.model_listener())\n    }\n}\n\nimpl Service<ModelRequest> for ModelService {\n    type Response = ModelResponse;\n    type Error = std::convert::Infallible;\n    type Future = Instrumented<Ready<Result<Self::Response, Self::Error>>>;\n\n    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        Poll::Ready(Ok(()))\n    }\n\n    fn call(&mut self, _req: ModelRequest) -> Self::Future {\n        future::ready(match self.0.get_latest().event {\n            ModelUpdate::Invalidate => Ok(None),\n            ModelUpdate::New(model) => Ok(Some(model)),\n        })\n        .instrument(error_span!(\"model_fetch_request\"))\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/fetchers/round_parameters.rs",
    "content": "use std::task::{Context, Poll};\n\nuse futures::future::{self, Ready};\nuse tower::Service;\nuse tracing::error_span;\nuse tracing_futures::{Instrument, Instrumented};\n\nuse crate::state_machine::events::{EventListener, EventSubscriber};\nuse xaynet_core::common::RoundParameters;\n\n/// [`RoundParamsService`]'s request type\n#[derive(Default, Clone, Eq, PartialEq, Debug)]\npub struct RoundParamsRequest;\n\n/// [`RoundParamsService`]'s response type\npub type RoundParamsResponse = RoundParameters;\n\n/// A service that serves the round parameters for the current round.\npub struct RoundParamsService(EventListener<RoundParameters>);\n\nimpl RoundParamsService {\n    pub fn new(events: &EventSubscriber) -> Self {\n        Self(events.params_listener())\n    }\n}\n\nimpl Service<RoundParamsRequest> for RoundParamsService {\n    type Response = RoundParameters;\n    type Error = std::convert::Infallible;\n    type Future = Instrumented<Ready<Result<Self::Response, Self::Error>>>;\n\n    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        Poll::Ready(Ok(()))\n    }\n\n    fn call(&mut self, _req: RoundParamsRequest) -> Self::Future {\n        future::ready(Ok(self.0.get_latest().event))\n            .instrument(error_span!(\"round_params_fetch_request\"))\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/fetchers/seed_dict.rs",
    "content": "use std::{\n    sync::Arc,\n    task::{Context, Poll},\n};\n\nuse futures::future::{self, Ready};\nuse tower::Service;\nuse tracing::error_span;\nuse tracing_futures::{Instrument, Instrumented};\n\nuse crate::state_machine::events::{DictionaryUpdate, EventListener, EventSubscriber};\nuse xaynet_core::SeedDict;\n\n/// A service that serves the seed dictionary for the current round.\npub struct SeedDictService(EventListener<DictionaryUpdate<SeedDict>>);\n\nimpl SeedDictService {\n    pub fn new(events: &EventSubscriber) -> Self {\n        Self(events.seed_dict_listener())\n    }\n}\n\n/// [`SeedDictService`]'s request type\n#[derive(Default, Clone, Eq, PartialEq, Debug)]\npub struct SeedDictRequest;\n\n/// [`SeedDictService`]'s response type.\n///\n/// The response is `None` when no seed dictionary is currently\n/// available\npub type SeedDictResponse = Option<Arc<SeedDict>>;\n\nimpl Service<SeedDictRequest> for SeedDictService {\n    type Response = SeedDictResponse;\n    type Error = std::convert::Infallible;\n    type Future = Instrumented<Ready<Result<Self::Response, Self::Error>>>;\n\n    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        Poll::Ready(Ok(()))\n    }\n\n    fn call(&mut self, _req: SeedDictRequest) -> Self::Future {\n        future::ready(match self.0.get_latest().event {\n            DictionaryUpdate::Invalidate => Ok(None),\n            DictionaryUpdate::New(dict) => Ok(Some(dict)),\n        })\n        .instrument(error_span!(\"seed_dict_fetch_request\"))\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/fetchers/sum_dict.rs",
    "content": "use std::{\n    sync::Arc,\n    task::{Context, Poll},\n};\n\nuse futures::future::{self, Ready};\nuse tower::Service;\nuse tracing::error_span;\nuse tracing_futures::{Instrument, Instrumented};\n\nuse crate::state_machine::events::{DictionaryUpdate, EventListener, EventSubscriber};\nuse xaynet_core::SumDict;\n\n/// A service that returns the sum dictionary for the current round.\npub struct SumDictService(EventListener<DictionaryUpdate<SumDict>>);\n\n/// [`SumDictService`]'s request type\n#[derive(Default, Clone, Eq, PartialEq, Debug)]\npub struct SumDictRequest;\n\n/// [`SumDictService`]'s response type.\n///\n/// The response is `None` when no sum dictionary is currently\n/// available\npub type SumDictResponse = Option<Arc<SumDict>>;\n\nimpl SumDictService {\n    pub fn new(events: &EventSubscriber) -> Self {\n        Self(events.sum_dict_listener())\n    }\n}\n\nimpl Service<SumDictRequest> for SumDictService {\n    type Response = SumDictResponse;\n    type Error = std::convert::Infallible;\n    type Future = Instrumented<Ready<Result<Self::Response, Self::Error>>>;\n\n    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        Poll::Ready(Ok(()))\n    }\n\n    fn call(&mut self, _req: SumDictRequest) -> Self::Future {\n        future::ready(match self.0.get_latest().event {\n            DictionaryUpdate::Invalidate => Ok(None),\n            DictionaryUpdate::New(dict) => Ok(Some(dict)),\n        })\n        .instrument(error_span!(\"sum_dict_fetch_request\"))\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/messages/decryptor.rs",
    "content": "use std::{pin::Pin, sync::Arc, task::Poll};\n\nuse futures::{future::Future, task::Context};\nuse rayon::ThreadPool;\nuse tokio::sync::oneshot;\nuse tower::{\n    limit::concurrency::{future::ResponseFuture, ConcurrencyLimit},\n    Service,\n};\nuse tracing::{debug, info, trace};\n\nuse crate::{\n    services::messages::{BoxedServiceFuture, ServiceError},\n    state_machine::events::{EventListener, EventSubscriber},\n};\nuse xaynet_core::crypto::EncryptKeyPair;\n\n/// A service for decrypting PET messages.\n///\n/// Since this is a CPU-intensive task for large messages, this\n/// service offloads the processing to a `rayon` thread-pool to avoid\n/// overloading the tokio thread-pool with blocking tasks.\n#[derive(Clone)]\nstruct RawDecryptor {\n    /// A listener to retrieve the latest coordinator keys. These are\n    /// necessary for decrypting messages and verifying their\n    /// signature.\n    keys_events: EventListener<EncryptKeyPair>,\n\n    /// Thread-pool the CPU-intensive tasks are offloaded to.\n    thread_pool: Arc<ThreadPool>,\n}\n\nimpl<T> Service<T> for RawDecryptor\nwhere\n    T: AsRef<[u8]> + Sync + Send + 'static,\n{\n    type Response = Vec<u8>;\n    type Error = ServiceError;\n    #[allow(clippy::type_complexity)]\n    type Future =\n        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + 'static + Send + Sync>>;\n\n    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        Poll::Ready(Ok(()))\n    }\n\n    fn call(&mut self, data: T) -> Self::Future {\n        debug!(\"retrieving the current keys\");\n        let keys = self.keys_events.get_latest().event;\n        let (tx, rx) = oneshot::channel::<Result<Self::Response, Self::Error>>();\n\n        trace!(\"spawning decryption task on threadpool\");\n        self.thread_pool.spawn(move || {\n            info!(\"decrypting message\");\n            let res = keys\n                .secret\n                
.decrypt(data.as_ref(), &keys.public)\n                .map_err(|_| ServiceError::Decrypt);\n            let _ = tx.send(res);\n        });\n        Box::pin(async move {\n            rx.await.unwrap_or_else(|_| {\n                Err(ServiceError::InternalError(\n                    \"failed to receive response from thread-pool\".to_string(),\n                ))\n            })\n        })\n    }\n}\n\n#[derive(Clone)]\npub struct Decryptor(ConcurrencyLimit<RawDecryptor>);\n\nimpl Decryptor {\n    pub fn new(state_machine_events: &EventSubscriber, thread_pool: Arc<ThreadPool>) -> Self {\n        let limit = thread_pool.current_num_threads();\n        let keys_events = state_machine_events.keys_listener();\n        let service = RawDecryptor {\n            keys_events,\n            thread_pool,\n        };\n        Self(ConcurrencyLimit::new(service, limit))\n    }\n}\n\nimpl<T> Service<T> for Decryptor\nwhere\n    T: AsRef<[u8]> + Sync + Send + 'static,\n{\n    type Response = Vec<u8>;\n    type Error = ServiceError;\n    type Future = ResponseFuture<BoxedServiceFuture<Self::Response, Self::Error>>;\n\n    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        <ConcurrencyLimit<RawDecryptor> as Service<T>>::poll_ready(&mut self.0, cx)\n    }\n\n    fn call(&mut self, data: T) -> Self::Future {\n        self.0.call(data)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use rayon::ThreadPoolBuilder;\n    use tokio_test::assert_ready;\n    use tower_test::mock::Spawn;\n\n    use crate::{\n        services::tests::utils,\n        state_machine::events::{EventPublisher, EventSubscriber},\n    };\n\n    use super::*;\n\n    fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn<Decryptor>) {\n        let (publisher, subscriber) = utils::new_event_channels();\n        let thread_pool = Arc::new(ThreadPoolBuilder::new().build().unwrap());\n        let task = Spawn::new(Decryptor::new(&subscriber, thread_pool));\n        (publisher, 
subscriber, task)\n    }\n\n    #[tokio::test]\n    async fn test_decrypt_fail() {\n        let (_publisher, _subscriber, mut task) = spawn_svc();\n        assert_ready!(task.poll_ready::<Vec<u8>>()).unwrap();\n\n        let req = vec![0, 1, 2, 3, 4, 5, 6];\n        match task.call(req).await {\n            Err(ServiceError::Decrypt) => {}\n            _ => panic!(\"expected decrypt error\"),\n        }\n        assert_ready!(task.poll_ready::<Vec<u8>>()).unwrap();\n    }\n\n    #[tokio::test]\n    async fn test_decrypt_ok() {\n        let (_publisher, subscriber, mut task) = spawn_svc();\n        assert_ready!(task.poll_ready::<Vec<u8>>()).unwrap();\n\n        let round_params = subscriber.params_listener().get_latest().event;\n        let (message, participant_signing_keys) = utils::new_sum_message(&round_params);\n        let serialized_message = utils::serialize_message(&message, &participant_signing_keys);\n        let encrypted_message =\n            utils::encrypt_message(&message, &round_params, &participant_signing_keys);\n\n        // Call the service\n        let decrypted_message = task.call(encrypted_message).await.unwrap();\n        assert_eq!(decrypted_message, serialized_message);\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/messages/error.rs",
    "content": "use displaydoc::Display;\nuse thiserror::Error;\n\nuse crate::state_machine::requests::RequestError;\nuse xaynet_core::message::DecodeError;\n\n/// Errors for the message parsing service.\n#[derive(Debug, Display, Error)]\npub enum ServiceError {\n    /// Failed to decrypt the message with the coordinator secret key.\n    Decrypt,\n    /// Failed to parse the message: {0}.\n    Parsing(DecodeError),\n    /// Invalid message signature.\n    InvalidMessageSignature,\n    /// Invalid coordinator public key.\n    InvalidCoordinatorPublicKey,\n    /// The message was not expected in the current phase.\n    UnexpectedMessage,\n    // FIXME: we need to refine the state machine errors and the\n    // conversion into a service error\n    /// The state machine failed to process the request: {0}.\n    StateMachine(RequestError),\n    /// Participant is not eligible for sum task.\n    NotSumEligible,\n    /// Participant is not eligible for update task.\n    NotUpdateEligible,\n    /// Internal error: {0}.\n    InternalError(String),\n}\n\nimpl From<Box<dyn std::error::Error>> for ServiceError {\n    fn from(e: Box<dyn std::error::Error>) -> Self {\n        match e.downcast::<ServiceError>() {\n            Ok(e) => *e,\n            Err(e) => ServiceError::InternalError(format!(\"{}\", e)),\n        }\n    }\n}\n\nimpl From<Box<dyn std::error::Error + Sync + Send>> for ServiceError {\n    fn from(e: Box<dyn std::error::Error + Sync + Send>) -> Self {\n        ServiceError::from(e as Box<dyn std::error::Error>)\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/messages/message_parser.rs",
    "content": "use std::{convert::TryInto, sync::Arc, task::Poll};\n\nuse futures::{future, task::Context};\nuse rayon::ThreadPool;\nuse tokio::sync::oneshot;\nuse tower::{layer::Layer, limit::concurrency::ConcurrencyLimit, Service, ServiceBuilder};\nuse tracing::{debug, info, trace, warn};\n\nuse crate::{\n    services::messages::{BoxedServiceFuture, ServiceError},\n    state_machine::{\n        events::{EventListener, EventSubscriber},\n        phases::PhaseName,\n    },\n};\nuse xaynet_core::{\n    crypto::{EncryptKeyPair, PublicEncryptKey},\n    message::{FromBytes, Message, MessageBuffer, Tag},\n};\n\n/// A type that hold a un-parsed message\nstruct RawMessage<T> {\n    /// The buffer that contains the message to parse\n    buffer: Arc<MessageBuffer<T>>,\n}\n\nimpl<T> Clone for RawMessage<T> {\n    fn clone(&self) -> Self {\n        Self {\n            buffer: self.buffer.clone(),\n        }\n    }\n}\n\nimpl<T> From<MessageBuffer<T>> for RawMessage<T> {\n    fn from(buffer: MessageBuffer<T>) -> Self {\n        RawMessage {\n            buffer: Arc::new(buffer),\n        }\n    }\n}\n\n/// A service that wraps a buffer `T` representing a message into a\n/// [`RawMessage<T>`]\n#[derive(Debug, Clone)]\nstruct BufferWrapper<S>(S);\n\nimpl<S, T> Service<T> for BufferWrapper<S>\nwhere\n    T: AsRef<[u8]> + Send + 'static,\n    S: Service<RawMessage<T>, Response = Message, Error = ServiceError>,\n    S::Future: Sync + Send + 'static,\n{\n    type Response = Message;\n    type Error = ServiceError;\n    type Future = BoxedServiceFuture<Self::Response, Self::Error>;\n\n    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        self.0.poll_ready(cx)\n    }\n\n    fn call(&mut self, req: T) -> Self::Future {\n        debug!(\"creating a RawMessage request\");\n        match MessageBuffer::new(req) {\n            Ok(buffer) => {\n                let fut = self.0.call(RawMessage::from(buffer));\n                Box::pin(async move {\n  
                  trace!(\"calling inner service\");\n                    fut.await\n                })\n            }\n            Err(e) => Box::pin(future::ready(Err(ServiceError::Parsing(e)))),\n        }\n    }\n}\n\nstruct BufferWrapperLayer;\n\nimpl<S> Layer<S> for BufferWrapperLayer {\n    type Service = BufferWrapper<S>;\n\n    fn layer(&self, service: S) -> BufferWrapper<S> {\n        BufferWrapper(service)\n    }\n}\n\n/// A service that discards messages that are not expected in the current phase\n#[derive(Debug, Clone)]\nstruct PhaseFilter<S> {\n    /// A listener to retrieve the current phase\n    phase: EventListener<PhaseName>,\n    /// Next service to be called\n    next_svc: S,\n}\n\nimpl<T, S> Service<RawMessage<T>> for PhaseFilter<S>\nwhere\n    T: AsRef<[u8]> + Send + 'static,\n    S: Service<RawMessage<T>, Response = Message, Error = ServiceError>,\n    S::Future: Sync + Send + 'static,\n{\n    type Response = Message;\n    type Error = ServiceError;\n    type Future = BoxedServiceFuture<Self::Response, Self::Error>;\n\n    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        self.next_svc.poll_ready(cx)\n    }\n\n    fn call(&mut self, req: RawMessage<T>) -> Self::Future {\n        debug!(\"retrieving the current phase\");\n        let phase = self.phase.get_latest().event;\n        match req.buffer.tag().try_into() {\n            Ok(tag) => match (phase, tag) {\n                (PhaseName::Sum, Tag::Sum)\n                | (PhaseName::Update, Tag::Update)\n                | (PhaseName::Sum2, Tag::Sum2) => {\n                    let fut = self.next_svc.call(req);\n                    Box::pin(async move { fut.await })\n                }\n                _ => Box::pin(future::ready(Err(ServiceError::UnexpectedMessage))),\n            },\n            Err(e) => Box::pin(future::ready(Err(ServiceError::Parsing(e)))),\n        }\n    }\n}\n\nstruct PhaseFilterLayer {\n    phase: 
EventListener<PhaseName>,\n}\n\nimpl<S> Layer<S> for PhaseFilterLayer {\n    type Service = PhaseFilter<S>;\n\n    fn layer(&self, service: S) -> PhaseFilter<S> {\n        PhaseFilter {\n            phase: self.phase.clone(),\n            next_svc: service,\n        }\n    }\n}\n\n/// A service for verifying the signature of PET messages\n///\n/// Since this is a CPU-intensive task for large messages, this\n/// service offloads the processing to a `rayon` thread-pool to avoid\n/// overloading the tokio thread-pool with blocking tasks.\n#[derive(Debug, Clone)]\nstruct SignatureVerifier<S> {\n    /// Thread-pool the CPU-intensive tasks are offloaded to.\n    thread_pool: Arc<ThreadPool>,\n    /// The service to be called after the [`SignatureVerifier`]\n    next_svc: S,\n}\n\nimpl<T, S> Service<RawMessage<T>> for SignatureVerifier<S>\nwhere\n    T: AsRef<[u8]> + Sync + Send + 'static,\n    S: Service<RawMessage<T>, Response = Message, Error = ServiceError>\n        + Clone\n        + Sync\n        + Send\n        + 'static,\n    S::Future: Sync + Send + 'static,\n{\n    type Response = Message;\n    type Error = ServiceError;\n    type Future = BoxedServiceFuture<Self::Response, Self::Error>;\n\n    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        self.next_svc.poll_ready(cx)\n    }\n\n    fn call(&mut self, req: RawMessage<T>) -> Self::Future {\n        let (tx, rx) = oneshot::channel::<Result<(), ServiceError>>();\n\n        let req_clone = req.clone();\n        trace!(\"spawning signature verification task on thread-pool\");\n        self.thread_pool.spawn(move || {\n            let res = match req.buffer.as_ref().as_ref().check_signature() {\n                Ok(()) => {\n                    info!(\"found a valid message signature\");\n                    Ok(())\n                }\n                Err(e) => {\n                    warn!(\"invalid message signature: {:?}\", e);\n                    
Err(ServiceError::InvalidMessageSignature)\n                }\n            };\n            let _ = tx.send(res);\n        });\n\n        let mut next_svc = self.next_svc.clone();\n        let fut = async move {\n            rx.await.map_err(|_| {\n                ServiceError::InternalError(\n                    \"failed to receive response from thread-pool\".to_string(),\n                )\n            })??;\n            next_svc.call(req_clone).await\n        };\n        Box::pin(fut)\n    }\n}\n\nstruct SignatureVerifierLayer {\n    thread_pool: Arc<ThreadPool>,\n}\n\nimpl<S> Layer<S> for SignatureVerifierLayer {\n    type Service = ConcurrencyLimit<SignatureVerifier<S>>;\n\n    fn layer(&self, service: S) -> Self::Service {\n        let limit = self.thread_pool.current_num_threads();\n        // FIXME: we actually want to limit the concurrency of just\n        // the SignatureVerifier middleware. Right now we're limiting\n        // the whole stack of services.\n        ConcurrencyLimit::new(\n            SignatureVerifier {\n                thread_pool: self.thread_pool.clone(),\n                next_svc: service,\n            },\n            limit,\n        )\n    }\n}\n\n/// A service that verifies the coordinator public key embedded in PET\n/// messsages\n#[derive(Debug, Clone)]\nstruct CoordinatorPublicKeyValidator<S> {\n    /// A listener to retrieve the latest coordinator keys\n    keys: EventListener<EncryptKeyPair>,\n    /// Next service to be called\n    next_svc: S,\n}\n\nimpl<T, S> Service<RawMessage<T>> for CoordinatorPublicKeyValidator<S>\nwhere\n    T: AsRef<[u8]> + Send + 'static,\n    S: Service<RawMessage<T>, Response = Message, Error = ServiceError>,\n    S::Future: Sync + Send + 'static,\n{\n    type Response = Message;\n    type Error = ServiceError;\n    type Future = BoxedServiceFuture<Self::Response, Self::Error>;\n\n    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        
self.next_svc.poll_ready(cx)\n    }\n\n    fn call(&mut self, req: RawMessage<T>) -> Self::Future {\n        debug!(\"retrieving the current keys\");\n        let coord_pk = self.keys.get_latest().event.public;\n        match PublicEncryptKey::from_byte_slice(&req.buffer.as_ref().as_ref().coordinator_pk()) {\n            Ok(pk) => {\n                if pk != coord_pk {\n                    warn!(\"found an invalid coordinator public key\");\n                    Box::pin(future::ready(Err(\n                        ServiceError::InvalidCoordinatorPublicKey,\n                    )))\n                } else {\n                    info!(\"found a valid coordinator public key\");\n                    let fut = self.next_svc.call(req);\n                    Box::pin(async move { fut.await })\n                }\n            }\n            Err(_) => Box::pin(future::ready(Err(\n                ServiceError::InvalidCoordinatorPublicKey,\n            ))),\n        }\n    }\n}\n\nstruct CoordinatorPublicKeyValidatorLayer {\n    keys: EventListener<EncryptKeyPair>,\n}\n\nimpl<S> Layer<S> for CoordinatorPublicKeyValidatorLayer {\n    type Service = CoordinatorPublicKeyValidator<S>;\n\n    fn layer(&self, service: S) -> CoordinatorPublicKeyValidator<S> {\n        CoordinatorPublicKeyValidator {\n            keys: self.keys.clone(),\n            next_svc: service,\n        }\n    }\n}\n\n#[derive(Debug, Clone)]\nstruct Parser;\n\nimpl<T> Service<RawMessage<T>> for Parser\nwhere\n    T: AsRef<[u8]> + Send + 'static,\n{\n    type Response = Message;\n    type Error = ServiceError;\n    type Future = future::Ready<Result<Self::Response, Self::Error>>;\n\n    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        Poll::Ready(Ok(()))\n    }\n\n    fn call(&mut self, req: RawMessage<T>) -> Self::Future {\n        let bytes = req.buffer.inner();\n        future::ready(Message::from_byte_slice(&bytes).map_err(ServiceError::Parsing))\n    }\n}\n\ntype 
InnerService = BufferWrapper<\n    PhaseFilter<ConcurrencyLimit<SignatureVerifier<CoordinatorPublicKeyValidator<Parser>>>>,\n>;\n\n#[derive(Debug, Clone)]\npub struct MessageParser(InnerService);\n\nimpl<T> Service<T> for MessageParser\nwhere\n    T: AsRef<[u8]> + Sync + Send + 'static,\n{\n    type Response = Message;\n    type Error = ServiceError;\n    type Future = BoxedServiceFuture<Self::Response, Self::Error>;\n\n    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        <InnerService as Service<T>>::poll_ready(&mut self.0, cx)\n    }\n\n    fn call(&mut self, req: T) -> Self::Future {\n        let fut = self.0.call(req);\n        Box::pin(async move { fut.await })\n    }\n}\n\nimpl MessageParser {\n    pub fn new(events: &EventSubscriber, thread_pool: Arc<ThreadPool>) -> Self {\n        let inner = ServiceBuilder::new()\n            .layer(BufferWrapperLayer)\n            .layer(PhaseFilterLayer {\n                phase: events.phase_listener(),\n            })\n            .layer(SignatureVerifierLayer { thread_pool })\n            .layer(CoordinatorPublicKeyValidatorLayer {\n                keys: events.keys_listener(),\n            })\n            .service(Parser);\n        Self(inner)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use rayon::ThreadPoolBuilder;\n    use tokio_test::assert_ready;\n    use tower_test::mock::Spawn;\n\n    use super::*;\n    use crate::{\n        services::tests::utils,\n        state_machine::events::{EventPublisher, EventSubscriber},\n    };\n\n    fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn<MessageParser>) {\n        let (publisher, subscriber) = utils::new_event_channels();\n        let thread_pool = Arc::new(ThreadPoolBuilder::new().build().unwrap());\n        let task = Spawn::new(MessageParser::new(&subscriber, thread_pool));\n        (publisher, subscriber, task)\n    }\n\n    #[tokio::test]\n    async fn test_valid_request() {\n        let (mut publisher, subscriber, 
mut task) = spawn_svc();\n        assert_ready!(task.poll_ready::<Vec<u8>>()).unwrap();\n\n        let round_params = subscriber.params_listener().get_latest().event;\n        let (message, signing_keys) = utils::new_sum_message(&round_params);\n        let serialized_message = utils::serialize_message(&message, &signing_keys);\n\n        // Simulate the state machine broadcasting the sum phase\n        // (otherwise the request will be rejected by the phase\n        // filter)\n        publisher.broadcast_phase(PhaseName::Sum);\n\n        // Call the service\n        let mut resp = task.call(serialized_message).await.unwrap();\n        // The signature should be set. However in `message` it's not been\n        // computed, so we just check that it's there, then set it to\n        // `None` in `resp`\n        assert!(resp.signature.is_some());\n        resp.signature = None;\n        // Now the comparison should work\n        assert_eq!(resp, message);\n    }\n\n    #[tokio::test]\n    async fn test_unexpected_message() {\n        let (_publisher, subscriber, mut task) = spawn_svc();\n        assert_ready!(task.poll_ready::<Vec<u8>>()).unwrap();\n\n        let round_params = subscriber.params_listener().get_latest().event;\n        let (message, signing_keys) = utils::new_sum_message(&round_params);\n        let serialized_message = utils::serialize_message(&message, &signing_keys);\n        let err = task.call(serialized_message).await.unwrap_err();\n        match err {\n            ServiceError::UnexpectedMessage => {}\n            _ => panic!(\"expected ServiceError::UnexpectedMessage got {:?}\", err),\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/messages/mod.rs",
    "content": "//! This module provides the services for processing PET messages.\n//!\n//! There are multiple such services and [`PetMessageHandler`]\n//! provides a single unifying interface for all of these.\n\nmod decryptor;\nmod error;\nmod message_parser;\nmod multipart;\nmod state_machine;\nmod task_validator;\n\nuse std::sync::Arc;\n\nuse futures::future::poll_fn;\nuse rayon::ThreadPoolBuilder;\nuse tower::Service;\nuse xaynet_core::message::Message;\n\npub use self::error::ServiceError;\nuse self::{\n    decryptor::Decryptor,\n    message_parser::MessageParser,\n    multipart::MultipartHandler,\n    state_machine::StateMachine,\n    task_validator::TaskValidator,\n};\nuse crate::state_machine::{events::EventSubscriber, requests::RequestSender};\n\nimpl PetMessageHandler {\n    pub fn new(event_subscriber: &EventSubscriber, requests_tx: RequestSender) -> Self {\n        // TODO: make this configurable. Users should be able to\n        // choose how many threads they want etc.\n        //\n        // TODO: don't unwrap\n        let thread_pool = Arc::new(ThreadPoolBuilder::new().build().unwrap());\n        let decryptor = Decryptor::new(event_subscriber, thread_pool.clone());\n        let multipart_handler = MultipartHandler::new();\n        let message_parser = MessageParser::new(event_subscriber, thread_pool);\n        let task_validator = TaskValidator::new(event_subscriber);\n        let state_machine = StateMachine::new(requests_tx);\n\n        Self {\n            decryptor,\n            multipart_handler,\n            message_parser,\n            task_validator,\n            state_machine,\n        }\n    }\n    async fn decrypt(&mut self, enc_data: Vec<u8>) -> Result<Vec<u8>, ServiceError> {\n        poll_fn(|cx| <Decryptor as Service<Vec<u8>>>::poll_ready(&mut self.decryptor, cx)).await?;\n        self.decryptor.call(enc_data).await\n    }\n\n    async fn parse(&mut self, data: Vec<u8>) -> Result<Message, ServiceError> {\n        poll_fn(|cx| 
<MessageParser as Service<Vec<u8>>>::poll_ready(&mut self.message_parser, cx))\n            .await?;\n        self.message_parser.call(data).await\n    }\n\n    async fn handle_multipart(\n        &mut self,\n        message: Message,\n    ) -> Result<Option<Message>, ServiceError> {\n        poll_fn(|cx| self.multipart_handler.poll_ready(cx)).await?;\n        self.multipart_handler.call(message).await\n    }\n\n    async fn validate_task(&mut self, message: Message) -> Result<Message, ServiceError> {\n        poll_fn(|cx| self.task_validator.poll_ready(cx)).await?;\n        self.task_validator.call(message).await\n    }\n\n    async fn process(&mut self, message: Message) -> Result<(), ServiceError> {\n        poll_fn(|cx| self.state_machine.poll_ready(cx)).await?;\n        self.state_machine.call(message).await\n    }\n\n    pub async fn handle_message(&mut self, enc_data: Vec<u8>) -> Result<(), ServiceError> {\n        let raw_message = self.decrypt(enc_data).await?;\n        let message = self.parse(raw_message).await?;\n        match self.handle_multipart(message).await? {\n            Some(message) => {\n                let message = self.validate_task(message).await?;\n                self.process(message).await\n            }\n            None => Ok(()),\n        }\n    }\n}\n\n/// A service that processes requests from the beginning to the\n/// end.\n///\n/// The processing is divided in five phases:\n///\n/// 1. The raw request (which is just a vector of bytes representing an\n///    encrypted message) goes through the `Decryptor` service, which\n///    decrypts it\n///\n/// 2. The decrypted bytes go through the `MessageParser` service,\n///    which validates them and parses them into a `Message`\n///\n/// 3. The message is passed to the `MultipartHandler`. Messages that\n///    are not multipart are passed through unchanged. Chunks of a\n///    multipart message are buffered until all of them have been\n///    received; until then, processing stops here\n///\n/// 4. The message is passed to the `TaskValidator`, which depending on\n///    the message type performs some additional checks. The\n///    `TaskValidator` may also reject the message, for instance if the\n///    participant is not eligible for the task\n///\n/// 5. 
Finally, the message is handled by the `StateMachine` service.\n#[derive(Clone)]\npub struct PetMessageHandler {\n    decryptor: Decryptor,\n    multipart_handler: MultipartHandler,\n    message_parser: MessageParser,\n    task_validator: TaskValidator,\n    state_machine: StateMachine,\n}\n\npub type BoxedServiceFuture<Response, Error> = std::pin::Pin<\n    Box<dyn futures::Future<Output = Result<Response, Error>> + 'static + Send + Sync>,\n>;\n"
  },
  {
    "path": "rust/xaynet-server/src/services/messages/multipart/buffer.rs",
    "content": "use std::{\n    collections::btree_map::{BTreeMap, IntoIter as BTreeMapIter},\n    iter::{ExactSizeIterator, Iterator},\n    vec::IntoIter as VecIter,\n};\n\n/// A data structure for reading a multipart message\npub struct MultipartMessageBuffer {\n    /// message chunks that haven't been read yet\n    remaining_chunks: BTreeMapIter<u16, Vec<u8>>,\n    /// chunk being read\n    current_chunk: Option<VecIter<u8>>,\n    /// total length of the buffer\n    initial_length: usize,\n    /// number of bytes that have been read\n    consumed: usize,\n}\n\nimpl From<BTreeMap<u16, Vec<u8>>> for MultipartMessageBuffer {\n    fn from(map: BTreeMap<u16, Vec<u8>>) -> Self {\n        let initial_length = map.values().fold(0, |acc, chunk| acc + chunk.len());\n        Self {\n            remaining_chunks: map.into_iter(),\n            current_chunk: None,\n            initial_length,\n            consumed: 0,\n        }\n    }\n}\n\n// Note that this Iterator implementation could be optimized. 
We\n// currently increment a counter for every byte consumed, but we could\n// exploit the fact that `VecIter` implements `ExactSizeIterator` to\n// avoid that.\nimpl Iterator for MultipartMessageBuffer {\n    type Item = u8;\n\n    fn next(&mut self) -> Option<Self::Item> {\n        if self.current_chunk.is_none() {\n            let (_, chunk) = self.remaining_chunks.next()?;\n            self.current_chunk = Some(chunk.into_iter());\n            return self.next();\n        }\n\n        // Per `if` above, `self.current_chunk` is not None\n        match self.current_chunk.as_mut().unwrap().next() {\n            Some(b) => {\n                self.consumed += 1;\n                Some(b)\n            }\n            None => {\n                self.current_chunk = None;\n                self.next()\n            }\n        }\n    }\n\n    fn size_hint(&self) -> (usize, Option<usize>) {\n        let lower_bound = self.initial_length - self.consumed;\n        let upper_bound = self.initial_length - self.consumed;\n        (lower_bound, Some(upper_bound))\n    }\n}\n\nimpl ExactSizeIterator for MultipartMessageBuffer {}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test() {\n        let mut map: BTreeMap<u16, Vec<u8>> = BTreeMap::new();\n        map.insert(1, vec![0, 1, 2]);\n        map.insert(2, vec![3]);\n        map.insert(3, vec![4, 5]);\n\n        let mut iter = MultipartMessageBuffer::from(map);\n        assert_eq!(iter.consumed, 0);\n        assert_eq!(iter.initial_length, 6);\n        assert_eq!(iter.len(), 6);\n        assert!(iter.current_chunk.is_none());\n\n        assert_eq!(iter.next(), Some(0));\n        assert_eq!(iter.consumed, 1);\n        assert_eq!(iter.initial_length, 6);\n        assert_eq!(iter.len(), 5);\n        assert!(iter.current_chunk.is_some());\n\n        assert_eq!(iter.next(), Some(1));\n        assert_eq!(iter.consumed, 2);\n        assert_eq!(iter.initial_length, 6);\n        assert_eq!(iter.len(), 4);\n        
assert!(iter.current_chunk.is_some());\n\n        assert_eq!(iter.next(), Some(2));\n        assert_eq!(iter.consumed, 3);\n        assert_eq!(iter.initial_length, 6);\n        assert_eq!(iter.len(), 3);\n        assert!(iter.current_chunk.is_some());\n\n        assert_eq!(iter.next(), Some(3));\n        assert_eq!(iter.consumed, 4);\n        assert_eq!(iter.initial_length, 6);\n        assert_eq!(iter.len(), 2);\n        assert!(iter.current_chunk.is_some());\n\n        assert_eq!(iter.next(), Some(4));\n        assert_eq!(iter.consumed, 5);\n        assert_eq!(iter.initial_length, 6);\n        assert_eq!(iter.len(), 1);\n        assert!(iter.current_chunk.is_some());\n\n        assert_eq!(iter.next(), Some(5));\n        assert_eq!(iter.consumed, 6);\n        assert_eq!(iter.initial_length, 6);\n        assert_eq!(iter.len(), 0);\n        assert!(iter.current_chunk.is_some());\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/messages/multipart/mod.rs",
    "content": "mod buffer;\nmod service;\n\nuse std::task::{Context, Poll};\n\nuse futures::future::TryFutureExt;\nuse tower::{buffer::Buffer, Service, ServiceBuilder};\n\nuse crate::services::messages::ServiceError;\nuse xaynet_core::message::Message;\n\ntype Inner = Buffer<service::MultipartHandler, Message>;\n\n#[derive(Clone)]\npub struct MultipartHandler(Inner);\n\nimpl Service<Message> for MultipartHandler {\n    type Response = Option<Message>;\n    type Error = ServiceError;\n    #[allow(clippy::type_complexity)]\n    type Future = futures::future::MapErr<\n        <Inner as Service<Message>>::Future,\n        fn(<Inner as Service<Message>>::Error) -> ServiceError,\n    >;\n\n    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        <Inner as Service<Message>>::poll_ready(&mut self.0, cx).map_err(ServiceError::from)\n    }\n\n    fn call(&mut self, req: Message) -> Self::Future {\n        <<Inner as Service<Message>>::Future>::map_err(self.0.call(req), ServiceError::from)\n    }\n}\n\nimpl MultipartHandler {\n    pub fn new() -> Self {\n        Self(\n            ServiceBuilder::new()\n                .buffer(100)\n                .service(service::MultipartHandler::new()),\n        )\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/messages/multipart/service.rs",
    "content": "use std::{\n    collections::{BTreeMap, HashMap},\n    task::Poll,\n};\n\nuse futures::{\n    future::{self, Ready},\n    task::Context,\n};\nuse tower::Service;\nuse tracing::{debug, trace, warn};\n\nuse crate::services::messages::{multipart::buffer::MultipartMessageBuffer, ServiceError};\nuse xaynet_core::{\n    crypto::{PublicEncryptKey, PublicSigningKey},\n    message::{Chunk, DecodeError, FromBytes, Message, Payload, Sum, Sum2, Tag, Update},\n};\n\n/// A `MessageBuilder` stores chunks of a multipart message. Once it\n/// has all the chunks, it can be consumed and turned into a\n/// full-blown [`Message`] (see [`into_message()`]).\n///\n/// [`into_message()`]: MessageBuilder::into_message\n#[derive(Debug)]\n#[cfg_attr(test, derive(Clone))]\npub struct MessageBuilder {\n    /// Public key of the participant sending the message\n    participant_pk: PublicSigningKey,\n    /// Public key of the coordinator\n    coordinator_pk: PublicEncryptKey,\n    /// Message type\n    tag: Tag,\n    /// The ID of the last chunk, once that chunk has been received.\n    /// Since chunk IDs start at 0, the message is made of\n    /// `last_chunk_id + 1` chunks.\n    last_chunk_id: Option<u16>,\n    /// Chunks, ordered by ID\n    data: BTreeMap<u16, Vec<u8>>,\n}\n\nimpl MessageBuilder {\n    /// Create a new [`MessageBuilder`] that contains no chunk.\n    fn new(tag: Tag, participant_pk: PublicSigningKey, coordinator_pk: PublicEncryptKey) -> Self {\n        MessageBuilder {\n            tag,\n            participant_pk,\n            coordinator_pk,\n            data: BTreeMap::new(),\n            last_chunk_id: None,\n        }\n    }\n\n    /// Return `true` if the message is complete, _i.e._ if the\n    /// builder holds all the chunks.\n    fn has_all_chunks(&self) -> bool {\n        self.last_chunk_id\n            .map(|last_chunk_id| {\n                // The IDs start at 0, hence the + 1\n                self.data.len() >= (last_chunk_id as usize + 1)\n            })\n            .unwrap_or(false)\n    }\n\n    /// Add 
a chunk.\n    fn add_chunk(&mut self, chunk: Chunk) {\n        let Chunk { id, last, data, .. } = chunk;\n        if last {\n            self.last_chunk_id = Some(id);\n        }\n        self.data.insert(id, data);\n    }\n\n    /// Aggregate all the chunks. This method should only be called\n    /// when all the chunks are here, otherwise the aggregated message\n    /// will be invalid.\n    fn into_message(self) -> Result<Message, DecodeError> {\n        let mut bytes = MultipartMessageBuffer::from(self.data);\n        let payload = match self.tag {\n            Tag::Sum => Sum::from_byte_stream(&mut bytes).map(Into::into)?,\n            Tag::Update => Update::from_byte_stream(&mut bytes).map(Into::into)?,\n            Tag::Sum2 => Sum2::from_byte_stream(&mut bytes).map(Into::into)?,\n        };\n        let message = Message {\n            signature: None,\n            participant_pk: self.participant_pk,\n            coordinator_pk: self.coordinator_pk,\n            tag: self.tag,\n            is_multipart: false,\n            payload,\n        };\n        Ok(message)\n    }\n}\n\n/// [`MessageId`] uniquely identifies a multipart message by its ID\n/// (which uniquely identify a message _for a given participant_), and\n/// the participant public key.\n#[derive(Debug, Hash, Eq, PartialEq, Clone)]\npub struct MessageId {\n    message_id: u16,\n    participant_pk: PublicSigningKey,\n}\n\n/// A service that handles multipart messages.\npub struct MultipartHandler {\n    message_builders: HashMap<MessageId, MessageBuilder>,\n}\n\nimpl MultipartHandler {\n    #[allow(dead_code)]\n    pub fn new() -> Self {\n        Self {\n            message_builders: HashMap::new(),\n        }\n    }\n}\n\nimpl Service<Message> for MultipartHandler {\n    type Response = Option<Message>;\n    type Error = ServiceError;\n    type Future = Ready<Result<Self::Response, Self::Error>>;\n\n    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        
Poll::Ready(Ok(()))\n    }\n\n    fn call(&mut self, message: Message) -> Self::Future {\n        // If the message doesn't have the multipart flag, this\n        // service has nothing to do with it.\n        if !message.is_multipart {\n            trace!(\"message is not multipart, nothing to do\");\n            return ready_ok(Some(message));\n        }\n\n        debug!(\"handling multipart message\");\n        if let Message {\n            tag,\n            participant_pk,\n            coordinator_pk,\n            payload: Payload::Chunk(chunk),\n            ..\n        } = message\n        {\n            let id = MessageId {\n                message_id: chunk.message_id,\n                participant_pk,\n            };\n            // If we don't have a partial message for this ID, create\n            // an empty one.\n            let mp_message = self.message_builders.entry(id.clone()).or_insert_with(|| {\n                debug!(\"new multipart message (id = {})\", id.message_id);\n                MessageBuilder::new(tag, participant_pk, coordinator_pk)\n            });\n            // Add the chunk to the partial message\n            mp_message.add_chunk(chunk);\n\n            // Check if the message is complete, and if so parse it\n            // and return it\n            if mp_message.has_all_chunks() {\n                debug!(\"received the final message chunk, now parsing the full message\");\n                // This entry exists, because `mp_message` above\n                // refers to it, so it's ok to unwrap.\n                match self.message_builders.remove(&id).unwrap().into_message() {\n                    Ok(message) => {\n                        debug!(\"multipart message succesfully parsed\");\n                        ready_ok(Some(message))\n                    }\n                    Err(e) => {\n                        warn!(\"invalid multipart message: {}\", e);\n                        ready_err(ServiceError::Parsing(e))\n                
    }\n                }\n            } else {\n                ready_ok(None)\n            }\n        } else {\n            // This cannot happen, because parsing would have failed\n            panic!(\"multipart flag is set but payload is not a multipart message\");\n        }\n    }\n}\n\nfn ready_ok<T, E>(t: T) -> Ready<Result<T, E>> {\n    future::ready(Ok(t))\n}\n\nfn ready_err<T, E>(e: E) -> Ready<Result<T, E>> {\n    future::ready(Err(e))\n}\n\n#[cfg(test)]\nmod tests {\n    use std::iter;\n\n    use tokio_test::assert_ready;\n    use tower_test::mock::Spawn;\n    use xaynet_core::crypto::{ByteObject, PublicEncryptKey, Signature};\n\n    use super::*;\n\n    fn spawn_svc() -> Spawn<MultipartHandler> {\n        Spawn::new(MultipartHandler::new())\n    }\n\n    fn sum() -> (Vec<u8>, Sum) {\n        let mut start_byte: u8 = 0xff;\n        let f = move || {\n            start_byte = start_byte.wrapping_add(1) & 0b_0001_1111;\n            Some(start_byte)\n        };\n        let bytes: Vec<u8> = iter::from_fn(f)\n            .take(PublicEncryptKey::LENGTH + Signature::LENGTH)\n            .collect();\n\n        let sum = Sum {\n            sum_signature: Signature::from_slice(&bytes[..Signature::LENGTH]).unwrap(),\n            ephm_pk: PublicEncryptKey::from_slice(&bytes[Signature::LENGTH..]).unwrap(),\n        };\n        (bytes, sum)\n    }\n\n    fn message_builder() -> MessageBuilder {\n        let participant_pk = PublicSigningKey::zeroed();\n        let coordinator_pk = PublicEncryptKey::zeroed();\n        let tag = Tag::Sum;\n        MessageBuilder::new(tag, participant_pk, coordinator_pk)\n    }\n\n    fn chunks(mut data: Vec<u8>) -> (Chunk, Chunk, Chunk, Chunk, Chunk) {\n        // Chunk 1: 1 byte\n        // Chunk 2: 2 bytes\n        // Chunk 3: 3 bytes\n        // Chunk 4: 4 bytes\n        // Chunk 5: 96 - (1 + 2 + 3 + 4) = 86 bytes\n\n        assert_eq!(data.len(), 96);\n\n        // 96 - 10 = 86, remains 10\n        let data5 = data.split_off(10);\n        
assert_eq!(data5.len(), 86);\n        assert_eq!(data.len(), 10);\n\n        // 10 - 6 = 4, remains 6\n        let data4 = data.split_off(6);\n        assert_eq!(data4.len(), 4);\n        assert_eq!(data.len(), 6);\n\n        // 6 - 3 = 3, remains 3\n        let data3 = data.split_off(3);\n        assert_eq!(data3.len(), 3);\n        assert_eq!(data.len(), 3);\n\n        // 3 - 1 = 2, remains 1\n        let data2 = data.split_off(1);\n        assert_eq!(data2.len(), 2);\n        assert_eq!(data.len(), 1);\n\n        let chunk1 = Chunk {\n            id: 0,\n            message_id: 1234,\n            last: false,\n            data,\n        };\n        let chunk2 = Chunk {\n            id: 1,\n            message_id: 1234,\n            last: false,\n            data: data2,\n        };\n        let chunk3 = Chunk {\n            id: 2,\n            message_id: 1234,\n            last: false,\n            data: data3,\n        };\n        let chunk4 = Chunk {\n            id: 3,\n            message_id: 1234,\n            last: false,\n            data: data4,\n        };\n        let chunk5 = Chunk {\n            id: 4,\n            message_id: 1234,\n            last: true,\n            data: data5,\n        };\n        (chunk1, chunk2, chunk3, chunk4, chunk5)\n    }\n\n    #[test]\n    fn test_message_builder_in_order() {\n        let mut msg = message_builder();\n        let (data, sum) = sum();\n        let (c1, c2, c3, c4, c5) = chunks(data);\n\n        assert!(msg.data.is_empty());\n        assert!(msg.last_chunk_id.is_none());\n\n        msg.add_chunk(c1);\n        assert_eq!(msg.data.len(), 1);\n        assert!(msg.last_chunk_id.is_none());\n        assert!(!msg.has_all_chunks());\n\n        msg.add_chunk(c2);\n        assert_eq!(msg.data.len(), 2);\n        assert!(msg.last_chunk_id.is_none());\n        assert!(!msg.has_all_chunks());\n\n        msg.add_chunk(c3);\n        assert_eq!(msg.data.len(), 3);\n        assert!(msg.last_chunk_id.is_none());\n        
assert!(!msg.has_all_chunks());\n\n        msg.add_chunk(c4);\n        assert_eq!(msg.data.len(), 4);\n        assert!(msg.last_chunk_id.is_none());\n        assert!(!msg.has_all_chunks());\n\n        msg.add_chunk(c5);\n        assert_eq!(msg.data.len(), 5);\n        assert_eq!(msg.last_chunk_id, Some(4));\n        assert!(msg.has_all_chunks());\n\n        let actual = msg.into_message().unwrap();\n        let expected =\n            Message::new_sum(PublicSigningKey::zeroed(), PublicEncryptKey::zeroed(), sum);\n\n        assert_eq!(actual, expected);\n    }\n\n    #[test]\n    fn test_message_builder_out_of_order() {\n        let mut msg = message_builder();\n        let (data, sum) = sum();\n        let (c1, c2, c3, c4, c5) = chunks(data);\n\n        assert!(msg.data.is_empty());\n        assert!(msg.last_chunk_id.is_none());\n\n        msg.add_chunk(c3);\n        assert_eq!(msg.data.len(), 1);\n        assert!(msg.last_chunk_id.is_none());\n        assert!(!msg.has_all_chunks());\n\n        msg.add_chunk(c1);\n        assert_eq!(msg.data.len(), 2);\n        assert!(msg.last_chunk_id.is_none());\n        assert!(!msg.has_all_chunks());\n\n        msg.add_chunk(c5);\n        assert_eq!(msg.data.len(), 3);\n        assert_eq!(msg.last_chunk_id, Some(4));\n        assert!(!msg.has_all_chunks());\n\n        msg.add_chunk(c2);\n        assert_eq!(msg.data.len(), 4);\n        assert_eq!(msg.last_chunk_id, Some(4));\n        assert!(!msg.has_all_chunks());\n\n        msg.add_chunk(c4);\n        assert_eq!(msg.data.len(), 5);\n        assert_eq!(msg.last_chunk_id, Some(4));\n        assert!(msg.has_all_chunks());\n\n        let actual = msg.into_message().unwrap();\n        let expected =\n            Message::new_sum(PublicSigningKey::zeroed(), PublicEncryptKey::zeroed(), sum);\n\n        assert_eq!(actual, expected);\n    }\n\n    #[tokio::test]\n    async fn message_handler() {\n        let mut task = spawn_svc();\n        
assert_ready!(task.poll_ready()).unwrap();\n\n        let coordinator_pk =\n            PublicEncryptKey::from_slice(&[0x00; PublicSigningKey::LENGTH]).unwrap();\n\n        // The payload of the message (and therefore the chunks) will\n        // be the same for the two participants. What must differ is\n        // the participant public key in the header.\n        let (data, sum) = sum();\n        let (c1, c2, c3, c4, c5) = chunks(data.clone());\n\n        // A signing key that identifies a first faked participant.\n        let pk1 = PublicSigningKey::from_slice(&[0x11; PublicSigningKey::LENGTH]).unwrap();\n        // message ID for the message from our fake participant identified by `pk1`\n        let message_id1 = MessageId {\n            message_id: 1234,\n            participant_pk: pk1,\n        };\n        // function that take a data chunk and create Chunk message\n        // with `pk1` as participant public key in the header\n        let make_message1 =\n            |chunk: &Chunk| Message::new_multipart(pk1, coordinator_pk, chunk.clone(), Tag::Sum);\n\n        // Do the same thing to fake a second participant: generate a\n        // public key, a message ID, and a function to create messages\n        // originating from that participant\n        let pk2 = PublicSigningKey::from_slice(&[0x22; PublicSigningKey::LENGTH]).unwrap();\n        let message_id2 = MessageId {\n            message_id: 1234,\n            participant_pk: pk2,\n        };\n        let make_message2 =\n            |chunk: &Chunk| Message::new_multipart(pk2, coordinator_pk, chunk.clone(), Tag::Sum);\n\n        // Start of the actual test. 
Notice that we send the chunks\n        // out of order.\n\n        assert!(task.call(make_message1(&c3)).await.unwrap().is_none());\n        assert_eq!(task.get_ref().message_builders.len(), 1);\n        let builder = task.get_ref().message_builders.get(&message_id1).unwrap();\n        assert_eq!(builder.data.len(), 1);\n\n        assert!(task.call(make_message2(&c3)).await.unwrap().is_none());\n        assert_eq!(task.get_ref().message_builders.len(), 2);\n        let builder = task.get_ref().message_builders.get(&message_id2).unwrap();\n        assert_eq!(builder.data.len(), 1);\n\n        assert!(task.call(make_message1(&c5)).await.unwrap().is_none());\n        assert!(task.call(make_message2(&c5)).await.unwrap().is_none());\n\n        assert!(task.call(make_message1(&c1)).await.unwrap().is_none());\n        assert!(task.call(make_message2(&c1)).await.unwrap().is_none());\n\n        assert!(task.call(make_message1(&c4)).await.unwrap().is_none());\n        assert!(task.call(make_message2(&c4)).await.unwrap().is_none());\n\n        let builder = task.get_ref().message_builders.get(&message_id1).unwrap();\n        assert_eq!(builder.data.len(), 4);\n\n        let builder = task.get_ref().message_builders.get(&message_id2).unwrap();\n        assert_eq!(builder.data.len(), 4);\n\n        let res1 = task.call(make_message1(&c2)).await.unwrap().unwrap();\n        let res2 = task.call(make_message2(&c2)).await.unwrap().unwrap();\n\n        assert!(task.get_ref().message_builders.get(&message_id1).is_none());\n        assert!(task.get_ref().message_builders.get(&message_id2).is_none());\n\n        assert_eq!(res1, Message::new_sum(pk1, coordinator_pk, sum.clone()));\n        assert_eq!(res2, Message::new_sum(pk2, coordinator_pk, sum.clone()));\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/messages/state_machine.rs",
    "content": "use std::task::Poll;\n\nuse futures::task::Context;\nuse tower::Service;\nuse xaynet_core::message::Message;\n\nuse crate::{\n    services::messages::{BoxedServiceFuture, ServiceError},\n    state_machine::requests::RequestSender,\n};\n\n/// A service that hands the requests to the [`StateMachine`] that runs in the background.\n///\n/// [`StateMachine`]: crate::state_machine::StateMachine\n#[derive(Debug, Clone)]\npub struct StateMachine {\n    handle: RequestSender,\n}\n\nimpl StateMachine {\n    /// Create a new service with the given handle for forwarding\n    /// requests to the state machine. The handle should be obtained\n    /// via [`init()`].\n    ///\n    /// [`init()`]: crate::state_machine::initializer::StateMachineInitializer::init\n    pub fn new(handle: RequestSender) -> Self {\n        Self { handle }\n    }\n}\n\nimpl Service<Message> for StateMachine {\n    type Response = ();\n    type Error = ServiceError;\n    type Future = BoxedServiceFuture<Self::Response, Self::Error>;\n\n    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        Poll::Ready(Ok(()))\n    }\n\n    fn call(&mut self, req: Message) -> Self::Future {\n        let handle = self.handle.clone();\n        Box::pin(async move {\n            handle\n                .request(req.into(), tracing::Span::none())\n                .await\n                .map_err(ServiceError::StateMachine)\n        })\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/messages/task_validator.rs",
    "content": "use std::task::Poll;\n\nuse futures::{future, task::Context};\nuse tower::Service;\n\nuse crate::{\n    services::messages::ServiceError,\n    state_machine::events::{EventListener, EventSubscriber},\n};\nuse xaynet_core::{\n    common::RoundParameters,\n    crypto::ByteObject,\n    message::{Message, Payload},\n};\n\n/// A service for performing sanity checks and preparing incoming\n/// requests to be handled by the state machine.\n#[derive(Clone, Debug)]\npub struct TaskValidator {\n    params_listener: EventListener<RoundParameters>,\n}\n\nimpl TaskValidator {\n    pub fn new(subscriber: &EventSubscriber) -> Self {\n        Self {\n            params_listener: subscriber.params_listener(),\n        }\n    }\n}\n\nimpl Service<Message> for TaskValidator {\n    type Response = Message;\n    type Error = ServiceError;\n    type Future = future::Ready<Result<Self::Response, Self::Error>>;\n\n    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n        Poll::Ready(Ok(()))\n    }\n\n    fn call(&mut self, message: Message) -> Self::Future {\n        let (sum_signature, update_signature) = match message.payload {\n            Payload::Sum(ref sum) => (sum.sum_signature, None),\n            Payload::Update(ref update) => (update.sum_signature, Some(update.update_signature)),\n            Payload::Sum2(ref sum2) => (sum2.sum_signature, None),\n            _ => return future::ready(Err(ServiceError::UnexpectedMessage)),\n        };\n        let params = self.params_listener.get_latest().event;\n        let seed = params.seed.as_slice();\n\n        // Check whether the participant is eligible for the sum task\n        let has_valid_sum_signature = message\n            .participant_pk\n            .verify_detached(&sum_signature, &[seed, b\"sum\"].concat());\n        let is_summer = has_valid_sum_signature && sum_signature.is_eligible(params.sum);\n\n        // Check whether the participant is eligible for the update task\n 
       let has_valid_update_signature = update_signature\n            .map(|sig| {\n                message\n                    .participant_pk\n                    .verify_detached(&sig, &[seed, b\"update\"].concat())\n            })\n            .unwrap_or(false);\n        let is_updater = !is_summer\n            && has_valid_update_signature\n            && update_signature\n                .map(|sig| sig.is_eligible(params.update))\n                .unwrap_or(false);\n\n        match message.payload {\n            Payload::Sum(_) | Payload::Sum2(_) => {\n                if is_summer {\n                    future::ready(Ok(message))\n                } else {\n                    future::ready(Err(ServiceError::NotSumEligible))\n                }\n            }\n            Payload::Update(_) => {\n                if is_updater {\n                    future::ready(Ok(message))\n                } else {\n                    future::ready(Err(ServiceError::NotUpdateEligible))\n                }\n            }\n            _ => future::ready(Err(ServiceError::UnexpectedMessage)),\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use tokio_test::assert_ready;\n    use tower_test::mock::Spawn;\n\n    use crate::{\n        services::tests::utils,\n        state_machine::{\n            events::{EventPublisher, EventSubscriber},\n            phases::PhaseName,\n        },\n    };\n\n    use super::*;\n\n    fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn<TaskValidator>) {\n        let (publisher, subscriber) = utils::new_event_channels();\n        let task = Spawn::new(TaskValidator::new(&subscriber));\n        (publisher, subscriber, task)\n    }\n\n    #[tokio::test]\n    async fn test_sum_ok() {\n        let (mut publisher, subscriber, mut task) = spawn_svc();\n\n        let mut round_params = subscriber.params_listener().get_latest().event;\n\n        // make sure everyone is eligible\n        round_params.sum = 1.0;\n\n        
publisher.broadcast_params(round_params.clone());\n        publisher.broadcast_phase(PhaseName::Sum);\n\n        let (message, _) = utils::new_sum_message(&round_params);\n\n        assert_ready!(task.poll_ready()).unwrap();\n        let resp = task.call(message.clone()).await.unwrap();\n        assert_eq!(resp, message);\n    }\n\n    #[tokio::test]\n    async fn test_sum_not_eligible() {\n        let (mut publisher, subscriber, mut task) = spawn_svc();\n\n        let mut round_params = subscriber.params_listener().get_latest().event;\n\n        // make sure no-one is eligible\n        round_params.sum = 0.0;\n\n        publisher.broadcast_params(round_params.clone());\n        publisher.broadcast_phase(PhaseName::Sum);\n\n        let (message, _) = utils::new_sum_message(&round_params);\n\n        assert_ready!(task.poll_ready()).unwrap();\n        let err = task.call(message).await.unwrap_err();\n        match err {\n            ServiceError::NotSumEligible => {}\n            _ => panic!(\"expected ServiceError::NotSumEligible got {:?}\", err),\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/mod.rs",
    "content": "//! This module implements the services the PET protocol provides.\n//!\n//! There are two main types of services:\n//!\n//! - the services for fetching data broadcast by the state\n//!   machine. These services are implemented in the [`fetchers`]\n//!   module.\n//! - the services for processing PET messages. These services are\n//!   implemented in the [`messages`] module.\n\npub mod fetchers;\npub mod messages;\n\n#[cfg(test)]\nmod tests;\n"
  },
  {
    "path": "rust/xaynet-server/src/services/tests/fetchers.rs",
    "content": "use std::{collections::HashMap, sync::Arc};\n\nuse tokio_test::assert_ready;\nuse tower_test::mock::Spawn;\n\nuse crate::{\n    services::{\n        fetchers::{\n            ModelRequest,\n            ModelService,\n            RoundParamsRequest,\n            RoundParamsService,\n            SeedDictRequest,\n            SeedDictService,\n            SumDictRequest,\n            SumDictService,\n        },\n        tests::utils::{mask_config, new_event_channels},\n    },\n    state_machine::events::{DictionaryUpdate, ModelUpdate},\n};\nuse xaynet_core::{\n    common::{RoundParameters, RoundSeed},\n    crypto::{ByteObject, PublicEncryptKey, PublicSigningKey},\n    mask::{EncryptedMaskSeed, Model},\n    SeedDict,\n    SumDict,\n    UpdateSeedDict,\n};\n\n#[tokio::test]\nasync fn test_model_svc() {\n    let (mut publisher, subscriber) = new_event_channels();\n\n    let mut task = Spawn::new(ModelService::new(&subscriber));\n    assert_ready!(task.poll_ready()).unwrap();\n\n    let resp = task.call(ModelRequest).await;\n    assert_eq!(resp, Ok(None));\n\n    let model = Arc::new(Model::from(vec![]));\n    publisher.broadcast_model(ModelUpdate::New(model.clone()));\n    assert_ready!(task.poll_ready()).unwrap();\n    let resp = task.call(ModelRequest).await;\n    assert_eq!(resp, Ok(Some(model)));\n\n    publisher.broadcast_model(ModelUpdate::Invalidate);\n    assert_ready!(task.poll_ready()).unwrap();\n    let resp = task.call(ModelRequest).await;\n    assert_eq!(resp, Ok(None));\n}\n\n#[tokio::test]\nasync fn test_round_params_svc() {\n    let (mut publisher, subscriber) = new_event_channels();\n    let initial_params = subscriber.params_listener().get_latest().event;\n\n    let mut task = Spawn::new(RoundParamsService::new(&subscriber));\n    assert_ready!(task.poll_ready()).unwrap();\n\n    let resp = task.call(RoundParamsRequest).await;\n    assert_eq!(resp, Ok(initial_params));\n\n    let params = RoundParameters {\n        pk: 
PublicEncryptKey::fill_with(0x11),\n        sum: 0.42,\n        update: 0.42,\n        seed: RoundSeed::fill_with(0x11),\n        mask_config: mask_config().into(),\n        model_length: 42,\n    };\n    publisher.broadcast_params(params.clone());\n    assert_ready!(task.poll_ready()).unwrap();\n    let resp = task.call(RoundParamsRequest).await;\n    assert_eq!(resp, Ok(params));\n}\n\nfn dummy_seed_dict() -> SeedDict {\n    let mut dict = HashMap::new();\n    dict.insert(PublicSigningKey::fill_with(0xaa), dummy_update_dict());\n    dict.insert(PublicSigningKey::fill_with(0xbb), dummy_update_dict());\n    dict\n}\n\nfn dummy_update_dict() -> UpdateSeedDict {\n    let mut dict = HashMap::new();\n    dict.insert(\n        PublicSigningKey::fill_with(0x11),\n        EncryptedMaskSeed::fill_with(0x11),\n    );\n    dict.insert(\n        PublicSigningKey::fill_with(0x22),\n        EncryptedMaskSeed::fill_with(0x22),\n    );\n    dict\n}\n\n#[tokio::test]\nasync fn test_seed_dict_svc() {\n    let (mut publisher, subscriber) = new_event_channels();\n\n    let mut task = Spawn::new(SeedDictService::new(&subscriber));\n    assert_ready!(task.poll_ready()).unwrap();\n\n    let resp = task.call(SeedDictRequest).await;\n    assert_eq!(resp, Ok(None));\n\n    let seed_dict = Arc::new(dummy_seed_dict());\n    publisher.broadcast_seed_dict(DictionaryUpdate::New(seed_dict.clone()));\n    assert_ready!(task.poll_ready()).unwrap();\n    let resp = task.call(SeedDictRequest).await;\n    assert_eq!(resp, Ok(Some(seed_dict)));\n\n    publisher.broadcast_seed_dict(DictionaryUpdate::Invalidate);\n    assert_ready!(task.poll_ready()).unwrap();\n    let resp = task.call(SeedDictRequest).await;\n    assert_eq!(resp, Ok(None));\n}\n\nfn dummy_sum_dict() -> SumDict {\n    let mut dict = HashMap::new();\n    dict.insert(\n        PublicSigningKey::fill_with(0xaa),\n        PublicEncryptKey::fill_with(0xcc),\n    );\n    dict.insert(\n        PublicSigningKey::fill_with(0xbb),\n        
PublicEncryptKey::fill_with(0xdd),\n    );\n    dict\n}\n\n#[tokio::test]\nasync fn test_sum_dict_svc() {\n    let (mut publisher, subscriber) = new_event_channels();\n\n    let mut task = Spawn::new(SumDictService::new(&subscriber));\n    assert_ready!(task.poll_ready()).unwrap();\n\n    let resp = task.call(SumDictRequest).await;\n    assert_eq!(resp, Ok(None));\n\n    let sum_dict = Arc::new(dummy_sum_dict());\n    publisher.broadcast_sum_dict(DictionaryUpdate::New(sum_dict.clone()));\n    assert_ready!(task.poll_ready()).unwrap();\n    let resp = task.call(SumDictRequest).await;\n    assert_eq!(resp, Ok(Some(sum_dict)));\n\n    publisher.broadcast_sum_dict(DictionaryUpdate::Invalidate);\n    assert_ready!(task.poll_ready()).unwrap();\n    let resp = task.call(SumDictRequest).await;\n    assert_eq!(resp, Ok(None));\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/services/tests/mod.rs",
    "content": "mod fetchers;\npub mod utils;\n"
  },
  {
    "path": "rust/xaynet-server/src/services/tests/utils.rs",
    "content": "use crate::state_machine::{\n    events::{EventPublisher, EventSubscriber, ModelUpdate},\n    phases::PhaseName,\n};\nuse xaynet_core::{\n    common::{RoundParameters, RoundSeed},\n    crypto::{ByteObject, EncryptKeyPair, PublicEncryptKey, SigningKeyPair},\n    mask::{self, MaskConfig},\n    message::{Message, Sum},\n};\n\npub fn mask_config() -> MaskConfig {\n    MaskConfig {\n        group_type: mask::GroupType::Integer,\n        data_type: mask::DataType::F32,\n        bound_type: mask::BoundType::B0,\n        model_type: mask::ModelType::M3,\n    }\n}\n\n/// Create an [`EventPublisher`]/[`EventSubscriber`] pair with default\n/// values similar to those produced in practice when instantiating a\n/// new coordinator.\npub fn new_event_channels() -> (EventPublisher, EventSubscriber) {\n    let keys = EncryptKeyPair::generate();\n    let params = RoundParameters {\n        pk: keys.public,\n        sum: 0.0,\n        update: 0.0,\n        seed: RoundSeed::generate(),\n        mask_config: mask_config().into(),\n        model_length: 0,\n    };\n    let phase = PhaseName::Idle;\n    let round_id = 0;\n    let model = ModelUpdate::Invalidate;\n    EventPublisher::init(round_id, keys, params, phase, model)\n}\n\n/// Simulate a participant generating keys and crafting a valid sum\n/// message for the given round parameters. 
The keys generated by the\n/// participants are returned along with the message.\npub fn new_sum_message(round_params: &RoundParameters) -> (Message, SigningKeyPair) {\n    let signing_keys = SigningKeyPair::generate();\n    let sum = Sum {\n        sum_signature: signing_keys\n            .secret\n            .sign_detached(&[round_params.seed.as_slice(), b\"sum\"].concat()),\n        ephm_pk: PublicEncryptKey::generate(),\n    };\n    let message = Message::new_sum(signing_keys.public, round_params.pk, sum);\n    (message, signing_keys)\n}\n\n/// Sign and encrypt the given message using the given round\n/// parameters and participant keys.\npub fn encrypt_message(\n    message: &Message,\n    round_params: &RoundParameters,\n    participant_signing_keys: &SigningKeyPair,\n) -> Vec<u8> {\n    let serialized = serialize_message(message, participant_signing_keys);\n    round_params.pk.encrypt(&serialized[..])\n}\n\npub fn serialize_message(message: &Message, participant_signing_keys: &SigningKeyPair) -> Vec<u8> {\n    let mut buf = vec![0; message.buffer_length()];\n    message.to_bytes(&mut buf, &participant_signing_keys.secret);\n    buf\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/settings/mod.rs",
    "content": "//! Loading and validation of settings.\n//!\n//! Values defined in the configuration file can be overridden by environment variables. Examples of\n//! configuration files can be found in the `configs/` directory located in the repository root.\n\n#[cfg(feature = \"tls\")]\nuse std::path::PathBuf;\nuse std::{fmt, path::Path};\n\nuse config::{Config, ConfigError, Environment, File};\nuse displaydoc::Display;\nuse redis::{ConnectionInfo, IntoConnectionInfo};\nuse serde::{\n    de::{self, Deserializer, Visitor},\n    Deserialize,\n};\nuse thiserror::Error;\nuse tracing_subscriber::filter::EnvFilter;\nuse validator::{Validate, ValidationError, ValidationErrors};\n\nuse xaynet_core::{\n    mask::{BoundType, DataType, GroupType, MaskConfig, ModelType},\n    message::{SUM_COUNT_MIN, UPDATE_COUNT_MIN},\n};\n\n#[cfg(feature = \"model-persistence\")]\n#[cfg_attr(docsrs, doc(cfg(feature = \"model-persistence\")))]\npub mod s3;\n#[cfg(feature = \"model-persistence\")]\npub use self::{s3::RestoreSettings, s3::S3BucketsSettings, s3::S3Settings};\n\n#[derive(Debug, Display, Error)]\n/// An error related to loading and validation of settings.\npub enum SettingsError {\n    /// Configuration loading failed: {0}.\n    Loading(#[from] ConfigError),\n    /// Validation failed: {0}.\n    Validation(#[from] ValidationErrors),\n}\n\n#[derive(Debug, Validate, Deserialize)]\n/// The combined settings.\n///\n/// Each section in the configuration file corresponds to the identically named settings field.\npub struct Settings {\n    pub api: ApiSettings,\n    #[validate]\n    pub pet: PetSettings,\n    pub mask: MaskSettings,\n    pub log: LoggingSettings,\n    pub model: ModelSettings,\n    #[validate]\n    pub metrics: MetricsSettings,\n    pub redis: RedisSettings,\n    #[cfg(feature = \"model-persistence\")]\n    #[validate]\n    pub s3: S3Settings,\n    #[cfg(feature = \"model-persistence\")]\n    #[validate]\n    pub restore: RestoreSettings,\n    #[serde(default)]\n    
pub trust_anchor: TrustAnchorSettings,\n}\n\nimpl Settings {\n    /// Loads and validates the settings via a configuration file.\n    ///\n    /// # Errors\n    /// Fails when the loading of the configuration file or its validation failed.\n    pub fn new(path: impl AsRef<Path>) -> Result<Self, SettingsError> {\n        let settings: Settings = Self::load(path)?;\n        settings.validate()?;\n        Ok(settings)\n    }\n\n    fn load(path: impl AsRef<Path>) -> Result<Self, ConfigError> {\n        Config::builder()\n            .add_source(File::from(path.as_ref()))\n            .add_source(Environment::with_prefix(\"xaynet\").separator(\"__\"))\n            .build()?\n            .try_deserialize()\n    }\n}\n\n/// The PET protocol count settings.\n#[derive(Debug, Deserialize, Clone, Copy)]\n#[cfg_attr(test, derive(PartialEq))]\npub struct PetSettingsCount {\n    /// The minimal number of participants selected in a phase.\n    pub min: u64,\n    /// The maximal number of participants selected in a phase.\n    pub max: u64,\n}\n\n/// The PET protocol time settings.\n#[derive(Debug, Deserialize, Clone, Copy)]\n#[cfg_attr(test, derive(PartialEq))]\npub struct PetSettingsTime {\n    /// The minimal amount of time reserved for a phase.\n    pub min: u64,\n    /// The maximal amount of time reserved for a phase.\n    pub max: u64,\n}\n\n/// The PET protocol `sum` phase settings.\n#[derive(Debug, Deserialize, Clone, Copy)]\n#[cfg_attr(test, derive(PartialEq))]\npub struct PetSettingsSum {\n    /// The probability of participants selected for preparing and computing the aggregated mask.\n    /// The value must be between `0` and `1` (i.e. 
`0 < sum.prob < 1`).\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [pet.sum]\n    /// prob = 0.01\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__PET__SUM__PROB=0.01\n    /// ```\n    pub prob: f64,\n\n    /// The minimal and maximal number of participants selected for preparing the unmasking.\n    ///\n    /// The minimal value must be greater or equal to `1` (i.e. `sum.count.min >= 1`) for the PET\n    /// protocol to function correctly. The maximal value must be greater or equal to the minimal\n    /// value (i.e. `sum.count.min <= sum.count.max`). No more than `sum.count.max` messages will be\n    /// processed in the `sum` phase if the `sum.time.min` has not yet elapsed.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [pet.sum.count]\n    /// min = 10\n    /// max = 100\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__PET__SUM__COUNT__MIN=10\n    /// XAYNET__PET__SUM__COUNT__MAX=100\n    /// ```\n    pub count: PetSettingsCount,\n\n    /// The minimal and maximal amount of time reserved for processing messages in the `sum` phase,\n    /// in seconds.\n    ///\n    /// Once the minimal time has passed, the `sum` phase ends *as soon as* `sum.count.min` messages\n    /// have been processed. Set this higher to allow for the possibility of more than\n    /// `sum.count.min` messages to be processed in the `sum` phase. 
Set the maximal time lower to\n    /// allow for the processing of `sum.count.min` messages to time-out sooner in the `sum` phase.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [pet.sum.time]\n    /// min = 5\n    /// max = 3600\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__PET__SUM__TIME__MIN=5\n    /// XAYNET__PET__SUM__TIME__MAX=3600\n    /// ```\n    pub time: PetSettingsTime,\n}\n\n/// The PET protocol `update` phase settings.\n#[derive(Debug, Deserialize, Clone, Copy)]\n#[cfg_attr(test, derive(PartialEq))]\npub struct PetSettingsUpdate {\n    /// The probability of participants selected for submitting an updated local model for\n    /// aggregation. The value must be between `0` and `1` (i.e. `0 < update.prob <= 1`). Here, `1`\n    /// is included to be able to express that every participant who is not a sum participant must be\n    /// an update participant.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [pet.update]\n    /// prob = 0.1\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__PET__UPDATE__PROB=0.1\n    /// ```\n    pub prob: f64,\n\n    /// The minimal and maximal number of participants selected for submitting an updated local\n    /// model for aggregation.\n    ///\n    /// The minimal value must be greater or equal to `3` (i.e. `update.count.min >= 3`) for the PET\n    /// protocol to function correctly. The maximal value must be greater or equal to the minimal\n    /// value (i.e. `update.count.min <= update.count.max`). 
No more than `update.count.max`\n    /// messages will be processed in the `update` phase if the `update.time.min` has not yet\n    /// elapsed.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [pet.update.count]\n    /// min = 100\n    /// max = 10000\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__PET__UPDATE__COUNT__MIN=100\n    /// XAYNET__PET__UPDATE__COUNT__MAX=10000\n    /// ```\n    pub count: PetSettingsCount,\n\n    /// The minimal and maximal amount of time reserved for processing messages in the `update`\n    /// phase, in seconds.\n    ///\n    /// Once the minimal time has passed, the `update` phase ends *as soon as* `update.count.min`\n    /// messages have been processed. Set this higher to allow for the possibility of more than\n    /// `update.count.min` messages to be processed in the `update` phase. Set the maximal time\n    /// lower to allow for the processing of `update.count.min` messages to time-out sooner in the\n    /// `update` phase.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [pet.update.time]\n    /// min = 10\n    /// max = 3600\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__PET__UPDATE__TIME__MIN=10\n    /// XAYNET__PET__UPDATE__TIME__MAX=3600\n    /// ```\n    pub time: PetSettingsTime,\n}\n\n/// The PET protocol `sum2` phase settings.\n#[derive(Debug, Deserialize, Clone, Copy)]\n#[cfg_attr(test, derive(PartialEq))]\npub struct PetSettingsSum2 {\n    /// The minimal and maximal number of participants selected for submitting the aggregated masks.\n    ///\n    /// The minimal value must be greater or equal to `1` (i.e. `sum2.count.min >= 1`) for the PET\n    /// protocol to function correctly and less or equal to the maximal value of the `sum` phase\n    /// (i.e. `sum2.count.min <= sum.count.max`). The maximal value must be greater or equal to the\n    /// minimal value (i.e. 
`sum2.count.min <= sum2.count.max`) and less or equal to the maximal\n    /// value of the `sum` phase (i.e. `sum2.count.max <= sum.count.max`). No more than\n    /// `sum2.count.max` messages will be processed in the `sum2` phase if the `sum2.time.min` has\n    /// not yet elapsed.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [pet.sum2.count]\n    /// min = 10\n    /// max = 100\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__PET__SUM2__COUNT__MIN=10\n    /// XAYNET__PET__SUM2__COUNT__MAX=100\n    /// ```\n    pub count: PetSettingsCount,\n\n    /// The minimal and maximal amount of time reserved for processing messages in the `sum2` phase,\n    /// in seconds.\n    ///\n    /// Once the minimal time has passed, the `sum2` phase ends *as soon as* `sum2.count.min`\n    /// messages have been processed. Set this higher to allow for the possibility of more than\n    /// `sum2.count.min` messages to be processed in the `sum2` phase. 
Set the maximal time lower to\n    /// allow for the processing of `sum2.count.min` messages to time-out sooner in the `sum2`\n    /// phase.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [pet.sum2.time]\n    /// min = 5\n    /// max = 3600\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__PET__SUM2__TIME__MIN=5\n    /// XAYNET__PET__SUM2__TIME__MAX=3600\n    /// ```\n    pub time: PetSettingsTime,\n}\n\n/// The PET protocol settings.\n#[derive(Debug, Validate, Deserialize, Clone, Copy)]\n#[cfg_attr(test, derive(PartialEq))]\n#[validate(schema(function = \"validate_pet\"))]\npub struct PetSettings {\n    /// The PET settings for the `sum` phase.\n    pub sum: PetSettingsSum,\n    /// The PET settings for the `update` phase.\n    pub update: PetSettingsUpdate,\n    /// The PET settings for the `sum2` phase.\n    pub sum2: PetSettingsSum2,\n}\n\nimpl PetSettings {\n    /// Checks the PET settings.\n    fn validate_pet(&self) -> Result<(), ValidationError> {\n        self.validate_counts()?;\n        self.validate_times()?;\n        self.validate_probabilities()\n    }\n\n    /// Checks the validity of phase count ranges.\n    fn validate_counts(&self) -> Result<(), ValidationError> {\n        // the validate attribute only accepts literals, therefore we check the invariants here\n        if SUM_COUNT_MIN <= self.sum.count.min\n            && self.sum.count.min <= self.sum.count.max\n            && UPDATE_COUNT_MIN <= self.update.count.min\n            && self.update.count.min <= self.update.count.max\n            && SUM_COUNT_MIN <= self.sum2.count.min\n            && self.sum2.count.min <= self.sum2.count.max\n            && self.sum2.count.min <= self.sum.count.max\n            && self.sum2.count.max <= self.sum.count.max\n        {\n            Ok(())\n        } else {\n            Err(ValidationError::new(\"invalid phase count range(s)\"))\n        }\n    }\n\n    /// Checks the validity 
of phase time ranges.\n    fn validate_times(&self) -> Result<(), ValidationError> {\n        if self.sum.time.min <= self.sum.time.max\n            && self.update.time.min <= self.update.time.max\n            && self.sum2.time.min <= self.sum2.time.max\n        {\n            Ok(())\n        } else {\n            Err(ValidationError::new(\"invalid phase time range(s)\"))\n        }\n    }\n\n    /// Checks the validity of fraction ranges including pathological cases of deadlocks.\n    fn validate_probabilities(&self) -> Result<(), ValidationError> {\n        if 0. < self.sum.prob\n            && self.sum.prob < 1.\n            && 0. < self.update.prob\n            && self.update.prob <= 1.\n            && 0. < self.sum.prob + self.update.prob - self.sum.prob * self.update.prob\n            && self.sum.prob + self.update.prob - self.sum.prob * self.update.prob <= 1.\n        {\n            Ok(())\n        } else {\n            Err(ValidationError::new(\"starvation\"))\n        }\n    }\n}\n\n/// A wrapper for validate derive.\nfn validate_pet(s: &PetSettings) -> Result<(), ValidationError> {\n    s.validate_pet()\n}\n\n#[derive(Debug, Deserialize, Clone)]\n#[cfg_attr(\n    feature = \"tls\",\n    derive(Validate),\n    validate(schema(function = \"validate_api\"))\n)]\n/// REST API settings.\n///\n/// Requires at least one of the following arguments if the `tls` feature is enabled:\n/// - `tls_certificate` together with `tls_key` for TLS server authentication\n/// - `tls_client_auth` for TLS client authentication\npub struct ApiSettings {\n    /// The address to which the REST API should be bound.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [api]\n    /// bind_address = \"0.0.0.0:8081\"\n    /// # or\n    /// bind_address = \"127.0.0.1:8081\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__API__BIND_ADDRESS=127.0.0.1:8081\n    /// ```\n    pub bind_address: std::net::SocketAddr,\n\n    
#[cfg(feature = \"tls\")]\n    #[cfg_attr(docsrs, doc(cfg(feature = \"tls\")))]\n    /// The path to the server certificate to enable TLS server authentication. Leave this out to\n    /// disable server authentication. If this is present, then `tls_key` must also be present.\n    ///\n    /// Requires the `tls` feature to be enabled.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [api]\n    /// tls_certificate = \"path/to/tls/files/cert.pem\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__API__TLS_CERTIFICATE=path/to/tls/files/cert.pem\n    /// ```\n    pub tls_certificate: Option<PathBuf>,\n\n    #[cfg(feature = \"tls\")]\n    #[cfg_attr(docsrs, doc(cfg(feature = \"tls\")))]\n    /// The path to the server private key to enable TLS server authentication. Leave this out to\n    /// disable server authentication. If this is present, then `tls_certificate` must also be\n    /// present.\n    ///\n    /// Requires the `tls` feature to be enabled.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [api]\n    /// tls_key = \"path/to/tls/files/key.rsa\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__API__TLS_KEY=path/to/tls/files/key.rsa\n    /// ```\n    pub tls_key: Option<PathBuf>,\n\n    #[cfg(feature = \"tls\")]\n    #[cfg_attr(docsrs, doc(cfg(feature = \"tls\")))]\n    /// The path to the trust anchor to enable TLS client authentication. 
Leave this out to disable\n    /// client authentication.\n    ///\n    /// Requires the `tls` feature to be enabled.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [api]\n    /// tls_client_auth = path/to/tls/files/trust_anchor.pem\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__API__TLS_CLIENT_AUTH=path/to/tls/files/trust_anchor.pem\n    /// ```\n    pub tls_client_auth: Option<PathBuf>,\n}\n\n#[cfg(feature = \"tls\")]\nimpl ApiSettings {\n    /// Checks API settings.\n    fn validate_api(&self) -> Result<(), ValidationError> {\n        match (&self.tls_certificate, &self.tls_key, &self.tls_client_auth) {\n            (Some(_), Some(_), _) | (None, None, Some(_)) => Ok(()),\n            _ => Err(ValidationError::new(\"invalid tls settings\")),\n        }\n    }\n}\n\n/// A wrapper for validate derive.\n#[cfg(feature = \"tls\")]\nfn validate_api(s: &ApiSettings) -> Result<(), ValidationError> {\n    s.validate_api()\n}\n\n#[derive(Debug, Validate, Deserialize, Clone, Copy)]\n#[cfg_attr(test, derive(PartialEq, Eq))]\n/// Masking settings.\npub struct MaskSettings {\n    /// The order of the finite group.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [mask]\n    /// group_type = \"Integer\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__MASK__GROUP_TYPE=Integer\n    /// ```\n    pub group_type: GroupType,\n\n    /// The data type of the numbers to be masked.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [mask]\n    /// data_type = \"F32\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__MASK__DATA_TYPE=F32\n    /// ```\n    pub data_type: DataType,\n\n    /// The bounds of the numbers to be masked.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [mask]\n    /// bound_type = \"B0\"\n    /// ```\n    ///\n 
   /// **Environment variable**\n    /// ```text\n    /// XAYNET__MASK__BOUND_TYPE=B0\n    /// ```\n    pub bound_type: BoundType,\n\n    /// The maximum number of models to be aggregated.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [mask]\n    /// model_type = \"M3\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__MASK__MODEL_TYPE=M3\n    /// ```\n    pub model_type: ModelType,\n}\n\nimpl From<MaskSettings> for MaskConfig {\n    fn from(\n        MaskSettings {\n            group_type,\n            data_type,\n            bound_type,\n            model_type,\n        }: MaskSettings,\n    ) -> MaskConfig {\n        MaskConfig {\n            group_type,\n            data_type,\n            bound_type,\n            model_type,\n        }\n    }\n}\n\n#[derive(Debug, Deserialize, Clone)]\n#[cfg_attr(test, derive(PartialEq))]\n/// Model settings.\npub struct ModelSettings {\n    /// The expected length of the model. 
The model length corresponds to the number of elements.\n    /// This value is used to validate the uniform length of the submitted models/masks.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [model]\n    /// length = 100\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__MODEL__LENGTH=100\n    /// ```\n    pub length: usize,\n}\n\n#[derive(Debug, Deserialize, Validate)]\n/// Metrics settings.\npub struct MetricsSettings {\n    #[validate]\n    /// Settings for the InfluxDB backend.\n    pub influxdb: InfluxSettings,\n}\n\n#[derive(Debug, Deserialize, Validate)]\n/// InfluxDB settings.\npub struct InfluxSettings {\n    #[validate(url)]\n    /// The URL where InfluxDB is running.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [metrics.influxdb]\n    /// url = \"http://localhost:8086\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__METRICS__INFLUXDB__URL=http://localhost:8086\n    /// ```\n    pub url: String,\n\n    /// The InfluxDB database name.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [metrics.influxdb]\n    /// db = \"test\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__METRICS__INFLUXDB__DB=test\n    /// ```\n    pub db: String,\n}\n\n#[derive(Debug, Deserialize)]\n/// Redis settings.\npub struct RedisSettings {\n    /// The URL where Redis is running.\n    ///\n    /// The format of the URL is `redis://[<username>][:<passwd>@]<hostname>[:port][/<db>]`.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [redis]\n    /// url = \"redis://127.0.0.1/\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__REDIS__URL=redis://127.0.0.1/\n    /// ```\n    #[serde(deserialize_with = \"deserialize_redis_url\")]\n    pub url: ConnectionInfo,\n}\n\nfn 
deserialize_redis_url<'de, D>(deserializer: D) -> Result<ConnectionInfo, D::Error>\nwhere\n    D: Deserializer<'de>,\n{\n    struct ConnectionInfoVisitor;\n\n    impl<'de> Visitor<'de> for ConnectionInfoVisitor {\n        type Value = ConnectionInfo;\n\n        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {\n            write!(\n                formatter,\n                \"redis://[<username>][:<passwd>@]<hostname>[:port][/<db>]\"\n            )\n        }\n\n        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>\n        where\n            E: de::Error,\n        {\n            value\n                .into_connection_info()\n                .map_err(|_| de::Error::invalid_value(serde::de::Unexpected::Str(value), &self))\n        }\n    }\n\n    deserializer.deserialize_str(ConnectionInfoVisitor)\n}\n\n#[derive(Debug, Deserialize, Validate)]\n/// Trust anchor settings.\npub struct TrustAnchorSettings {}\n\n// Empty trust anchor settings, used when the `trust_anchor` section is absent (`#[serde(default)]`)\nimpl Default for TrustAnchorSettings {\n    fn default() -> Self {\n        Self {}\n    }\n}\n\n#[derive(Debug, Deserialize)]\n/// Logging settings.\npub struct LoggingSettings {\n    /// A comma-separated list of logging directives. 
More information about logging directives\n    /// can be found [here].\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [log]\n    /// filter = \"info\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__LOG__FILTER=info\n    /// ```\n    ///\n    /// [here]: https://docs.rs/tracing-subscriber/0.2.15/tracing_subscriber/filter/struct.EnvFilter.html#directives\n    #[serde(deserialize_with = \"deserialize_env_filter\")]\n    pub filter: EnvFilter,\n}\n\nfn deserialize_env_filter<'de, D>(deserializer: D) -> Result<EnvFilter, D::Error>\nwhere\n    D: Deserializer<'de>,\n{\n    struct EnvFilterVisitor;\n\n    impl<'de> Visitor<'de> for EnvFilterVisitor {\n        type Value = EnvFilter;\n\n        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {\n            write!(formatter, \"a valid tracing filter directive: https://docs.rs/tracing-subscriber/0.2.6/tracing_subscriber/filter/struct.EnvFilter.html#directives\")\n        }\n\n        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>\n        where\n            E: de::Error,\n        {\n            EnvFilter::try_new(value)\n                .map_err(|_| de::Error::invalid_value(serde::de::Unexpected::Str(value), &self))\n        }\n    }\n\n    deserializer.deserialize_str(EnvFilterVisitor)\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    impl Default for PetSettings {\n        fn default() -> Self {\n            Self {\n                sum: PetSettingsSum {\n                    prob: 0.01,\n                    count: PetSettingsCount { min: 10, max: 100 },\n                    time: PetSettingsTime {\n                        min: 0,\n                        max: 604800,\n                    },\n                },\n                update: PetSettingsUpdate {\n                    prob: 0.1,\n                    count: PetSettingsCount {\n                        min: 100,\n                        max: 10000,\n   
                 },\n                    time: PetSettingsTime {\n                        min: 0,\n                        max: 604800,\n                    },\n                },\n                sum2: PetSettingsSum2 {\n                    count: PetSettingsCount { min: 10, max: 100 },\n                    time: PetSettingsTime {\n                        min: 0,\n                        max: 604800,\n                    },\n                },\n            }\n        }\n    }\n\n    impl Default for MaskSettings {\n        fn default() -> Self {\n            Self {\n                group_type: GroupType::Prime,\n                data_type: DataType::F32,\n                bound_type: BoundType::B0,\n                model_type: ModelType::M3,\n            }\n        }\n    }\n\n    #[test]\n    fn test_settings_new() {\n        assert!(Settings::new(\"../../configs/config.toml\").is_ok());\n        assert!(Settings::new(\"\").is_err());\n    }\n\n    #[test]\n    fn test_validate_pet() {\n        assert!(PetSettings::default().validate_pet().is_ok());\n    }\n\n    #[test]\n    fn test_validate_pet_counts() {\n        assert_eq!(SUM_COUNT_MIN, 1);\n        assert_eq!(UPDATE_COUNT_MIN, 3);\n\n        let mut pet = PetSettings::default();\n        pet.sum.count.min = 0;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.sum.count.min = 11;\n        pet.sum.count.max = 10;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.update.count.min = 2;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.update.count.min = 11;\n        pet.update.count.max = 10;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.sum2.count.min = 0;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.sum2.count.min = 11;\n        
pet.sum2.count.max = 10;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.sum2.count.min = 11;\n        pet.sum.count.max = 10;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.sum2.count.max = 11;\n        pet.sum.count.max = 10;\n        assert!(pet.validate().is_err());\n    }\n\n    #[test]\n    fn test_validate_pet_times() {\n        let mut pet = PetSettings::default();\n        pet.sum.time.min = 2;\n        pet.sum.time.max = 1;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.update.time.min = 2;\n        pet.update.time.max = 1;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.sum2.time.min = 2;\n        pet.sum2.time.max = 1;\n        assert!(pet.validate().is_err());\n    }\n\n    #[test]\n    fn test_validate_pet_probabilities() {\n        let mut pet = PetSettings::default();\n        pet.sum.prob = 0.;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.sum.prob = 1.;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.update.prob = 0.;\n        assert!(pet.validate().is_err());\n\n        let mut pet = PetSettings::default();\n        pet.update.prob = 1. 
+ f64::EPSILON;\n        assert!(pet.validate().is_err());\n    }\n\n    #[cfg(feature = \"tls\")]\n    #[test]\n    fn test_validate_api() {\n        let bind_address = ([0, 0, 0, 0], 0).into();\n        let some_path = Some(std::path::PathBuf::new());\n\n        assert!(ApiSettings {\n            bind_address,\n            tls_certificate: some_path.clone(),\n            tls_key: some_path.clone(),\n            tls_client_auth: some_path.clone(),\n        }\n        .validate()\n        .is_ok());\n        assert!(ApiSettings {\n            bind_address,\n            tls_certificate: some_path.clone(),\n            tls_key: some_path.clone(),\n            tls_client_auth: None,\n        }\n        .validate()\n        .is_ok());\n        assert!(ApiSettings {\n            bind_address,\n            tls_certificate: None,\n            tls_key: None,\n            tls_client_auth: some_path.clone(),\n        }\n        .validate()\n        .is_ok());\n\n        assert!(ApiSettings {\n            bind_address,\n            tls_certificate: some_path.clone(),\n            tls_key: None,\n            tls_client_auth: some_path.clone(),\n        }\n        .validate()\n        .is_err());\n        assert!(ApiSettings {\n            bind_address,\n            tls_certificate: None,\n            tls_key: some_path.clone(),\n            tls_client_auth: some_path.clone(),\n        }\n        .validate()\n        .is_err());\n        assert!(ApiSettings {\n            bind_address,\n            tls_certificate: some_path.clone(),\n            tls_key: None,\n            tls_client_auth: None,\n        }\n        .validate()\n        .is_err());\n        assert!(ApiSettings {\n            bind_address,\n            tls_certificate: None,\n            tls_key: some_path,\n            tls_client_auth: None,\n        }\n        .validate()\n        .is_err());\n        assert!(ApiSettings {\n            bind_address,\n            tls_certificate: None,\n            tls_key: 
None,\n            tls_client_auth: None,\n        }\n        .validate()\n        .is_err());\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/settings/s3.rs",
    "content": "//! S3 settings.\n\nuse std::fmt;\n\nuse fancy_regex::Regex;\nuse rusoto_core::Region;\nuse serde::{\n    de::{self, value, Deserializer, Visitor},\n    Deserialize,\n};\nuse validator::{Validate, ValidationError};\n\n#[derive(Debug, Validate, Deserialize)]\n/// S3 settings.\npub struct S3Settings {\n    /// The [access key ID](https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html).\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [s3]\n    /// access_key = \"AKIAIOSFODNN7EXAMPLE\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__S3__ACCESS_KEY=AKIAIOSFODNN7EXAMPLE\n    /// ```\n    pub access_key: String,\n\n    /// The [secret access key](https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html).\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [s3]\n    /// secret_access_key = \"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__S3__SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY\n    /// ```\n    pub secret_access_key: String,\n\n    /// The Regional AWS endpoint.\n    ///\n    /// The region is specified using the [Region code](https://docs.aws.amazon.com/general/latest/gr/rande.html#regional-endpoints)\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [s3]\n    /// region = [\"eu-west-1\"]\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__S3__REGION=\"eu-west-1\"\n    /// ```\n    ///\n    /// To connect to AWS-compatible services such as Minio, you need to specify a custom region.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [s3]\n    /// region = [\"minio\", \"http://localhost:8000\"]\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__S3__REGION=\"minio 
http://localhost:8000\"\n    /// ```\n    #[serde(deserialize_with = \"deserialize_s3_region\")]\n    pub region: Region,\n    #[validate]\n    #[serde(default)]\n    pub buckets: S3BucketsSettings,\n}\n\n#[derive(Debug, Validate, Deserialize)]\n/// S3 buckets settings.\npub struct S3BucketsSettings {\n    /// The bucket name in which the global models are stored.\n    /// Defaults to `global-models`.\n    ///\n    /// Please follow the [rules for bucket naming](https://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html)\n    /// when creating the name.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [s3.buckets]\n    /// global_models = \"global-models\"\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__S3__BUCKETS__GLOBAL_MODELS=\"global-models\"\n    /// ```\n    #[validate(custom = \"validate_s3_bucket_name\")]\n    pub global_models: String,\n}\n\n// Default value for the global models bucket\nimpl Default for S3BucketsSettings {\n    fn default() -> Self {\n        Self {\n            global_models: String::from(\"global-models\"),\n        }\n    }\n}\n\n// Validates the bucket name\n// [Rules for AWS bucket naming](https://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html)\nfn validate_s3_bucket_name(bucket_name: &str) -> Result<(), ValidationError> {\n    // https://stackoverflow.com/questions/50480924/regex-for-s3-bucket-name#comment104807676_58248645\n    // I had to use fancy_regex here because the std regex does not support `look-around`\n    let re =\n        Regex::new(r\"(?!^(\\d{1,3}\\.){3}\\d{1,3}$)(^[a-z0-9]([a-z0-9-]*(\\.[a-z0-9])?)*$(?<!\\-))\")\n            .unwrap();\n    match re.is_match(bucket_name) {\n        Ok(true) => Ok(()),\n        Ok(false) => Err(ValidationError::new(\"invalid bucket name\\n See here: https://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html\")),\n        // something went wrong with the regex 
engine\n        Err(_) => Err(ValidationError::new(\"can not validate bucket name\")),\n    }\n}\n\n// A small wrapper to support the list type for environment variable values.\n// config-rs always converts an environment variable value to a string\n// https://github.com/mehcode/config-rs/blob/master/src/env.rs#L114 .\n// Strings, however, are not supported by the deserializer of rusoto_core::Region (only sequences).\n// Therefore we use S3RegionVisitor to implement `visit_str` and thus support\n// the deserialization of rusoto_core::Region from strings.\nfn deserialize_s3_region<'de, D>(deserializer: D) -> Result<Region, D::Error>\nwhere\n    D: Deserializer<'de>,\n{\n    struct S3RegionVisitor;\n\n    impl<'de> Visitor<'de> for S3RegionVisitor {\n        type Value = Region;\n\n        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {\n            formatter.write_str(\"sequence of \\\"name Optional<endpoint>\\\"\")\n        }\n\n        // FIXME: a copy of https://rusoto.github.io/rusoto/src/rusoto_core/region.rs.html#185\n        // I haven't managed to create a sequence and call `self.visit_seq(seq)`.\n        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>\n        where\n            E: de::Error,\n        {\n            let mut seq = value.split_whitespace();\n\n            let name: &str = seq\n                .next()\n                .ok_or_else(|| de::Error::custom(\"region is missing name\"))?;\n            let endpoint: Option<&str> = seq.next();\n\n            match (name, endpoint) {\n                (name, Some(endpoint)) => Ok(Region::Custom {\n                    name: name.to_string(),\n                    endpoint: endpoint.to_string(),\n                }),\n                (name, None) => name.parse().map_err(de::Error::custom),\n            }\n        }\n\n        // delegate the call for sequences to the deserializer of rusoto_core::Region\n        fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>\n    
    where\n            A: de::SeqAccess<'de>,\n        {\n            Deserialize::deserialize(value::SeqAccessDeserializer::new(seq))\n        }\n    }\n\n    deserializer.deserialize_any(S3RegionVisitor)\n}\n\n#[derive(Debug, Deserialize, Validate)]\n/// Restore settings.\npub struct RestoreSettings {\n    /// If set to `false`, the restoring of coordinator state is prevented.\n    /// Instead, the state is reset and the coordinator is started with the\n    /// settings of the configuration file.\n    ///\n    /// # Examples\n    ///\n    /// **TOML**\n    /// ```text\n    /// [restore]\n    /// enable = true\n    /// ```\n    ///\n    /// **Environment variable**\n    /// ```text\n    /// XAYNET__RESTORE__ENABLE=false\n    /// ```\n    pub enable: bool,\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::settings::Settings;\n    use config::{Config, ConfigError, Environment, File, FileFormat};\n    use serial_test::serial;\n\n    impl Settings {\n        fn load_from_str(string: &str) -> Result<Self, ConfigError> {\n            Config::builder()\n                .add_source(File::from_str(string, FileFormat::Toml))\n                .add_source(Environment::with_prefix(\"xaynet\").separator(\"__\"))\n                .build()?\n                .try_deserialize()\n        }\n    }\n\n    struct ConfigBuilder {\n        config: String,\n    }\n\n    impl ConfigBuilder {\n        fn new() -> Self {\n            Self {\n                config: String::new(),\n            }\n        }\n\n        fn build(self) -> String {\n            self.config\n        }\n\n        fn with_log(mut self) -> Self {\n            let log = r#\"\n            [log]\n            filter = \"xaynet=debug,http=warn,info\"\n            \"#;\n\n            self.config.push_str(log);\n            self\n        }\n\n        fn with_api(mut self) -> Self {\n            let api = r#\"\n            [api]\n            bind_address = \"127.0.0.1:8081\"\n            tls_certificate = 
\"/app/ssl/tls.pem\"\n            tls_key = \"/app/ssl/tls.key\"\n            \"#;\n\n            self.config.push_str(api);\n            self\n        }\n\n        fn with_pet(mut self) -> Self {\n            let pet = r#\"\n            [pet.sum]\n            prob = 0.5\n            count = { min = 1, max = 100 }\n            time = { min = 5, max = 3600 }\n\n            [pet.update]\n            prob = 0.9\n            count = { min = 3, max = 10000 }\n            time = { min = 10, max = 3600 }\n\n            [pet.sum2]\n            count = { min = 1, max = 100 }\n            time = { min = 5, max = 3600 }\n            \"#;\n\n            self.config.push_str(pet);\n            self\n        }\n\n        fn with_mask(mut self) -> Self {\n            let mask = r#\"\n            [mask]\n            group_type = \"Prime\"\n            data_type = \"F32\"\n            bound_type = \"B0\"\n            model_type = \"M3\"\n            \"#;\n\n            self.config.push_str(mask);\n            self\n        }\n\n        fn with_model(mut self) -> Self {\n            let model = r#\"\n            [model]\n            length = 4\n            \"#;\n\n            self.config.push_str(model);\n            self\n        }\n\n        fn with_metrics(mut self) -> Self {\n            let metrics = r#\"\n            [metrics.influxdb]\n            url = \"http://influxdb:8086\"\n            db = \"metrics\"\n            \"#;\n\n            self.config.push_str(metrics);\n            self\n        }\n\n        fn with_redis(mut self) -> Self {\n            let redis = r#\"\n            [redis]\n            url = \"redis://127.0.0.1/\"\n            \"#;\n\n            self.config.push_str(redis);\n            self\n        }\n\n        fn with_s3(mut self) -> Self {\n            let s3 = r#\"\n            [s3]\n            access_key = \"minio\"\n            secret_access_key = \"minio123\"\n            region = [\"minio\", \"http://localhost:9000\"]\n            \"#;\n\n       
     self.config.push_str(s3);\n            self\n        }\n\n        fn with_s3_buckets(mut self) -> Self {\n            let s3_buckets = r#\"\n            [s3.buckets]\n            global_models = \"global-models-toml\"\n            \"#;\n\n            self.config.push_str(s3_buckets);\n            self\n        }\n\n        fn with_restore(mut self) -> Self {\n            let restore = r#\"\n            [restore]\n            enable = true\n            \"#;\n\n            self.config.push_str(restore);\n            self\n        }\n\n        fn with_custom(mut self, custom_config: &str) -> Self {\n            self.config.push_str(custom_config);\n            self\n        }\n    }\n\n    #[test]\n    fn test_validate_s3_bucket_name() {\n        // I took the examples from https://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html\n\n        // valid names\n        assert!(validate_s3_bucket_name(\"docexamplebucket\").is_ok());\n        assert!(validate_s3_bucket_name(\"log-delivery-march-2020\").is_ok());\n        assert!(validate_s3_bucket_name(\"my-hosted-content\").is_ok());\n\n        // valid but not recommended names\n        assert!(validate_s3_bucket_name(\"docexamplewebsite.com\").is_ok());\n        assert!(validate_s3_bucket_name(\"www.docexamplewebsite.com\").is_ok());\n        assert!(validate_s3_bucket_name(\"my.example.s3.bucket\").is_ok());\n\n        // invalid names\n        assert!(validate_s3_bucket_name(\"doc_example_bucket\").is_err());\n        assert!(validate_s3_bucket_name(\"DocExampleBucket\").is_err());\n        assert!(validate_s3_bucket_name(\"doc-example-bucket-\").is_err());\n    }\n\n    #[test]\n    #[serial]\n    fn test_s3_bucket_name_default() {\n        let config = ConfigBuilder::new()\n            .with_log()\n            .with_api()\n            .with_pet()\n            .with_mask()\n            .with_model()\n            .with_metrics()\n            .with_redis()\n            .with_restore()\n            
.with_s3()\n            .build();\n\n        let settings = Settings::load_from_str(&config).unwrap();\n        assert_eq!(\n            settings.s3.buckets.global_models,\n            S3BucketsSettings::default().global_models\n        )\n    }\n\n    #[test]\n    #[serial]\n    fn test_s3_bucket_name_toml_overrides_default() {\n        let config = ConfigBuilder::new()\n            .with_log()\n            .with_api()\n            .with_pet()\n            .with_mask()\n            .with_model()\n            .with_metrics()\n            .with_redis()\n            .with_restore()\n            .with_s3()\n            .with_s3_buckets()\n            .build();\n\n        let settings = Settings::load_from_str(&config).unwrap();\n        assert_eq!(settings.s3.buckets.global_models, \"global-models-toml\")\n    }\n\n    #[test]\n    #[serial]\n    fn test_s3_bucket_name_env_overrides_toml_and_default() {\n        let config = ConfigBuilder::new()\n            .with_log()\n            .with_api()\n            .with_pet()\n            .with_mask()\n            .with_model()\n            .with_metrics()\n            .with_redis()\n            .with_restore()\n            .with_s3()\n            .with_s3_buckets()\n            .build();\n\n        std::env::set_var(\"XAYNET__S3__BUCKETS__GLOBAL_MODELS\", \"global-models-env\");\n        let settings = Settings::load_from_str(&config).unwrap();\n        assert_eq!(settings.s3.buckets.global_models, \"global-models-env\");\n        std::env::remove_var(\"XAYNET__S3__BUCKETS__GLOBAL_MODELS\");\n    }\n\n    #[test]\n    #[serial]\n    fn test_s3_bucket_name_env_overrides_default() {\n        let config = ConfigBuilder::new()\n            .with_log()\n            .with_api()\n            .with_pet()\n            .with_mask()\n            .with_model()\n            .with_metrics()\n            .with_redis()\n            .with_restore()\n            .with_s3()\n            .build();\n\n        
std::env::set_var(\"XAYNET__S3__BUCKETS__GLOBAL_MODELS\", \"global-models-env\");\n        let settings = Settings::load_from_str(&config).unwrap();\n        assert_eq!(settings.s3.buckets.global_models, \"global-models-env\");\n        std::env::remove_var(\"XAYNET__S3__BUCKETS__GLOBAL_MODELS\");\n    }\n\n    #[test]\n    #[serial]\n    fn test_s3_region_toml() {\n        let region = r#\"\n        [s3]\n        access_key = \"minio\"\n        secret_access_key = \"minio123\"\n        region = [\"eu-west-1\"]\n        \"#;\n\n        let config = ConfigBuilder::new()\n            .with_log()\n            .with_api()\n            .with_pet()\n            .with_mask()\n            .with_model()\n            .with_metrics()\n            .with_redis()\n            .with_restore()\n            .with_custom(region)\n            .build();\n\n        let settings = Settings::load_from_str(&config).unwrap();\n        assert!(matches!(settings.s3.region, Region::EuWest1));\n    }\n\n    #[test]\n    #[serial]\n    fn test_s3_custom_region_toml() {\n        let config = ConfigBuilder::new()\n            .with_log()\n            .with_api()\n            .with_pet()\n            .with_mask()\n            .with_model()\n            .with_metrics()\n            .with_redis()\n            .with_restore()\n            .with_s3()\n            .build();\n\n        let settings = Settings::load_from_str(&config).unwrap();\n        assert!(matches!(\n            settings.s3.region,\n            Region::Custom {\n                name,\n                endpoint\n            } if name == \"minio\" && endpoint == \"http://localhost:9000\"\n        ));\n    }\n\n    #[test]\n    #[serial]\n    fn test_s3_region_env() {\n        let config = ConfigBuilder::new()\n            .with_log()\n            .with_api()\n            .with_pet()\n            .with_mask()\n            .with_model()\n            .with_metrics()\n            .with_redis()\n            .with_restore()\n            
.with_s3()\n            .build();\n\n        std::env::set_var(\"XAYNET__S3__REGION\", \"eu-west-1\");\n        let settings = Settings::load_from_str(&config).unwrap();\n        assert!(matches!(settings.s3.region, Region::EuWest1));\n        std::env::remove_var(\"XAYNET__S3__REGION\");\n    }\n\n    #[test]\n    #[serial]\n    fn test_restore() {\n        let no_restore = r#\"\n        [restore]\n        enable = false\n        \"#;\n\n        let config = ConfigBuilder::new()\n            .with_log()\n            .with_api()\n            .with_pet()\n            .with_mask()\n            .with_model()\n            .with_metrics()\n            .with_redis()\n            .with_s3()\n            .with_custom(no_restore)\n            .build();\n\n        let settings = Settings::load_from_str(&config).unwrap();\n        assert!(!settings.restore.enable);\n    }\n\n    #[test]\n    #[serial]\n    fn test_s3_custom_region_env() {\n        let config = ConfigBuilder::new()\n            .with_log()\n            .with_api()\n            .with_pet()\n            .with_mask()\n            .with_model()\n            .with_metrics()\n            .with_redis()\n            .with_restore()\n            .with_s3()\n            .build();\n\n        std::env::set_var(\"XAYNET__S3__REGION\", \"minio-env http://localhost:8000\");\n        let settings = Settings::load_from_str(&config).unwrap();\n        assert!(matches!(\n            settings.s3.region,\n            Region::Custom {\n                name,\n                endpoint\n            } if name == \"minio-env\" && endpoint == \"http://localhost:8000\"\n        ));\n        std::env::remove_var(\"XAYNET__S3__REGION\");\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/coordinator.rs",
    "content": "//! Coordinator state and round parameter types.\n\nuse serde::{Deserialize, Serialize};\n\nuse crate::settings::{\n    MaskSettings,\n    ModelSettings,\n    PetSettings,\n    PetSettingsCount,\n    PetSettingsSum,\n    PetSettingsSum2,\n    PetSettingsTime,\n    PetSettingsUpdate,\n};\nuse xaynet_core::{\n    common::{RoundParameters, RoundSeed},\n    crypto::{ByteObject, EncryptKeyPair},\n    mask::MaskConfig,\n};\n\n/// The phase count parameters.\n#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]\npub struct CountParameters {\n    /// The minimal number of required messages.\n    pub min: u64,\n    /// The maximal number of accepted messages.\n    pub max: u64,\n}\n\nimpl From<PetSettingsCount> for CountParameters {\n    fn from(count: PetSettingsCount) -> Self {\n        let PetSettingsCount { min, max } = count;\n        Self { min, max }\n    }\n}\n\n/// The phase time parameters.\n#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]\npub struct TimeParameters {\n    /// The minimal amount of time (in seconds) reserved for processing messages.\n    pub min: u64,\n    /// The maximal amount of time (in seconds) permitted for processing messages.\n    pub max: u64,\n}\n\nimpl From<PetSettingsTime> for TimeParameters {\n    fn from(time: PetSettingsTime) -> Self {\n        let PetSettingsTime { min, max } = time;\n        Self { min, max }\n    }\n}\n\n/// The phase parameters.\n#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]\npub struct PhaseParameters {\n    /// The number of messages.\n    pub count: CountParameters,\n    /// The amount of time for processing messages.\n    pub time: TimeParameters,\n}\n\nimpl From<PetSettingsSum> for PhaseParameters {\n    fn from(sum: PetSettingsSum) -> Self {\n        let PetSettingsSum { count, time, .. 
} = sum;\n        Self {\n            count: count.into(),\n            time: time.into(),\n        }\n    }\n}\n\nimpl From<PetSettingsUpdate> for PhaseParameters {\n    fn from(update: PetSettingsUpdate) -> Self {\n        let PetSettingsUpdate { count, time, .. } = update;\n        Self {\n            count: count.into(),\n            time: time.into(),\n        }\n    }\n}\n\nimpl From<PetSettingsSum2> for PhaseParameters {\n    fn from(sum2: PetSettingsSum2) -> Self {\n        let PetSettingsSum2 { count, time } = sum2;\n        Self {\n            count: count.into(),\n            time: time.into(),\n        }\n    }\n}\n\n/// The coordinator state.\n#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]\npub struct CoordinatorState {\n    /// The credentials of the coordinator.\n    pub keys: EncryptKeyPair,\n    /// Internal ID used to identify a round\n    pub round_id: u64,\n    /// The round parameters.\n    pub round_params: RoundParameters,\n    /// The sum phase parameters.\n    pub sum: PhaseParameters,\n    /// The update phase parameters.\n    pub update: PhaseParameters,\n    /// The sum2 phase parameters.\n    pub sum2: PhaseParameters,\n}\n\nimpl CoordinatorState {\n    pub fn new(\n        pet_settings: PetSettings,\n        mask_settings: MaskSettings,\n        model_settings: ModelSettings,\n    ) -> Self {\n        let keys = EncryptKeyPair::generate();\n        let round_params = RoundParameters {\n            pk: keys.public,\n            sum: pet_settings.sum.prob,\n            update: pet_settings.update.prob,\n            seed: RoundSeed::zeroed(),\n            mask_config: MaskConfig::from(mask_settings).into(),\n            model_length: model_settings.length,\n        };\n        let round_id = 0;\n        Self {\n            keys,\n            round_params,\n            round_id,\n            sum: pet_settings.sum.into(),\n            update: pet_settings.update.into(),\n            sum2: pet_settings.sum2.into(),\n        }\n   
 }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/events.rs",
    "content": "//! This module provides the coordinator event types (`Event`, `ModelUpdate` and `DictionaryUpdate`) along with the `EventPublisher`, `EventSubscriber`, `EventListener` and `EventBroadcaster` types.\n\nuse std::sync::Arc;\n\nuse tokio::sync::watch;\n\nuse crate::state_machine::phases::PhaseName;\nuse xaynet_core::{\n    common::RoundParameters,\n    crypto::EncryptKeyPair,\n    mask::Model,\n    SeedDict,\n    SumDict,\n};\n\n/// An event emitted by the coordinator.\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct Event<E> {\n    /// Metadata that associates this event with the round in which it is\n    /// emitted.\n    pub round_id: u64,\n    /// The event itself\n    pub event: E,\n}\n\n// FIXME: should we simply use `Option`s here?\n/// Global model update event.\n#[derive(Debug, Clone, PartialEq)]\npub enum ModelUpdate {\n    Invalidate,\n    New(Arc<Model>),\n}\n\n/// Dictionary update event.\n#[derive(Debug, Clone, Eq, PartialEq)]\npub enum DictionaryUpdate<D> {\n    Invalidate,\n    New(Arc<D>),\n}\n\n/// A convenience type to emit any coordinator event.\n#[derive(Debug)]\npub struct EventPublisher {\n    /// Round ID that is attached to all the events.\n    round_id: u64,\n    keys_tx: EventBroadcaster<EncryptKeyPair>,\n    params_tx: EventBroadcaster<RoundParameters>,\n    phase_tx: EventBroadcaster<PhaseName>,\n    model_tx: EventBroadcaster<ModelUpdate>,\n    sum_dict_tx: EventBroadcaster<DictionaryUpdate<SumDict>>,\n    seed_dict_tx: EventBroadcaster<DictionaryUpdate<SeedDict>>,\n}\n\n/// The `EventSubscriber` hands out `EventListener`s for any\n/// coordinator event.\n#[derive(Debug)]\npub struct EventSubscriber {\n    keys_rx: EventListener<EncryptKeyPair>,\n    params_rx: EventListener<RoundParameters>,\n    phase_rx: EventListener<PhaseName>,\n    model_rx: EventListener<ModelUpdate>,\n    sum_dict_rx: EventListener<DictionaryUpdate<SumDict>>,\n    seed_dict_rx: EventListener<DictionaryUpdate<SeedDict>>,\n}\n\nimpl EventPublisher {\n    /// Initialize a new event publisher with the given initial events.\n    pub fn init(\n    
    round_id: u64,\n        keys: EncryptKeyPair,\n        params: RoundParameters,\n        phase: PhaseName,\n        model: ModelUpdate,\n    ) -> (Self, EventSubscriber) {\n        let (keys_tx, keys_rx) = watch::channel::<Event<EncryptKeyPair>>(Event {\n            round_id,\n            event: keys,\n        });\n\n        let (params_tx, params_rx) = watch::channel::<Event<RoundParameters>>(Event {\n            round_id,\n            event: params,\n        });\n\n        let (phase_tx, phase_rx) = watch::channel::<Event<PhaseName>>(Event {\n            round_id,\n            event: phase,\n        });\n\n        let (model_tx, model_rx) = watch::channel::<Event<ModelUpdate>>(Event {\n            round_id,\n            event: model,\n        });\n\n        let (sum_dict_tx, sum_dict_rx) =\n            watch::channel::<Event<DictionaryUpdate<SumDict>>>(Event {\n                round_id,\n                event: DictionaryUpdate::Invalidate,\n            });\n\n        let (seed_dict_tx, seed_dict_rx) =\n            watch::channel::<Event<DictionaryUpdate<SeedDict>>>(Event {\n                round_id,\n                event: DictionaryUpdate::Invalidate,\n            });\n\n        let publisher = EventPublisher {\n            round_id,\n            keys_tx: keys_tx.into(),\n            params_tx: params_tx.into(),\n            phase_tx: phase_tx.into(),\n            model_tx: model_tx.into(),\n            sum_dict_tx: sum_dict_tx.into(),\n            seed_dict_tx: seed_dict_tx.into(),\n        };\n\n        let subscriber = EventSubscriber {\n            keys_rx: keys_rx.into(),\n            params_rx: params_rx.into(),\n            phase_rx: phase_rx.into(),\n            model_rx: model_rx.into(),\n            sum_dict_rx: sum_dict_rx.into(),\n            seed_dict_rx: seed_dict_rx.into(),\n        };\n\n        (publisher, subscriber)\n    }\n\n    /// Set the round ID that is attached to the events the publisher broadcasts.\n    pub fn set_round_id(&mut 
self, id: u64) {\n        self.round_id = id;\n    }\n\n    fn event<T>(&self, event: T) -> Event<T> {\n        Event {\n            round_id: self.round_id,\n            event,\n        }\n    }\n\n    /// Emit a keys event\n    pub fn broadcast_keys(&mut self, keys: EncryptKeyPair) {\n        let _ = self.keys_tx.broadcast(self.event(keys));\n    }\n\n    /// Emit a round parameters event\n    pub fn broadcast_params(&mut self, params: RoundParameters) {\n        let _ = self.params_tx.broadcast(self.event(params));\n    }\n\n    /// Emit a phase event\n    pub fn broadcast_phase(&mut self, phase: PhaseName) {\n        let _ = self.phase_tx.broadcast(self.event(phase));\n    }\n\n    /// Emit a model event\n    pub fn broadcast_model(&mut self, update: ModelUpdate) {\n        let _ = self.model_tx.broadcast(self.event(update));\n    }\n\n    /// Emit a sum dictionary update\n    pub fn broadcast_sum_dict(&mut self, update: DictionaryUpdate<SumDict>) {\n        let _ = self.sum_dict_tx.broadcast(self.event(update));\n    }\n\n    /// Emit a seed dictionary update\n    pub fn broadcast_seed_dict(&mut self, update: DictionaryUpdate<SeedDict>) {\n        let _ = self.seed_dict_tx.broadcast(self.event(update));\n    }\n}\n\nimpl EventSubscriber {\n    /// Get a listener for keys events. 
Callers must be careful not to\n    /// leak the secret key they receive, since that would compromise\n    /// the security of the coordinator.\n    pub fn keys_listener(&self) -> EventListener<EncryptKeyPair> {\n        self.keys_rx.clone()\n    }\n    /// Get a listener for round parameters events\n    pub fn params_listener(&self) -> EventListener<RoundParameters> {\n        self.params_rx.clone()\n    }\n\n    /// Get a listener for new phase events\n    pub fn phase_listener(&self) -> EventListener<PhaseName> {\n        self.phase_rx.clone()\n    }\n\n    /// Get a listener for new model events\n    pub fn model_listener(&self) -> EventListener<ModelUpdate> {\n        self.model_rx.clone()\n    }\n\n    /// Get a listener for sum dictionary updates\n    pub fn sum_dict_listener(&self) -> EventListener<DictionaryUpdate<SumDict>> {\n        self.sum_dict_rx.clone()\n    }\n\n    /// Get a listener for seed dictionary updates\n    pub fn seed_dict_listener(&self) -> EventListener<DictionaryUpdate<SeedDict>> {\n        self.seed_dict_rx.clone()\n    }\n}\n\n/// A listener for coordinator events. 
It can be used to\n/// retrieve the latest `Event<E>` emitted by the coordinator (with\n/// `EventListener::get_latest`).\n#[derive(Debug, Clone)]\npub struct EventListener<E>(watch::Receiver<Event<E>>);\n\nimpl<E> From<watch::Receiver<Event<E>>> for EventListener<E> {\n    fn from(receiver: watch::Receiver<Event<E>>) -> Self {\n        EventListener(receiver)\n    }\n}\n\nimpl<E> EventListener<E>\nwhere\n    E: Clone,\n{\n    pub fn get_latest(&self) -> Event<E> {\n        self.0.borrow().clone()\n    }\n\n    #[cfg(test)]\n    pub async fn changed(&mut self) -> Result<(), watch::error::RecvError> {\n        self.0.changed().await\n    }\n}\n\n/// A channel to send `Event<E>` to all the `EventListener<E>`.\n#[derive(Debug)]\npub struct EventBroadcaster<E>(watch::Sender<Event<E>>);\n\nimpl<E> EventBroadcaster<E> {\n    /// Send `event` to all the `EventListener<E>`\n    fn broadcast(&self, event: Event<E>) {\n        // We don't care whether there's a listener or not\n        let _ = self.0.send(event);\n    }\n}\n\nimpl<E> From<watch::Sender<Event<E>>> for EventBroadcaster<E> {\n    fn from(sender: watch::Sender<Event<E>>) -> Self {\n        Self(sender)\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/initializer.rs",
    "content": "//! A state machine initializer.\n\nuse displaydoc::Display;\nuse thiserror::Error;\n#[cfg(feature = \"model-persistence\")]\nuse tracing::{debug, info};\n\n#[cfg(feature = \"model-persistence\")]\nuse crate::settings::RestoreSettings;\nuse crate::{\n    settings::{MaskSettings, ModelSettings, PetSettings},\n    state_machine::{\n        coordinator::CoordinatorState,\n        events::{EventPublisher, EventSubscriber, ModelUpdate},\n        phases::{Idle, PhaseName, PhaseState, Shared},\n        requests::{RequestReceiver, RequestSender},\n        StateMachine,\n    },\n    storage::{Storage, StorageError},\n};\n#[cfg(feature = \"model-persistence\")]\nuse xaynet_core::mask::Model;\n\ntype StateMachineInitializationResult<T> = Result<T, StateMachineInitializationError>;\n\n/// Errors which can occur during the initialization of the [`StateMachine`].\n#[derive(Debug, Display, Error)]\npub enum StateMachineInitializationError {\n    /// Initializing crypto library failed.\n    CryptoInit,\n    /// Fetching coordinator state failed: {0}.\n    FetchCoordinatorState(StorageError),\n    /// Deleting coordinator data failed: {0}.\n    DeleteCoordinatorData(StorageError),\n    /// Fetching latest global model id failed: {0}.\n    FetchLatestGlobalModelId(StorageError),\n    /// Fetching global model failed: {0}.\n    FetchGlobalModel(StorageError),\n    /// Global model is unavailable: {0}.\n    GlobalModelUnavailable(String),\n    /// Global model is invalid: {0}.\n    GlobalModelInvalid(String),\n}\n\n/// The state machine initializer that initializes a new state machine.\npub struct StateMachineInitializer<T> {\n    pet_settings: PetSettings,\n    mask_settings: MaskSettings,\n    model_settings: ModelSettings,\n    #[cfg(feature = \"model-persistence\")]\n    restore_settings: RestoreSettings,\n    store: T,\n}\n\nimpl<T> StateMachineInitializer<T> {\n    /// Creates a new [`StateMachineInitializer`].\n    pub fn new(\n        pet_settings: 
PetSettings,\n        mask_settings: MaskSettings,\n        model_settings: ModelSettings,\n        #[cfg(feature = \"model-persistence\")] restore_settings: RestoreSettings,\n        store: T,\n    ) -> Self {\n        Self {\n            pet_settings,\n            mask_settings,\n            model_settings,\n            #[cfg(feature = \"model-persistence\")]\n            restore_settings,\n            store,\n        }\n    }\n\n    // Initializes a new [`StateMachine`] with its components.\n    fn init_state_machine(\n        self,\n        coordinator_state: CoordinatorState,\n        global_model: ModelUpdate,\n    ) -> (StateMachine<T>, RequestSender, EventSubscriber) {\n        let (event_publisher, event_subscriber) = EventPublisher::init(\n            coordinator_state.round_id,\n            coordinator_state.keys.clone(),\n            coordinator_state.round_params.clone(),\n            PhaseName::Idle,\n            global_model,\n        );\n\n        let (request_rx, request_tx) = RequestReceiver::new();\n\n        let shared = Shared::new(coordinator_state, event_publisher, request_rx, self.store);\n\n        let state_machine = StateMachine::from(PhaseState::<Idle, _>::new(shared));\n        (state_machine, request_tx, event_subscriber)\n    }\n}\n\nimpl<T> StateMachineInitializer<T>\nwhere\n    T: Storage,\n{\n    #[cfg(not(feature = \"model-persistence\"))]\n    /// Initializes a new [`StateMachine`] with the given settings.\n    pub async fn init(\n        mut self,\n    ) -> StateMachineInitializationResult<(StateMachine<T>, RequestSender, EventSubscriber)> {\n        // crucial: init must be called before anything else in this module\n        sodiumoxide::init().or(Err(StateMachineInitializationError::CryptoInit))?;\n\n        let (coordinator_state, global_model) = { self.from_settings().await? 
};\n        Ok(self.init_state_machine(coordinator_state, global_model))\n    }\n\n    // Creates a new [`CoordinatorState`] from the given settings and deletes\n    // all coordinator data. Should only be called for the first start\n    // or if we need to perform reset.\n    pub(in crate::state_machine) async fn from_settings(\n        &mut self,\n    ) -> StateMachineInitializationResult<(CoordinatorState, ModelUpdate)> {\n        self.store\n            .delete_coordinator_data()\n            .await\n            .map_err(StateMachineInitializationError::DeleteCoordinatorData)?;\n        Ok((\n            CoordinatorState::new(\n                self.pet_settings,\n                self.mask_settings,\n                self.model_settings.clone(),\n            ),\n            ModelUpdate::Invalidate,\n        ))\n    }\n}\n\n#[cfg(feature = \"model-persistence\")]\n#[cfg_attr(docsrs, doc(cfg(feature = \"model-persistence\")))]\nimpl<T> StateMachineInitializer<T>\nwhere\n    T: Storage,\n{\n    /// Initializes a new [`StateMachine`] by trying to restore the previous coordinator state\n    /// along with the latest global model. After a successful initialization, the state machine\n    /// always starts from a new round. 
This means that the round id is increased by one.\n    /// If the state machine is reset during the initialization, the state machine starts\n    /// with the round id `1`.\n    ///\n    /// # Behavior\n    /// ![](https://mermaid.ink/svg/eyJjb2RlIjoic2VxdWVuY2VEaWFncmFtXG4gICAgYWx0IHJlc3RvcmUuZW5hYmxlID0gZmFsc2VcbiAgICAgICAgQ29vcmRpbmF0b3ItPj4rUmVkaXM6IGZsdXNoIGRiXG4gICAgICAgIE5vdGUgb3ZlciBDb29yZGluYXRvcixSZWRpczogc3RhcnQgZnJvbSBzZXR0aW5nc1xuICAgIGVsc2VcbiAgICAgICAgQ29vcmRpbmF0b3ItPj4rUmVkaXM6IGdldCBzdGF0ZVxuICAgICAgICBSZWRpcy0tPj4tQ29vcmRpbmF0b3I6IHN0YXRlXG4gICAgICAgIGFsdCBzdGF0ZSBub24tZXhpc3RlbnRcbiAgICAgICAgICAgIENvb3JkaW5hdG9yLT4-K1JlZGlzOiBmbHVzaCBkYlxuICAgICAgICAgICAgTm90ZSBvdmVyIENvb3JkaW5hdG9yLFJlZGlzOiBzdGFydCBmcm9tIHNldHRpbmdzXG4gICAgICAgIGVsc2Ugc3RhdGUgZXhpc3RcbiAgICAgICAgICAgIENvb3JkaW5hdG9yLT4-K1JlZGlzOiBnZXQgbGF0ZXN0IGdsb2JhbCBtb2RlbCBpZFxuICAgICAgICAgICAgUmVkaXMtLT4-LUNvb3JkaW5hdG9yOiBnbG9iYWwgbW9kZWwgaWRcbiAgICAgICAgICAgIGFsdCBnbG9iYWwgbW9kZWwgaWQgbm9uLWV4aXN0ZW50XG4gICAgICAgICAgICAgICAgTm90ZSBvdmVyIENvb3JkaW5hdG9yLFMzOiByZXN0b3JlIGNvb3JkaW5hdG9yIHdpdGggbGF0ZXN0IHN0YXRlIGJ1dCB3aXRob3V0IGEgZ2xvYmFsIG1vZGVsXG4gICAgICAgICAgICBlbHNlIGdsb2JhbCBtb2RlbCBpZCBleGlzdFxuICAgICAgICAgICAgICBDb29yZGluYXRvci0-PitTMzogZ2V0IGdsb2JhbCBtb2RlbFxuICAgICAgICAgICAgICBTMy0tPj4tQ29vcmRpbmF0b3I6IGdsb2JhbCBtb2RlbFxuICAgICAgICAgICAgICBhbHQgZ2xvYmFsIG1vZGVsIG5vbi1leGlzdGVudFxuICAgICAgICAgICAgICAgIE5vdGUgb3ZlciBDb29yZGluYXRvcixTMzogZXhpdCB3aXRoIGVycm9yXG4gICAgICAgICAgICAgIGVsc2UgZ2xvYmFsIG1vZGVsIGV4aXN0XG4gICAgICAgICAgICAgICAgTm90ZSBvdmVyIENvb3JkaW5hdG9yLFMzOiByZXN0b3JlIGNvb3JkaW5hdG9yIHdpdGggbGF0ZXN0IHN0YXRlIGFuZCBsYXRlc3QgZ2xvYmFsIG1vZGVsXG4gICAgICAgICAgICAgIGVuZFxuICAgICAgICAgICAgZW5kXG4gICAgICAgICAgZW5kXG4gICAgICAgIGVuZCIsIm1lcm1haWQiOnsidGhlbWUiOiJkZWZhdWx0IiwidGhlbWVWYXJpYWJsZXMiOnsiYmFja2dyb3VuZCI6IndoaXRlIiwicHJpbWFyeUNvbG9yIjoiI0VDRUNGRiIsInNlY29uZGFyeUNvbG9yIjoiI2ZmZmZkZSIsInRlcnRpYXJ5Q29sb3IiOiJoc2woODAsIDEwMCUsIDk2LjI3NDUwOTgwMzklKSIsInByaW1hcnlCb3JkZXJDb2xvciI6Imhzb
CgyNDAsIDYwJSwgODYuMjc0NTA5ODAzOSUpIiwic2Vjb25kYXJ5Qm9yZGVyQ29sb3IiOiJoc2woNjAsIDYwJSwgODMuNTI5NDExNzY0NyUpIiwidGVydGlhcnlCb3JkZXJDb2xvciI6ImhzbCg4MCwgNjAlLCA4Ni4yNzQ1MDk4MDM5JSkiLCJwcmltYXJ5VGV4dENvbG9yIjoiIzEzMTMwMCIsInNlY29uZGFyeVRleHRDb2xvciI6IiMwMDAwMjEiLCJ0ZXJ0aWFyeVRleHRDb2xvciI6InJnYig5LjUwMDAwMDAwMDEsIDkuNTAwMDAwMDAwMSwgOS41MDAwMDAwMDAxKSIsImxpbmVDb2xvciI6IiMzMzMzMzMiLCJ0ZXh0Q29sb3IiOiIjMzMzIiwibWFpbkJrZyI6IiNFQ0VDRkYiLCJzZWNvbmRCa2ciOiIjZmZmZmRlIiwiYm9yZGVyMSI6IiM5MzcwREIiLCJib3JkZXIyIjoiI2FhYWEzMyIsImFycm93aGVhZENvbG9yIjoiIzMzMzMzMyIsImZvbnRGYW1pbHkiOiJcInRyZWJ1Y2hldCBtc1wiLCB2ZXJkYW5hLCBhcmlhbCIsImZvbnRTaXplIjoiMTZweCIsImxhYmVsQmFja2dyb3VuZCI6IiNlOGU4ZTgiLCJub2RlQmtnIjoiI0VDRUNGRiIsIm5vZGVCb3JkZXIiOiIjOTM3MERCIiwiY2x1c3RlckJrZyI6IiNmZmZmZGUiLCJjbHVzdGVyQm9yZGVyIjoiI2FhYWEzMyIsImRlZmF1bHRMaW5rQ29sb3IiOiIjMzMzMzMzIiwidGl0bGVDb2xvciI6IiMzMzMiLCJlZGdlTGFiZWxCYWNrZ3JvdW5kIjoiI2U4ZThlOCIsImFjdG9yQm9yZGVyIjoiaHNsKDI1OS42MjYxNjgyMjQzLCA1OS43NzY1MzYzMTI4JSwgODcuOTAxOTYwNzg0MyUpIiwiYWN0b3JCa2ciOiIjRUNFQ0ZGIiwiYWN0b3JUZXh0Q29sb3IiOiJibGFjayIsImFjdG9yTGluZUNvbG9yIjoiZ3JleSIsInNpZ25hbENvbG9yIjoiIzMzMyIsInNpZ25hbFRleHRDb2xvciI6IiMzMzMiLCJsYWJlbEJveEJrZ0NvbG9yIjoiI0VDRUNGRiIsImxhYmVsQm94Qm9yZGVyQ29sb3IiOiJoc2woMjU5LjYyNjE2ODIyNDMsIDU5Ljc3NjUzNjMxMjglLCA4Ny45MDE5NjA3ODQzJSkiLCJsYWJlbFRleHRDb2xvciI6ImJsYWNrIiwibG9vcFRleHRDb2xvciI6ImJsYWNrIiwibm90ZUJvcmRlckNvbG9yIjoiI2FhYWEzMyIsIm5vdGVCa2dDb2xvciI6IiNmZmY1YWQiLCJub3RlVGV4dENvbG9yIjoiYmxhY2siLCJhY3RpdmF0aW9uQm9yZGVyQ29sb3IiOiIjNjY2IiwiYWN0aXZhdGlvbkJrZ0NvbG9yIjoiI2Y0ZjRmNCIsInNlcXVlbmNlTnVtYmVyQ29sb3IiOiJ3aGl0ZSIsInNlY3Rpb25Ca2dDb2xvciI6InJnYmEoMTAyLCAxMDIsIDI1NSwgMC40OSkiLCJhbHRTZWN0aW9uQmtnQ29sb3IiOiJ3aGl0ZSIsInNlY3Rpb25Ca2dDb2xvcjIiOiIjZmZmNDAwIiwidGFza0JvcmRlckNvbG9yIjoiIzUzNGZiYyIsInRhc2tCa2dDb2xvciI6IiM4YTkwZGQiLCJ0YXNrVGV4dExpZ2h0Q29sb3IiOiJ3aGl0ZSIsInRhc2tUZXh0Q29sb3IiOiJ3aGl0ZSIsInRhc2tUZXh0RGFya0NvbG9yIjoiYmxhY2siLCJ0YXNrVGV4dE91dHNpZGVDb2xvciI6ImJsYWNrIiwidGFza1RleHRDbGlja2FibGVDb2xvciI6IiMwMDMxNjMiLCJhY3RpdmVUY
XNrQm9yZGVyQ29sb3IiOiIjNTM0ZmJjIiwiYWN0aXZlVGFza0JrZ0NvbG9yIjoiI2JmYzdmZiIsImdyaWRDb2xvciI6ImxpZ2h0Z3JleSIsImRvbmVUYXNrQmtnQ29sb3IiOiJsaWdodGdyZXkiLCJkb25lVGFza0JvcmRlckNvbG9yIjoiZ3JleSIsImNyaXRCb3JkZXJDb2xvciI6IiNmZjg4ODgiLCJjcml0QmtnQ29sb3IiOiJyZWQiLCJ0b2RheUxpbmVDb2xvciI6InJlZCIsImxhYmVsQ29sb3IiOiJibGFjayIsImVycm9yQmtnQ29sb3IiOiIjNTUyMjIyIiwiZXJyb3JUZXh0Q29sb3IiOiIjNTUyMjIyIiwiY2xhc3NUZXh0IjoiIzEzMTMwMCIsImZpbGxUeXBlMCI6IiNFQ0VDRkYiLCJmaWxsVHlwZTEiOiIjZmZmZmRlIiwiZmlsbFR5cGUyIjoiaHNsKDMwNCwgMTAwJSwgOTYuMjc0NTA5ODAzOSUpIiwiZmlsbFR5cGUzIjoiaHNsKDEyNCwgMTAwJSwgOTMuNTI5NDExNzY0NyUpIiwiZmlsbFR5cGU0IjoiaHNsKDE3NiwgMTAwJSwgOTYuMjc0NTA5ODAzOSUpIiwiZmlsbFR5cGU1IjoiaHNsKC00LCAxMDAlLCA5My41Mjk0MTE3NjQ3JSkiLCJmaWxsVHlwZTYiOiJoc2woOCwgMTAwJSwgOTYuMjc0NTA5ODAzOSUpIiwiZmlsbFR5cGU3IjoiaHNsKDE4OCwgMTAwJSwgOTMuNTI5NDExNzY0NyUpIn19LCJ1cGRhdGVFZGl0b3IiOmZhbHNlfQ)\n    ///\n    /// - If the [`RestoreSettings.enable`] flag is set to `false`, the current coordinator\n    ///   state will be reset and a new [`StateMachine`] is created with the given settings.\n    /// - If no coordinator state exists, the current coordinator state will be reset and a new\n    ///   [`StateMachine`] is created with the given settings.\n    /// - If a coordinator state exists but no global model has been created so far, the\n    ///   [`StateMachine`] will be restored with the coordinator state but without a global model.\n    /// - If a coordinator state and a global model exists, the [`StateMachine`] will be restored\n    ///   with the coordinator state and the global model.\n    /// - If a global model has been created but does not exists, the initialization will fail with\n    ///   [`StateMachineInitializationError::GlobalModelUnavailable`].\n    /// - If a global model exists but its properties do not match the coordinator model settings,\n    ///   the initialization will fail with [`StateMachineInitializationError::GlobalModelInvalid`].\n    /// - Any network error will cause the initialization 
to fail.\n    pub async fn init(\n        mut self,\n    ) -> StateMachineInitializationResult<(StateMachine<T>, RequestSender, EventSubscriber)> {\n        // crucial: init must be called before anything else in this module\n        sodiumoxide::init().or(Err(StateMachineInitializationError::CryptoInit))?;\n\n        let (coordinator_state, global_model) = if self.restore_settings.enable {\n            self.from_previous_state().await?\n        } else {\n            info!(\"restoring coordinator state is disabled\");\n            info!(\"initialize state machine from settings\");\n            self.from_settings().await?\n        };\n\n        Ok(self.init_state_machine(coordinator_state, global_model))\n    }\n\n    // see [`StateMachineInitializer::init`]\n    async fn from_previous_state(\n        &mut self,\n    ) -> StateMachineInitializationResult<(CoordinatorState, ModelUpdate)> {\n        let (coordinator_state, global_model) = if let Some(coordinator_state) = self\n            .store\n            .coordinator_state()\n            .await\n            .map_err(StateMachineInitializationError::FetchCoordinatorState)?\n        {\n            self.try_restore_state(coordinator_state).await?\n        } else {\n            // no coordinator state available seems to be a fresh start\n            self.from_settings().await?\n        };\n\n        Ok((coordinator_state, global_model))\n    }\n\n    // see [`StateMachineInitializer::init`]\n    async fn try_restore_state(\n        &mut self,\n        coordinator_state: CoordinatorState,\n    ) -> StateMachineInitializationResult<(CoordinatorState, ModelUpdate)> {\n        let global_model_id = match self\n            .store\n            .latest_global_model_id()\n            .await\n            .map_err(StateMachineInitializationError::FetchLatestGlobalModelId)?\n        {\n            // the state machine was shut down before completing a round\n            // we cannot use the round_id here because we increment the 
round_id after each restart\n            // that means even if the round id is larger than one, it doesn't mean that a\n            // round has ever been completed\n            None => {\n                debug!(\"apparently no round has been completed yet\");\n                debug!(\"restore coordinator without a global model\");\n                return Ok((coordinator_state, ModelUpdate::Invalidate));\n            }\n            Some(global_model_id) => global_model_id,\n        };\n\n        let global_model = self\n            .load_global_model(&coordinator_state, &global_model_id)\n            .await?;\n\n        debug!(\n            \"restore coordinator with global model id: {}\",\n            global_model_id\n        );\n        Ok((\n            coordinator_state,\n            ModelUpdate::New(std::sync::Arc::new(global_model)),\n        ))\n    }\n\n    // Loads a global model and checks its properties for suitability.\n    async fn load_global_model(\n        &mut self,\n        coordinator_state: &CoordinatorState,\n        global_model_id: &str,\n    ) -> StateMachineInitializationResult<Model> {\n        match self\n            .store\n            .global_model(global_model_id)\n            .await\n            .map_err(StateMachineInitializationError::FetchGlobalModel)?\n        {\n            Some(global_model) => {\n                if Self::model_properties_matches_settings(coordinator_state, &global_model) {\n                    Ok(global_model)\n                } else {\n                    let error_msg = format!(\n                        \"the length of global model with the id {} does not match with the value of the model length setting {} != {}\",\n                        &global_model_id,\n                        global_model.len(),\n                        coordinator_state.round_params.model_length);\n\n                    Err(StateMachineInitializationError::GlobalModelInvalid(\n                        error_msg,\n                    
))\n                }\n            }\n            None => {\n                // the model id exists but we cannot find it in the model store\n                // here we better fail because if we restart a coordinator with an empty model\n                // the clients will throw away their current global model and start from scratch\n                Err(StateMachineInitializationError::GlobalModelUnavailable(\n                    format!(\"cannot find global model {}\", &global_model_id),\n                ))\n            }\n        }\n    }\n\n    // Checks whether the properties of the loaded global model match the current\n    // model settings of the coordinator.\n    fn model_properties_matches_settings(\n        coordinator_state: &CoordinatorState,\n        global_model: &Model,\n    ) -> bool {\n        coordinator_state.round_params.model_length == global_model.len()\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/mod.rs",
    "content": "//! The state machine that controls the execution of the PET protocol.\n//!\n//! # Overview\n//!\n//! ![State Machine](https://mermaid.ink/svg/eyJjb2RlIjoic3RhdGVEaWFncmFtXG5cdFsqXSAtLT4gSWRsZVxuXG4gICAgSWRsZSAtLT4gU3VtXG4gICAgU3VtIC0tPiBVcGRhdGVcbiAgICBVcGRhdGUgLS0-IFN1bTJcbiAgICBTdW0yIC0tPiBVbm1hc2tcbiAgICBVbm1hc2sgLS0-IElkbGVcblxuICAgIFN1bSAtLT4gRmFpbHVyZVxuICAgIFVwZGF0ZSAtLT4gRmFpbHVyZVxuICAgIFN1bTIgLS0-IEZhaWx1cmVcbiAgICBVbm1hc2sgLS0-IEZhaWx1cmVcbiAgICBGYWlsdXJlIC0tPiBJZGxlXG4gICAgRmFpbHVyZSAtLT4gU2h1dGRvd25cblxuICAgIFNodXRkb3duIC0tPiBbKl1cbiIsIm1lcm1haWQiOnsidGhlbWUiOiJuZXV0cmFsIn0sInVwZGF0ZUVkaXRvciI6ZmFsc2V9)\n//!\n//! The [`StateMachine`] is responsible for executing the individual tasks of the PET protocol.\n//! The main tasks include: building the sum and seed dictionaries, aggregating the masked\n//! models, determining the applicable mask and unmasking the global masked model.\n//!\n//! Furthermore, the [`StateMachine`] publishes protocol events and handles protocol errors.\n//!\n//! The [`StateMachine`] as well as the PET settings can be configured in the config file.\n//! See [here][settings] for more details.\n//!\n//! # Phase states\n//!\n//! **Idle**\n//!\n//! Publishes [`PhaseName::Idle`] and increments the `round_id` by `1`. Invalidates the [`SumDict`],\n//! [`SeedDict`], `scalar` and `mask length`. Updates the [`EncryptKeyPair`], `probabilities` for\n//! the tasks and the `seed`. Publishes the [`EncryptKeyPair`] and the [`RoundParameters`].\n//!\n//! **Sum**\n//!\n//! Publishes [`PhaseName::Sum`], builds and publishes the [`SumDict`], ensures that enough sum\n//! messages have been submitted and initializes the [`SeedDict`].\n//!\n//! **Update**\n//!\n//! Publishes [`PhaseName::Update`], builds and publishes the [`SeedDict`], ensures that enough\n//! update messages have been submitted and aggregates the masked model.\n//!\n//! **Sum2**\n//!\n//! 
Publishes [`PhaseName::Sum2`], builds the mask dictionary, ensures that enough sum2\n//! messages have been submitted and determines the applicable mask for unmasking the global\n//! masked model.\n//!\n//! **Unmask**\n//!\n//! Publishes [`PhaseName::Unmask`], unmasks the global masked model and publishes the global\n//! model.\n//!\n//! **Failure**\n//!\n//! Publishes [`PhaseName::Failure`] and handles [`PhaseError`]s that can occur during the\n//! execution of the [`StateMachine`]. In most cases, the error is handled by restarting the round.\n//! However, if a [`PhaseError::RequestChannel`] occurs, the [`StateMachine`] will shut down.\n//!\n//! **Shutdown**\n//!\n//! Publishes [`PhaseName::Shutdown`] and shuts down the [`StateMachine`]. During the shutdown,\n//! the [`StateMachine`] performs a clean shutdown of the [Request][requests] channel by\n//! closing it and consuming all remaining messages.\n//!\n//! # Requests\n//!\n//! By initiating a new [`StateMachine`] via [`StateMachineInitializer::init()`], a new\n//! [StateMachineRequest][requests] channel is created, the function of which is to send\n//! [`StateMachineRequest`]s to the [`StateMachine`]. The sender half of that channel\n//! ([`RequestSender`]) is returned back to the caller of\n//! [`StateMachineInitializer::init()`], whereas the receiver half ([`RequestReceiver`])\n//! is used by the [`StateMachine`].\n//!\n//! See [here][requests] for more details.\n//!\n//! # Events\n//!\n//! During the execution of the PET protocol, the [`StateMachine`] will publish various events\n//! (see Phase states). Everyone who is interested in the events can subscribe to the respective\n//! events via the [`EventSubscriber`]. An [`EventSubscriber`] is automatically created when a new\n//! [`StateMachine`] is created through [`StateMachineInitializer::init()`].\n//!\n//! See [here][events] for more details.\n//!\n//! [settings]: crate::settings\n//! [`PhaseName::Idle`]: crate::state_machine::phases::PhaseName::Idle\n//! 
[`PhaseName::Sum`]: crate::state_machine::phases::PhaseName::Sum\n//! [`PhaseName::Update`]: crate::state_machine::phases::PhaseName::Update\n//! [`PhaseName::Sum2`]: crate::state_machine::phases::PhaseName::Sum2\n//! [`PhaseName::Unmask`]: crate::state_machine::phases::PhaseName::Unmask\n//! [`PhaseName::Failure`]: crate::state_machine::phases::PhaseName::Failure\n//! [`PhaseName::Shutdown`]: crate::state_machine::phases::PhaseName::Shutdown\n//! [`PhaseError`]: crate::state_machine::phases::PhaseError\n//! [`PhaseError::RequestChannel`]: crate::state_machine::phases::PhaseError::RequestChannel\n//! [`SumDict`]: xaynet_core::SumDict\n//! [`SeedDict`]: xaynet_core::SeedDict\n//! [`EncryptKeyPair`]: xaynet_core::crypto::EncryptKeyPair\n//! [`RoundParameters`]: xaynet_core::common::RoundParameters\n//! [`StateMachineInitializer::init()`]: crate::state_machine::initializer::StateMachineInitializer::init\n//! [`StateMachineRequest`]: crate::state_machine::requests::StateMachineRequest\n//! [requests]: crate::state_machine::requests\n//! [`RequestSender`]: crate::state_machine::requests::RequestSender\n//! [`RequestReceiver`]: crate::state_machine::requests::RequestReceiver\n//! [events]: crate::state_machine::events\n//! 
[`EventSubscriber`]: crate::state_machine::events::EventSubscriber\n\npub mod coordinator;\npub mod events;\npub mod initializer;\npub mod phases;\npub mod requests;\n\nuse derive_more::From;\n\nuse crate::{\n    state_machine::phases::{\n        Failure,\n        Idle,\n        Phase,\n        PhaseState,\n        Shutdown,\n        Sum,\n        Sum2,\n        Unmask,\n        Update,\n    },\n    storage::Storage,\n};\n\n/// The state machine with all its states.\n#[derive(From)]\npub enum StateMachine<T> {\n    /// The [`Idle`] phase.\n    Idle(PhaseState<Idle, T>),\n    /// The [`Sum`] phase.\n    Sum(PhaseState<Sum, T>),\n    /// The [`Update`] phase.\n    Update(PhaseState<Update, T>),\n    /// The [`Sum2`] phase.\n    Sum2(PhaseState<Sum2, T>),\n    /// The [`Unmask`] phase.\n    Unmask(PhaseState<Unmask, T>),\n    /// The [`Failure`] phase.\n    Failure(PhaseState<Failure, T>),\n    /// The [`Shutdown`] phase.\n    Shutdown(PhaseState<Shutdown, T>),\n}\n\nimpl<T> StateMachine<T>\nwhere\n    T: Storage,\n    PhaseState<Idle, T>: Phase<T>,\n    PhaseState<Sum, T>: Phase<T>,\n    PhaseState<Update, T>: Phase<T>,\n    PhaseState<Sum2, T>: Phase<T>,\n    PhaseState<Unmask, T>: Phase<T>,\n    PhaseState<Failure, T>: Phase<T>,\n    PhaseState<Shutdown, T>: Phase<T>,\n{\n    /// Moves the [`StateMachine`] to the next state and consumes the current one.\n    ///\n    /// Returns the next state or `None` if the [`StateMachine`] reached the state [`Shutdown`].\n    pub async fn next(self) -> Option<Self> {\n        match self {\n            StateMachine::Idle(state) => state.run_phase().await,\n            StateMachine::Sum(state) => state.run_phase().await,\n            StateMachine::Update(state) => state.run_phase().await,\n            StateMachine::Sum2(state) => state.run_phase().await,\n            StateMachine::Unmask(state) => state.run_phase().await,\n            StateMachine::Failure(state) => state.run_phase().await,\n            
StateMachine::Shutdown(state) => state.run_phase().await,\n        }\n    }\n\n    /// Runs the state machine until it shuts down.\n    ///\n    /// The [`StateMachine`] shuts down once all [`RequestSender`] have been dropped.\n    ///\n    /// [`RequestSender`]: crate::state_machine::requests::RequestSender\n    pub async fn run(mut self) -> Option<()> {\n        loop {\n            self = self.next().await?;\n        }\n    }\n}\n\n/// Records a message accepted metric.\n#[doc(hidden)]\n#[macro_export]\nmacro_rules! accepted {\n    ($round_id: expr, $phase: expr $(,)?) => {\n        crate::metric!(\n            crate::metrics::Measurement::MessageAccepted,\n            1,\n            (\"round_id\", $round_id),\n            (\"phase\", $phase as u8),\n        );\n    };\n}\n\n/// Records a message rejected metric.\n#[doc(hidden)]\n#[macro_export]\nmacro_rules! rejected {\n    ($round_id: expr, $phase: expr $(,)?) => {\n        crate::metric!(\n            crate::metrics::Measurement::MessageRejected,\n            1,\n            (\"round_id\", $round_id),\n            (\"phase\", $phase as u8),\n        );\n    };\n}\n\n/// Records a message discarded metric.\n#[doc(hidden)]\n#[macro_export]\nmacro_rules! discarded {\n    ($round_id: expr, $phase: expr $(,)?) => {\n        crate::metric!(\n            crate::metrics::Measurement::MessageDiscarded,\n            1,\n            (\"round_id\", $round_id),\n            (\"phase\", $phase as u8),\n        );\n    };\n}\n\n#[cfg(test)]\npub(crate) mod tests;\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/phases/failure.rs",
    "content": "use std::time::Duration;\n\nuse async_trait::async_trait;\nuse displaydoc::Display;\nuse thiserror::Error;\nuse tokio::time::sleep;\nuse tracing::{error, info};\n\nuse crate::{\n    event,\n    state_machine::{\n        events::DictionaryUpdate,\n        phases::{\n            Idle,\n            IdleError,\n            Phase,\n            PhaseName,\n            PhaseState,\n            Shared,\n            Shutdown,\n            SumError,\n            UnmaskError,\n            UpdateError,\n        },\n        StateMachine,\n    },\n    storage::Storage,\n};\n\n/// Errors which can occur during the execution of the [`StateMachine`].\n#[derive(Debug, Display, Error)]\npub enum PhaseError {\n    /// Request channel error: {0}.\n    RequestChannel(&'static str),\n    /// Phase timeout.\n    PhaseTimeout(#[from] tokio::time::error::Elapsed),\n    /// Idle phase failed: {0}.\n    Idle(#[from] IdleError),\n    /// Sum phase failed: {0}.\n    Sum(#[from] SumError),\n    /// Update phase failed: {0}.\n    Update(#[from] UpdateError),\n    /// Unmask phase failed: {0}.\n    Unmask(#[from] UnmaskError),\n}\n\n/// The failure state.\n#[derive(Debug)]\npub struct Failure {\n    pub(in crate::state_machine) error: PhaseError,\n}\n\n#[async_trait]\nimpl<T> Phase<T> for PhaseState<Failure, T>\nwhere\n    T: Storage,\n{\n    const NAME: PhaseName = PhaseName::Failure;\n\n    async fn process(&mut self) -> Result<(), PhaseError> {\n        error!(\"phase state error: {}\", self.private.error);\n        event!(\"Phase error\", self.private.error.to_string());\n\n        Ok(())\n    }\n\n    fn broadcast(&mut self) {\n        info!(\"broadcasting invalidation of sum dictionary\");\n        self.shared\n            .events\n            .broadcast_sum_dict(DictionaryUpdate::Invalidate);\n\n        info!(\"broadcasting invalidation of seed dictionary\");\n        self.shared\n            .events\n            .broadcast_seed_dict(DictionaryUpdate::Invalidate);\n    }\n\n 
   async fn next(mut self) -> Option<StateMachine<T>> {\n        if let PhaseError::RequestChannel(_) = self.private.error {\n            Some(PhaseState::<Shutdown, _>::new(self.shared).into())\n        } else {\n            self.wait_for_store_readiness().await;\n            Some(PhaseState::<Idle, _>::new(self.shared).into())\n        }\n    }\n}\n\nimpl<T> PhaseState<Failure, T> {\n    /// Creates a new error phase.\n    pub fn new(shared: Shared<T>, error: PhaseError) -> Self {\n        Self {\n            private: Failure { error },\n            shared,\n        }\n    }\n}\n\nimpl<T> PhaseState<Failure, T>\nwhere\n    T: Storage,\n{\n    /// Waits until the [`Store`] is ready.\n    ///\n    /// [`Store`]: crate::storage::Store\n    async fn wait_for_store_readiness(&mut self) {\n        while let Err(err) = <T as Storage>::is_ready(&mut self.shared.store).await {\n            error!(\"store not ready: {}\", err);\n            info!(\"try again in 5 sec\");\n            sleep(Duration::from_secs(5)).await;\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use std::sync::Arc;\n\n    use super::*;\n\n    use anyhow::anyhow;\n    use tokio::time::{timeout, Duration, Instant};\n    use xaynet_core::{SeedDict, SumDict};\n\n    use crate::{\n        state_machine::{\n            coordinator::CoordinatorState,\n            events::{EventPublisher, EventSubscriber, ModelUpdate},\n            tests::{\n                utils::{enable_logging, init_shared, EventSnapshot},\n                CoordinatorStateBuilder,\n                EventBusBuilder,\n            },\n        },\n        storage::{\n            tests::{utils::create_global_model, MockCoordinatorStore, MockModelStore},\n            Store,\n        },\n    };\n\n    fn state_and_events_from_sum2_phase() -> (CoordinatorState, EventPublisher, EventSubscriber) {\n        let state = CoordinatorStateBuilder::new().build();\n\n        let (event_publisher, event_subscriber) = EventBusBuilder::new(&state)\n     
       .broadcast_phase(PhaseName::Sum2)\n            .broadcast_sum_dict(DictionaryUpdate::New(Arc::new(SumDict::new())))\n            .broadcast_seed_dict(DictionaryUpdate::New(Arc::new(SeedDict::new())))\n            .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1))))\n            .build();\n\n        (state, event_publisher, event_subscriber)\n    }\n\n    #[tokio::test]\n    async fn error_to_idle_phase() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Error phase\n        // 2. broadcast invalidation of sum and seed dict\n        // 3. check if store is ready to process requests\n        // 4. move into idle phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        //   (except for `round_id` when moving into idle phase)\n        // - events have been broadcasted (except phase event and invalidation\n        //   event of sum and seed dict)\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_is_ready().return_once(move || Ok(()));\n\n        let mut ms = MockModelStore::new();\n        ms.expect_is_ready().return_once(move || Ok(()));\n\n        let store = Store::new(cs, ms);\n\n        let (state, event_publisher, event_subscriber) = state_and_events_from_sum2_phase();\n        let events_before_error = EventSnapshot::from(&event_subscriber);\n        let state_before_error = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Failure, _>::new(\n            shared,\n            PhaseError::Idle(IdleError::DeleteDictionaries(anyhow!(\"\"))),\n        ));\n        assert!(state_machine.is_failure());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_error = state_machine.as_ref().clone();\n\n        // round id is updated in idle phase\n        
assert_ne!(state_after_error.round_id, state_before_error.round_id);\n        assert_eq!(\n            state_after_error.round_params,\n            state_before_error.round_params\n        );\n        assert_eq!(state_after_error.keys, state_before_error.keys);\n        assert_eq!(state_after_error.sum, state_before_error.sum);\n        assert_eq!(state_after_error.update, state_before_error.update);\n        assert_eq!(state_after_error.sum2, state_before_error.sum2);\n\n        let events_after_error = EventSnapshot::from(&event_subscriber);\n        assert_ne!(events_after_error.phase, events_before_error.phase);\n        assert_eq!(events_after_error.keys, events_before_error.keys);\n        assert_eq!(events_after_error.params, events_before_error.params);\n        assert_eq!(\n            events_after_error.sum_dict.event,\n            DictionaryUpdate::Invalidate\n        );\n        assert_eq!(\n            events_after_error.seed_dict.event,\n            DictionaryUpdate::Invalidate\n        );\n        assert_eq!(events_after_error.model, events_before_error.model);\n        assert_eq!(events_after_error.phase.event, PhaseName::Failure);\n\n        assert!(state_machine.is_idle());\n    }\n\n    #[tokio::test]\n    async fn test_error_to_shutdown_phase() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Error phase\n        // 2. broadcast invalidation of sum and seed dict\n        // 3. previous phase failed with PhaseError::RequestChannel\n        //    which means that the state machine should be shut down\n        // 4. 
move into shutdown phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - events have been broadcasted (except phase event and invalidation\n        //   event of sum and seed dict)\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_is_ready().return_once(move || Ok(()));\n\n        let mut ms = MockModelStore::new();\n        ms.expect_is_ready().return_once(move || Ok(()));\n\n        let store = Store::new(cs, ms);\n\n        let (state, event_publisher, event_subscriber) = state_and_events_from_sum2_phase();\n        let events_before_error = EventSnapshot::from(&event_subscriber);\n        let state_before_error = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Failure, _>::new(\n            shared,\n            PhaseError::RequestChannel(\"\"),\n        ));\n        assert!(state_machine.is_failure());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_error = state_machine.as_ref().clone();\n\n        assert_eq!(state_after_error, state_before_error);\n\n        let events_after_error = EventSnapshot::from(&event_subscriber);\n        assert_ne!(events_after_error.phase, events_before_error.phase);\n        assert_eq!(events_after_error.keys, events_before_error.keys);\n        assert_eq!(events_after_error.params, events_before_error.params);\n        assert_eq!(\n            events_after_error.sum_dict.event,\n            DictionaryUpdate::Invalidate\n        );\n        assert_eq!(\n            events_after_error.seed_dict.event,\n            DictionaryUpdate::Invalidate\n        );\n        assert_eq!(events_after_error.model, events_before_error.model);\n        assert_eq!(events_after_error.phase.event, PhaseName::Failure);\n\n        assert!(state_machine.is_shutdown());\n    }\n\n    
#[tokio::test]\n    async fn test_error_to_idle_store_failed() {\n        // Storage error:\n        // - first call on `is_ready` the coordinator store and model store fails\n        // - second call on `is_ready` the coordinator store fails and model store passes\n        // - third call on `is_ready` the coordinator store passes and model store fails\n        // - forth call on `is_ready` the coordinator store and model store passes\n        //\n        // What should happen:\n        // 1. broadcast Error phase\n        // 2. broadcast invalidation of sum and seed dict\n        // 3. check if store is ready to process requests\n        // 4. wait until store is ready again (15 sec)\n        // 5. move into idle phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        //   (except for`round_id` when moving into idle phase)\n        // - events have been broadcasted (except phase event and invalidation\n        //   event of sum and seed dict)\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        let mut cs_counter = 0;\n        cs.expect_is_ready().returning(move || {\n            let res = match cs_counter {\n                0 => Err(anyhow!(\"\")),\n                1 => Err(anyhow!(\"\")),\n                2 => Ok(()),\n                3 => Ok(()),\n                _ => panic!(\"\"),\n            };\n            cs_counter += 1;\n            res\n        });\n\n        let mut ms = MockModelStore::new();\n        let mut ms_counter = 0;\n        ms.expect_is_ready().returning(move || {\n            let res = match ms_counter {\n                // we skip step 1 and 2 because Storage::is_ready does not call\n                // MockModelStore::is_ready if MockCoordinatorStore::is_ready\n                // has already failed\n                0 => Err(anyhow!(\"\")),\n                1 => Ok(()),\n                _ => panic!(\"\"),\n            };\n            ms_counter += 
1;\n            res\n        });\n\n        let store = Store::new(cs, ms);\n\n        let state = CoordinatorStateBuilder::new().build();\n        let (event_publisher, _event_subscriber) = EventBusBuilder::new(&state).build();\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Failure, _>::new(\n            shared,\n            PhaseError::Idle(IdleError::DeleteDictionaries(anyhow!(\"\"))),\n        ));\n\n        assert!(state_machine.is_failure());\n\n        let now = Instant::now();\n\n        let state_machine = timeout(Duration::from_secs(20), state_machine.next())\n            .await\n            .unwrap()\n            .unwrap();\n\n        assert!(now.elapsed().as_secs() > 14);\n\n        assert!(state_machine.is_idle());\n    }\n\n    #[tokio::test]\n    async fn test_error_to_shutdown_skip_store_readiness_check() {\n        // Storage error:\n        //\n        // What should happen:\n        // 1. broadcast Error phase\n        // 2. broadcast invalidation of sum and seed dict\n        // 3. previous phase failed with Failure::RequestChannel\n        //    which means that the state machine should be shut down\n        // 4. skip store readiness check\n        // 5. 
move into shutdown phase\n        //\n        // What should not happen:\n        // - wait for the store to be ready again\n        // - the shared state has been changed\n        // - events have been broadcasted (except phase event and invalidation\n        //   event of sum and seed dict)\n        enable_logging();\n\n        let store = Store::new(MockCoordinatorStore::new(), MockModelStore::new());\n\n        let state = CoordinatorStateBuilder::new().build();\n        let (event_publisher, _event_subscriber) = EventBusBuilder::new(&state).build();\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Failure, _>::new(\n            shared,\n            PhaseError::RequestChannel(\"\"),\n        ));\n\n        assert!(state_machine.is_failure());\n\n        let state_machine = timeout(Duration::from_secs(5), state_machine.next())\n            .await\n            .unwrap()\n            .unwrap();\n\n        assert!(state_machine.is_shutdown());\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/phases/handler.rs",
    "content": "use async_trait::async_trait;\nuse tokio::time::{timeout, Duration};\nuse tracing::{debug, info, Span};\n\nuse crate::{\n    accepted,\n    discarded,\n    rejected,\n    state_machine::{\n        coordinator::{CountParameters, PhaseParameters},\n        phases::{Phase, PhaseError, PhaseState},\n        requests::{RequestError, ResponseSender, StateMachineRequest},\n    },\n    storage::Storage,\n};\n\n/// A trait that must be implemented by a state to handle a request.\n#[async_trait]\npub trait Handler {\n    /// Handles a request.\n    ///\n    /// # Errors\n    /// Fails on PET and storage errors.\n    async fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), RequestError>;\n}\n\n/// A counter to keep track of handled messages.\nstruct Counter {\n    /// The minimal number of successfully processed messages.\n    min: u64,\n    /// The maximal number of successfully processed messages.\n    max: u64,\n    /// The number of messages successfully processed.\n    accepted: u64,\n    /// The number of messages failed to processed.\n    rejected: u64,\n    /// The number of messages discarded without being processed.\n    discarded: u64,\n}\n\nimpl AsMut<Counter> for Counter {\n    fn as_mut(&mut self) -> &mut Self {\n        self\n    }\n}\n\nimpl Counter {\n    /// Creates a new message counter.\n    fn new(CountParameters { min, max }: CountParameters) -> Self {\n        Self {\n            min,\n            max,\n            accepted: 0,\n            rejected: 0,\n            discarded: 0,\n        }\n    }\n\n    /// Checks whether enough requests have been processed successfully wrt the PET settings.\n    fn has_enough_messages(&self) -> bool {\n        self.accepted >= self.min\n    }\n\n    /// Checks whether too many requests are processed wrt the PET settings.\n    fn has_overmuch_messages(&self) -> bool {\n        self.accepted >= self.max\n    }\n\n    /// Increments the counter for accepted requests.\n    fn 
increment_accepted(&mut self) {\n        self.accepted += 1;\n        debug!(\n            \"{} messages accepted (min {} and max {} required)\",\n            self.accepted, self.min, self.max,\n        );\n    }\n\n    /// Increments the counter for rejected requests.\n    fn increment_rejected(&mut self) {\n        self.rejected += 1;\n        debug!(\"{} messages rejected\", self.rejected);\n    }\n\n    /// Increments the counter for discarded requests.\n    fn increment_discarded(&mut self) {\n        self.discarded += 1;\n        debug!(\"{} messages discarded\", self.discarded);\n    }\n}\n\nimpl<S, T> PhaseState<S, T>\nwhere\n    T: Storage,\n    Self: Phase<T> + Handler,\n{\n    /// Processes requests wrt the phase parameters.\n    ///\n    /// - Processes at most `count.max` requests during the time interval `[now, now + time.min]`.\n    /// - Processes requests until there are enough (ie `count.min`) for the time interval\n    /// `[now + time.min, now + time.max]`.\n    /// - Aborts if either all connections were dropped or not enough requests were processed until\n    /// timeout.\n    pub(super) async fn process(\n        &mut self,\n        PhaseParameters { count, time }: PhaseParameters,\n    ) -> Result<(), PhaseError> {\n        let mut counter = Counter::new(count);\n\n        info!(\"processing requests\");\n        debug!(\n            \"processing for min {} and max {} seconds\",\n            time.min, time.max\n        );\n        self.process_during(Duration::from_secs(time.min), counter.as_mut())\n            .await?;\n\n        let time_left = time.max - time.min;\n        timeout(\n            Duration::from_secs(time_left),\n            self.process_until_enough(counter.as_mut()),\n        )\n        .await??;\n\n        info!(\n            \"in total {} messages accepted (min {} and max {} required)\",\n            counter.accepted, counter.min, counter.max,\n        );\n        info!(\"in total {} messages rejected\", 
counter.rejected);\n        info!(\n            \"in total {} messages discarded (purged not included)\",\n            counter.discarded,\n        );\n\n        Ok(())\n    }\n\n    /// Processes requests for as long as the given duration.\n    async fn process_during(\n        &mut self,\n        dur: tokio::time::Duration,\n        counter: &mut Counter,\n    ) -> Result<(), PhaseError> {\n        let deadline = tokio::time::sleep(dur);\n        tokio::pin!(deadline);\n\n        loop {\n            tokio::select! {\n                biased;\n\n                _ = &mut deadline => {\n                    debug!(\"duration elapsed\");\n                    break Ok(());\n                }\n                next = self.next_request() => {\n                    let (req, span, resp_tx) = next?;\n                    self.process_single(req, span, resp_tx, counter).await;\n                }\n            }\n        }\n    }\n\n    /// Processes requests until there are enough.\n    async fn process_until_enough(&mut self, counter: &mut Counter) -> Result<(), PhaseError> {\n        while !counter.has_enough_messages() {\n            let (req, span, resp_tx) = self.next_request().await?;\n            self.process_single(req, span, resp_tx, counter).await;\n        }\n        Ok(())\n    }\n\n    /// Processes a single request.\n    ///\n    /// The request is discarded if the maximum message count is reached, accepted if processed\n    /// successfully and rejected otherwise.\n    async fn process_single(\n        &mut self,\n        req: StateMachineRequest,\n        span: Span,\n        resp_tx: ResponseSender,\n        counter: &mut Counter,\n    ) {\n        let _span_guard = span.enter();\n\n        let response = if counter.has_overmuch_messages() {\n            counter.increment_discarded();\n            discarded!(self.shared.state.round_id, Self::NAME);\n            Err(RequestError::MessageDiscarded)\n        } else {\n            let response = 
self.handle_request(req).await;\n            if response.is_ok() {\n                counter.increment_accepted();\n                accepted!(self.shared.state.round_id, Self::NAME);\n            } else {\n                counter.increment_rejected();\n                rejected!(self.shared.state.round_id, Self::NAME);\n            }\n            response\n        };\n\n        // This may error out if the receiver has already been dropped but it doesn't matter for us.\n        let _ = resp_tx.send(response);\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_counter() {\n        // 0 accepted\n        let mut counter = Counter::new(CountParameters { min: 1, max: 3 });\n        assert!(!counter.has_enough_messages());\n        assert!(!counter.has_overmuch_messages());\n\n        // 1 accepted\n        counter.increment_accepted();\n        assert!(counter.has_enough_messages());\n        assert!(!counter.has_overmuch_messages());\n\n        // 2 accepted\n        counter.increment_accepted();\n        assert!(counter.has_enough_messages());\n        assert!(!counter.has_overmuch_messages());\n\n        // 3 accepted\n        counter.increment_accepted();\n        assert!(counter.has_enough_messages());\n        assert!(counter.has_overmuch_messages());\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/phases/idle.rs",
    "content": "use async_trait::async_trait;\nuse displaydoc::Display;\nuse sodiumoxide::crypto::hash::sha256;\nuse thiserror::Error;\nuse tracing::{debug, info, warn};\n\nuse crate::{\n    metric,\n    metrics::Measurement,\n    state_machine::{\n        phases::{Phase, PhaseError, PhaseName, PhaseState, Shared, Sum},\n        StateMachine,\n    },\n    storage::{Storage, StorageError},\n};\nuse xaynet_core::{\n    common::RoundSeed,\n    crypto::{ByteObject, EncryptKeyPair, SigningKeySeed},\n};\n\n/// Errors which can occur during the idle phase.\n#[derive(Debug, Display, Error)]\npub enum IdleError {\n    /// Setting the coordinator state failed: {0}.\n    SetCoordinatorState(StorageError),\n    /// Deleting the dictionaries failed: {0}.\n    DeleteDictionaries(StorageError),\n}\n\n/// The idle state.\n#[derive(Debug)]\npub struct Idle;\n\n#[async_trait]\nimpl<T> Phase<T> for PhaseState<Idle, T>\nwhere\n    T: Storage,\n{\n    const NAME: PhaseName = PhaseName::Idle;\n\n    async fn process(&mut self) -> Result<(), PhaseError> {\n        self.delete_dicts().await?;\n\n        self.gen_round_keypair();\n        self.update_round_probabilities();\n        self.update_round_seed();\n\n        self.set_coordinator_state().await?;\n\n        Ok(())\n    }\n\n    fn broadcast(&mut self) {\n        self.broadcast_keys();\n        self.broadcast_params();\n        self.broadcast_metrics();\n    }\n\n    async fn next(self) -> Option<StateMachine<T>> {\n        Some(PhaseState::<Sum, _>::new(self.shared).into())\n    }\n}\n\nimpl<T> PhaseState<Idle, T> {\n    /// Creates a new idle state.\n    pub fn new(mut shared: Shared<T>) -> Self {\n        // Since some events are emitted very early, the round id must\n        // be correct when the idle phase starts. 
Therefore, we update\n        // it here, when instantiating the idle PhaseState.\n        shared.set_round_id(shared.round_id() + 1);\n        debug!(\"new round ID = {}\", shared.round_id());\n        Self {\n            private: Idle,\n            shared,\n        }\n    }\n\n    /// Updates the participant probabilities round parameters.\n    fn update_round_probabilities(&mut self) {\n        info!(\"updating round probabilities\");\n        warn!(\"round probabilities stay constant, no update strategy implemented yet\");\n    }\n\n    /// Updates the seed round parameter.\n    fn update_round_seed(&mut self) {\n        info!(\"updating round seed\");\n        // Safe unwrap: `sk` and `seed` have same number of bytes\n        let (_, sk) =\n            SigningKeySeed::from_slice_unchecked(self.shared.state.keys.secret.as_slice())\n                .derive_signing_key_pair();\n        let signature = sk.sign_detached(\n            &[\n                self.shared.state.round_params.seed.as_slice(),\n                &self.shared.state.round_params.sum.to_le_bytes(),\n                &self.shared.state.round_params.update.to_le_bytes(),\n            ]\n            .concat(),\n        );\n        // Safe unwrap: the length of the hash is 32 bytes\n        self.shared.state.round_params.seed =\n            RoundSeed::from_slice_unchecked(sha256::hash(signature.as_slice()).as_ref());\n    }\n\n    /// Generates fresh round credentials.\n    fn gen_round_keypair(&mut self) {\n        info!(\"updating the keys\");\n        self.shared.state.keys = EncryptKeyPair::generate();\n        self.shared.state.round_params.pk = self.shared.state.keys.public;\n    }\n\n    /// Broadcasts the keys.\n    fn broadcast_keys(&mut self) {\n        info!(\"broadcasting new keys\");\n        self.shared\n            .events\n            .broadcast_keys(self.shared.state.keys.clone());\n    }\n\n    /// Broadcasts the round parameters.\n    fn broadcast_params(&mut self) {\n        
info!(\"broadcasting new round parameters\");\n        self.shared\n            .events\n            .broadcast_params(self.shared.state.round_params.clone());\n    }\n}\n\nimpl<T> PhaseState<Idle, T>\nwhere\n    T: Storage,\n{\n    /// Deletes the dicts from the store.\n    async fn delete_dicts(&mut self) -> Result<(), IdleError> {\n        info!(\"removing phase dictionaries from previous round\");\n        self.shared\n            .store\n            .delete_dicts()\n            .await\n            .map_err(IdleError::DeleteDictionaries)\n    }\n\n    /// Persists the coordinator state to the store.\n    async fn set_coordinator_state(&mut self) -> Result<(), IdleError> {\n        info!(\"storing new coordinator state\");\n        self.shared\n            .store\n            .set_coordinator_state(&self.shared.state)\n            .await\n            .map_err(IdleError::SetCoordinatorState)\n    }\n}\n\nimpl<T> PhaseState<Idle, T>\nwhere\n    T: Storage,\n    Self: Phase<T>,\n{\n    /// Broadcasts idle phase metrics.\n    fn broadcast_metrics(&self) {\n        metric!(Measurement::RoundTotalNumber, self.shared.state.round_id);\n        metric!(\n            Measurement::RoundParamSum,\n            self.shared.state.round_params.sum,\n            (\"round_id\", self.shared.state.round_id),\n            (\"phase\", Self::NAME as u8),\n        );\n        metric!(\n            Measurement::RoundParamUpdate,\n            self.shared.state.round_params.update,\n            (\"round_id\", self.shared.state.round_id),\n            (\"phase\", Self::NAME as u8),\n        );\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    use std::sync::Arc;\n\n    use anyhow::anyhow;\n    use xaynet_core::common::RoundParameters;\n\n    use crate::{\n        state_machine::{\n            coordinator::CoordinatorState,\n            events::{DictionaryUpdate, EventPublisher, EventSubscriber, ModelUpdate},\n            tests::{\n                
utils::{assert_event_updated_with_id, enable_logging, init_shared, EventSnapshot},\n                CoordinatorStateBuilder,\n                EventBusBuilder,\n            },\n        },\n        storage::{\n            tests::{utils::create_global_model, MockCoordinatorStore, MockModelStore},\n            Store,\n        },\n    };\n\n    fn state_and_events_from_unmask_phase() -> (CoordinatorState, EventPublisher, EventSubscriber) {\n        let state = CoordinatorStateBuilder::new().build();\n\n        let (event_publisher, event_subscriber) = EventBusBuilder::new(&state)\n            .broadcast_phase(PhaseName::Unmask)\n            .broadcast_sum_dict(DictionaryUpdate::Invalidate)\n            .broadcast_seed_dict(DictionaryUpdate::Invalidate)\n            .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1))))\n            .build();\n\n        (state, event_publisher, event_subscriber)\n    }\n\n    fn assert_params(params1: &RoundParameters, params2: &RoundParameters) {\n        assert_ne!(params1.pk, params2.pk);\n        assert_ne!(params1.seed, params2.seed);\n        assert!((params1.sum - params2.sum).abs() <= f64::EPSILON);\n        assert!((params1.update - params2.update).abs() <= f64::EPSILON);\n        assert_eq!(params1.mask_config, params2.mask_config);\n        assert_eq!(params1.model_length, params2.model_length);\n    }\n\n    fn assert_after_delete_dict_failure(\n        state_before: &CoordinatorState,\n        events_before: &EventSnapshot,\n        state_after: &CoordinatorState,\n        events_after: &EventSnapshot,\n    ) {\n        assert_eq!(state_after.round_params.pk, state_before.round_params.pk);\n        assert_eq!(\n            state_after.round_params.seed,\n            state_before.round_params.seed\n        );\n        assert!(\n            (state_after.round_params.sum - state_before.round_params.sum).abs() <= f64::EPSILON\n        );\n        assert!(\n            (state_after.round_params.update - 
state_before.round_params.update).abs()\n                <= f64::EPSILON\n        );\n        assert_eq!(\n            state_after.round_params.mask_config,\n            state_before.round_params.mask_config\n        );\n        assert_eq!(\n            state_after.round_params.model_length,\n            state_before.round_params.model_length\n        );\n\n        assert_ne!(state_after.round_id, state_before.round_id);\n        assert_eq!(state_after.keys, state_before.keys);\n        assert_eq!(state_after.sum, state_before.sum);\n        assert_eq!(state_after.update, state_before.update);\n        assert_eq!(state_after.sum2, state_before.sum2);\n        assert_eq!(state_after.keys.public, state_after.round_params.pk);\n        assert_eq!(state_after.round_id, 1);\n\n        assert_event_updated_with_id(&events_after.phase, &events_before.phase);\n        assert_eq!(events_after.phase.event, PhaseName::Idle);\n        assert_eq!(&events_after.keys, &events_before.keys);\n        assert_eq!(&events_after.sum_dict, &events_before.sum_dict);\n        assert_eq!(&events_after.seed_dict, &events_before.seed_dict);\n        assert_eq!(events_after.params, events_before.params);\n        assert_eq!(events_after.model, events_before.model);\n    }\n\n    #[tokio::test]\n    async fn test_idle_to_sum_phase() {\n        // No Storage errors\n        // lets pretend we come from the unmask phase\n        //\n        // What should happen:\n        // 1. increase round id by 1\n        // 2. broadcast Idle phase\n        // 3. delete the sum/seed/mask dict\n        // 4. update coordinator keys\n        // 5. update round thresholds (not implemented yet)\n        // 6. update round seeds\n        // 7. save the new coordinator state\n        // 8. broadcast updated keys\n        // 9. broadcast new round parameters\n        // 10. 
move into sum phase\n        //\n        // What should not happen:\n        // - the global model has been invalidated\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_delete_dicts().return_once(move || Ok(()));\n        cs.expect_set_coordinator_state()\n            .return_once(move |_| Ok(()));\n        let store = Store::new(cs, MockModelStore::new());\n\n        let (state, event_publisher, event_subscriber) = state_and_events_from_unmask_phase();\n        let events_before_idle = EventSnapshot::from(&event_subscriber);\n        let state_before_idle = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Idle, _>::new(shared));\n        assert!(state_machine.is_idle());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_idle = state_machine.as_ref().clone();\n        assert_params(\n            &state_after_idle.round_params,\n            &state_before_idle.round_params,\n        );\n        assert_ne!(state_after_idle.keys, state_before_idle.keys);\n        assert_ne!(state_after_idle.round_id, state_before_idle.round_id);\n        assert_eq!(state_after_idle.sum, state_before_idle.sum);\n        assert_eq!(state_after_idle.update, state_before_idle.update);\n        assert_eq!(state_after_idle.sum2, state_before_idle.sum2);\n        assert_eq!(\n            state_after_idle.keys.public,\n            state_after_idle.round_params.pk\n        );\n        assert_eq!(state_after_idle.round_id, 1);\n\n        let events_after_idle = EventSnapshot::from(&event_subscriber);\n        assert_event_updated_with_id(&events_after_idle.keys, &events_before_idle.keys);\n        assert_event_updated_with_id(&events_after_idle.params, &events_before_idle.params);\n        assert_event_updated_with_id(&events_after_idle.phase, &events_before_idle.phase);\n        
assert_eq!(events_after_idle.phase.event, PhaseName::Idle);\n        assert_eq!(events_after_idle.sum_dict, events_before_idle.sum_dict);\n        assert_eq!(events_after_idle.seed_dict, events_before_idle.seed_dict);\n        assert_eq!(events_after_idle.model, events_before_idle.model);\n\n        assert!(state_machine.is_sum());\n    }\n\n    #[tokio::test]\n    async fn test_idle_to_sum_delete_dicts_failed() {\n        // Storage:\n        // - delete_dicts fails\n        //\n        // What should happen:\n        // 1. increase round id by 1\n        // 2. broadcast Idle phase\n        // 3. delete the sum/seed/mask dict (fails)\n        // 4. move into error phase\n        //\n        // What should not happen:\n        // - new keys have been broadcasted\n        // - new round parameters have been broadcasted\n        // - the global model has been invalidated\n        // - the state machine has moved into sum phase\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_delete_dicts()\n            .return_once(move || Err(anyhow!(\"\")));\n        let store = Store::new(cs, MockModelStore::new());\n\n        let (state, event_publisher, event_subscriber) = state_and_events_from_unmask_phase();\n        let events_before_idle = EventSnapshot::from(&event_subscriber);\n        let state_before_idle = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Idle, _>::new(shared));\n        assert!(state_machine.is_idle());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_idle = state_machine.as_ref().clone();\n        let events_after_idle = EventSnapshot::from(&event_subscriber);\n        assert_after_delete_dict_failure(\n            &state_before_idle,\n            &events_before_idle,\n            &state_after_idle,\n            &events_after_idle,\n        );\n\n    
    assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::Idle(IdleError::DeleteDictionaries(_))\n        ))\n    }\n\n    #[tokio::test]\n    async fn test_idle_to_sum_save_state_failed() {\n        // Storage:\n        // - set_coordinator_state fails\n        //\n        // What should happen:\n        // 1. increase round id by 1\n        // 2. broadcast Idle phase\n        // 3. delete the sum/seed/mask dict\n        // 4. update coordinator keys\n        // 5. update round thresholds (not implemented yet)\n        // 6. update round seeds\n        // 7. save the new coordinator state (fails)\n\n        // 6. broadcast updated keys\n\n        // 10. move into error phase\n        //\n        // What should not happen:\n        // - new round parameters have been broadcast\n        // - the global model has been invalidated\n        // - the state machine has moved into sum phase\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_delete_dicts().return_once(move || Ok(()));\n        cs.expect_set_coordinator_state()\n            .return_once(move |_| Err(anyhow!(\"\")));\n        let store = Store::new(cs, MockModelStore::new());\n\n        let (state, event_publisher, event_subscriber) = state_and_events_from_unmask_phase();\n        let events_before_idle = EventSnapshot::from(&event_subscriber);\n        let state_before_idle = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Idle, _>::new(shared));\n        assert!(state_machine.is_idle());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_idle = state_machine.as_ref().clone();\n        let events_after_idle = EventSnapshot::from(&event_subscriber);\n\n        assert_params(\n            
&state_after_idle.round_params,\n            &state_before_idle.round_params,\n        );\n        assert_ne!(state_after_idle.keys, state_before_idle.keys);\n        assert_ne!(state_after_idle.round_id, state_before_idle.round_id);\n        assert_eq!(state_after_idle.sum, state_before_idle.sum);\n        assert_eq!(state_after_idle.update, state_before_idle.update);\n        assert_eq!(state_after_idle.sum2, state_before_idle.sum2);\n        assert_eq!(\n            state_after_idle.keys.public,\n            state_after_idle.round_params.pk\n        );\n        assert_eq!(state_after_idle.round_id, 1);\n\n        assert_event_updated_with_id(&events_after_idle.phase, &events_before_idle.phase);\n        assert_eq!(events_after_idle.phase.event, PhaseName::Idle);\n        assert_eq!(&events_after_idle.keys, &events_before_idle.keys);\n        assert_eq!(&events_after_idle.sum_dict, &events_before_idle.sum_dict);\n        assert_eq!(&events_after_idle.seed_dict, &events_before_idle.seed_dict);\n        assert_eq!(events_after_idle.params, events_before_idle.params);\n        assert_eq!(events_after_idle.model, events_before_idle.model);\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::Idle(IdleError::SetCoordinatorState(_))\n        ))\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/phases/mod.rs",
    "content": "//! This module provides the states (aka phases) of the [`StateMachine`].\n//!\n//! [`StateMachine`]: crate::state_machine::StateMachine\n\nmod failure;\nmod handler;\nmod idle;\nmod phase;\nmod shutdown;\nmod sum;\nmod sum2;\nmod unmask;\nmod update;\n\npub use self::{\n    failure::{Failure, PhaseError},\n    handler::Handler,\n    idle::{Idle, IdleError},\n    phase::{Phase, PhaseName, PhaseState, Shared},\n    shutdown::Shutdown,\n    sum::{Sum, SumError},\n    sum2::Sum2,\n    unmask::{Unmask, UnmaskError},\n    update::{Update, UpdateError},\n};\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/phases/phase.rs",
    "content": "use std::fmt;\n\nuse async_trait::async_trait;\nuse derive_more::Display;\nuse futures::StreamExt;\nuse tracing::{debug, error, error_span, info, warn, Span};\nuse tracing_futures::Instrument;\n\nuse crate::{\n    discarded,\n    metric,\n    metrics::Measurement,\n    state_machine::{\n        coordinator::CoordinatorState,\n        events::EventPublisher,\n        phases::{Failure, PhaseError},\n        requests::{RequestError, RequestReceiver, ResponseSender, StateMachineRequest},\n        StateMachine,\n    },\n    storage::Storage,\n};\n\n/// The name of the current phase.\n#[derive(Clone, Copy, Debug, Display, Eq, PartialEq)]\npub enum PhaseName {\n    #[display(fmt = \"Idle\")]\n    Idle,\n    #[display(fmt = \"Sum\")]\n    Sum,\n    #[display(fmt = \"Update\")]\n    Update,\n    #[display(fmt = \"Sum2\")]\n    Sum2,\n    #[display(fmt = \"Unmask\")]\n    Unmask,\n    #[display(fmt = \"Failure\")]\n    Failure,\n    #[display(fmt = \"Shutdown\")]\n    Shutdown,\n}\n\n/// A trait that must be implemented by a state in order to perform its tasks and to move to a next\n/// state.\n///\n/// See the [module level documentation] for more details.\n///\n/// [module level documentation]: crate::state_machine\n#[async_trait]\npub trait Phase<T>\nwhere\n    T: Storage,\n{\n    /// The name of the current phase.\n    const NAME: PhaseName;\n\n    /// Performs the tasks of this phase.\n    async fn process(&mut self) -> Result<(), PhaseError>;\n    // TODO: add a filter service in PetMessageHandler that only passes through messages if\n    // the state machine is in one of the Sum, Update or Sum2 phases. then we can add a Purge\n    // phase here which gets broadcasted when the purge starts to prevent further incomming\n    // messages, which means we can split `purge()` from `process()` and use a no-op default impl\n    // for all phases except Sum, Update and Sum. 
until then we have to have a purge impl in every\n    // phase, which also means that the metrics can be a bit off.\n\n    /// Broadcasts data of this phase (nothing by default).\n    fn broadcast(&mut self) {}\n\n    /// Moves from this phase to the next phase.\n    async fn next(self) -> Option<StateMachine<T>>;\n}\n\n/// The coordinator state and the I/O interfaces that are shared and accessible by all\n/// [`PhaseState`]s.\npub struct Shared<T> {\n    /// The coordinator state.\n    pub(in crate::state_machine) state: CoordinatorState,\n    /// The request receiver half.\n    pub(in crate::state_machine) request_rx: RequestReceiver,\n    /// The event publisher.\n    pub(in crate::state_machine) events: EventPublisher,\n    /// The store for storing coordinator and model data.\n    pub(in crate::state_machine) store: T,\n}\n\nimpl<T> fmt::Debug for Shared<T> {\n    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n        f.debug_struct(\"Shared\")\n            .field(\"state\", &self.state)\n            .field(\"request_rx\", &self.request_rx)\n            .field(\"events\", &self.events)\n            .finish()\n    }\n}\n\nimpl<T> Shared<T> {\n    /// Creates a new shared state.\n    pub fn new(\n        coordinator_state: CoordinatorState,\n        publisher: EventPublisher,\n        request_rx: RequestReceiver,\n        store: T,\n    ) -> Self {\n        Self {\n            state: coordinator_state,\n            request_rx,\n            events: publisher,\n            store,\n        }\n    }\n\n    /// Sets the round ID to the given value.\n    pub fn set_round_id(&mut self, id: u64) {\n        self.state.round_id = id;\n        self.events.set_round_id(id);\n    }\n\n    /// Returns the current round ID.\n    pub fn round_id(&self) -> u64 {\n        self.state.round_id\n    }\n}\n\n/// The state corresponding to a phase of the PET protocol.\n///\n/// This contains the state-dependent `private` state and the state-independent `shared` state\n/// 
which is shared across state transitions.\npub struct PhaseState<S, T> {\n    /// The private state.\n    pub(in crate::state_machine) private: S,\n    /// The shared coordinator state and I/O interfaces.\n    pub(in crate::state_machine) shared: Shared<T>,\n}\n\nimpl<S, T> PhaseState<S, T>\nwhere\n    S: Send,\n    T: Storage,\n    Self: Phase<T>,\n{\n    /// Runs the current phase to completion.\n    ///\n    /// 1. Performs the phase tasks.\n    /// 2. Purges outdated phase messages.\n    /// 3. Broadcasts the phase data.\n    /// 4. Transitions to the next phase.\n    pub async fn run_phase(mut self) -> Option<StateMachine<T>> {\n        let phase = Self::NAME;\n        let span = error_span!(\"run_phase\", phase = %phase);\n\n        async move {\n            info!(\"starting phase\");\n            self.shared.events.broadcast_phase(phase);\n            metric!(Measurement::Phase, phase as u8);\n\n            if let Err(err) = self.process().await {\n                warn!(\"failed to perform the phase tasks\");\n                return Some(self.into_failure_state(err));\n            }\n            info!(\"phase ran successfully\");\n\n            if let Err(err) = self.purge_outdated_requests() {\n                warn!(\"failed to purge outdated requests\");\n                if let PhaseName::Failure | PhaseName::Shutdown = phase {\n                    debug!(\n                        \"already in {} phase: ignoring error while purging outdated requests\",\n                        phase,\n                    );\n                } else {\n                    return Some(self.into_failure_state(err));\n                }\n            }\n\n            self.broadcast();\n\n            info!(\"transitioning to the next phase\");\n            self.next().await\n        }\n        .instrument(span)\n        .await\n    }\n\n    /// Purges all pending requests that are considered outdated at the end of a successful phase.\n    fn purge_outdated_requests(&mut self) -> 
Result<(), PhaseError> {\n        info!(\"discarding outdated requests\");\n        while let Some((_, span, resp_tx)) = self.try_next_request()? {\n            debug!(\"discarding outdated request\");\n            let _span_guard = span.enter();\n            discarded!(self.shared.state.round_id, Self::NAME);\n            let _ = resp_tx.send(Err(RequestError::MessageDiscarded));\n        }\n        Ok(())\n    }\n}\n\nimpl<S, T> PhaseState<S, T> {\n    /// Receives the next [`StateMachineRequest`].\n    ///\n    /// # Errors\n    /// Returns [`PhaseError::RequestChannel`] when all sender halves have been dropped.\n    pub async fn next_request(\n        &mut self,\n    ) -> Result<(StateMachineRequest, Span, ResponseSender), PhaseError> {\n        debug!(\"waiting for the next incoming request\");\n        self.shared.request_rx.next().await.ok_or_else(|| {\n            error!(\"request receiver broken: senders have been dropped\");\n            PhaseError::RequestChannel(\"all message senders have been dropped!\")\n        })\n    }\n\n    pub fn try_next_request(\n        &mut self,\n    ) -> Result<Option<(StateMachineRequest, Span, ResponseSender)>, PhaseError> {\n        match self.shared.request_rx.try_recv() {\n            Some(Some(item)) => Ok(Some(item)),\n            None => {\n                debug!(\"no pending request\");\n                Ok(None)\n            }\n            Some(None) => {\n                warn!(\"failed to get next pending request: channel shut down\");\n                Err(PhaseError::RequestChannel(\n                    \"all message senders have been dropped!\",\n                ))\n            }\n        }\n    }\n\n    fn into_failure_state(self, err: PhaseError) -> StateMachine<T> {\n        PhaseState::<Failure, _>::new(self.shared, err).into()\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/phases/shutdown.rs",
    "content": "use async_trait::async_trait;\nuse tracing::debug;\n\nuse crate::{\n    state_machine::{\n        phases::{Phase, PhaseError, PhaseName, PhaseState, Shared},\n        StateMachine,\n    },\n    storage::Storage,\n};\n\n/// The shutdown state.\n#[derive(Debug)]\npub struct Shutdown;\n\n#[async_trait]\nimpl<T> Phase<T> for PhaseState<Shutdown, T>\nwhere\n    T: Storage,\n{\n    const NAME: PhaseName = PhaseName::Shutdown;\n\n    async fn process(&mut self) -> Result<(), PhaseError> {\n        debug!(\"clearing the request channel\");\n        self.shared.request_rx.close();\n        while self.shared.request_rx.recv().await.is_some() {}\n\n        Ok(())\n    }\n\n    async fn next(self) -> Option<StateMachine<T>> {\n        None\n    }\n}\n\nimpl<T> PhaseState<Shutdown, T> {\n    /// Creates a new shutdown state.\n    pub fn new(shared: Shared<T>) -> Self {\n        Self {\n            private: Shutdown,\n            shared,\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    use crate::{\n        state_machine::tests::{\n            utils::{enable_logging, init_shared},\n            CoordinatorStateBuilder,\n            EventBusBuilder,\n        },\n        storage::{\n            tests::{MockCoordinatorStore, MockModelStore},\n            Store,\n        },\n    };\n\n    #[tokio::test]\n    async fn test_shutdown_to_none() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Shutdown phase\n        // 2. request channel is closed\n        // 3. 
state machine is stopped\n        //\n        // What should not happen:\n        // - events have been broadcasted (except phase event)\n        enable_logging();\n        let store = Store::new(MockCoordinatorStore::new(), MockModelStore::new());\n\n        let state = CoordinatorStateBuilder::new().build();\n        let (event_publisher, _event_subscriber) = EventBusBuilder::new(&state).build();\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Shutdown, _>::new(shared));\n\n        assert!(state_machine.is_shutdown());\n\n        assert!(!request_tx.is_closed());\n\n        let state_machine = state_machine.next().await;\n\n        assert!(request_tx.is_closed());\n        assert!(state_machine.is_none());\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/phases/sum.rs",
    "content": "use std::sync::Arc;\n\nuse async_trait::async_trait;\nuse displaydoc::Display;\nuse thiserror::Error;\nuse tracing::info;\n\nuse crate::{\n    state_machine::{\n        events::DictionaryUpdate,\n        phases::{Handler, Phase, PhaseError, PhaseName, PhaseState, Shared, Update},\n        requests::{RequestError, StateMachineRequest, SumRequest},\n        StateMachine,\n    },\n    storage::{Storage, StorageError},\n};\nuse xaynet_core::{SumDict, SumParticipantEphemeralPublicKey, SumParticipantPublicKey};\n\n/// Errors which can occur during the sum phase.\n#[derive(Debug, Display, Error)]\npub enum SumError {\n    /// Sum dictionary does not exists.\n    NoSumDict,\n    /// Fetching sum dictionary failed: {0}.\n    FetchSumDict(StorageError),\n}\n\n/// The sum state.\n#[derive(Debug)]\npub struct Sum {\n    /// The sum dictionary which gets assembled during the sum phase.\n    sum_dict: Option<SumDict>,\n}\n\n#[async_trait]\nimpl<T> Phase<T> for PhaseState<Sum, T>\nwhere\n    T: Storage,\n    Self: Handler,\n{\n    const NAME: PhaseName = PhaseName::Sum;\n\n    async fn process(&mut self) -> Result<(), PhaseError> {\n        self.process(self.shared.state.sum).await?;\n        self.sum_dict().await?;\n\n        Ok(())\n    }\n\n    fn broadcast(&mut self) {\n        info!(\"broadcasting sum dictionary\");\n        let sum_dict = self\n            .private\n            .sum_dict\n            .take()\n            .expect(\"unreachable: never fails when `broadcast()` is called after `process()`\");\n        self.shared\n            .events\n            .broadcast_sum_dict(DictionaryUpdate::New(Arc::new(sum_dict)));\n    }\n\n    async fn next(self) -> Option<StateMachine<T>> {\n        Some(PhaseState::<Update, _>::new(self.shared).into())\n    }\n}\n\n#[async_trait]\nimpl<T> Handler for PhaseState<Sum, T>\nwhere\n    T: Storage,\n{\n    async fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), RequestError> {\n        if let 
StateMachineRequest::Sum(SumRequest {\n            participant_pk,\n            ephm_pk,\n        }) = req\n        {\n            self.update_sum_dict(participant_pk, ephm_pk).await\n        } else {\n            Err(RequestError::MessageRejected)\n        }\n    }\n}\n\nimpl<T> PhaseState<Sum, T> {\n    /// Creates a new sum state.\n    pub fn new(shared: Shared<T>) -> Self {\n        Self {\n            private: Sum { sum_dict: None },\n            shared,\n        }\n    }\n}\n\nimpl<T> PhaseState<Sum, T>\nwhere\n    T: Storage,\n{\n    /// Updates the sum dict with a sum participant request.\n    async fn update_sum_dict(\n        &mut self,\n        participant_pk: SumParticipantPublicKey,\n        ephm_pk: SumParticipantEphemeralPublicKey,\n    ) -> Result<(), RequestError> {\n        self.shared\n            .store\n            .add_sum_participant(&participant_pk, &ephm_pk)\n            .await?\n            .into_inner()\n            .map_err(RequestError::from)\n    }\n\n    /// Gets the sum dict from the store.\n    async fn sum_dict(&mut self) -> Result<(), SumError> {\n        self.private.sum_dict = self\n            .shared\n            .store\n            .sum_dict()\n            .await\n            .map_err(SumError::FetchSumDict)?\n            .ok_or(SumError::NoSumDict)?\n            .into();\n\n        Ok(())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    use anyhow::anyhow;\n    use tokio::time::{timeout, Duration};\n    use xaynet_core::SumDict;\n\n    use crate::{\n        state_machine::{\n            coordinator::CoordinatorState,\n            events::{EventPublisher, EventSubscriber, ModelUpdate},\n            tests::{\n                utils::{\n                    assert_event_updated,\n                    enable_logging,\n                    init_shared,\n                    send_sum2_messages,\n                    send_sum_messages,\n                    send_update_messages,\n                    EventSnapshot,\n        
        },\n                CoordinatorStateBuilder,\n                EventBusBuilder,\n            },\n        },\n        storage::{\n            tests::{utils::create_global_model, MockCoordinatorStore, MockModelStore},\n            Store,\n            SumPartAdd,\n            SumPartAddError,\n        },\n    };\n\n    fn events_from_idle_phase(state: &CoordinatorState) -> (EventPublisher, EventSubscriber) {\n        EventBusBuilder::new(state)\n            .broadcast_phase(PhaseName::Idle)\n            .broadcast_sum_dict(DictionaryUpdate::Invalidate)\n            .broadcast_seed_dict(DictionaryUpdate::Invalidate)\n            .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1))))\n            .build()\n    }\n\n    fn assert_after_phase_success(\n        state_before: &CoordinatorState,\n        events_before: &EventSnapshot,\n        state_after: &CoordinatorState,\n        events_after: &EventSnapshot,\n    ) {\n        assert_eq!(state_after, state_before);\n\n        assert_event_updated(&events_after.phase, &events_before.phase);\n        assert_event_updated(&events_after.sum_dict, &events_before.sum_dict);\n        assert_eq!(events_after.keys, events_before.keys);\n        assert_eq!(events_after.params, events_before.params);\n        assert_eq!(events_after.phase.event, PhaseName::Sum);\n        assert_eq!(events_after.seed_dict, events_before.seed_dict);\n        assert_eq!(events_after.model, events_before.model);\n    }\n\n    fn assert_after_phase_failure(\n        state_before: &CoordinatorState,\n        events_before: &EventSnapshot,\n        state_after: &CoordinatorState,\n        events_after: &EventSnapshot,\n    ) {\n        assert_eq!(state_after, state_before);\n\n        assert_event_updated(&events_after.phase, &events_before.phase);\n        assert_eq!(events_after.keys, events_before.keys);\n        assert_eq!(events_after.params, events_before.params);\n        assert_eq!(events_after.phase.event, PhaseName::Sum);\n  
      assert_eq!(events_after.sum_dict, events_before.sum_dict);\n        assert_eq!(events_after.seed_dict, events_before.seed_dict);\n        assert_eq!(events_after.model, events_before.model);\n    }\n\n    #[tokio::test]\n    async fn test_sum_to_update_phase() {\n        // No Storage errors\n        // lets pretend we come from the sum phase\n        //\n        // What should happen:\n        // 1. broadcast Sum phase\n        // 2. accept 10 sum messages\n        // 3. fetch sum dict\n        // 4. broadcast sum dict\n        // 5. move into update phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_add_sum_participant()\n            .times(10)\n            .returning(move |_, _| Ok(SumPartAdd(Ok(()))));\n        cs.expect_sum_dict()\n            .return_once(move || Ok(Some(SumDict::new())));\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_sum_count_min(10)\n            .with_sum_count_max(10)\n            .with_sum_time_min(1)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_idle_phase(&state);\n        let events_before_sum = EventSnapshot::from(&event_subscriber);\n        let state_before_sum = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Sum, _>::new(shared));\n        assert!(state_machine.is_sum());\n\n        send_sum_messages(10, request_tx.clone());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum = state_machine.as_ref().clone();\n        let events_after_sum = EventSnapshot::from(&event_subscriber);\n        
assert_after_phase_success(\n            &state_before_sum,\n            &events_before_sum,\n            &state_after_sum,\n            &events_after_sum,\n        );\n\n        assert!(state_machine.is_update());\n    }\n\n    #[tokio::test]\n    async fn test_sum_phase_timeout() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Sum phase\n        // 2. phase should timeout\n        // 3. move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        // - the sum dict has been fetched\n        // - the sum dict has been broadcasted\n        enable_logging();\n\n        let store = Store::new(MockCoordinatorStore::new(), MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_sum_time_min(1)\n            .with_sum_time_max(2)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_idle_phase(&state);\n        let events_before_sum = EventSnapshot::from(&event_subscriber);\n        let state_before_sum = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Sum, _>::new(shared));\n        assert!(state_machine.is_sum());\n\n        let state_machine = timeout(Duration::from_secs(4), state_machine.next())\n            .await\n            .unwrap()\n            .unwrap();\n\n        let state_after_sum = state_machine.as_ref().clone();\n        let events_after_sum = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum,\n            &events_before_sum,\n            &state_after_sum,\n            &events_after_sum,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            
state_machine.into_failure_phase_state().private.error,\n            PhaseError::PhaseTimeout(_)\n        ))\n    }\n\n    #[tokio::test]\n    async fn test_rejected_messages() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Sum phase\n        // 2. accept 7 sum messages\n        // 3. reject 3 update and 5 sum2 messages\n        // 4. fetch sum dict\n        // 5. broadcast sum dict\n        // 6. move into update phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_add_sum_participant()\n            .times(7)\n            .returning(move |_, _| Ok(SumPartAdd(Ok(()))));\n        cs.expect_sum_dict()\n            .return_once(move || Ok(Some(SumDict::new())));\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_sum_count_min(7)\n            .with_sum_count_max(7)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_idle_phase(&state);\n        let events_before_sum = EventSnapshot::from(&event_subscriber);\n        let state_before_sum = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Sum, _>::new(shared));\n        assert!(state_machine.is_sum());\n\n        send_update_messages(3, request_tx.clone());\n        send_sum2_messages(5, request_tx.clone());\n        send_sum_messages(7, request_tx.clone());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum = state_machine.as_ref().clone();\n        let events_after_sum = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_success(\n            
&state_before_sum,\n            &events_before_sum,\n            &state_after_sum,\n            &events_after_sum,\n        );\n\n        assert!(state_machine.is_update());\n    }\n\n    #[tokio::test]\n    async fn test_discarded_messages() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Sum phase\n        // 2. accept 5 sum messages\n        // 3. discard 5 sum messages\n        // 4. fetch sum dict\n        // 5. broadcast sum dict\n        // 6. move into update phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_add_sum_participant()\n            .times(5)\n            .returning(move |_, _| Ok(SumPartAdd(Ok(()))));\n        cs.expect_sum_dict()\n            .return_once(move || Ok(Some(SumDict::new())));\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_sum_count_min(5)\n            .with_sum_count_max(5)\n            .with_sum_time_min(5)\n            .with_sum_time_max(10)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_idle_phase(&state);\n        let events_before_sum = EventSnapshot::from(&event_subscriber);\n        let state_before_sum = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Sum, _>::new(shared));\n        assert!(state_machine.is_sum());\n\n        send_sum_messages(10, request_tx.clone());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum = state_machine.as_ref().clone();\n        let events_after_sum = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_success(\n      
      &state_before_sum,\n            &events_before_sum,\n            &state_after_sum,\n            &events_after_sum,\n        );\n\n        assert!(state_machine.is_update());\n    }\n\n    #[tokio::test]\n    async fn test_request_channel_is_dropped() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Sum phase\n        // 2. request channel is dropped\n        // 3. move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        // - the sum dict has been fetched\n        // - the sum dict has been broadcasted\n        enable_logging();\n\n        let store = Store::new(MockCoordinatorStore::new(), MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_sum_count_min(1)\n            .with_sum_count_max(1)\n            .with_sum_time_min(1)\n            .with_sum_time_max(5)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_idle_phase(&state);\n        let events_before_sum = EventSnapshot::from(&event_subscriber);\n        let state_before_sum = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Sum, _>::new(shared));\n        assert!(state_machine.is_sum());\n\n        drop(request_tx);\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum = state_machine.as_ref().clone();\n        let events_after_sum = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum,\n            &events_before_sum,\n            &state_after_sum,\n            &events_after_sum,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            
state_machine.into_failure_phase_state().private.error,\n            PhaseError::RequestChannel(_)\n        ))\n    }\n\n    #[tokio::test]\n    async fn test_sum_to_update_fetch_sum_dict_failed() {\n        // Storage errors\n        // - sum_dict fails\n        //\n        // What should happen:\n        // 1. broadcast Sum phase\n        // 2. accept 1 sum message\n        // 3. fetch sum dict (fails)\n        // 4. move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        // - the sum dict has been broadcasted\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_add_sum_participant()\n            .times(1)\n            .returning(move |_, _| Ok(SumPartAdd(Ok(()))));\n        cs.expect_sum_dict().return_once(move || Err(anyhow!(\"\")));\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_sum_count_min(1)\n            .with_sum_count_max(1)\n            .with_sum_time_min(1)\n            .with_sum_time_max(5)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_idle_phase(&state);\n        let events_before_sum = EventSnapshot::from(&event_subscriber);\n        let state_before_sum = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Sum, _>::new(shared));\n        assert!(state_machine.is_sum());\n\n        send_sum_messages(1, request_tx.clone());\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum = state_machine.as_ref().clone();\n        let events_after_sum = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum,\n            &events_before_sum,\n   
         &state_after_sum,\n            &events_after_sum,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::Sum(SumError::FetchSumDict(_))\n        ))\n    }\n\n    #[tokio::test]\n    async fn test_sum_to_update_sum_dict_none() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Sum phase\n        // 2. accept 1 sum message\n        // 3. fetch sum dict (no storage error but the sum dict is None)\n        // 4. move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        // - the sum dict has been broadcasted\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_add_sum_participant()\n            .times(1)\n            .returning(move |_, _| Ok(SumPartAdd(Ok(()))));\n        cs.expect_sum_dict().return_once(move || Ok(None));\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_sum_count_min(1)\n            .with_sum_count_max(1)\n            .with_sum_time_min(1)\n            .with_sum_time_max(5)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_idle_phase(&state);\n        let events_before_sum = EventSnapshot::from(&event_subscriber);\n        let state_before_sum = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Sum, _>::new(shared));\n        assert!(state_machine.is_sum());\n\n        send_sum_messages(1, request_tx.clone());\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum = state_machine.as_ref().clone();\n        let 
events_after_sum = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum,\n            &events_before_sum,\n            &state_after_sum,\n            &events_after_sum,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::Sum(SumError::NoSumDict)\n        ))\n    }\n\n    #[tokio::test]\n    async fn test_rejected_messages_pet_error() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Sum phase\n        // 2. reject 3 sum messages (pet error SumPartAddError::AlreadyExists)\n        // 3. phase should timeout\n        // 4. move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        // - the sum dict has been fetched\n        // - the sum dict has been broadcasted\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_add_sum_participant()\n            .times(3)\n            .returning(move |_, _| Ok(SumPartAdd(Err(SumPartAddError::AlreadyExists))));\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_sum_count_min(3)\n            .with_sum_count_max(3)\n            .with_sum_time_min(0)\n            .with_sum_time_max(2)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_idle_phase(&state);\n        let events_before_sum = EventSnapshot::from(&event_subscriber);\n        let state_before_sum = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Sum, _>::new(shared));\n        assert!(state_machine.is_sum());\n\n        
send_sum_messages(3, request_tx.clone());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum = state_machine.as_ref().clone();\n        let events_after_sum = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum,\n            &events_before_sum,\n            &state_after_sum,\n            &events_after_sum,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::PhaseTimeout(_)\n        ))\n    }\n\n    // #[tokio::test]\n    // async fn test_sum_phase_publish_after_purge() {\n    //     // Publish sum dict after purging all remaining messages.\n    //     enable_logging();\n\n    //     let mut cs = MockCoordinatorStore::new();\n    //     cs.expect_add_sum_participant()\n    //         .returning(move |_, _| Ok(SumPartAdd(Ok(()))));\n    //     cs.expect_sum_dict()\n    //         .return_once(move || Ok(Some(SumDict::new())));\n\n    //     let store = Store::new(cs, MockModelStore::new());\n    //     let state = CoordinatorStateBuilder::new()\n    //         .with_round_id(1)\n    //         .with_sum_count_min(2)\n    //         .with_sum_count_max(500)\n    //         .with_sum_time_min(0)\n    //         .build();\n\n    //     let (event_publisher, event_subscriber) = events_from_idle_phase(&state);\n\n    //     let (shared, request_tx) = init_shared(state, store, event_publisher);\n    //     let state_machine = StateMachine::from(PhaseState::<Sum, _>::new(shared));\n    //     assert!(state_machine.is_sum());\n\n    //     let (mut ready, latch) = Readiness::new();\n\n    //     send_sum_messages_with_latch(1000, request_tx.clone(), latch);\n\n    //     let mut sum_dict_listener = event_subscriber.sum_dict_listener();\n    //     sum_dict_listener.changed().await.unwrap();\n    //     
tokio::time::sleep(Duration::from_secs(10)).await;\n    //     tokio::select! {\n    //         // TODO: purge_outdated_requests blocks the current thread (we should fix that)\n    //         // and sum_dict_listener.changed() would always be executed after\n    //         // state_machine.next(). The test always passes although it shouldn't\n    //         // therefore we need to spawn it here to run the state machine on a separate\n    //         // thread\n    //         //\n    //         // Further more we suffer from the https://github.com/tokio-rs/tokio/issues/3350\n    //         // issue in request_tx::try_recv(). We fill the request channel with 1000\n    //         // before we start the machine. Nevertheless, the message purging stops after\n    //         // around 134 messages.\n    //         _ = state_machine.next() => {\n    //             panic!(\"state did no run successfully\")\n    //         }\n    //         _ = sum_dict_listener.changed() => {\n    //             panic!(\"sum dict was broadcasted before all requests has been purged\")\n    //         }\n    //         _ = ready.is_ready() => {\n\n    //         }\n    //     }\n    // }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/phases/sum2.rs",
    "content": "use async_trait::async_trait;\nuse tracing::info;\n\nuse crate::{\n    state_machine::{\n        events::DictionaryUpdate,\n        phases::{Handler, Phase, PhaseError, PhaseName, PhaseState, Shared, Unmask},\n        requests::{RequestError, StateMachineRequest, Sum2Request},\n        StateMachine,\n    },\n    storage::Storage,\n};\nuse xaynet_core::{\n    mask::{Aggregation, MaskObject},\n    SumParticipantPublicKey,\n};\n\n/// The sum2 state.\n#[derive(Debug)]\npub struct Sum2 {\n    /// The aggregator for masked models.\n    model_agg: Aggregation,\n}\n\n#[async_trait]\nimpl<T> Phase<T> for PhaseState<Sum2, T>\nwhere\n    T: Storage,\n    Self: Handler,\n{\n    const NAME: PhaseName = PhaseName::Sum2;\n\n    async fn process(&mut self) -> Result<(), PhaseError> {\n        self.process(self.shared.state.sum2).await\n    }\n\n    fn broadcast(&mut self) {\n        info!(\"broadcasting invalidation of sum dictionary\");\n        self.shared\n            .events\n            .broadcast_sum_dict(DictionaryUpdate::Invalidate);\n\n        info!(\"broadcasting invalidation of seed dictionary\");\n        self.shared\n            .events\n            .broadcast_seed_dict(DictionaryUpdate::Invalidate);\n    }\n\n    async fn next(self) -> Option<StateMachine<T>> {\n        Some(PhaseState::<Unmask, _>::new(self.shared, self.private.model_agg).into())\n    }\n}\n\n#[async_trait]\nimpl<T> Handler for PhaseState<Sum2, T>\nwhere\n    T: Storage,\n{\n    async fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), RequestError> {\n        if let StateMachineRequest::Sum2(Sum2Request {\n            participant_pk,\n            model_mask,\n        }) = req\n        {\n            self.update_mask_dict(participant_pk, model_mask).await\n        } else {\n            Err(RequestError::MessageRejected)\n        }\n    }\n}\n\nimpl<T> PhaseState<Sum2, T> {\n    /// Creates a new sum2 state.\n    pub fn new(shared: Shared<T>, model_agg: Aggregation) 
-> Self {\n        Self {\n            private: Sum2 { model_agg },\n            shared,\n        }\n    }\n}\n\nimpl<T> PhaseState<Sum2, T>\nwhere\n    T: Storage,\n{\n    /// Updates the mask dict with a sum2 participant request.\n    async fn update_mask_dict(\n        &mut self,\n        participant_pk: SumParticipantPublicKey,\n        model_mask: MaskObject,\n    ) -> Result<(), RequestError> {\n        self.shared\n            .store\n            .incr_mask_score(&participant_pk, &model_mask)\n            .await?\n            .into_inner()\n            .map_err(RequestError::from)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    use std::sync::Arc;\n\n    use xaynet_core::{SeedDict, SumDict};\n\n    use crate::{\n        state_machine::{\n            coordinator::CoordinatorState,\n            events::{DictionaryUpdate, EventPublisher, EventSubscriber, ModelUpdate},\n            tests::{\n                utils::{\n                    assert_event_updated,\n                    enable_logging,\n                    init_shared,\n                    send_sum2_messages,\n                    EventSnapshot,\n                },\n                CoordinatorStateBuilder,\n                EventBusBuilder,\n            },\n        },\n        storage::{\n            tests::{utils::create_global_model, MockCoordinatorStore, MockModelStore},\n            MaskScoreIncr,\n            MaskScoreIncrError,\n            Store,\n        },\n    };\n\n    fn events_from_update_phase(state: &CoordinatorState) -> (EventPublisher, EventSubscriber) {\n        EventBusBuilder::new(state)\n            .broadcast_phase(PhaseName::Update)\n            .broadcast_sum_dict(DictionaryUpdate::New(Arc::new(SumDict::new())))\n            .broadcast_seed_dict(DictionaryUpdate::New(Arc::new(SeedDict::new())))\n            .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1))))\n            .build()\n    }\n\n    fn assert_after_phase_success(\n        state_before: 
&CoordinatorState,\n        events_before: &EventSnapshot,\n        state_after: &CoordinatorState,\n        events_after: &EventSnapshot,\n    ) {\n        assert_eq!(state_after, state_before);\n\n        assert_event_updated(&events_after.phase, &events_before.phase);\n        assert_event_updated(&events_after.sum_dict, &events_before.sum_dict);\n        assert_event_updated(&events_after.seed_dict, &events_before.seed_dict);\n        assert_eq!(events_after.sum_dict.event, DictionaryUpdate::Invalidate);\n        assert_eq!(events_after.seed_dict.event, DictionaryUpdate::Invalidate);\n        assert_eq!(events_after.keys, events_before.keys);\n        assert_eq!(events_after.params, events_before.params);\n        assert_eq!(events_after.phase.event, PhaseName::Sum2);\n        assert_eq!(events_after.model, events_before.model);\n    }\n\n    fn assert_after_phase_failure(\n        state_before: &CoordinatorState,\n        events_before: &EventSnapshot,\n        state_after: &CoordinatorState,\n        events_after: &EventSnapshot,\n    ) {\n        assert_eq!(state_after, state_before);\n\n        assert_event_updated(&events_after.phase, &events_before.phase);\n        assert_eq!(events_after.keys, events_before.keys);\n        assert_eq!(events_after.params, events_before.params);\n        assert_eq!(events_after.phase.event, PhaseName::Sum2);\n        assert_eq!(events_after.sum_dict, events_before.sum_dict);\n        assert_eq!(events_after.seed_dict, events_before.seed_dict);\n        assert_eq!(events_after.model, events_before.model);\n    }\n\n    #[tokio::test]\n    async fn test_sum2_to_unmask_phase() {\n        // No Storage errors\n        // lets pretend we come from the update phase\n        //\n        // What should happen:\n        // 1. broadcast Sum2 phase\n        // 2. accept 10 sum2 messages\n        // 3. broadcast invalidation of sum and seed dict\n        // 4. 
move into unmask phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - events have been broadcasted (except phase event and invalidation\n        //   event of sum and seed dict)\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_incr_mask_score()\n            .times(10)\n            .returning(move |_, _| Ok(MaskScoreIncr(Ok(()))));\n\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_sum2_count_min(10)\n            .with_sum2_count_max(10)\n            .with_sum2_time_min(1)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_update_phase(&state);\n        let events_before_sum2 = EventSnapshot::from(&event_subscriber);\n        let state_before_sum2 = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let agg = Aggregation::new(\n            state_before_sum2.round_params.mask_config,\n            state_before_sum2.round_params.model_length,\n        );\n        let state_machine = StateMachine::from(PhaseState::<Sum2, _>::new(shared, agg));\n        assert!(state_machine.is_sum2());\n\n        send_sum2_messages(10, request_tx.clone());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum2 = state_machine.as_ref().clone();\n        let events_after_sum2 = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_success(\n            &state_before_sum2,\n            &events_before_sum2,\n            &state_after_sum2,\n            &events_after_sum2,\n        );\n\n        assert!(state_machine.is_unmask());\n    }\n\n    #[tokio::test]\n    async fn test_rejected_messages_pet_error() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. 
broadcast Sum2 phase\n        // 2. reject 3 sum2 messages (pet error MaskScoreIncrError::UnknownSumPk)\n        // 3. phase should timeout\n        // 4. move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        // - the sum dict has been invalidated\n        // - the seed dict has been invalidated\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_incr_mask_score()\n            .times(3)\n            .returning(move |_, _| Ok(MaskScoreIncr(Err(MaskScoreIncrError::UnknownSumPk))));\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_sum2_count_min(3)\n            .with_sum2_count_max(3)\n            .with_sum2_time_min(0)\n            .with_sum2_time_max(2)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_update_phase(&state);\n        let events_before_sum2 = EventSnapshot::from(&event_subscriber);\n        let state_before_sum2 = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let agg = Aggregation::new(\n            state_before_sum2.round_params.mask_config,\n            state_before_sum2.round_params.model_length,\n        );\n        let state_machine = StateMachine::from(PhaseState::<Sum2, _>::new(shared, agg));\n        assert!(state_machine.is_sum2());\n\n        send_sum2_messages(3, request_tx.clone());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum2 = state_machine.as_ref().clone();\n        let events_after_sum2 = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum2,\n            &events_before_sum2,\n            &state_after_sum2,\n            &events_after_sum2,\n        
);\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::PhaseTimeout(_)\n        ))\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/phases/unmask.rs",
    "content": "use std::{cmp::Ordering, sync::Arc};\n\nuse async_trait::async_trait;\nuse displaydoc::Display;\nuse thiserror::Error;\n#[cfg(feature = \"model-persistence\")]\nuse tracing::warn;\nuse tracing::{error, info};\n\nuse crate::{\n    metric,\n    metrics::{GlobalRecorder, Measurement},\n    state_machine::{\n        events::ModelUpdate,\n        phases::{Idle, Phase, PhaseError, PhaseName, PhaseState, Shared},\n        StateMachine,\n    },\n    storage::{Storage, StorageError},\n};\nuse xaynet_core::mask::{Aggregation, MaskObject, Model, UnmaskingError};\n\n/// Errors which can occur during the unmask phase.\n#[derive(Debug, Display, Error)]\npub enum UnmaskError {\n    /// Ambiguous masks were computed by the sum participants.\n    AmbiguousMasks,\n    /// No mask found.\n    NoMask,\n    /// Unmasking global model failed: {0}.\n    Unmasking(#[from] UnmaskingError),\n    /// Fetching best masks failed: {0}.\n    FetchBestMasks(#[from] StorageError),\n    #[cfg(feature = \"model-persistence\")]\n    /// Saving the global model failed: {0}.\n    SaveGlobalModel(crate::storage::StorageError),\n    /// Publishing the proof of the global model failed: {0}.\n    PublishProof(crate::storage::StorageError),\n}\n\n/// The unmask state.\n#[derive(Debug)]\npub struct Unmask {\n    /// The aggregator for masked models.\n    model_agg: Option<Aggregation>,\n    /// The global model of the current round.\n    global_model: Option<Arc<Model>>,\n}\n\n#[async_trait]\nimpl<T> Phase<T> for PhaseState<Unmask, T>\nwhere\n    T: Storage,\n{\n    const NAME: PhaseName = PhaseName::Unmask;\n\n    async fn process(&mut self) -> Result<(), PhaseError> {\n        self.emit_number_of_unique_masks_metrics();\n        let best_masks = self.best_masks().await?;\n        self.end_round(best_masks).await?;\n\n        #[cfg(feature = \"model-persistence\")]\n        self.save_global_model().await?;\n        self.publish_proof().await?;\n\n        Ok(())\n    }\n\n    fn 
broadcast(&mut self) {\n        info!(\"broadcasting the new global model\");\n        let global_model =\n            self.private.global_model.take().expect(\n                \"unreachable: never fails when `broadcast()` is called after `end_round()`\",\n            );\n        self.shared\n            .events\n            .broadcast_model(ModelUpdate::New(global_model));\n    }\n\n    async fn next(self) -> Option<StateMachine<T>> {\n        Some(PhaseState::<Idle, _>::new(self.shared).into())\n    }\n}\n\nimpl<T> PhaseState<Unmask, T> {\n    /// Creates a new unmask state.\n    pub fn new(shared: Shared<T>, model_agg: Aggregation) -> Self {\n        Self {\n            private: Unmask {\n                model_agg: Some(model_agg),\n                global_model: None,\n            },\n            shared,\n        }\n    }\n\n    /// Freezes the mask dictionary.\n    async fn freeze_mask_dict(\n        &mut self,\n        mut best_masks: Vec<(MaskObject, u64)>,\n    ) -> Result<MaskObject, UnmaskError> {\n        let mask = best_masks\n            .drain(0..)\n            .fold(\n                (None, 0),\n                |(unique_mask, unique_count), (mask, count)| match unique_count.cmp(&count) {\n                    Ordering::Less => (Some(mask), count),\n                    Ordering::Greater => (unique_mask, unique_count),\n                    Ordering::Equal => (None, unique_count),\n                },\n            )\n            .0\n            .ok_or(UnmaskError::AmbiguousMasks)?;\n\n        Ok(mask)\n    }\n\n    /// Ends the round by unmasking the global model.\n    async fn end_round(&mut self, best_masks: Vec<(MaskObject, u64)>) -> Result<(), UnmaskError> {\n        let mask = self.freeze_mask_dict(best_masks).await?;\n\n        // Safe unwrap: State::<Unmask>::new always creates Some(aggregation)\n        let model_agg = self.private.model_agg.take().unwrap();\n\n        model_agg\n            .validate_unmasking(&mask)\n            
.map_err(UnmaskError::from)?;\n        self.private.global_model = Some(Arc::new(model_agg.unmask(mask)));\n\n        Ok(())\n    }\n}\n\nimpl<T> PhaseState<Unmask, T>\nwhere\n    T: Storage,\n{\n    /// Broadcasts mask metrics.\n    fn emit_number_of_unique_masks_metrics(&mut self) {\n        if GlobalRecorder::global().is_none() {\n            return;\n        }\n\n        let mut store = self.shared.store.clone();\n        let (round_id, phase_name) = (self.shared.state.round_id, Self::NAME);\n\n        tokio::spawn(async move {\n            match store.number_of_unique_masks().await {\n                Ok(number_of_masks) => metric!(\n                    Measurement::MasksTotalNumber,\n                    number_of_masks,\n                    (\"round_id\", round_id),\n                    (\"phase\", phase_name as u8),\n                ),\n                Err(err) => error!(\"failed to fetch total number of masks: {}\", err),\n            };\n        });\n    }\n\n    /// Gets the two masks with the highest score.\n    async fn best_masks(&mut self) -> Result<Vec<(MaskObject, u64)>, UnmaskError> {\n        self.shared\n            .store\n            .best_masks()\n            .await\n            .map_err(UnmaskError::FetchBestMasks)?\n            .ok_or(UnmaskError::NoMask)\n    }\n\n    /// Persists the global model to the store.\n    #[cfg(feature = \"model-persistence\")]\n    async fn save_global_model(&mut self) -> Result<(), UnmaskError> {\n        info!(\"saving global model\");\n        let global_model = self\n            .private\n            .global_model\n            .as_ref()\n            .expect(\n                \"unreachable: never fails when `save_global_model()` is called after `end_round()`\",\n            )\n            .as_ref();\n        let global_model_id = self\n            .shared\n            .store\n            .set_global_model(\n                self.shared.state.round_id,\n                &self.shared.state.round_params.seed,\n     
           global_model,\n            )\n            .await\n            .map_err(UnmaskError::SaveGlobalModel)?;\n        if let Err(err) = self\n            .shared\n            .store\n            .set_latest_global_model_id(&global_model_id)\n            .await\n        {\n            warn!(\"failed to update latest global model id: {}\", err);\n        }\n\n        Ok(())\n    }\n\n    /// Publishes proof of the global model.\n    async fn publish_proof(&mut self) -> Result<(), UnmaskError> {\n        info!(\"publishing proof of the new global model\");\n        let global_model = self\n            .private\n            .global_model\n            .as_ref()\n            .expect(\n                \"unreachable: never fails when `save_global_model()` is called after `end_round()`\",\n            )\n            .as_ref();\n        self.shared\n            .store\n            .publish_proof(global_model)\n            .await\n            .map_err(UnmaskError::PublishProof)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    use std::sync::Arc;\n\n    use anyhow::anyhow;\n\n    use crate::{\n        state_machine::{\n            coordinator::CoordinatorState,\n            events::{DictionaryUpdate, EventPublisher, EventSubscriber, ModelUpdate},\n            tests::{\n                utils::{assert_event_updated, enable_logging, init_shared, EventSnapshot},\n                CoordinatorStateBuilder,\n                EventBusBuilder,\n            },\n        },\n        storage::{\n            tests::{\n                utils::{create_global_model, create_mask},\n                MockCoordinatorStore,\n                MockModelStore,\n                MockTrustAnchor,\n            },\n            Store,\n        },\n    };\n\n    fn events_from_sum2_phase(state: &CoordinatorState) -> (EventPublisher, EventSubscriber) {\n        EventBusBuilder::new(state)\n            .broadcast_phase(PhaseName::Sum2)\n            
.broadcast_sum_dict(DictionaryUpdate::Invalidate)\n            .broadcast_seed_dict(DictionaryUpdate::Invalidate)\n            .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1))))\n            .build()\n    }\n\n    fn assert_after_phase_success(\n        state_before: &CoordinatorState,\n        events_before: &EventSnapshot,\n        state_after: &CoordinatorState,\n        events_after: &EventSnapshot,\n    ) {\n        assert_ne!(state_after.round_id, state_before.round_id);\n        assert_eq!(state_after.round_params, state_before.round_params);\n        assert_eq!(state_after.keys, state_before.keys);\n        assert_eq!(state_after.sum, state_before.sum);\n        assert_eq!(state_after.update, state_before.update);\n        assert_eq!(state_after.sum2, state_before.sum2);\n\n        assert_event_updated(&events_after.phase, &events_before.phase);\n        assert_event_updated(&events_after.model, &events_before.model);\n        assert_eq!(events_after.keys, events_before.keys);\n        assert_eq!(events_after.params, events_before.params);\n        assert_eq!(events_after.phase.event, PhaseName::Unmask);\n        assert_eq!(events_after.sum_dict, events_before.sum_dict);\n        assert_eq!(events_after.seed_dict, events_before.seed_dict);\n    }\n\n    fn assert_after_phase_failure(\n        state_before: &CoordinatorState,\n        events_before: &EventSnapshot,\n        state_after: &CoordinatorState,\n        events_after: &EventSnapshot,\n    ) {\n        assert_eq!(state_after, state_before);\n\n        assert_event_updated(&events_after.phase, &events_before.phase);\n        assert_eq!(events_after.keys, events_before.keys);\n        assert_eq!(events_after.params, events_before.params);\n        assert_eq!(events_after.phase.event, PhaseName::Unmask);\n        assert_eq!(events_after.sum_dict, events_before.sum_dict);\n        assert_eq!(events_after.seed_dict, events_before.seed_dict);\n        assert_eq!(events_after.model, 
events_before.model);\n    }\n\n    fn init_aggregator(state: &CoordinatorState) -> Aggregation {\n        let mut aggregator = Aggregation::new(\n            state.round_params.mask_config,\n            state.round_params.model_length,\n        );\n        aggregator.aggregate(create_mask(state.round_params.model_length, 1));\n        aggregator\n    }\n\n    #[tokio::test]\n    async fn test_unmask_to_idle_phase() {\n        // No Storage errors\n        // lets pretend we come from the sum2 phase\n        //\n        // What should happen:\n        // 1. broadcast Unmask phase\n        // 2  fetch best masks (return only one)\n        // 3. unmask the masked global model\n        // 4. publish proof\n        // 5. broadcast unmasked global model\n        // 6. move into idle phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - events have been broadcasted (except phase event and global model)\n        enable_logging();\n\n        let state = CoordinatorStateBuilder::new().with_round_id(1).build();\n        let model_length = state.round_params.model_length;\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_best_masks()\n            .returning(move || Ok(Some(vec![(create_mask(model_length, 1), 1)])));\n        #[cfg(feature = \"model-persistence\")]\n        {\n            cs.expect_set_latest_global_model_id()\n                .returning(move |_| Ok(()));\n        }\n        let ms = {\n            #[cfg(not(feature = \"model-persistence\"))]\n            {\n                MockModelStore::new()\n            }\n            #[cfg(feature = \"model-persistence\")]\n            {\n                let mut ms = MockModelStore::new();\n                ms.expect_set_global_model()\n                    .returning(move |_, _, _| Ok(\"id\".to_string()));\n                ms\n            }\n        };\n\n        let store = Store::new(cs, ms);\n\n        let (event_publisher, 
event_subscriber) = events_from_sum2_phase(&state);\n        let events_before_sum2 = EventSnapshot::from(&event_subscriber);\n        let state_before_sum2 = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let aggregator = init_aggregator(&state_before_sum2);\n        let state_machine = StateMachine::from(PhaseState::<Unmask, _>::new(shared, aggregator));\n        assert!(state_machine.is_unmask());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum2 = state_machine.as_ref().clone();\n        let events_after_sum2 = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_success(\n            &state_before_sum2,\n            &events_before_sum2,\n            &state_after_sum2,\n            &events_after_sum2,\n        );\n\n        assert!(state_machine.is_idle());\n    }\n\n    #[tokio::test]\n    async fn test_unmask_to_idle_phase_best_masks_fails() {\n        // Storage:\n        // - best_masks fails\n        //\n        // What should happen:\n        // 1. broadcast Unmask phase\n        // 2. fetch best masks (fails)\n        // 3. 
move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated/changed\n        // - the sum dict has been invalidated\n        // - the seed dict has been invalidated\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_best_masks().returning(move || Err(anyhow!(\"\")));\n\n        let store = Store::new(cs, MockModelStore::new());\n\n        let state = CoordinatorStateBuilder::new().with_round_id(1).build();\n        let (event_publisher, event_subscriber) = events_from_sum2_phase(&state);\n        let events_before_sum2 = EventSnapshot::from(&event_subscriber);\n        let state_before_sum2 = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let aggregator = init_aggregator(&state_before_sum2);\n        let state_machine = StateMachine::from(PhaseState::<Unmask, _>::new(shared, aggregator));\n        assert!(state_machine.is_unmask());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum2 = state_machine.as_ref().clone();\n        let events_after_sum2 = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum2,\n            &events_before_sum2,\n            &state_after_sum2,\n            &events_after_sum2,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::Unmask(UnmaskError::FetchBestMasks(_))\n        ))\n    }\n\n    #[tokio::test]\n    async fn test_unmask_to_idle_phase_no_mask() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Unmask phase\n        // 2. fetch best masks (no storage error but the mask vec is None)\n        // 3. 
move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated/changed\n        // - the sum dict has been invalidated\n        // - the seed dict has been invalidated\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_best_masks().returning(move || Ok(None));\n\n        let store = Store::new(cs, MockModelStore::new());\n\n        let state = CoordinatorStateBuilder::new().with_round_id(1).build();\n        let (event_publisher, event_subscriber) = events_from_sum2_phase(&state);\n        let events_before_sum2 = EventSnapshot::from(&event_subscriber);\n        let state_before_sum2 = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let aggregator = init_aggregator(&state_before_sum2);\n        let state_machine = StateMachine::from(PhaseState::<Unmask, _>::new(shared, aggregator));\n        assert!(state_machine.is_unmask());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum2 = state_machine.as_ref().clone();\n        let events_after_sum2 = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum2,\n            &events_before_sum2,\n            &state_after_sum2,\n            &events_after_sum2,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::Unmask(UnmaskError::NoMask)\n        ))\n    }\n\n    #[tokio::test]\n    async fn test_unmask_to_idle_phase_ambiguous_masks() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Unmask phase\n        // 2. fetch best masks\n        // 3. unmask the masked global model (fails because of ambiguous masks)\n        // 4. 
move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated/changed\n        // - the sum dict has been invalidated\n        // - the seed dict has been invalidated\n        enable_logging();\n\n        let state = CoordinatorStateBuilder::new().with_round_id(1).build();\n        let model_length = state.round_params.model_length;\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_best_masks().returning(move || {\n            Ok(Some(vec![\n                (create_mask(model_length, 1), 1),\n                (create_mask(model_length, 2), 1),\n            ]))\n        });\n\n        let store = Store::new(cs, MockModelStore::new());\n\n        let (event_publisher, event_subscriber) = events_from_sum2_phase(&state);\n        let events_before_sum2 = EventSnapshot::from(&event_subscriber);\n        let state_before_sum2 = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let aggregator = init_aggregator(&state_before_sum2);\n        let state_machine = StateMachine::from(PhaseState::<Unmask, _>::new(shared, aggregator));\n        assert!(state_machine.is_unmask());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum2 = state_machine.as_ref().clone();\n        let events_after_sum2 = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum2,\n            &events_before_sum2,\n            &state_after_sum2,\n            &events_after_sum2,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::Unmask(UnmaskError::AmbiguousMasks)\n        ))\n    }\n\n    #[tokio::test]\n    async fn test_unmask_to_idle_phase_validate_unmasking_fails() {\n        // No Storage 
errors\n        //\n        // What should happen:\n        // 1. broadcast Unmask phase\n        // 2. fetch best masks\n        // 3. unmask the masked global model (fails because of validate unmasking error)\n        // 4. move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated/changed\n        // - the sum dict has been invalidated\n        // - the seed dict has been invalidated\n        enable_logging();\n\n        let state = CoordinatorStateBuilder::new().with_round_id(1).build();\n        let model_length = state.round_params.model_length;\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_best_masks()\n            .returning(move || Ok(Some(vec![(create_mask(model_length, 1), 1)])));\n\n        let store = Store::new(cs, MockModelStore::new());\n\n        let (event_publisher, event_subscriber) = events_from_sum2_phase(&state);\n        let events_before_sum2 = EventSnapshot::from(&event_subscriber);\n        let state_before_sum2 = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let aggregator = Aggregation::new(\n            state_before_sum2.round_params.mask_config,\n            state_before_sum2.round_params.model_length,\n        );\n        let state_machine = StateMachine::from(PhaseState::<Unmask, _>::new(shared, aggregator));\n        assert!(state_machine.is_unmask());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum2 = state_machine.as_ref().clone();\n        let events_after_sum2 = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum2,\n            &events_before_sum2,\n            &state_after_sum2,\n            &events_after_sum2,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            
state_machine.into_failure_phase_state().private.error,\n            PhaseError::Unmask(UnmaskError::Unmasking(UnmaskingError::NoModel))\n        ))\n    }\n\n    #[tokio::test]\n    async fn test_unmask_to_idle_phase_publish_proof_fails() {\n        // TODO: we should set the latest_global_model_id only if the\n        // the proof was successfully published\n        //\n        // Why? If the coordinator were to restart after this phase, they would\n        // be using a model that has no evidence and therefore cannot be validated\n        // by the user.\n\n        // Storage:\n        // - publish_proof fails\n        //\n        // What should happen:\n        // 1. broadcast Unmask phase\n        // 2. fetch best masks\n        // 3. unmask the masked global model\n        // 4. save global model and model id (model-persistence feature)\n        // 5. publish proof (fails)\n        // 6. move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated/changed\n        // - the sum dict has been invalidated\n        // - the seed dict has been invalidated\n        enable_logging();\n\n        let state = CoordinatorStateBuilder::new().with_round_id(1).build();\n        let model_length = state.round_params.model_length;\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_best_masks()\n            .returning(move || Ok(Some(vec![(create_mask(model_length, 1), 1)])));\n        #[cfg(feature = \"model-persistence\")]\n        {\n            cs.expect_set_latest_global_model_id()\n                .returning(move |_| Ok(()));\n        }\n        let ms = {\n            #[cfg(not(feature = \"model-persistence\"))]\n            {\n                MockModelStore::new()\n            }\n            #[cfg(feature = \"model-persistence\")]\n            {\n                let mut ms = MockModelStore::new();\n                
ms.expect_set_global_model()\n                    .returning(move |_, _, _| Ok(\"id\".to_string()));\n                ms\n            }\n        };\n        let mut ta = MockTrustAnchor::new();\n        ta.expect_publish_proof()\n            .returning(move |_| Err(anyhow!(\"\")));\n\n        let store = Store::new_with_trust_anchor(cs, ms, ta);\n\n        let (event_publisher, event_subscriber) = events_from_sum2_phase(&state);\n        let events_before_sum2 = EventSnapshot::from(&event_subscriber);\n        let state_before_sum2 = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let aggregator = init_aggregator(&state_before_sum2);\n        let state_machine = StateMachine::from(PhaseState::<Unmask, _>::new(shared, aggregator));\n        assert!(state_machine.is_unmask());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum2 = state_machine.as_ref().clone();\n        let events_after_sum2 = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum2,\n            &events_before_sum2,\n            &state_after_sum2,\n            &events_after_sum2,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::Unmask(UnmaskError::PublishProof(_))\n        ))\n    }\n\n    #[cfg(feature = \"model-persistence\")]\n    #[tokio::test]\n    async fn test_unmask_to_idle_phase_set_global_model_fails() {\n        // Storage:\n        // - set_global_model fails\n        //\n        // What should happen:\n        // 1. broadcast Unmask phase\n        // 2. fetch best masks\n        // 3. unmask the masked global model\n        // 4. save global model (fails)\n        // 5. 
move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated/changed\n        // - the sum dict has been invalidated\n        // - the seed dict has been invalidated\n        enable_logging();\n\n        let state = CoordinatorStateBuilder::new().with_round_id(1).build();\n        let model_length = state.round_params.model_length;\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_best_masks()\n            .returning(move || Ok(Some(vec![(create_mask(model_length, 1), 1)])));\n        cs.expect_set_latest_global_model_id()\n            .returning(move |_| Ok(()));\n\n        let mut ms = MockModelStore::new();\n        ms.expect_set_global_model()\n            .returning(move |_, _, _| Err(anyhow!(\"\")));\n\n        let store = Store::new(cs, ms);\n\n        let (event_publisher, event_subscriber) = events_from_sum2_phase(&state);\n        let events_before_sum2 = EventSnapshot::from(&event_subscriber);\n        let state_before_sum2 = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let aggregator = init_aggregator(&state_before_sum2);\n        let state_machine = StateMachine::from(PhaseState::<Unmask, _>::new(shared, aggregator));\n        assert!(state_machine.is_unmask());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum2 = state_machine.as_ref().clone();\n        let events_after_sum2 = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_sum2,\n            &events_before_sum2,\n            &state_after_sum2,\n            &events_after_sum2,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::Unmask(UnmaskError::SaveGlobalModel(_))\n        ))\n   
 }\n\n    #[cfg(feature = \"model-persistence\")]\n    #[tokio::test]\n    async fn test_unmask_to_idle_phase_set_global_model_id_fails() {\n        // Storage:\n        // - set_latest_global_model_id fails\n        //\n        // What should happen:\n        // 1. broadcast Unmask phase\n        // 2. fetch best masks\n        // 3. unmask the masked global model\n        // 4. save global model and model id (fails)\n        // 5. publish proof\n        // 6. broadcast unmasked global model\n        // 7. move into idle phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - events have been broadcasted (except phase event and global model)\n        enable_logging();\n\n        let state = CoordinatorStateBuilder::new().with_round_id(1).build();\n        let model_length = state.round_params.model_length;\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_best_masks()\n            .returning(move || Ok(Some(vec![(create_mask(model_length, 1), 1)])));\n        cs.expect_set_latest_global_model_id()\n            .returning(move |_| Err(anyhow!(\"\")));\n\n        let mut ms = MockModelStore::new();\n        ms.expect_set_global_model()\n            .returning(move |_, _, _| Ok(\"id\".to_string()));\n\n        let store = Store::new(cs, ms);\n\n        let (event_publisher, event_subscriber) = events_from_sum2_phase(&state);\n        let events_before_sum2 = EventSnapshot::from(&event_subscriber);\n        let state_before_sum2 = state.clone();\n\n        let (shared, _request_tx) = init_shared(state, store, event_publisher);\n        let aggregator = init_aggregator(&state_before_sum2);\n        let state_machine = StateMachine::from(PhaseState::<Unmask, _>::new(shared, aggregator));\n        assert!(state_machine.is_unmask());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_sum2 = state_machine.as_ref().clone();\n        let 
events_after_sum2 = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_success(\n            &state_before_sum2,\n            &events_before_sum2,\n            &state_after_sum2,\n            &events_after_sum2,\n        );\n\n        assert!(state_machine.is_idle());\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/phases/update.rs",
    "content": "use std::sync::Arc;\n\nuse async_trait::async_trait;\nuse displaydoc::Display;\nuse thiserror::Error;\nuse tracing::{debug, info, warn};\n\nuse crate::{\n    state_machine::{\n        events::DictionaryUpdate,\n        phases::{Handler, Phase, PhaseError, PhaseName, PhaseState, Shared, Sum2},\n        requests::{RequestError, StateMachineRequest, UpdateRequest},\n        StateMachine,\n    },\n    storage::{Storage, StorageError},\n};\nuse xaynet_core::{\n    mask::{Aggregation, MaskObject},\n    LocalSeedDict,\n    SeedDict,\n    UpdateParticipantPublicKey,\n};\n\n/// Errors which can occur during the update phase.\n#[derive(Debug, Display, Error)]\npub enum UpdateError {\n    /// Seed dictionary does not exists.\n    NoSeedDict,\n    /// Fetching seed dictionary failed: {0}.\n    FetchSeedDict(StorageError),\n}\n\n/// The update state.\n#[derive(Debug)]\npub struct Update {\n    /// The aggregator for masked models.\n    model_agg: Aggregation,\n    /// The seed dictionary which gets assembled during the update phase.\n    seed_dict: Option<SeedDict>,\n}\n\n#[async_trait]\nimpl<T> Phase<T> for PhaseState<Update, T>\nwhere\n    T: Storage,\n    Self: Handler,\n{\n    const NAME: PhaseName = PhaseName::Update;\n\n    async fn process(&mut self) -> Result<(), PhaseError> {\n        self.process(self.shared.state.update).await?;\n        self.seed_dict().await?;\n\n        Ok(())\n    }\n\n    fn broadcast(&mut self) {\n        info!(\"broadcasting the global seed dictionary\");\n        let seed_dict = self\n            .private\n            .seed_dict\n            .take()\n            .expect(\"unreachable: never fails when `broadcast()` is called after `process()`\");\n        self.shared\n            .events\n            .broadcast_seed_dict(DictionaryUpdate::New(Arc::new(seed_dict)));\n    }\n\n    async fn next(self) -> Option<StateMachine<T>> {\n        Some(PhaseState::<Sum2, _>::new(self.shared, self.private.model_agg).into())\n    
}\n}\n\n#[async_trait]\nimpl<T> Handler for PhaseState<Update, T>\nwhere\n    T: Storage,\n{\n    async fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), RequestError> {\n        if let StateMachineRequest::Update(UpdateRequest {\n            participant_pk,\n            local_seed_dict,\n            masked_model,\n        }) = req\n        {\n            self.update_seed_dict_and_aggregate_mask(\n                &participant_pk,\n                &local_seed_dict,\n                masked_model,\n            )\n            .await\n        } else {\n            Err(RequestError::MessageRejected)\n        }\n    }\n}\n\nimpl<T> PhaseState<Update, T> {\n    /// Creates a new update state.\n    pub fn new(shared: Shared<T>) -> Self {\n        let model_agg = Aggregation::new(\n            shared.state.round_params.mask_config,\n            shared.state.round_params.model_length,\n        );\n        Self {\n            private: Update {\n                model_agg,\n                seed_dict: None,\n            },\n            shared,\n        }\n    }\n}\n\nimpl<T> PhaseState<Update, T>\nwhere\n    T: Storage,\n{\n    /// Updates the local seed dict and aggregates the masked model.\n    async fn update_seed_dict_and_aggregate_mask(\n        &mut self,\n        pk: &UpdateParticipantPublicKey,\n        local_seed_dict: &LocalSeedDict,\n        mask_object: MaskObject,\n    ) -> Result<(), RequestError> {\n        // Check if aggregation can be performed. 
It is important to\n        // do that _before_ updating the seed dictionary, because we\n        // don't want to add the local seed dict if the corresponding\n        // masked model is invalid\n        debug!(\"checking whether the masked model can be aggregated\");\n        self.private\n            .model_agg\n            .validate_aggregation(&mask_object)\n            .map_err(|e| {\n                warn!(\"model aggregation error: {}\", e);\n                RequestError::AggregationFailed\n            })?;\n\n        // Try to update local seed dict first. If this fail, we do\n        // not want to aggregate the model.\n        info!(\"updating the global seed dictionary\");\n        self.add_local_seed_dict(pk, local_seed_dict)\n            .await\n            .map_err(|err| {\n                warn!(\"invalid local seed dictionary, ignoring update message\");\n                err\n            })?;\n\n        info!(\"aggregating the masked model and scalar\");\n        self.private.model_agg.aggregate(mask_object);\n        Ok(())\n    }\n\n    /// Adds a local seed dictionary to the global seed dictionary.\n    ///\n    /// # Error\n    ///\n    /// Fails if the local seed dict cannot be added due to a PET or [`StorageError`].\n    async fn add_local_seed_dict(\n        &mut self,\n        pk: &UpdateParticipantPublicKey,\n        local_seed_dict: &LocalSeedDict,\n    ) -> Result<(), RequestError> {\n        self.shared\n            .store\n            .add_local_seed_dict(pk, local_seed_dict)\n            .await?\n            .into_inner()\n            .map_err(RequestError::from)\n    }\n\n    /// Gets the global seed dict from the store.\n    async fn seed_dict(&mut self) -> Result<(), UpdateError> {\n        self.private.seed_dict = self\n            .shared\n            .store\n            .seed_dict()\n            .await\n            .map_err(UpdateError::FetchSeedDict)?\n            .ok_or(UpdateError::NoSeedDict)?\n            .into();\n\n        
Ok(())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    use anyhow::anyhow;\n    use xaynet_core::{SeedDict, SumDict};\n\n    use crate::{\n        state_machine::{\n            coordinator::CoordinatorState,\n            events::{EventPublisher, EventSubscriber, ModelUpdate},\n            tests::{\n                utils::{\n                    assert_event_updated,\n                    enable_logging,\n                    init_shared,\n                    send_update_messages,\n                    send_update_messages_with_model,\n                    EventSnapshot,\n                },\n                CoordinatorStateBuilder,\n                EventBusBuilder,\n            },\n        },\n        storage::{\n            tests::{\n                utils::{create_global_model, create_mask},\n                MockCoordinatorStore,\n                MockModelStore,\n            },\n            LocalSeedDictAdd,\n            LocalSeedDictAddError,\n            Store,\n        },\n    };\n\n    fn events_from_sum_phase(state: &CoordinatorState) -> (EventPublisher, EventSubscriber) {\n        EventBusBuilder::new(state)\n            .broadcast_phase(PhaseName::Sum)\n            .broadcast_sum_dict(DictionaryUpdate::New(Arc::new(SumDict::new())))\n            .broadcast_seed_dict(DictionaryUpdate::Invalidate)\n            .broadcast_model(ModelUpdate::New(Arc::new(create_global_model(1))))\n            .build()\n    }\n\n    fn assert_after_phase_success(\n        state_before: &CoordinatorState,\n        events_before: &EventSnapshot,\n        state_after: &CoordinatorState,\n        events_after: &EventSnapshot,\n    ) {\n        assert_eq!(state_after, state_before);\n\n        assert_event_updated(&events_after.phase, &events_before.phase);\n        assert_event_updated(&events_after.seed_dict, &events_before.seed_dict);\n        assert_eq!(events_after.keys, events_before.keys);\n        assert_eq!(events_after.params, events_before.params);\n        
assert_eq!(events_after.phase.event, PhaseName::Update);\n        assert_eq!(events_after.sum_dict, events_before.sum_dict);\n        assert_eq!(events_after.model, events_before.model);\n    }\n\n    fn assert_after_phase_failure(\n        state_before: &CoordinatorState,\n        events_before: &EventSnapshot,\n        state_after: &CoordinatorState,\n        events_after: &EventSnapshot,\n    ) {\n        assert_eq!(state_after, state_before);\n\n        assert_event_updated(&events_after.phase, &events_before.phase);\n        assert_eq!(events_after.keys, events_before.keys);\n        assert_eq!(events_after.params, events_before.params);\n        assert_eq!(events_after.phase.event, PhaseName::Update);\n        assert_eq!(events_after.sum_dict, events_before.sum_dict);\n        assert_eq!(events_after.seed_dict, events_before.seed_dict);\n        assert_eq!(events_after.model, events_before.model);\n    }\n\n    #[tokio::test]\n    async fn test_update_to_sum2_phase() {\n        // No Storage errors\n        // lets pretend we come from the sum phase\n        //\n        // What should happen:\n        // 1. broadcast Update phase\n        // 2. accept 10 update messages\n        // 3. fetch seed dict\n        // 4. broadcast seed dict\n        // 5. 
move into sum2 phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        // - the sum dict has been invalidated\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_add_local_seed_dict()\n            .times(10)\n            .returning(move |_, _| Ok(LocalSeedDictAdd(Ok(()))));\n        cs.expect_seed_dict()\n            .return_once(move || Ok(Some(SeedDict::new())));\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_update_count_min(10)\n            .with_update_count_max(10)\n            .with_update_time_min(1)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_sum_phase(&state);\n        let events_before_update = EventSnapshot::from(&event_subscriber);\n        let state_before_update = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Update, _>::new(shared));\n        assert!(state_machine.is_update());\n\n        send_update_messages(10, request_tx.clone());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_update = state_machine.as_ref().clone();\n        let events_after_update = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_success(\n            &state_before_update,\n            &events_before_update,\n            &state_after_update,\n            &events_after_update,\n        );\n\n        assert!(state_machine.is_sum2());\n    }\n\n    #[tokio::test]\n    async fn test_update_to_sum2_fetch_seed_dict_failed() {\n        // Storage errors\n        // - seed_dict fails\n        //\n        // What should happen:\n        // 1. broadcast Update phase\n        // 2. 
accept 1 update message\n        // 3. fetch seed dict (fails)\n        // 4. move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        // - the sum dict has been invalidated\n        // - the seed dict has been broadcasted\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_add_local_seed_dict()\n            .times(1)\n            .returning(move |_, _| Ok(LocalSeedDictAdd(Ok(()))));\n        cs.expect_seed_dict().return_once(move || Err(anyhow!(\"\")));\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_update_count_min(1)\n            .with_update_count_max(1)\n            .with_update_time_min(1)\n            .with_update_time_max(5)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_sum_phase(&state);\n        let events_before_update = EventSnapshot::from(&event_subscriber);\n        let state_before_update = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Update, _>::new(shared));\n        assert!(state_machine.is_update());\n\n        send_update_messages(1, request_tx.clone());\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_update = state_machine.as_ref().clone();\n        let events_after_update = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_update,\n            &events_before_update,\n            &state_after_update,\n            &events_after_update,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            
PhaseError::Update(UpdateError::FetchSeedDict(_))\n        ))\n    }\n\n    #[tokio::test]\n    async fn test_update_to_sum2_seed_dict_none() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Update phase\n        // 2. accept 1 update message\n        // 3. fetch seed dict (no storage error but the seed dict is None)\n        // 4. move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        // - the sum dict has been invalidated\n        // - the seed dict has been broadcasted\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_add_local_seed_dict()\n            .times(1)\n            .returning(move |_, _| Ok(LocalSeedDictAdd(Ok(()))));\n        cs.expect_seed_dict().return_once(move || Ok(None));\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_update_count_min(1)\n            .with_update_count_max(1)\n            .with_update_time_min(1)\n            .with_update_time_max(5)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_sum_phase(&state);\n        let events_before_update = EventSnapshot::from(&event_subscriber);\n        let state_before_update = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Update, _>::new(shared));\n        assert!(state_machine.is_update());\n\n        send_update_messages(1, request_tx.clone());\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_update = state_machine.as_ref().clone();\n        let events_after_update = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            
&state_before_update,\n            &events_before_update,\n            &state_after_update,\n            &events_after_update,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::Update(UpdateError::NoSeedDict)\n        ))\n    }\n\n    #[tokio::test]\n    async fn test_aggregation_error() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Update phase\n        // 2. reject 3 update messages (validation of the models fail due to an invalid length)\n        // 3. accept 3 update messages\n        // 4. fetch seed dict\n        // 5. broadcast seed dict\n        // 6. move into sum2 phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        // - the sum dict has been invalidated\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_add_local_seed_dict()\n            .times(3)\n            .returning(move |_, _| Ok(LocalSeedDictAdd(Ok(()))));\n        cs.expect_seed_dict()\n            .return_once(move || Ok(Some(SeedDict::new())));\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_update_count_min(3)\n            .with_update_count_max(3)\n            .with_update_time_min(1)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_sum_phase(&state);\n        let events_before_update = EventSnapshot::from(&event_subscriber);\n        let state_before_update = state.clone();\n\n        let (shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Update, _>::new(shared));\n        assert!(state_machine.is_update());\n\n        
send_update_messages_with_model(3, request_tx.clone(), create_mask(2, 1));\n        send_update_messages(3, request_tx.clone());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_update = state_machine.as_ref().clone();\n        let events_after_update = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_success(\n            &state_before_update,\n            &events_before_update,\n            &state_after_update,\n            &events_after_update,\n        );\n\n        assert!(state_machine.is_sum2());\n    }\n\n    #[tokio::test]\n    async fn test_rejected_messages_pet_error() {\n        // No Storage errors\n        //\n        // What should happen:\n        // 1. broadcast Update phase\n        // 2. reject 3 update messages (pet error LocalSeedDictAddError::LengthMisMatch)\n        // 3. phase should timeout\n        // 4. move into error phase\n        //\n        // What should not happen:\n        // - the shared state has been changed\n        // - the global model has been invalidated\n        // - the sum dict has been invalidated\n        // - the seed dict has been broadcasted\n        enable_logging();\n\n        let mut cs = MockCoordinatorStore::new();\n        cs.expect_add_local_seed_dict()\n            .times(3)\n            .returning(move |_, _| {\n                Ok(LocalSeedDictAdd(Err(LocalSeedDictAddError::LengthMisMatch)))\n            });\n        let store = Store::new(cs, MockModelStore::new());\n        let state = CoordinatorStateBuilder::new()\n            .with_round_id(1)\n            .with_update_count_min(3)\n            .with_update_count_max(3)\n            .with_update_time_min(0)\n            .with_update_time_max(2)\n            .build();\n\n        let (event_publisher, event_subscriber) = events_from_sum_phase(&state);\n        let events_before_update = EventSnapshot::from(&event_subscriber);\n        let state_before_update = state.clone();\n\n        let 
(shared, request_tx) = init_shared(state, store, event_publisher);\n        let state_machine = StateMachine::from(PhaseState::<Update, _>::new(shared));\n        assert!(state_machine.is_update());\n\n        send_update_messages(3, request_tx.clone());\n\n        let state_machine = state_machine.next().await.unwrap();\n\n        let state_after_update = state_machine.as_ref().clone();\n        let events_after_update = EventSnapshot::from(&event_subscriber);\n        assert_after_phase_failure(\n            &state_before_update,\n            &events_before_update,\n            &state_after_update,\n            &events_after_update,\n        );\n\n        assert!(state_machine.is_failure());\n        assert!(matches!(\n            state_machine.into_failure_phase_state().private.error,\n            PhaseError::PhaseTimeout(_)\n        ))\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/requests.rs",
    "content": "//! This module provides the the `StateMachine`, `Request`, `RequestSender` and `RequestReceiver`\n//! types.\n\nuse std::{\n    pin::Pin,\n    task::{Context, Poll},\n};\n\nuse derive_more::From;\nuse displaydoc::Display;\nuse futures::{future::FutureExt, Stream};\nuse thiserror::Error;\nuse tokio::sync::{mpsc, oneshot};\nuse tracing::{trace, Span};\n\nuse crate::storage::{LocalSeedDictAddError, MaskScoreIncrError, StorageError, SumPartAddError};\nuse xaynet_core::{\n    mask::MaskObject,\n    message::{Message, Payload, Update},\n    LocalSeedDict,\n    ParticipantPublicKey,\n    SumParticipantEphemeralPublicKey,\n    SumParticipantPublicKey,\n    UpdateParticipantPublicKey,\n};\n\n/// Errors which can occur while the state machine handles a request.\n#[derive(Debug, Display, Error)]\npub enum RequestError {\n    /// The message was rejected.\n    MessageRejected,\n    /// The message was discarded.\n    MessageDiscarded,\n    /// Invalid update: the model or scalar sent by the participant could not be aggregated.\n    AggregationFailed,\n    /// The request could not be processed due to an internal error: {0}.\n    InternalError(&'static str),\n    /// Storage request failed: {0}.\n    CoordinatorStorage(#[from] StorageError),\n    /// Adding a local seed dict to the seed dictionary failed: {0}.\n    LocalSeedDictAdd(#[from] LocalSeedDictAddError),\n    /// Adding a sum participant to the sum dictionary failed: {0}.\n    SumPartAdd(#[from] SumPartAddError),\n    /// Incrementing a mask score failed: {0}.\n    MaskScoreIncr(#[from] MaskScoreIncrError),\n}\n\n/// A sum request.\n#[derive(Debug)]\npub struct SumRequest {\n    /// The public key of the participant.\n    pub participant_pk: SumParticipantPublicKey,\n    /// The ephemeral public key of the participant.\n    pub ephm_pk: SumParticipantEphemeralPublicKey,\n}\n\n/// An update request.\n#[derive(Debug)]\npub struct UpdateRequest {\n    /// The public key of the participant.\n    pub 
participant_pk: UpdateParticipantPublicKey,\n    /// The local seed dict that contains the seed used to mask `masked_model`.\n    pub local_seed_dict: LocalSeedDict,\n    /// The masked model trained by the participant.\n    pub masked_model: MaskObject,\n}\n\n/// A sum2 request.\n#[derive(Debug)]\npub struct Sum2Request {\n    /// The public key of the participant.\n    pub participant_pk: ParticipantPublicKey,\n    /// The model mask computed by the participant.\n    pub model_mask: MaskObject,\n}\n\n/// A [`StateMachine`] request.\n///\n/// [`StateMachine`]: crate::state_machine\n#[derive(Debug, From)]\npub enum StateMachineRequest {\n    Sum(SumRequest),\n    Update(UpdateRequest),\n    Sum2(Sum2Request),\n}\n\nimpl From<Message> for StateMachineRequest {\n    fn from(message: Message) -> Self {\n        let participant_pk = message.participant_pk;\n        match message.payload {\n            Payload::Sum(sum) => StateMachineRequest::Sum(SumRequest {\n                participant_pk,\n                ephm_pk: sum.ephm_pk,\n            }),\n            Payload::Update(update) => {\n                let Update {\n                    local_seed_dict,\n                    masked_model,\n                    ..\n                } = update;\n                StateMachineRequest::Update(UpdateRequest {\n                    participant_pk,\n                    local_seed_dict,\n                    masked_model,\n                })\n            }\n            Payload::Sum2(sum2) => StateMachineRequest::Sum2(Sum2Request {\n                participant_pk,\n                model_mask: sum2.model_mask,\n            }),\n            Payload::Chunk(_) => unimplemented!(),\n        }\n    }\n}\n\n/// A handle to send requests to the [`StateMachine`].\n///\n/// [`StateMachine`]: crate::state_machine\n#[derive(Clone, From, Debug)]\npub struct RequestSender(mpsc::UnboundedSender<(StateMachineRequest, Span, ResponseSender)>);\n\nimpl RequestSender {\n    /// Sends a request to the 
[`StateMachine`].\n    ///\n    /// # Errors\n    /// Fails if the [`StateMachine`] has already shut down and the `Request` channel has been\n    /// closed as a result.\n    ///\n    /// [`StateMachine`]: crate::state_machine\n    pub async fn request(&self, req: StateMachineRequest, span: Span) -> Result<(), RequestError> {\n        let (resp_tx, resp_rx) = oneshot::channel::<Result<(), RequestError>>();\n        self.0.send((req, span, resp_tx)).map_err(|_| {\n            RequestError::InternalError(\n                \"failed to send request to the state machine: state machine is shutting down\",\n            )\n        })?;\n        resp_rx.await.map_err(|_| {\n            RequestError::InternalError(\"failed to receive response from the state machine\")\n        })?\n    }\n\n    #[cfg(test)]\n    pub fn is_closed(&self) -> bool {\n        self.0.is_closed()\n    }\n}\n\n/// A channel for sending the state machine to send the response to a\n/// [`StateMachineRequest`].\npub(in crate::state_machine) type ResponseSender = oneshot::Sender<Result<(), RequestError>>;\n\n/// The receiver half of the `Request` channel that is used by the [`StateMachine`] to receive\n/// requests.\n///\n/// [`StateMachine`]: crate::state_machine\n#[derive(From, Debug)]\npub struct RequestReceiver(mpsc::UnboundedReceiver<(StateMachineRequest, Span, ResponseSender)>);\n\nimpl Stream for RequestReceiver {\n    type Item = (StateMachineRequest, Span, ResponseSender);\n\n    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {\n        trace!(\"RequestReceiver: polling\");\n        Pin::new(&mut self.get_mut().0).poll_recv(cx)\n    }\n}\n\nimpl RequestReceiver {\n    /// Creates a new `Request` channel and returns the [`RequestReceiver`] as well as the\n    /// [`RequestSender`] half.\n    pub fn new() -> (Self, RequestSender) {\n        let (tx, rx) = mpsc::unbounded_channel::<(StateMachineRequest, Span, ResponseSender)>();\n        let receiver = 
RequestReceiver::from(rx);\n        let handle = RequestSender::from(tx);\n        (receiver, handle)\n    }\n\n    /// Closes the `Request` channel.\n    /// See [the `tokio` documentation][close] for more information.\n    ///\n    /// [close]: https://docs.rs/tokio/1.1.0/tokio/sync/mpsc/struct.UnboundedReceiver.html#method.close\n    pub fn close(&mut self) {\n        self.0.close()\n    }\n\n    /// Receives the next request.\n    /// See [the `tokio` documentation][receive] for more information.\n    ///\n    /// [receive]: https://docs.rs/tokio/1.1.0/tokio/sync/mpsc/struct.UnboundedReceiver.html#method.recv\n    pub async fn recv(&mut self) -> Option<(StateMachineRequest, Span, ResponseSender)> {\n        self.0.recv().await\n    }\n\n    /// Tries to retrieve the next request without blocking.\n    ///\n    /// Returns `None` if no request is ready yet; otherwise returns the result of `recv()`.\n    pub fn try_recv(&mut self) -> Option<Option<(StateMachineRequest, Span, ResponseSender)>> {\n        // Note: `try_recv` (tokio 0.2.x) or `recv().now_or_never()` (tokio 1.x)\n        // has an implementation bug where previously sent messages may not be\n        // available immediately.\n        // Related issue: https://github.com/tokio-rs/tokio/issues/3350\n        // At the moment this behaves like `try_recv`, but we should check whether this\n        // bug is a problem for us. But first we should replace the unbounded channel with\n        // a bounded channel (XN-1162).\n        self.0.recv().now_or_never()\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/tests/coordinator_state.rs",
    "content": "use xaynet_core::{common::RoundSeed, crypto::EncryptKeyPair, mask::MaskConfig};\n\nuse crate::state_machine::coordinator::CoordinatorState;\n\nuse super::utils::{mask_settings, model_settings, pet_settings};\n\npub struct CoordinatorStateBuilder {\n    state: CoordinatorState,\n}\n\n#[allow(dead_code)]\nimpl CoordinatorStateBuilder {\n    pub fn new() -> Self {\n        Self {\n            state: CoordinatorState::new(pet_settings(), mask_settings(), model_settings()),\n        }\n    }\n\n    pub fn build(self) -> CoordinatorState {\n        self.state\n    }\n\n    pub fn with_keys(mut self, keys: EncryptKeyPair) -> Self {\n        self.state.round_params.pk = keys.public;\n        self.state.keys = keys;\n        self\n    }\n\n    pub fn with_round_id(mut self, id: u64) -> Self {\n        self.state.round_id = id;\n        self\n    }\n\n    pub fn with_sum_probability(mut self, prob: f64) -> Self {\n        self.state.round_params.sum = prob;\n        self\n    }\n\n    pub fn with_update_probability(mut self, prob: f64) -> Self {\n        self.state.round_params.update = prob;\n        self\n    }\n\n    pub fn with_seed(mut self, seed: RoundSeed) -> Self {\n        self.state.round_params.seed = seed;\n        self\n    }\n\n    pub fn with_sum_count_min(mut self, min: u64) -> Self {\n        self.state.sum.count.min = min;\n        self\n    }\n\n    pub fn with_sum_count_max(mut self, max: u64) -> Self {\n        self.state.sum.count.max = max;\n        self\n    }\n\n    pub fn with_mask_config(mut self, mask_config: MaskConfig) -> Self {\n        self.state.round_params.mask_config = mask_config.into();\n        self\n    }\n\n    pub fn with_update_count_min(mut self, min: u64) -> Self {\n        self.state.update.count.min = min;\n        self\n    }\n\n    pub fn with_update_count_max(mut self, max: u64) -> Self {\n        self.state.update.count.max = max;\n        self\n    }\n\n    pub fn with_sum2_count_min(mut self, min: u64) -> 
Self {\n        self.state.sum2.count.min = min;\n        self\n    }\n\n    pub fn with_sum2_count_max(mut self, max: u64) -> Self {\n        self.state.sum2.count.max = max;\n        self\n    }\n\n    pub fn with_model_length(mut self, model_length: usize) -> Self {\n        self.state.round_params.model_length = model_length;\n        self\n    }\n\n    pub fn with_sum_time_min(mut self, min: u64) -> Self {\n        self.state.sum.time.min = min;\n        self\n    }\n\n    pub fn with_sum_time_max(mut self, max: u64) -> Self {\n        self.state.sum.time.max = max;\n        self\n    }\n\n    pub fn with_update_time_min(mut self, min: u64) -> Self {\n        self.state.update.time.min = min;\n        self\n    }\n\n    pub fn with_update_time_max(mut self, max: u64) -> Self {\n        self.state.update.time.max = max;\n        self\n    }\n\n    pub fn with_sum2_time_min(mut self, min: u64) -> Self {\n        self.state.sum2.time.min = min;\n        self\n    }\n\n    pub fn with_sum2_time_max(mut self, max: u64) -> Self {\n        self.state.sum2.time.max = max;\n        self\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/tests/event_bus.rs",
    "content": "use xaynet_core::{SeedDict, SumDict};\n\nuse crate::state_machine::{\n    coordinator::CoordinatorState,\n    events::{DictionaryUpdate, EventPublisher, EventSubscriber, ModelUpdate},\n    phases::PhaseName,\n};\n\nuse super::{utils::EventSnapshot, CoordinatorStateBuilder, WARNING};\n\npub struct EventBusBuilder {\n    event_publisher: EventPublisher,\n    event_subscriber: EventSubscriber,\n}\n\nimpl EventBusBuilder {\n    pub fn new(state: &CoordinatorState) -> Self {\n        let (event_publisher, event_subscriber) = EventPublisher::init(\n            state.round_id,\n            state.keys.clone(),\n            state.round_params.clone(),\n            PhaseName::Idle,\n            ModelUpdate::Invalidate,\n        );\n\n        Self {\n            event_publisher,\n            event_subscriber,\n        }\n    }\n\n    pub fn broadcast_phase(mut self, phase: PhaseName) -> Self {\n        self.event_publisher.broadcast_phase(phase);\n        self\n    }\n\n    pub fn broadcast_model(mut self, update: ModelUpdate) -> Self {\n        self.event_publisher.broadcast_model(update);\n        self\n    }\n\n    pub fn broadcast_sum_dict(mut self, update: DictionaryUpdate<SumDict>) -> Self {\n        self.event_publisher.broadcast_sum_dict(update);\n        self\n    }\n\n    pub fn broadcast_seed_dict(mut self, update: DictionaryUpdate<SeedDict>) -> Self {\n        self.event_publisher.broadcast_seed_dict(update);\n        self\n    }\n\n    pub fn build(self) -> (EventPublisher, EventSubscriber) {\n        (self.event_publisher, self.event_subscriber)\n    }\n}\n\n#[test]\nfn test_initial_events() {\n    const PANIC_MESSAGE: &str = \"the initial events have been changed.\";\n\n    let state = CoordinatorStateBuilder::new().build();\n    let (_, subscriber) = EventBusBuilder::new(&state).build();\n    let events = EventSnapshot::from(&subscriber);\n\n    assert_eq!(\n        events.phase.event,\n        PhaseName::Idle,\n        \"{} {}\",\n        
PANIC_MESSAGE,\n        WARNING\n    );\n    assert_eq!(\n        events.model.event,\n        ModelUpdate::Invalidate,\n        \"{} {}\",\n        PANIC_MESSAGE,\n        WARNING\n    );\n    assert_eq!(\n        events.sum_dict.event,\n        DictionaryUpdate::Invalidate,\n        \"{} {}\",\n        PANIC_MESSAGE,\n        WARNING\n    );\n    assert_eq!(\n        events.seed_dict.event,\n        DictionaryUpdate::Invalidate,\n        \"{} {}\",\n        PANIC_MESSAGE,\n        WARNING\n    );\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/tests/impls.rs",
    "content": "use tracing::Span;\nuse xaynet_core::message::Message;\n\nuse crate::state_machine::{\n    coordinator::CoordinatorState,\n    events::DictionaryUpdate,\n    phases::{Failure, Idle, PhaseState, Shutdown, Sum, Sum2, Unmask, Update},\n    requests::{RequestError, RequestSender},\n    StateMachine,\n};\n\nimpl RequestSender {\n    pub async fn msg(&self, msg: &Message) -> Result<(), RequestError> {\n        self.request(msg.clone().into(), Span::none()).await\n    }\n}\n\nimpl<T> StateMachine<T> {\n    pub fn is_idle(&self) -> bool {\n        matches!(self, StateMachine::Idle(_))\n    }\n\n    pub fn into_idle_phase_state(self) -> PhaseState<Idle, T> {\n        match self {\n            StateMachine::Idle(state) => state,\n            _ => panic!(\"not in idle state\"),\n        }\n    }\n\n    pub fn is_sum(&self) -> bool {\n        matches!(self, StateMachine::Sum(_))\n    }\n\n    pub fn into_sum_phase_state(self) -> PhaseState<Sum, T> {\n        match self {\n            StateMachine::Sum(state) => state,\n            _ => panic!(\"not in sum state\"),\n        }\n    }\n\n    pub fn is_update(&self) -> bool {\n        matches!(self, StateMachine::Update(_))\n    }\n\n    pub fn into_update_phase_state(self) -> PhaseState<Update, T> {\n        match self {\n            StateMachine::Update(state) => state,\n            _ => panic!(\"not in update state\"),\n        }\n    }\n\n    pub fn is_sum2(&self) -> bool {\n        matches!(self, StateMachine::Sum2(_))\n    }\n\n    pub fn into_sum2_phase_state(self) -> PhaseState<Sum2, T> {\n        match self {\n            StateMachine::Sum2(state) => state,\n            _ => panic!(\"not in sum2 state\"),\n        }\n    }\n\n    pub fn is_unmask(&self) -> bool {\n        matches!(self, StateMachine::Unmask(_))\n    }\n\n    pub fn into_unmask_phase_state(self) -> PhaseState<Unmask, T> {\n        match self {\n            StateMachine::Unmask(state) => state,\n            _ => panic!(\"not in unmask 
state\"),\n        }\n    }\n\n    pub fn is_failure(&self) -> bool {\n        matches!(self, StateMachine::Failure(_))\n    }\n\n    pub fn into_failure_phase_state(self) -> PhaseState<Failure, T> {\n        match self {\n            StateMachine::Failure(state) => state,\n            _ => panic!(\"not in error state\"),\n        }\n    }\n\n    pub fn is_shutdown(&self) -> bool {\n        matches!(self, StateMachine::Shutdown(_))\n    }\n\n    pub fn into_shutdown_phase_state(self) -> PhaseState<Shutdown, T> {\n        match self {\n            StateMachine::Shutdown(state) => state,\n            _ => panic!(\"not in shutdown state\"),\n        }\n    }\n}\n\nimpl<T> AsRef<CoordinatorState> for StateMachine<T> {\n    fn as_ref(&self) -> &CoordinatorState {\n        match self {\n            StateMachine::Idle(state) => &state.shared.state,\n            StateMachine::Sum(state) => &state.shared.state,\n            StateMachine::Update(state) => &state.shared.state,\n            StateMachine::Sum2(state) => &state.shared.state,\n            StateMachine::Unmask(state) => &state.shared.state,\n            StateMachine::Failure(state) => &state.shared.state,\n            StateMachine::Shutdown(state) => &state.shared.state,\n        }\n    }\n}\n\nimpl<D> DictionaryUpdate<D> {\n    pub fn unwrap(self) -> std::sync::Arc<D> {\n        if let DictionaryUpdate::New(inner) = self {\n            inner\n        } else {\n            panic!(\"DictionaryUpdate::Invalidate\");\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/tests/initializer.rs",
    "content": "//! State machine initialization test utilities.\n\nuse serial_test::serial;\n\n#[cfg(feature = \"model-persistence\")]\nuse crate::{\n    settings::RestoreSettings,\n    state_machine::{\n        events::{DictionaryUpdate, ModelUpdate},\n        initializer::StateMachineInitializationError,\n        phases::PhaseName,\n    },\n    storage::tests::utils::create_global_model,\n    storage::ModelStorage,\n};\nuse crate::{\n    state_machine::{\n        coordinator::CoordinatorState,\n        initializer::StateMachineInitializer,\n        tests::utils::{mask_settings, model_settings, pet_settings},\n    },\n    storage::{tests::init_store, CoordinatorStorage},\n};\n\n#[cfg(feature = \"model-persistence\")]\n#[tokio::test]\n#[serial]\n#[ignore]\nasync fn integration_state_machine_initializer_no_restore() {\n    let store = init_store().await;\n    let smi = StateMachineInitializer::new(\n        pet_settings(),\n        mask_settings(),\n        model_settings(),\n        RestoreSettings { enable: false },\n        store,\n    );\n\n    let (state_machine, _request_sender, event_subscriber) = smi.init().await.unwrap();\n\n    assert!(state_machine.is_idle());\n\n    let phase = event_subscriber.phase_listener().get_latest().event;\n    assert!(matches!(phase, PhaseName::Idle));\n\n    let sum_dict = event_subscriber.sum_dict_listener().get_latest().event;\n    assert!(matches!(sum_dict, DictionaryUpdate::Invalidate));\n\n    let seed_dict = event_subscriber.seed_dict_listener().get_latest().event;\n    assert!(matches!(seed_dict, DictionaryUpdate::Invalidate));\n\n    let global_model = event_subscriber.model_listener().get_latest().event;\n    assert!(matches!(global_model, ModelUpdate::Invalidate));\n\n    let round_id = event_subscriber.params_listener().get_latest().round_id;\n    assert_eq!(round_id, 0);\n}\n\n#[cfg(feature = \"model-persistence\")]\n#[tokio::test]\n#[serial]\n#[ignore]\nasync fn integration_state_machine_initializer_no_state() {\n 
   let store = init_store().await;\n    let smi = StateMachineInitializer::new(\n        pet_settings(),\n        mask_settings(),\n        model_settings(),\n        RestoreSettings { enable: true },\n        store,\n    );\n\n    let (state_machine, _request_sender, event_subscriber) = smi.init().await.unwrap();\n\n    assert!(state_machine.is_idle());\n\n    let phase = event_subscriber.phase_listener().get_latest().event;\n    assert!(matches!(phase, PhaseName::Idle));\n\n    let sum_dict = event_subscriber.sum_dict_listener().get_latest().event;\n    assert!(matches!(sum_dict, DictionaryUpdate::Invalidate));\n\n    let seed_dict = event_subscriber.seed_dict_listener().get_latest().event;\n    assert!(matches!(seed_dict, DictionaryUpdate::Invalidate));\n\n    let global_model = event_subscriber.model_listener().get_latest().event;\n    assert!(matches!(global_model, ModelUpdate::Invalidate));\n\n    let round_id = event_subscriber.params_listener().get_latest().round_id;\n    assert_eq!(round_id, 0);\n}\n\n#[cfg(feature = \"model-persistence\")]\n#[tokio::test]\n#[serial]\n#[ignore]\nasync fn integration_state_machine_initializer_without_global_model() {\n    let pet_settings = pet_settings();\n    let mask_settings = mask_settings();\n    let model_settings = model_settings();\n\n    // we change the round id to ensure that the state machine is\n    // initialized with the coordinator state in the store\n    // if we don't update the round_id we can't check if the state in the store was used or if the state was reset\n    // because in both cases the round id will be 0\n    let mut store = init_store().await;\n    let mut state = CoordinatorState::new(pet_settings, mask_settings, model_settings.clone());\n    let new_round_id = 5;\n    state.round_id = new_round_id;\n    store.set_coordinator_state(&state).await.unwrap();\n\n    let smi = StateMachineInitializer::new(\n        pet_settings,\n        mask_settings,\n        model_settings,\n        
RestoreSettings { enable: true },\n        store,\n    );\n\n    let (state_machine, _request_sender, event_subscriber) = smi.init().await.unwrap();\n\n    assert!(state_machine.is_idle());\n\n    let phase = event_subscriber.phase_listener().get_latest().event;\n    assert!(matches!(phase, PhaseName::Idle));\n\n    let sum_dict = event_subscriber.sum_dict_listener().get_latest().event;\n    assert!(matches!(sum_dict, DictionaryUpdate::Invalidate));\n\n    let seed_dict = event_subscriber.seed_dict_listener().get_latest().event;\n    assert!(matches!(seed_dict, DictionaryUpdate::Invalidate));\n\n    let global_model = event_subscriber.model_listener().get_latest().event;\n    assert!(matches!(global_model, ModelUpdate::Invalidate));\n\n    let round_id = event_subscriber.params_listener().get_latest().round_id;\n    assert_eq!(round_id, new_round_id);\n}\n\n#[cfg(feature = \"model-persistence\")]\n#[tokio::test]\n#[serial]\n#[ignore]\nasync fn integration_state_machine_initializer_with_global_model() {\n    let pet_settings = pet_settings();\n    let mask_settings = mask_settings();\n    let model_settings = model_settings();\n\n    let mut store = init_store().await;\n    let mut state = CoordinatorState::new(pet_settings, mask_settings, model_settings.clone());\n    let new_round_id = 7;\n    state.round_id = new_round_id;\n    store.set_coordinator_state(&state).await.unwrap();\n\n    // upload a global model and set the id\n    let uploaded_global_model = create_global_model(state.round_params.model_length);\n    let global_model_id = store\n        .set_global_model(\n            state.round_id,\n            &state.round_params.seed,\n            &uploaded_global_model,\n        )\n        .await\n        .unwrap();\n    store\n        .set_latest_global_model_id(&global_model_id)\n        .await\n        .unwrap();\n\n    let smi = StateMachineInitializer::new(\n        pet_settings,\n        mask_settings,\n        model_settings,\n        RestoreSettings { 
enable: true },\n        store,\n    );\n\n    let (state_machine, _request_sender, event_subscriber) = smi.init().await.unwrap();\n\n    assert!(state_machine.is_idle());\n\n    let phase = event_subscriber.phase_listener().get_latest().event;\n    assert!(matches!(phase, PhaseName::Idle));\n\n    let sum_dict = event_subscriber.sum_dict_listener().get_latest().event;\n    assert!(matches!(sum_dict, DictionaryUpdate::Invalidate));\n\n    let seed_dict = event_subscriber.seed_dict_listener().get_latest().event;\n    assert!(matches!(seed_dict, DictionaryUpdate::Invalidate));\n\n    let global_model = event_subscriber.model_listener().get_latest().event;\n    assert!(\n        matches!(global_model, ModelUpdate::New(broadcasted_model) if uploaded_global_model == *broadcasted_model)\n    );\n\n    let round_id = event_subscriber.params_listener().get_latest().round_id;\n    assert_eq!(round_id, new_round_id);\n}\n\n#[cfg(feature = \"model-persistence\")]\n#[tokio::test]\n#[serial]\n#[ignore]\nasync fn integration_state_machine_initializer_failed_because_of_wrong_size() {\n    let pet_settings = pet_settings();\n    let mask_settings = mask_settings();\n    let model_settings = model_settings();\n\n    let mut store = init_store().await;\n    let mut state = CoordinatorState::new(pet_settings, mask_settings, model_settings.clone());\n    let new_round_id = 9;\n    state.round_id = new_round_id;\n    store.set_coordinator_state(&state).await.unwrap();\n\n    // upload a global model with a wrong model length and set the id\n    let uploaded_global_model = create_global_model(state.round_params.model_length + 10);\n    let global_model_id = store\n        .set_global_model(\n            state.round_id,\n            &state.round_params.seed,\n            &uploaded_global_model,\n        )\n        .await\n        .unwrap();\n    store\n        .set_latest_global_model_id(&global_model_id)\n        .await\n        .unwrap();\n\n    let smi = 
StateMachineInitializer::new(\n        pet_settings,\n        mask_settings,\n        model_settings,\n        RestoreSettings { enable: true },\n        store,\n    );\n\n    let result = smi.init().await;\n\n    assert!(matches!(\n        result,\n        Err(StateMachineInitializationError::GlobalModelInvalid(_))\n    ));\n}\n\n#[cfg(feature = \"model-persistence\")]\n#[tokio::test]\n#[serial]\n#[ignore]\nasync fn integration_state_machine_initializer_failed_to_find_global_model() {\n    let pet_settings = pet_settings();\n    let mask_settings = mask_settings();\n    let model_settings = model_settings();\n\n    let mut store = init_store().await;\n    let mut state = CoordinatorState::new(pet_settings, mask_settings, model_settings.clone());\n    let new_round_id = 11;\n    state.round_id = new_round_id;\n    store.set_coordinator_state(&state).await.unwrap();\n\n    // set a model id but don't store a model\n    let global_model_id = \"1_412957050209fcfa733b1fb4ad51f321\";\n    store\n        .set_latest_global_model_id(global_model_id)\n        .await\n        .unwrap();\n\n    let smi = StateMachineInitializer::new(\n        pet_settings,\n        mask_settings,\n        model_settings,\n        RestoreSettings { enable: true },\n        store,\n    );\n\n    let result = smi.init().await;\n\n    assert!(matches!(\n        result,\n        Err(StateMachineInitializationError::GlobalModelUnavailable(_))\n    ));\n}\n\n#[tokio::test]\n#[serial]\n#[ignore]\nasync fn integration_state_machine_initializer_reset_state() {\n    let pet_settings = pet_settings();\n    let mask_settings = mask_settings();\n    let model_settings = model_settings();\n\n    let mut store = init_store().await;\n    let state = CoordinatorState::new(pet_settings, mask_settings, model_settings.clone());\n    store.set_coordinator_state(&state).await.unwrap();\n\n    let mut smi = StateMachineInitializer::new(\n        pet_settings,\n        mask_settings,\n        model_settings,\n       
 #[cfg(feature = \"model-persistence\")]\n        RestoreSettings { enable: true },\n        store.clone(),\n    );\n\n    smi.from_settings().await.unwrap();\n\n    assert!(store.coordinator_state().await.unwrap().is_none());\n    assert!(store.sum_dict().await.unwrap().is_none());\n    assert!(store.seed_dict().await.unwrap().is_none());\n    assert!(store.best_masks().await.unwrap().is_none());\n    assert!(store.latest_global_model_id().await.unwrap().is_none());\n    assert_eq!(store.number_of_unique_masks().await.unwrap(), 0);\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/tests/mod.rs",
    "content": "//! State machine test utilities.\n\npub mod coordinator_state;\npub mod event_bus;\npub mod impls;\npub mod initializer;\npub mod utils;\n\npub use coordinator_state::CoordinatorStateBuilder;\npub use event_bus::EventBusBuilder;\n\nconst WARNING: &str = \"All state machine tests were written assuming these initial values.\nFirst, carefully check the correctness of the state machine test before finally\nchanging these values.\";\n"
  },
  {
    "path": "rust/xaynet-server/src/state_machine/tests/utils.rs",
    "content": "//! State machine misc test utilities.\n\nuse std::fmt::Debug;\n\nuse tokio::sync::mpsc;\nuse tracing_subscriber::{EnvFilter, FmtSubscriber};\nuse xaynet_core::{\n    common::RoundParameters,\n    crypto::{ByteObject, EncryptKeyPair, PublicEncryptKey, PublicSigningKey},\n    mask::{BoundType, DataType, GroupType, MaskObject, ModelType},\n    message::{Message, Sum, Sum2, Update},\n    LocalSeedDict,\n    ParticipantTaskSignature,\n    SeedDict,\n    SumDict,\n};\n\nuse crate::{\n    settings::{\n        MaskSettings,\n        ModelSettings,\n        PetSettings,\n        PetSettingsCount,\n        PetSettingsSum,\n        PetSettingsSum2,\n        PetSettingsTime,\n        PetSettingsUpdate,\n    },\n    state_machine::{\n        coordinator::CoordinatorState,\n        events::{DictionaryUpdate, Event, EventPublisher, EventSubscriber, ModelUpdate},\n        phases::{PhaseName, Shared},\n        requests::{RequestReceiver, RequestSender},\n    },\n    storage::tests::utils::create_mask,\n};\n\nuse super::WARNING;\n\npub fn enable_logging() {\n    let _fmt_subscriber = FmtSubscriber::builder()\n        .with_env_filter(EnvFilter::from_default_env())\n        .with_ansi(true)\n        .try_init();\n}\n\npub fn pet_settings() -> PetSettings {\n    PetSettings {\n        sum: PetSettingsSum {\n            prob: 0.4,\n            count: PetSettingsCount { min: 1, max: 100 },\n            time: PetSettingsTime { min: 1, max: 2 },\n        },\n        update: PetSettingsUpdate {\n            prob: 0.5,\n            count: PetSettingsCount { min: 3, max: 1000 },\n            time: PetSettingsTime { min: 1, max: 2 },\n        },\n        sum2: PetSettingsSum2 {\n            count: PetSettingsCount { min: 1, max: 100 },\n            time: PetSettingsTime { min: 1, max: 2 },\n        },\n    }\n}\n\npub fn mask_settings() -> MaskSettings {\n    MaskSettings {\n        group_type: GroupType::Prime,\n        data_type: DataType::F32,\n        bound_type: 
BoundType::B0,\n        model_type: ModelType::M3,\n    }\n}\n\npub fn model_settings() -> ModelSettings {\n    ModelSettings { length: 1 }\n}\n\npub fn init_shared<T>(\n    coordinator_state: CoordinatorState,\n    store: T,\n    event_publisher: EventPublisher,\n) -> (Shared<T>, RequestSender) {\n    let (request_rx, request_tx) = RequestReceiver::new();\n    (\n        Shared::new(coordinator_state, event_publisher, request_rx, store),\n        request_tx,\n    )\n}\n\n#[derive(Debug, Clone, PartialEq)]\npub struct EventSnapshot {\n    pub keys: Event<EncryptKeyPair>,\n    pub params: Event<RoundParameters>,\n    pub phase: Event<PhaseName>,\n    pub model: Event<ModelUpdate>,\n    pub sum_dict: Event<DictionaryUpdate<SumDict>>,\n    pub seed_dict: Event<DictionaryUpdate<SeedDict>>,\n}\n\nimpl From<&EventSubscriber> for EventSnapshot {\n    fn from(event_subscriber: &EventSubscriber) -> Self {\n        Self {\n            keys: event_subscriber.keys_listener().get_latest(),\n            params: event_subscriber.params_listener().get_latest(),\n            phase: event_subscriber.phase_listener().get_latest(),\n            model: event_subscriber.model_listener().get_latest(),\n            sum_dict: event_subscriber.sum_dict_listener().get_latest(),\n            seed_dict: event_subscriber.seed_dict_listener().get_latest(),\n        }\n    }\n}\n\npub fn assert_event_updated_with_id<T: Debug + PartialEq>(event1: &Event<T>, event2: &Event<T>) {\n    assert_ne!(event1.round_id, event2.round_id);\n    assert_ne!(event1.event, event2.event);\n}\n\npub fn assert_event_updated<T: Debug + PartialEq>(event1: &Event<T>, event2: &Event<T>) {\n    assert_eq!(event1.round_id, event2.round_id);\n    assert_ne!(event1.event, event2.event);\n}\n\npub fn compose_sum_message() -> Message {\n    let payload = Sum {\n        sum_signature: ParticipantTaskSignature::zeroed(),\n        ephm_pk: PublicEncryptKey::zeroed(),\n    };\n    Message::new_sum(\n        
PublicSigningKey::zeroed(),\n        PublicEncryptKey::zeroed(),\n        payload,\n    )\n}\n\npub fn compose_update_message(masked_model: MaskObject) -> Message {\n    let payload = Update {\n        sum_signature: ParticipantTaskSignature::zeroed(),\n        update_signature: ParticipantTaskSignature::zeroed(),\n        masked_model,\n        local_seed_dict: LocalSeedDict::new(),\n    };\n    Message::new_update(\n        PublicSigningKey::zeroed(),\n        PublicEncryptKey::zeroed(),\n        payload,\n    )\n}\n\npub fn compose_sum2_message() -> Message {\n    let payload = Sum2 {\n        sum_signature: ParticipantTaskSignature::zeroed(),\n        model_mask: create_mask(1, 1),\n    };\n    Message::new_sum2(\n        PublicSigningKey::zeroed(),\n        PublicEncryptKey::zeroed(),\n        payload,\n    )\n}\n\npub fn send_sum_messages(n: u32, request_tx: RequestSender) {\n    for _ in 0..n {\n        let request = request_tx.clone();\n        tokio::spawn(async move { request.msg(&compose_sum_message()).await });\n    }\n}\n\n#[allow(dead_code)]\npub fn send_sum_messages_with_latch(n: u32, request_tx: RequestSender, latch: Latch) {\n    for _ in 0..n {\n        let request = request_tx.clone();\n        let l = latch.clone();\n        tokio::spawn(async move {\n            let _ = request.msg(&compose_sum_message()).await;\n            l.release();\n        });\n    }\n}\n\npub fn send_sum2_messages(n: u32, request_tx: RequestSender) {\n    for _ in 0..n {\n        let request = request_tx.clone();\n        tokio::spawn(async move { request.msg(&compose_sum2_message()).await });\n    }\n}\n\npub fn send_update_messages(n: u32, request_tx: RequestSender) {\n    let default_model = create_mask(1, 1);\n    for _ in 0..n {\n        let request = request_tx.clone();\n        let masked_model = default_model.clone();\n        tokio::spawn(async move { request.msg(&compose_update_message(masked_model)).await });\n    }\n}\n\npub fn 
send_update_messages_with_model(\n    n: u32,\n    request_tx: RequestSender,\n    masked_model: MaskObject,\n) {\n    for _ in 0..n {\n        let request = request_tx.clone();\n        let moved_masked_model = masked_model.clone();\n        tokio::spawn(async move {\n            request\n                .msg(&compose_update_message(moved_masked_model))\n                .await\n        });\n    }\n}\n\n#[allow(dead_code)]\npub struct Readiness(mpsc::Receiver<()>);\n\n#[allow(dead_code)]\n#[derive(Clone)]\npub struct Latch(mpsc::Sender<()>);\n\n#[allow(dead_code)]\nimpl Readiness {\n    pub fn new() -> (Readiness, Latch) {\n        let (sender, receiver) = mpsc::channel(1);\n        (Readiness(receiver), Latch(sender))\n    }\n\n    pub async fn is_ready(&mut self) {\n        let _ = self.0.recv().await;\n    }\n}\n\nimpl Latch {\n    /// Releases this readiness latch.\n    pub fn release(self) {\n        drop(self);\n    }\n}\n\n#[test]\nfn test_initial_settings() {\n    let pet = PetSettings {\n        sum: PetSettingsSum {\n            prob: 0.4,\n            count: PetSettingsCount { min: 1, max: 100 },\n            time: PetSettingsTime { min: 1, max: 2 },\n        },\n        update: PetSettingsUpdate {\n            prob: 0.5,\n            count: PetSettingsCount { min: 3, max: 1000 },\n            time: PetSettingsTime { min: 1, max: 2 },\n        },\n        sum2: PetSettingsSum2 {\n            count: PetSettingsCount { min: 1, max: 100 },\n            time: PetSettingsTime { min: 1, max: 2 },\n        },\n    };\n\n    assert_eq!(\n        pet,\n        pet_settings(),\n        \"the initial PetSettings have been changed. {}\",\n        WARNING\n    );\n\n    let mask = MaskSettings {\n        group_type: GroupType::Prime,\n        data_type: DataType::F32,\n        bound_type: BoundType::B0,\n        model_type: ModelType::M3,\n    };\n\n    assert_eq!(\n        mask,\n        mask_settings(),\n        \"the initial MaskSettings have been changed. 
{}\",\n        WARNING\n    );\n\n    let model = ModelSettings { length: 1 };\n\n    assert_eq!(\n        model,\n        model_settings(),\n        \"the initial ModelSettings have been changed. {}\",\n        WARNING\n    );\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/coordinator_storage/mod.rs",
    "content": "//! Storage backends to manage the coordinator state.\n\npub mod redis;\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/coordinator_storage/redis/impls.rs",
    "content": "use std::convert::TryFrom;\n\nuse derive_more::{From, Into};\nuse paste::paste;\nuse redis::{ErrorKind, FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs, Value};\nuse serde::{Deserialize, Serialize};\n\nuse crate::{\n    state_machine::coordinator::CoordinatorState,\n    storage::{\n        LocalSeedDictAdd,\n        LocalSeedDictAddError,\n        MaskScoreIncr,\n        MaskScoreIncrError,\n        SumPartAdd,\n        SumPartAddError,\n    },\n};\nuse xaynet_core::{\n    crypto::{ByteObject, PublicEncryptKey, PublicSigningKey},\n    mask::{EncryptedMaskSeed, MaskObject},\n    LocalSeedDict,\n};\n\npub fn redis_type_error(desc: &'static str, details: Option<String>) -> RedisError {\n    if let Some(details) = details {\n        RedisError::from((ErrorKind::TypeError, desc, details))\n    } else {\n        RedisError::from((ErrorKind::TypeError, desc))\n    }\n}\n\nfn error_code_type_error(response: &Value) -> RedisError {\n    redis_type_error(\n        \"Response status not valid integer\",\n        Some(format!(\"Response was {:?}\", response)),\n    )\n}\n\n/// Implements ['FromRedisValue'] and ['ToRedisArgs'] for types that implement ['ByteObject'].\n/// The Redis traits as well as the crypto types are both defined in foreign crates.\n/// To bypass the restrictions of orphan rule, we use `Newtypes` for the crypto types.\n///\n/// Each crypto type has two `Newtypes`, one for reading and one for writing.\n/// The difference between `Read` and `Write` is that the write `Newtype` does not take the\n/// ownership of the value but only a reference. This allows us to use references in the\n/// [`Client`] methods. 
The `Read` Newtype also implements [`ToRedisArgs`] to reduce the\n/// conversion overhead that you would get if you wanted to reuse a `Read` value for another\n/// Redis query.\n///\n/// Example:\n///\n/// ```compile_fail\n/// let sum_pks: Vec<PublicSigningKeyRead> = self.connection.hkeys(\"sum_dict\").await?;\n/// for sum_pk in sum_pks {\n///    let sum_pk_seed_dict: HashMap<PublicSigningKeyRead, EncryptedMaskSeedRead>\n///       = self.connection.hgetall(&sum_pk).await?; // no need to convert sum_pk from PublicSigningKeyRead to PublicSigningKeyWrite\n/// }\n/// ```\n///\n/// [`Client`]: crate::storage::redis::Client\nmacro_rules! impl_byte_object_redis_traits {\n    ($ty: ty) => {\n        paste! {\n            #[derive(Into, Hash, Eq, PartialEq)]\n            pub(crate) struct [<$ty Read>]($ty);\n\n            impl FromRedisValue for [<$ty Read>] {\n                fn from_redis_value(v: &Value) -> RedisResult<[<$ty Read>]> {\n                    match *v {\n                        Value::Data(ref bytes) => {\n                            let inner = <$ty>::from_slice(bytes).ok_or_else(|| {\n                                redis_type_error(concat!(\"Invalid \", stringify!($ty)), None)\n                            })?;\n                            Ok([<$ty Read>](inner))\n                        }\n                        _ => Err(redis_type_error(\n                            concat!(\"Response not \", stringify!($ty), \" compatible\"),\n                            None,\n                        )),\n                    }\n                }\n            }\n\n            impl ToRedisArgs for [<$ty Read>] {\n                fn write_redis_args<W>(&self, out: &mut W)\n                where\n                    W: ?Sized + RedisWrite,\n                {\n                    self.0.as_slice().write_redis_args(out)\n                }\n            }\n\n            #[derive(From)]\n            pub(crate) struct [<$ty Write>]<'a>(&'a $ty);\n\n            impl ToRedisArgs 
for [<$ty Write>]<'_> {\n                fn write_redis_args<W>(&self, out: &mut W)\n                where\n                    W: ?Sized + RedisWrite,\n                {\n                    self.0.as_slice().write_redis_args(out)\n                }\n            }\n        }\n    };\n}\n\nimpl_byte_object_redis_traits!(PublicEncryptKey);\nimpl_byte_object_redis_traits!(PublicSigningKey);\nimpl_byte_object_redis_traits!(EncryptedMaskSeed);\n\n/// Implements ['FromRedisValue'] and ['ToRedisArgs'] for types that implement\n/// ['Serialize`] and [`Deserialize']. The data is de/serialized via bincode.\n///\n/// # Panics\n///\n/// `write_redis_args` will panic if the data cannot be serialized with `bincode`\n///\n/// More information about what can cause a panic in bincode:\n/// - https://github.com/servo/bincode/issues/293\n/// - https://github.com/servo/bincode/issues/255\n/// - https://github.com/servo/bincode/issues/130#issuecomment-284641263\nmacro_rules! impl_bincode_redis_traits {\n    ($ty: ty) => {\n        impl FromRedisValue for $ty {\n            fn from_redis_value(v: &Value) -> RedisResult<$ty> {\n                match *v {\n                    Value::Data(ref bytes) => bincode::deserialize(bytes)\n                        .map_err(|e| redis_type_error(\"Invalid data\", Some(e.to_string()))),\n                    _ => Err(redis_type_error(\"Response not bincode compatible\", None)),\n                }\n            }\n        }\n\n        impl ToRedisArgs for $ty {\n            fn write_redis_args<W>(&self, out: &mut W)\n            where\n                W: ?Sized + RedisWrite,\n            {\n                let data = bincode::serialize(self).unwrap();\n                data.write_redis_args(out)\n            }\n        }\n    };\n}\n\n// CoordinatorState is pretty straightforward:\n// - all the sequences have known length (\n// - no untagged enum\n// so bincode will not panic.\nimpl_bincode_redis_traits!(CoordinatorState);\n\n#[derive(From, Into, 
Serialize, Deserialize)]\npub(crate) struct MaskObjectRead(MaskObject);\n\nimpl_bincode_redis_traits!(MaskObjectRead);\n\n#[derive(From, Serialize)]\npub(crate) struct MaskObjectWrite<'a>(&'a MaskObject);\n\nimpl ToRedisArgs for MaskObjectWrite<'_> {\n    fn write_redis_args<W>(&self, out: &mut W)\n    where\n        W: ?Sized + RedisWrite,\n    {\n        let data = bincode::serialize(self).unwrap();\n        data.write_redis_args(out)\n    }\n}\n\n#[derive(From)]\npub(crate) struct LocalSeedDictWrite<'a>(&'a LocalSeedDict);\n\nimpl ToRedisArgs for LocalSeedDictWrite<'_> {\n    fn write_redis_args<W>(&self, out: &mut W)\n    where\n        W: ?Sized + RedisWrite,\n    {\n        let args: Vec<(PublicSigningKeyWrite, EncryptedMaskSeedWrite)> = self\n            .0\n            .iter()\n            .map(|(pk, seed)| {\n                (\n                    PublicSigningKeyWrite::from(pk),\n                    EncryptedMaskSeedWrite::from(seed),\n                )\n            })\n            .collect();\n\n        args.write_redis_args(out)\n    }\n}\n\nimpl FromRedisValue for LocalSeedDictAdd {\n    fn from_redis_value(v: &Value) -> RedisResult<LocalSeedDictAdd> {\n        match *v {\n            Value::Int(0) => Ok(LocalSeedDictAdd(Ok(()))),\n            Value::Int(error_code) => match LocalSeedDictAddError::try_from(error_code) {\n                Ok(error_variant) => Ok(LocalSeedDictAdd(Err(error_variant))),\n                Err(_) => Err(error_code_type_error(v)),\n            },\n            _ => Err(error_code_type_error(v)),\n        }\n    }\n}\n\nimpl FromRedisValue for SumPartAdd {\n    fn from_redis_value(v: &Value) -> RedisResult<SumPartAdd> {\n        match *v {\n            Value::Int(1) => Ok(SumPartAdd(Ok(()))),\n            Value::Int(error_code) => match SumPartAddError::try_from(error_code) {\n                Ok(error_variant) => Ok(SumPartAdd(Err(error_variant))),\n                Err(_) => Err(error_code_type_error(v)),\n            },\n        
    _ => Err(error_code_type_error(v)),\n        }\n    }\n}\n\nimpl FromRedisValue for MaskScoreIncr {\n    fn from_redis_value(v: &Value) -> RedisResult<MaskScoreIncr> {\n        match *v {\n            Value::Int(0) => Ok(MaskScoreIncr(Ok(()))),\n            Value::Int(error_code) => match MaskScoreIncrError::try_from(error_code) {\n                Ok(error_variant) => Ok(MaskScoreIncr(Err(error_variant))),\n                Err(_) => Err(error_code_type_error(v)),\n            },\n            _ => Err(error_code_type_error(v)),\n        }\n    }\n}\n\n#[cfg(test)]\n#[derive(derive_more::Deref)]\npub struct SumDictDelete(Result<(), SumDictDeleteError>);\n\n#[cfg(test)]\nimpl SumDictDelete {\n    pub fn into_inner(self) -> Result<(), SumDictDeleteError> {\n        self.0\n    }\n}\n\n#[cfg(test)]\n#[derive(thiserror::Error, Debug, num_enum::TryFromPrimitive)]\n#[repr(i64)]\npub enum SumDictDeleteError {\n    #[error(\"sum participant does not exist\")]\n    DoesNotExist = 0,\n}\n\n#[cfg(test)]\nimpl FromRedisValue for SumDictDelete {\n    fn from_redis_value(v: &Value) -> RedisResult<SumDictDelete> {\n        match *v {\n            Value::Int(1) => Ok(SumDictDelete(Ok(()))),\n            Value::Int(error_code) => match SumDictDeleteError::try_from(error_code) {\n                Ok(error_variant) => Ok(SumDictDelete(Err(error_variant))),\n                Err(_) => Err(error_code_type_error(v)),\n            },\n            _ => Err(error_code_type_error(v)),\n        }\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/coordinator_storage/redis/mod.rs",
    "content": "//! A Redis [`CoordinatorStorage`] backend.\n//!\n//! # Redis Data Model\n//!\n//!```text\n//! {\n//!     // Coordinator state\n//!     \"coordinator_state\": \"...\", // bincode encoded string\n//!     // Sum dict\n//!     \"sum_dict\": { // hash\n//!         \"SumParticipantPublicKey_1\": SumParticipantEphemeralPublicKey_1,\n//!         \"SumParticipantPublicKey_2\": SumParticipantEphemeralPublicKey_2\n//!     },\n//!     // Seed dict\n//!     \"update_participants\": [ // set\n//!         UpdateParticipantPublicKey_1,\n//!         UpdateParticipantPublicKey_2\n//!     ],\n//!     \"SumParticipantPublicKey_1\": { // hash\n//!         \"UpdateParticipantPublicKey_1\": EncryptedMaskSeed,\n//!         \"UpdateParticipantPublicKey_2\": EncryptedMaskSeed\n//!     },\n//!     \"SumParticipantPublicKey_2\": {\n//!         \"UpdateParticipantPublicKey_1\": EncryptedMaskSeed,\n//!         \"UpdateParticipantPublicKey_2\": EncryptedMaskSeed\n//!     },\n//!     // Mask dict\n//!     \"mask_submitted\": [ // set\n//!         SumParticipantPublicKey_1,\n//!         SumParticipantPublicKey_2\n//!     ],\n//!     \"mask_dict\": [ // sorted set\n//!         (mask_object_1, 2), // (mask: bincode encoded string, score/counter: number)\n//!         (mask_object_2, 1)\n//!     ],\n//!     \"latest_global_model_id\": global_model_id\n//! }\n//! 
```\n\npub(in crate::storage) mod impls;\n\nuse std::collections::HashMap;\n\nuse async_trait::async_trait;\nuse redis::{aio::ConnectionManager, AsyncCommands, IntoConnectionInfo, Pipeline, Script};\npub use redis::{RedisError, RedisResult};\nuse tracing::debug;\n\nuse self::impls::{\n    EncryptedMaskSeedRead,\n    LocalSeedDictWrite,\n    MaskObjectRead,\n    MaskObjectWrite,\n    PublicEncryptKeyRead,\n    PublicEncryptKeyWrite,\n    PublicSigningKeyRead,\n    PublicSigningKeyWrite,\n};\nuse crate::{\n    state_machine::coordinator::CoordinatorState,\n    storage::{\n        CoordinatorStorage,\n        LocalSeedDictAdd,\n        MaskScoreIncr,\n        StorageError,\n        StorageResult,\n        SumPartAdd,\n    },\n};\nuse xaynet_core::{\n    mask::MaskObject,\n    LocalSeedDict,\n    SeedDict,\n    SumDict,\n    SumParticipantEphemeralPublicKey,\n    SumParticipantPublicKey,\n    UpdateParticipantPublicKey,\n};\n\n/// Redis client.\n#[derive(Clone)]\npub struct Client {\n    connection: ConnectionManager,\n}\n\nfn to_storage_err(e: RedisError) -> StorageError {\n    anyhow::anyhow!(e)\n}\n\nimpl Client {\n    /// Creates a new Redis client.\n    ///\n    /// `url` to which Redis instance the client should connect to.\n    /// The URL format is `redis://[<username>][:<passwd>@]<hostname>[:port][/<db>]`.\n    ///\n    /// The [`Client`] uses a [`ConnectionManager`] that automatically reconnects\n    /// if the connection is dropped.\n    pub async fn new<T: IntoConnectionInfo>(url: T) -> Result<Self, RedisError> {\n        let client = redis::Client::open(url)?;\n        let connection = client.get_tokio_connection_manager().await?;\n        Ok(Self { connection })\n    }\n\n    async fn create_flush_dicts_pipeline(&mut self) -> RedisResult<Pipeline> {\n        // https://redis.io/commands/hkeys\n        // > Return value:\n        //   Array reply: list of fields in the hash, or an empty list when key does not exist.\n        let sum_pks: 
Vec<PublicSigningKeyRead> = self.connection.hkeys(\"sum_dict\").await?;\n        let mut pipe = redis::pipe();\n\n        // https://redis.io/commands/del\n        // > Return value:\n        //   The number of keys that were removed.\n        //\n        // Returns `0` if the key does not exist.\n        // We ignore the return value because we are not interested in it.\n\n        // delete sum dict\n        pipe.del(\"sum_dict\").ignore();\n\n        // delete seed dict\n        pipe.del(\"update_participants\").ignore();\n        for sum_pk in sum_pks {\n            pipe.del(sum_pk).ignore();\n        }\n\n        // delete mask dict\n        pipe.del(\"mask_submitted\").ignore();\n        pipe.del(\"mask_dict\").ignore();\n        Ok(pipe)\n    }\n}\n\n#[async_trait]\nimpl CoordinatorStorage for Client {\n    async fn set_coordinator_state(&mut self, state: &CoordinatorState) -> StorageResult<()> {\n        debug!(\"set coordinator state\");\n        // https://redis.io/commands/set\n        // > Set key to hold the string value. If key already holds a value,\n        //   it is overwritten, regardless of its type.\n        // Possible return value in our case:\n        // > Simple string reply: OK if SET was executed correctly.\n        self.connection\n            .set(\"coordinator_state\", state)\n            .await\n            .map_err(to_storage_err)\n    }\n\n    async fn coordinator_state(&mut self) -> StorageResult<Option<CoordinatorState>> {\n        // https://redis.io/commands/get\n        // > Get the value of key. 
If the key does not exist the special value nil is returned.\n        //   An error is returned if the value stored at key is not a string, because GET only\n        //   handles string values.\n        // > Return value\n        //   Bulk string reply: the value of key, or nil when key does not exist.\n        self.connection\n            .get(\"coordinator_state\")\n            .await\n            .map_err(to_storage_err)\n    }\n\n    async fn add_sum_participant(\n        &mut self,\n        pk: &SumParticipantPublicKey,\n        ephm_pk: &SumParticipantEphemeralPublicKey,\n    ) -> StorageResult<SumPartAdd> {\n        debug!(\"add sum participant with pk {:?}\", pk);\n        // https://redis.io/commands/hsetnx\n        // > If field already exists, this operation has no effect.\n        // > Return value\n        //   Integer reply, specifically:\n        //   1 if field is a new field in the hash and value was set.\n        //   0 if field already exists in the hash and no operation was performed.\n        self.connection\n            .hset_nx(\n                \"sum_dict\",\n                PublicSigningKeyWrite::from(pk),\n                PublicEncryptKeyWrite::from(ephm_pk),\n            )\n            .await\n            .map_err(to_storage_err)\n    }\n\n    async fn sum_dict(&mut self) -> StorageResult<Option<SumDict>> {\n        debug!(\"get sum dictionary\");\n        // https://redis.io/commands/hgetall\n        // > Return value\n        //   Array reply: list of fields and their values stored in the hash, or an empty\n        //   list when key does not exist.\n        let reply: Vec<(PublicSigningKeyRead, PublicEncryptKeyRead)> = self\n            .connection\n            .hgetall(\"sum_dict\")\n            .await\n            .map_err(to_storage_err)?;\n\n        if reply.is_empty() {\n            return Ok(None);\n        };\n\n        let sum_dict = reply\n            .into_iter()\n            .map(|(pk, ephm_pk)| (pk.into(), 
ephm_pk.into()))\n            .collect();\n\n        Ok(Some(sum_dict))\n    }\n\n    async fn add_local_seed_dict(\n        &mut self,\n        update_pk: &UpdateParticipantPublicKey,\n        local_seed_dict: &LocalSeedDict,\n    ) -> StorageResult<LocalSeedDictAdd> {\n        debug!(\n            \"update seed dictionary for update participant with pk {:?}\",\n            update_pk\n        );\n        let script = Script::new(\n            r#\"\n                -- lua lists (tables) start at 1\n                local update_pk = ARGV[1]\n\n                -- check if the local seed dict has the same length as the sum_dict\n\n                -- KEYS is a list (table) of key value pairs ([sum_pk_1, seed_1, sum_pk_2, seed_2, ...])\n                local seed_dict_len = #KEYS / 2\n                local sum_dict_len = redis.call(\"HLEN\", \"sum_dict\")\n                if seed_dict_len ~= sum_dict_len then\n                    return -1\n                end\n\n                -- check if all pks of the local seed dict exists in sum_dict\n                for i = 1, #KEYS, 2 do\n                    local exist_in_sum_dict = redis.call(\"HEXISTS\", \"sum_dict\", KEYS[i])\n                    if exist_in_sum_dict == 0 then\n                        return -2\n                    end\n                end\n\n                -- check if the update pk already exists (i.e. 
the local seed dict has already been submitted)\n                local exist_in_seed_dict = redis.call(\"SADD\", \"update_participants\", update_pk)\n                -- SADD returns 0 if the key already exists\n                if exist_in_seed_dict == 0 then\n                    return -3\n                end\n\n                -- update the seed dict\n                for i = 1, #KEYS, 2 do\n                    local exist_in_update_seed_dict = redis.call(\"HSETNX\", KEYS[i], update_pk, KEYS[i + 1])\n                    -- HSETNX returns 0 if the update pk already exists\n                    if exist_in_update_seed_dict == 0 then\n                        -- This condition should never apply.\n                        -- If this condition is true, it is an indication that the data in redis is corrupted.\n                        return -4\n                    end\n                end\n\n                return 0\n            \"#,\n        );\n\n        script\n            .key(LocalSeedDictWrite::from(local_seed_dict))\n            .arg(PublicSigningKeyWrite::from(update_pk))\n            .invoke_async(&mut self.connection)\n            .await\n            .map_err(to_storage_err)\n    }\n\n    /// # Note\n    /// This method is **not** an atomic operation.\n    async fn seed_dict(&mut self) -> StorageResult<Option<SeedDict>> {\n        debug!(\"get seed dictionary\");\n        // https://redis.io/commands/hkeys\n        // > Return value:\n        //   Array reply: list of fields in the hash, or an empty list when key does not exist.\n        let sum_pks: Vec<PublicSigningKeyRead> = self.connection.hkeys(\"sum_dict\").await?;\n\n        if sum_pks.is_empty() {\n            return Ok(None);\n        };\n\n        let mut seed_dict: SeedDict = SeedDict::new();\n        for sum_pk in sum_pks {\n            // https://redis.io/commands/hgetall\n            // > Return value\n            //   Array reply: list of fields and their values stored in the hash, or an empty\n    
        //   list when key does not exist.\n            let sum_pk_seed_dict: HashMap<PublicSigningKeyRead, EncryptedMaskSeedRead> =\n                self.connection.hgetall(&sum_pk).await?;\n            seed_dict.insert(\n                sum_pk.into(),\n                sum_pk_seed_dict\n                    .into_iter()\n                    .map(|(pk, seed)| (pk.into(), seed.into()))\n                    .collect(),\n            );\n        }\n\n        Ok(Some(seed_dict))\n    }\n\n    /// The maximum length of a serialized mask is 512 Megabytes.\n    async fn incr_mask_score(\n        &mut self,\n        sum_pk: &SumParticipantPublicKey,\n        mask: &MaskObject,\n    ) -> StorageResult<MaskScoreIncr> {\n        debug!(\"increment mask count\");\n        let script = Script::new(\n            r#\"\n                -- lua lists (tables) start at 1\n                local sum_pk = ARGV[1]\n\n                -- check if the client participated in sum phase\n                --\n                -- Note: we cannot delete the sum_pk in the sum_dict because we\n                -- need the sum_dict later to delete the seed_dict\n                local sum_pk_exist = redis.call(\"HEXISTS\", \"sum_dict\", sum_pk)\n                if sum_pk_exist == 0 then\n                    return -1\n                end\n\n                -- check if sum participant has not already submitted a mask\n                local mask_already_submitted = redis.call(\"SADD\", \"mask_submitted\", sum_pk)\n                -- SADD returns 0 if the key already exists\n                if mask_already_submitted == 0 then\n                    return -2\n                end\n\n                redis.call(\"ZINCRBY\", \"mask_dict\", 1, KEYS[1])\n\n                return 0\n            \"#,\n        );\n\n        script\n            .key(MaskObjectWrite::from(mask))\n            .arg(PublicSigningKeyWrite::from(sum_pk))\n            .invoke_async(&mut self.connection)\n            .await\n            
.map_err(to_storage_err)\n    }\n\n    async fn best_masks(&mut self) -> StorageResult<Option<Vec<(MaskObject, u64)>>> {\n        debug!(\"get best masks\");\n        // https://redis.io/commands/zrevrangebyscore\n        // > Return value:\n        //   Array reply: list of elements in the specified range (optionally with their scores,\n        //   in case the WITHSCORES option is given).\n        let reply: Vec<(MaskObjectRead, u64)> = self\n            .connection\n            .zrevrange_withscores(\"mask_dict\", 0, 1)\n            .await?;\n\n        let result = match reply.is_empty() {\n            true => None,\n            _ => {\n                let masks = reply\n                    .into_iter()\n                    .map(|(mask, count)| (mask.into(), count))\n                    .collect();\n\n                Some(masks)\n            }\n        };\n\n        Ok(result)\n    }\n\n    async fn number_of_unique_masks(&mut self) -> StorageResult<u64> {\n        debug!(\"get number of unique masks\");\n        // https://redis.io/commands/zcount\n        // > Return value:\n        //   Integer reply: the number of elements in the specified score range.\n        self.connection\n            .zcount(\"mask_dict\", \"-inf\", \"+inf\")\n            .await\n            .map_err(to_storage_err)\n    }\n\n    /// # Note\n    /// This method is **not** an atomic operation.\n    async fn delete_coordinator_data(&mut self) -> StorageResult<()> {\n        debug!(\"flush coordinator data\");\n        let mut pipe = self.create_flush_dicts_pipeline().await?;\n        pipe.del(\"coordinator_state\").ignore();\n        pipe.del(\"latest_global_model_id\").ignore();\n        pipe.atomic()\n            .query_async(&mut self.connection)\n            .await\n            .map_err(to_storage_err)\n    }\n\n    /// # Note\n    /// This method is **not** an atomic operation.\n    async fn delete_dicts(&mut self) -> StorageResult<()> {\n        debug!(\"flush all 
dictionaries\");\n        let mut pipe = self.create_flush_dicts_pipeline().await?;\n        pipe.atomic()\n            .query_async(&mut self.connection)\n            .await\n            .map_err(to_storage_err)\n    }\n\n    async fn set_latest_global_model_id(&mut self, global_model_id: &str) -> StorageResult<()> {\n        debug!(\"set latest global model with id {}\", global_model_id);\n        // https://redis.io/commands/set\n        // > Set key to hold the string value. If key already holds a value,\n        //   it is overwritten, regardless of its type.\n        // Possible return value in our case:\n        // > Simple string reply: OK if SET was executed correctly.\n        self.connection\n            .set(\"latest_global_model_id\", global_model_id)\n            .await\n            .map_err(to_storage_err)\n    }\n\n    async fn latest_global_model_id(&mut self) -> StorageResult<Option<String>> {\n        debug!(\"get latest global model id\");\n        // https://redis.io/commands/get\n        // > Get the value of key. 
If the key does not exist the special value nil is returned.\n        //   An error is returned if the value stored at key is not a string, because GET only\n        //   handles string values.\n        // > Return value\n        //   Bulk string reply: the value of key, or nil when key does not exist.\n        self.connection\n            .get(\"latest_global_model_id\")\n            .await\n            .map_err(to_storage_err)\n    }\n\n    async fn is_ready(&mut self) -> StorageResult<()> {\n        // https://redis.io/commands/ping\n        redis::cmd(\"PING\")\n            .query_async(&mut self.connection)\n            .await\n            .map_err(to_storage_err)\n    }\n}\n\n#[cfg(test)]\n// Functions that are not needed in the state machine but handy for testing.\nimpl Client {\n    // Removes an entry in the [`SumDict`].\n    //\n    // Returns [`SumDictDelete(Ok(()))`] if field was deleted or\n    // [`SumDictDelete(Err(SumDictDeleteError::DoesNotExist)`] if field does not exist.\n    pub async fn remove_sum_dict_entry(\n        &mut self,\n        pk: &SumParticipantPublicKey,\n    ) -> RedisResult<self::impls::SumDictDelete> {\n        // https://redis.io/commands/hdel\n        // > Return value\n        //   Integer reply: the number of fields that were removed from the hash,\n        //   not including specified but non existing fields.\n        self.connection\n            .hdel(\"sum_dict\", PublicSigningKeyWrite::from(pk))\n            .await\n    }\n\n    // Returns the length of the [`SumDict`].\n    pub async fn sum_dict_len(&mut self) -> RedisResult<u64> {\n        // https://redis.io/commands/hlen\n        // > Return value\n        //   Integer reply: number of fields in the hash, or 0 when key does not exist.\n        self.connection.hlen(\"sum_dict\").await\n    }\n\n    // Returns the [`SumParticipantPublicKey`] of the [`SumDict`] or an empty list when the\n    // [`SumDict`] does not exist.\n    pub async fn sum_pks(\n        &mut self,\n 
   ) -> RedisResult<std::collections::HashSet<SumParticipantPublicKey>> {\n        // https://redis.io/commands/hkeys\n        // > Return value:\n        //   Array reply: list of fields in the hash, or an empty list when key does not exist.\n        let result: std::collections::HashSet<PublicSigningKeyRead> =\n            self.connection.hkeys(\"sum_dict\").await?;\n        let sum_pks = result.into_iter().map(|pk| pk.into()).collect();\n\n        Ok(sum_pks)\n    }\n\n    // Removes an update pk from the `update_participants` set.\n    pub async fn remove_update_participant(\n        &mut self,\n        update_pk: &UpdateParticipantPublicKey,\n    ) -> RedisResult<u64> {\n        self.connection\n            .srem(\n                \"update_participants\",\n                PublicSigningKeyWrite::from(update_pk),\n            )\n            .await\n    }\n\n    pub async fn mask_submitted_set(&mut self) -> RedisResult<Vec<SumParticipantPublicKey>> {\n        let result: Vec<PublicSigningKeyRead> =\n            self.connection.smembers(\"update_submitted\").await?;\n        let sum_pks = result.into_iter().map(|pk| pk.into()).collect();\n        Ok(sum_pks)\n    }\n\n    // Returns all keys in the current database.\n    pub async fn keys(&mut self) -> RedisResult<Vec<String>> {\n        self.connection.keys(\"*\").await\n    }\n\n    /// Returns the [`SeedDict`] entry for the given [`SumParticipantPublicKey`] or an empty map\n    /// when a [`SeedDict`] entry does not exist.\n    pub async fn seed_dict_for_sum_pk(\n        &mut self,\n        sum_pk: &SumParticipantPublicKey,\n    ) -> RedisResult<HashMap<UpdateParticipantPublicKey, xaynet_core::mask::EncryptedMaskSeed>>\n    {\n        debug!(\n            \"get seed dictionary for sum participant with pk {:?}\",\n            sum_pk\n        );\n        // https://redis.io/commands/hgetall\n        // > Return value\n        //   Array reply: list of fields and their values stored in the hash, or an empty\n   
     //   list when key does not exist.\n        let result: Vec<(PublicSigningKeyRead, EncryptedMaskSeedRead)> = self\n            .connection\n            .hgetall(PublicSigningKeyWrite::from(sum_pk))\n            .await?;\n        let seed_dict = result\n            .into_iter()\n            .map(|(pk, seed)| (pk.into(), seed.into()))\n            .collect();\n\n        Ok(seed_dict)\n    }\n\n    /// Deletes all data in the current database.\n    pub async fn flush_db(&mut self) -> RedisResult<()> {\n        debug!(\"flush current database\");\n        // https://redis.io/commands/flushdb\n        // > This command never fails.\n        redis::cmd(\"FLUSHDB\")\n            .arg(\"ASYNC\")\n            .query_async(&mut self.connection)\n            .await\n    }\n}\n\n#[cfg(test)]\npub(in crate) mod tests {\n    use self::impls::SumDictDeleteError;\n    use super::*;\n    use crate::{\n        state_machine::tests::utils::{mask_settings, model_settings, pet_settings},\n        storage::{tests::utils::*, LocalSeedDictAddError, MaskScoreIncrError, SumPartAddError},\n    };\n    use serial_test::serial;\n\n    async fn create_redis_client() -> Client {\n        Client::new(\"redis://127.0.0.1/\").await.unwrap()\n    }\n\n    pub async fn init_client() -> Client {\n        let mut client = create_redis_client().await;\n        client.flush_db().await.unwrap();\n        client\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_set_and_get_coordinator_state() {\n        // test the writing and reading of the coordinator state\n        let mut client = init_client().await;\n\n        let set_state = CoordinatorState::new(pet_settings(), mask_settings(), model_settings());\n        client.set_coordinator_state(&set_state).await.unwrap();\n\n        let get_state = client.coordinator_state().await.unwrap().unwrap();\n\n        assert_eq!(set_state, get_state)\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn 
integration_get_coordinator_empty() {\n        // test the reading of a non existing coordinator state\n        let mut client = init_client().await;\n\n        let get_state = client.coordinator_state().await.unwrap();\n\n        assert_eq!(None, get_state)\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_incr_mask_score() {\n        // test the increment of the mask counter\n        let mut client = init_client().await;\n\n        let should_be_none = client.best_masks().await.unwrap();\n        assert!(should_be_none.is_none());\n\n        let sum_pks = create_and_add_sum_participant_entries(&mut client, 3).await;\n        let mask = create_mask_zeroed(10);\n        for sum_pk in sum_pks {\n            let res = client.incr_mask_score(&sum_pk, &mask).await;\n            assert!(res.is_ok())\n        }\n\n        let best_masks = client.best_masks().await.unwrap().unwrap();\n        assert!(best_masks.len() == 1);\n\n        let (best_mask, count) = best_masks.into_iter().next().unwrap();\n        assert_eq!(best_mask, mask);\n        assert_eq!(count, 3);\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_get_incr_mask_count_unknown_sum_pk() {\n        // test that incr_mask_score rejects a sum pk that is not in the sum dict\n        let mut client = init_client().await;\n\n        let should_be_none = client.best_masks().await.unwrap();\n        assert!(should_be_none.is_none());\n\n        let (sum_pk, _) = create_sum_participant_entry();\n        let mask = create_mask_zeroed(10);\n        let unknown_sum_pk = client.incr_mask_score(&sum_pk, &mask).await.unwrap();\n\n        assert!(matches!(\n            unknown_sum_pk.into_inner().unwrap_err(),\n            MaskScoreIncrError::UnknownSumPk\n        ));\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_get_incr_mask_score_sum_pk_already_submitted() {\n        // test that incr_mask_score rejects a second mask from the same sum pk\n        let mut client = 
init_client().await;\n\n        let should_be_none = client.best_masks().await.unwrap();\n        assert!(should_be_none.is_none());\n\n        let mut sum_pks = create_and_add_sum_participant_entries(&mut client, 1).await;\n        let sum_pk = sum_pks.pop().unwrap();\n        let mask = create_mask_zeroed(10);\n        let result = client.incr_mask_score(&sum_pk, &mask).await.unwrap();\n        assert!(result.is_ok());\n\n        let already_submitted = client.incr_mask_score(&sum_pk, &mask).await.unwrap();\n\n        assert!(matches!(\n            already_submitted.into_inner().unwrap_err(),\n            MaskScoreIncrError::MaskAlreadySubmitted\n        ));\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_get_best_masks_only_one_mask() {\n        // test the writing and reading of one mask\n        let mut client = init_client().await;\n\n        let should_be_none = client.best_masks().await.unwrap();\n        assert!(should_be_none.is_none());\n\n        let sum_pks = create_and_add_sum_participant_entries(&mut client, 1).await;\n        let mask = create_mask_zeroed(10);\n        let res = client.incr_mask_score(sum_pks.get(0).unwrap(), &mask).await;\n        assert!(res.is_ok());\n\n        let best_masks = client.best_masks().await.unwrap().unwrap();\n        assert!(best_masks.len() == 1);\n\n        let (best_mask, count) = best_masks.into_iter().next().unwrap();\n        assert_eq!(best_mask, mask);\n        assert_eq!(count, 1);\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_get_best_masks_two_masks() {\n        // test the writing and reading of two masks\n        // the first mask is incremented twice\n        let mut client = init_client().await;\n\n        let should_be_none = client.best_masks().await.unwrap();\n        assert!(should_be_none.is_none());\n\n        let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await;\n        let mask_1 = 
create_mask_zeroed(10);\n        for sum_pk in sum_pks {\n            let res = client.incr_mask_score(&sum_pk, &mask_1).await;\n            assert!(res.is_ok())\n        }\n\n        let sum_pks = create_and_add_sum_participant_entries(&mut client, 1).await;\n        let mask_2 = create_mask_zeroed(100);\n        for sum_pk in sum_pks {\n            let res = client.incr_mask_score(&sum_pk, &mask_2).await;\n            assert!(res.is_ok())\n        }\n\n        let best_masks = client.best_masks().await.unwrap().unwrap();\n        assert!(best_masks.len() == 2);\n        let mut best_masks_iter = best_masks.into_iter();\n\n        let (first_mask, count) = best_masks_iter.next().unwrap();\n        assert_eq!(first_mask, mask_1);\n        assert_eq!(count, 2);\n        let (second_mask, count) = best_masks_iter.next().unwrap();\n        assert_eq!(second_mask, mask_2);\n        assert_eq!(count, 1);\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_get_best_masks_no_mask() {\n        // ensure that best_masks returns None if no mask exists\n        let mut client = init_client().await;\n\n        let best_masks = client.best_masks().await.unwrap();\n        assert!(best_masks.is_none())\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_get_number_of_unique_masks_empty() {\n        // ensure that number_of_unique_masks returns 0 if no mask exists\n        let mut client = init_client().await;\n\n        let number_of_unique_masks = client.number_of_unique_masks().await.unwrap();\n        assert_eq!(number_of_unique_masks, 0)\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_get_number_of_unique_masks() {\n        // ensure that number_of_unique_masks counts each distinct mask once\n        let mut client = init_client().await;\n\n        let should_be_none = client.best_masks().await.unwrap();\n        assert!(should_be_none.is_none());\n\n        let 
sum_pks = create_and_add_sum_participant_entries(&mut client, 4).await;\n        for (number, sum_pk) in sum_pks.iter().enumerate() {\n            let mask_1 = create_mask(10, number as u32);\n            let res = client.incr_mask_score(sum_pk, &mask_1).await;\n            assert!(res.is_ok())\n        }\n\n        let number_of_unique_masks = client.number_of_unique_masks().await.unwrap();\n        assert_eq!(number_of_unique_masks, 4)\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_sum_dict() {\n        // test multiple sum dict related methods\n        let mut client = init_client().await;\n\n        // create two entries and write them into redis\n        let mut entries = vec![];\n        for _ in 0..2 {\n            let (pk, epk) = create_sum_participant_entry();\n            let add_new_key = client.add_sum_participant(&pk, &epk).await.unwrap();\n            assert!(add_new_key.is_ok());\n\n            entries.push((pk, epk));\n        }\n\n        // ensure that add_sum_participant returns SumPartAddError::AlreadyExists if the key already exist\n        let (pk, epk) = entries.get(0).unwrap();\n        let key_already_exist = client.add_sum_participant(pk, epk).await.unwrap();\n        assert!(matches!(\n            key_already_exist.into_inner().unwrap_err(),\n            SumPartAddError::AlreadyExists\n        ));\n\n        // ensure that get_sum_dict_len returns 2\n        let len_of_sum_dict = client.sum_dict_len().await.unwrap();\n        assert_eq!(len_of_sum_dict, 2);\n\n        // read the written sum keys\n        // ensure they are equal\n        let sum_pks = client.sum_pks().await.unwrap();\n        for (sum_pk, _) in entries.iter() {\n            assert!(sum_pks.contains(sum_pk));\n        }\n\n        // remove both sum entries\n        for (sum_pk, _) in entries.iter() {\n            let remove_sum_pk = client.remove_sum_dict_entry(sum_pk).await.unwrap();\n\n            assert!(remove_sum_pk.is_ok());\n  
      }\n\n        // ensure that remove_sum_dict_entry returns SumDictDeleteError::DoesNotExist if the key does not exist\n        let (sum_pk, _) = entries.get(0).unwrap();\n        let key_does_not_exist = client.remove_sum_dict_entry(sum_pk).await.unwrap();\n        assert!(matches!(\n            key_does_not_exist.into_inner().unwrap_err(),\n            SumDictDeleteError::DoesNotExist\n        ));\n\n        // ensure that sum_dict returns None once the sum dict is empty\n        let sum_dict = client.sum_dict().await.unwrap();\n        assert!(sum_dict.is_none());\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_seed_dict() {\n        let mut client = init_client().await;\n\n        let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await;\n        let local_seed_dicts = create_local_seed_entries(&sum_pks);\n\n        let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await;\n        update_result.iter().for_each(|res| assert!(res.is_ok()));\n\n        let redis_sum_dict = client.sum_dict().await.unwrap().unwrap();\n        let seed_dict = create_seed_dict(redis_sum_dict, &local_seed_dicts);\n\n        let redis_seed_dict = client.seed_dict().await.unwrap().unwrap();\n        assert_eq!(seed_dict, redis_seed_dict)\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_seed_dict_len_mis_match() {\n        let mut client = init_client().await;\n\n        let mut sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await;\n\n        // remove one sum pk to create invalid local seed dicts\n        sum_pks.pop();\n\n        let local_seed_dicts = create_local_seed_entries(&sum_pks);\n        let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await;\n        update_result.into_iter().for_each(|res| {\n            assert!(matches!(\n                res.into_inner().unwrap_err(),\n                LocalSeedDictAddError::LengthMisMatch\n          
  ))\n        });\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_seed_dict_unknown_sum_participant() {\n        let mut client = init_client().await;\n\n        let mut sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await;\n\n        // replace a known sum_pk with an unknown one\n        sum_pks.pop();\n        let (pk, _) = create_sum_participant_entry();\n        sum_pks.push(pk);\n\n        let local_seed_dicts = create_local_seed_entries(&sum_pks);\n        let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await;\n        update_result.into_iter().for_each(|res| {\n            assert!(matches!(\n                res.into_inner().unwrap_err(),\n                LocalSeedDictAddError::UnknownSumParticipant\n            ))\n        });\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_seed_dict_update_pk_already_submitted() {\n        let mut client = init_client().await;\n        let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await;\n\n        let local_seed_dicts = create_local_seed_entries(&sum_pks);\n        let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await;\n        update_result.iter().for_each(|res| assert!(res.is_ok()));\n\n        let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await;\n        update_result.into_iter().for_each(|res| {\n            assert!(matches!(\n                res.into_inner().unwrap_err(),\n                LocalSeedDictAddError::UpdatePkAlreadySubmitted\n            ))\n        });\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_seed_dict_update_pk_already_exists_in_update_seed_dict() {\n        let mut client = init_client().await;\n        let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await;\n\n        let local_seed_dicts = create_local_seed_entries(&sum_pks);\n        let 
update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await;\n        update_result.iter().for_each(|res| assert!(res.is_ok()));\n\n        let (update_participant, local_seed_dict) = local_seed_dicts.get(0).unwrap().clone();\n        let remove_result = client\n            .remove_update_participant(&update_participant)\n            .await\n            .unwrap();\n        assert_eq!(remove_result, 1);\n\n        let update_result =\n            add_local_seed_entries(&mut client, &[(update_participant, local_seed_dict)]).await;\n        update_result.into_iter().for_each(|res| {\n            assert!(matches!(\n                res.into_inner().unwrap_err(),\n                LocalSeedDictAddError::UpdatePkAlreadyExistsInUpdateSeedDict\n            ))\n        });\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_seed_dict_get_seed_dict_for_sum_pk() {\n        let mut client = init_client().await;\n        let mut sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await;\n\n        let local_seed_dicts = create_local_seed_entries(&sum_pks);\n        let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await;\n        update_result.iter().for_each(|res| assert!(res.is_ok()));\n\n        let redis_sum_dict = client.sum_dict().await.unwrap().unwrap();\n        let seed_dict = create_seed_dict(redis_sum_dict, &local_seed_dicts);\n\n        let sum_pk = sum_pks.pop().unwrap();\n\n        let redis_sum_seed_dict = client.seed_dict_for_sum_pk(&sum_pk).await.unwrap();\n\n        assert_eq!(&redis_sum_seed_dict, seed_dict.get(&sum_pk).unwrap())\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_seed_dict_get_seed_dict_for_sum_pk_empty() {\n        let mut client = init_client().await;\n        let (sum_pk, _) = create_sum_participant_entry();\n\n        let result = client.seed_dict_for_sum_pk(&sum_pk).await.unwrap();\n        assert!(result.is_empty())\n  
  }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_flush_dicts() {\n        let mut client = init_client().await;\n\n        // write some data into redis\n        let set_state = CoordinatorState::new(pet_settings(), mask_settings(), model_settings());\n        let res = client.set_coordinator_state(&set_state).await;\n        assert!(res.is_ok());\n\n        let res = client.set_latest_global_model_id(\"global_model_id\").await;\n        assert!(res.is_ok());\n\n        let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await;\n\n        let local_seed_dicts = create_local_seed_entries(&sum_pks);\n        let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await;\n        update_result.iter().for_each(|res| assert!(res.is_ok()));\n\n        let mask = create_mask_zeroed(10);\n        client\n            .incr_mask_score(sum_pks.get(0).unwrap(), &mask)\n            .await\n            .unwrap();\n\n        // remove dicts\n        let res = client.delete_dicts().await;\n        assert!(res.is_ok());\n\n        // ensure that only the coordinator state and latest global model id exists\n        let res = client.coordinator_state().await;\n        assert!(res.unwrap().is_some());\n\n        let res = client.latest_global_model_id().await;\n        assert!(res.unwrap().is_some());\n\n        let res = client.sum_dict().await;\n        assert!(res.unwrap().is_none());\n\n        let res = client.seed_dict().await;\n        assert!(res.unwrap().is_none());\n\n        let res = client.mask_submitted_set().await;\n        assert!(res.unwrap().is_empty());\n\n        let res = client.best_masks().await;\n        assert!(res.unwrap().is_none());\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_flush_coordinator_data() {\n        let mut client = init_client().await;\n\n        // write some data into redis\n        let set_state = 
CoordinatorState::new(pet_settings(), mask_settings(), model_settings());\n        let res = client.set_coordinator_state(&set_state).await;\n        assert!(res.is_ok());\n\n        let res = client.set_latest_global_model_id(\"global_model_id\").await;\n        assert!(res.is_ok());\n\n        let sum_pks = create_and_add_sum_participant_entries(&mut client, 2).await;\n\n        let local_seed_dicts = create_local_seed_entries(&sum_pks);\n        let update_result = add_local_seed_entries(&mut client, &local_seed_dicts).await;\n        update_result.iter().for_each(|res| assert!(res.is_ok()));\n\n        let mask = create_mask_zeroed(10);\n        client\n            .incr_mask_score(sum_pks.get(0).unwrap(), &mask)\n            .await\n            .unwrap();\n\n        // remove all coordinator data\n        let res = client.delete_coordinator_data().await;\n        assert!(res.is_ok());\n\n        let keys = client.keys().await.unwrap();\n        assert!(keys.is_empty());\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_set_and_get_latest_global_model_id() {\n        // test the writing and reading of the global model id\n        let mut client = init_client().await;\n\n        let set_id = \"global_model_id\";\n        client.set_latest_global_model_id(set_id).await.unwrap();\n\n        let get_id = client.latest_global_model_id().await.unwrap().unwrap();\n\n        assert_eq!(set_id, get_id)\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_is_ready_ok() {\n        // test is_ready command\n        let mut client = init_client().await;\n\n        let res = client.is_ready().await;\n        assert!(res.is_ok())\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_get_latest_global_model_id_empty() {\n        // test the reading of a non existing global model id\n        let mut client = init_client().await;\n\n        let get_id = 
client.latest_global_model_id().await.unwrap();\n\n        assert_eq!(None, get_id)\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/mod.rs",
    "content": "//! Storage backends for the coordinator.\n\npub mod coordinator_storage;\npub mod model_storage;\npub mod store;\n#[cfg(test)]\npub(crate) mod tests;\npub mod traits;\npub mod trust_anchor;\n\npub use self::{\n    store::Store,\n    traits::{\n        CoordinatorStorage,\n        LocalSeedDictAdd,\n        LocalSeedDictAddError,\n        MaskScoreIncr,\n        MaskScoreIncrError,\n        ModelStorage,\n        Storage,\n        StorageError,\n        StorageResult,\n        SumPartAdd,\n        SumPartAddError,\n        TrustAnchor,\n    },\n};\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/model_storage/mod.rs",
    "content": "//! Storage backends to manage global models.\n\npub mod noop;\n#[cfg(feature = \"model-persistence\")]\n#[cfg_attr(docsrs, doc(cfg(feature = \"model-persistence\")))]\npub mod s3;\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/model_storage/noop.rs",
    "content": "//! A NoOp [`ModelStorage`] backend.\n\nuse crate::storage::{ModelStorage, StorageResult};\nuse async_trait::async_trait;\nuse xaynet_core::{common::RoundSeed, mask::Model};\n\n#[derive(Clone)]\npub struct NoOp;\n\n#[async_trait]\nimpl ModelStorage for NoOp {\n    async fn set_global_model(\n        &mut self,\n        round_id: u64,\n        round_seed: &RoundSeed,\n        _global_model: &Model,\n    ) -> StorageResult<String> {\n        Ok(Self::create_global_model_id(round_id, round_seed))\n    }\n\n    async fn global_model(&mut self, _id: &str) -> StorageResult<Option<Model>> {\n        Err(anyhow::anyhow!(\"No-op model store\"))\n    }\n\n    async fn is_ready(&mut self) -> StorageResult<()> {\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/model_storage/s3.rs",
    "content": "//! A S3 [`ModelStorage`] backend.\n\nuse std::sync::Arc;\n\nuse async_trait::async_trait;\nuse displaydoc::Display;\nuse http::StatusCode;\nuse rusoto_core::{credential::StaticProvider, request::TlsError, HttpClient, RusotoError};\nuse rusoto_s3::{\n    CreateBucketError,\n    CreateBucketOutput,\n    CreateBucketRequest,\n    DeleteObjectsError,\n    GetObjectError,\n    GetObjectOutput,\n    GetObjectRequest,\n    HeadBucketError,\n    HeadBucketRequest,\n    ListObjectsV2Error,\n    PutObjectError,\n    PutObjectOutput,\n    PutObjectRequest,\n    S3Client,\n    StreamingBody,\n    S3,\n};\nuse thiserror::Error;\nuse tokio::io::AsyncReadExt;\nuse tracing::debug;\n\nuse crate::{\n    settings::{S3BucketsSettings, S3Settings},\n    storage::{ModelStorage, StorageResult},\n};\nuse xaynet_core::{common::RoundSeed, mask::Model};\n\ntype ClientResult<T> = Result<T, ClientError>;\n\n#[derive(Debug, Display, Error)]\npub enum ClientError {\n    /// Failed to create bucket: {0}.\n    CreateBucket(#[from] RusotoError<CreateBucketError>),\n    /// Failed to get object: {0}.\n    GetObject(#[from] RusotoError<GetObjectError>),\n    /// Failed to put object: {0}.\n    PutObject(#[from] RusotoError<PutObjectError>),\n    /// Failed to list objects: {0}.\n    ListObjects(#[from] RusotoError<ListObjectsV2Error>),\n    /// Failed to delete objects: {0}.\n    DeleteObjects(#[from] RusotoError<DeleteObjectsError>),\n    /// Failed to dispatch: {0}.\n    Dispatcher(#[from] TlsError),\n    /// Failed to serialize: {0}.\n    Serialization(bincode::Error),\n    /// Failed to deserialize: {0}.\n    Deserialization(bincode::Error),\n    /// Response contains no body.\n    NoBody,\n    /// Failed to download body: {0}.\n    DownloadBody(std::io::Error),\n    /// Object {0} already exists.\n    ObjectAlreadyExists(String),\n    /// Storage not ready: {0}.\n    NotReady(RusotoError<HeadBucketError>),\n}\n\n#[derive(Clone)]\npub struct Client {\n    buckets: 
Arc<S3BucketsSettings>,\n    client: S3Client,\n}\n\nimpl Client {\n    /// Creates a new S3 client. The client creates and maintains one bucket for storing global models.\n    ///\n    /// To connect to AWS-compatible services such as Minio, you need to specify a custom region.\n    /// ```\n    /// use rusoto_core::Region;\n    /// use xaynet_server::{\n    ///     settings::{S3BucketsSettings, S3Settings},\n    ///     storage::model_storage::s3::Client,\n    /// };\n    ///\n    /// let region = Region::Custom {\n    ///     name: String::from(\"minio\"),\n    ///     endpoint: String::from(\"http://127.0.0.1:9000\"), // URL of minio\n    /// };\n    ///\n    /// let s3_settings = S3Settings {\n    ///     region,\n    ///     access_key: String::from(\"minio\"),\n    ///     secret_access_key: String::from(\"minio123\"),\n    ///     buckets: S3BucketsSettings {\n    ///         global_models: String::from(\"global-models\"),\n    ///     },\n    /// };\n    ///\n    /// let store = Client::new(s3_settings).unwrap();\n    /// ```\n    pub fn new(settings: S3Settings) -> ClientResult<Self> {\n        let credentials_provider =\n            StaticProvider::new_minimal(settings.access_key, settings.secret_access_key);\n\n        let dispatcher = HttpClient::new()?;\n        Ok(Self {\n            buckets: Arc::new(settings.buckets),\n            client: S3Client::new_with(dispatcher, credentials_provider, settings.region),\n        })\n    }\n\n    /// Creates the `global models` bucket.\n    /// This method does not fail if the bucket already exists or is already owned by you.\n    pub async fn create_global_models_bucket(&self) -> ClientResult<()> {\n        debug!(\"create {} bucket\", &self.buckets.global_models);\n        match self.create_bucket(&self.buckets.global_models).await {\n            Ok(_)\n            | Err(RusotoError::Service(CreateBucketError::BucketAlreadyExists(_)))\n            | 
Err(RusotoError::Service(CreateBucketError::BucketAlreadyOwnedByYou(_))) => Ok(()),\n            Err(err) => Err(ClientError::from(err)),\n        }\n    }\n\n    // Downloads the content of the given object.\n    async fn download_object_body(object: GetObjectOutput) -> ClientResult<Vec<u8>> {\n        let mut body = Vec::new();\n        object\n            .body\n            .ok_or(ClientError::NoBody)?\n            .into_async_read()\n            .read_to_end(&mut body)\n            .await\n            .map_err(ClientError::DownloadBody)?;\n        Ok(body)\n    }\n\n    // Fetches the metadata of the object with the given key from the given bucket.\n    async fn fetch_object_meta(\n        &self,\n        bucket: &str,\n        key: &str,\n    ) -> Result<GetObjectOutput, RusotoError<GetObjectError>> {\n        // If an object does not exist, S3 / Minio will return an error\n        let req = GetObjectRequest {\n            bucket: bucket.to_string(),\n            key: key.to_string(),\n            ..Default::default()\n        };\n        self.client.get_object(req).await\n    }\n\n    // Uploads an object with the given key to the given bucket.\n    async fn upload_object(\n        &self,\n        bucket: &str,\n        key: &str,\n        data: Vec<u8>,\n    ) -> Result<PutObjectOutput, RusotoError<PutObjectError>> {\n        let req = PutObjectRequest {\n            bucket: bucket.to_string(),\n            key: key.to_string(),\n            body: Some(StreamingBody::from(data)),\n            ..Default::default()\n        };\n        self.client.put_object(req).await\n    }\n\n    // Creates a new bucket with the given bucket name.\n    async fn create_bucket(\n        &self,\n        bucket: &str,\n    ) -> Result<CreateBucketOutput, RusotoError<CreateBucketError>> {\n        let req = CreateBucketRequest {\n            bucket: bucket.to_string(),\n            ..Default::default()\n        };\n        self.client.create_bucket(req).await\n    
}\n}\n\n#[async_trait]\nimpl ModelStorage for Client {\n    async fn set_global_model(\n        &mut self,\n        round_id: u64,\n        round_seed: &RoundSeed,\n        global_model: &Model,\n    ) -> StorageResult<String> {\n        let id = Self::create_global_model_id(round_id, round_seed);\n\n        debug!(\"upload global model: {}\", id);\n        let output = self\n            .fetch_object_meta(&self.buckets.global_models, &id)\n            .await;\n        if output.is_ok() {\n            return Err(anyhow::anyhow!(ClientError::ObjectAlreadyExists(\n                id.to_string()\n            )));\n        };\n\n        let data = bincode::serialize(global_model).map_err(ClientError::Serialization)?;\n        self.upload_object(&self.buckets.global_models, &id, data)\n            .await\n            .map(|_| Ok(id))?\n    }\n\n    async fn global_model(&mut self, id: &str) -> StorageResult<Option<Model>> {\n        debug!(\"download global model {}\", id);\n        let output = self\n            .fetch_object_meta(&self.buckets.global_models, id)\n            .await;\n        let object_meta = match output {\n            Err(RusotoError::Service(GetObjectError::NoSuchKey(_))) => return Ok(None),\n            Err(err) => return Err(anyhow::anyhow!(err)),\n            Ok(object) => object,\n        };\n\n        let body = Self::download_object_body(object_meta).await?;\n        let model = bincode::deserialize(&body).map_err(ClientError::Deserialization)?;\n        Ok(Some(model))\n    }\n\n    async fn is_ready(&mut self) -> StorageResult<()> {\n        let req = HeadBucketRequest {\n            // we can't use an empty string because S3/Minio would return BAD_REQUEST\n            bucket: self.buckets.global_models.clone(),\n            ..Default::default()\n        };\n        let res = self.client.head_bucket(req).await;\n\n        match res {\n            // rusoto doesn't return NoSuchBucket if the bucket doesn't exist\n            // 
https://github.com/rusoto/rusoto/issues/1099\n            //\n            // a workaround is to check if the StatusCode is NOT_FOUND\n            Err(RusotoError::Service(HeadBucketError::NoSuchBucket(_))) | Ok(_) => Ok(()),\n            Err(RusotoError::Unknown(resp)) => match resp.status {\n                // https://github.com/timberio/vector/blob/803c68c031e5872876e1167c428cd41358123d64/src/sinks/aws_s3.rs#L229\n                StatusCode::NOT_FOUND => Ok(()),\n                _ => Err(anyhow::anyhow!(ClientError::NotReady(\n                    RusotoError::Unknown(resp)\n                ))),\n            },\n            Err(e) => Err(anyhow::anyhow!(ClientError::NotReady(e))),\n        }\n    }\n}\n\n#[cfg(test)]\npub(in crate) mod tests {\n    use super::*;\n    use crate::storage::tests::utils::create_global_model;\n    use rusoto_core::Region;\n    use rusoto_s3::{\n        Delete,\n        DeleteBucketError,\n        DeleteBucketRequest,\n        DeleteObjectsOutput,\n        DeleteObjectsRequest,\n        ListObjectsV2Output,\n        ListObjectsV2Request,\n        ObjectIdentifier,\n    };\n    use serial_test::serial;\n\n    use xaynet_core::{common::RoundSeed, crypto::ByteObject};\n\n    impl Client {\n        // Deletes all objects in a bucket.\n        pub async fn clear_bucket(&self, bucket: &str) -> ClientResult<()> {\n            let mut continuation_token: Option<String> = None;\n\n            loop {\n                let list_obj_resp = self.list_objects(bucket, continuation_token).await?;\n\n                if let Some(identifiers) = Self::unpack_object_identifier(&list_obj_resp) {\n                    self.delete_objects(bucket, identifiers).await?;\n                } else {\n                    break;\n                }\n\n                // check if more objects exist\n                continuation_token = Self::unpack_next_continuation_token(&list_obj_resp);\n                if continuation_token.is_none() {\n                    break;\n      
          }\n            }\n            Ok(())\n        }\n\n        // Unpacks the object identifier/keys of a [`ListObjectsV2Output`] response.\n        fn unpack_object_identifier(\n            list_obj_resp: &ListObjectsV2Output,\n        ) -> Option<Vec<ObjectIdentifier>> {\n            if let Some(objects) = &list_obj_resp.contents {\n                let keys = objects\n                    .iter()\n                    .filter_map(|obj| obj.key.clone())\n                    .map(|key| ObjectIdentifier {\n                        key,\n                        ..Default::default()\n                    })\n                    .collect();\n                Some(keys)\n            } else {\n                None\n            }\n        }\n\n        // Deletes the objects of the given bucket.\n        async fn delete_objects(\n            &self,\n            bucket: &str,\n            identifiers: Vec<ObjectIdentifier>,\n        ) -> Result<DeleteObjectsOutput, RusotoError<DeleteObjectsError>> {\n            let req = DeleteObjectsRequest {\n                bucket: bucket.to_string(),\n                delete: Delete {\n                    objects: identifiers,\n                    ..Default::default()\n                },\n                ..Default::default()\n            };\n\n            self.client.delete_objects(req).await.map_err(From::from)\n        }\n\n        // Returns all object keys for the given bucket.\n        async fn list_objects(\n            &self,\n            bucket: &str,\n            continuation_token: Option<String>,\n        ) -> Result<ListObjectsV2Output, RusotoError<ListObjectsV2Error>> {\n            let req = ListObjectsV2Request {\n                bucket: bucket.to_string(),\n                continuation_token,\n                // the S3 response is limited to 1000 keys max.\n                // https://docs.aws.amazon.com/AWSJavaScriptSDK/latest/AWS/S3.html#listObjectsV2-property\n                // However, Minio could return more.\n     
           max_keys: Some(1000),\n                ..Default::default()\n            };\n\n            self.client.list_objects_v2(req).await.map_err(From::from)\n        }\n\n        // Unpacks the next_continuation_token of the [`ListObjectsV2Output`] response.\n        fn unpack_next_continuation_token(list_obj_resp: &ListObjectsV2Output) -> Option<String> {\n            // https://docs.aws.amazon.com/AmazonS3/latest/dev/ListingObjectKeysUsingJava.html\n            if let Some(is_truncated) = list_obj_resp.is_truncated {\n                if is_truncated {\n                    list_obj_resp.next_continuation_token.clone()\n                } else {\n                    None\n                }\n            } else {\n                None\n            }\n        }\n\n        async fn delete_bucket(&self, bucket: &str) -> Result<(), RusotoError<DeleteBucketError>> {\n            let req = DeleteBucketRequest {\n                bucket: bucket.to_string(),\n                ..Default::default()\n            };\n            self.client.delete_bucket(req).await\n        }\n    }\n\n    fn create_minio_setup(url: &str) -> S3Settings {\n        let region = Region::Custom {\n            name: String::from(\"minio\"),\n            endpoint: String::from(url),\n        };\n\n        S3Settings {\n            region,\n            access_key: String::from(\"minio\"),\n            secret_access_key: String::from(\"minio123\"),\n            buckets: S3BucketsSettings::default(),\n        }\n    }\n\n    pub async fn init_client() -> Client {\n        let settings = create_minio_setup(\"http://localhost:9000\");\n        let client = Client::new(settings).unwrap();\n        client.create_global_models_bucket().await.unwrap();\n        client.clear_bucket(\"global-models\").await.unwrap();\n        client\n    }\n\n    async fn init_disconnected_client() -> Client {\n        let settings = create_minio_setup(\"http://localhost:11000\");\n        Client::new(settings).unwrap()\n    
}\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_test_set_and_get_global_model() {\n        let mut client = init_client().await;\n\n        let global_model = create_global_model(10);\n        let id = client\n            .set_global_model(1, &RoundSeed::generate(), &global_model)\n            .await\n            .unwrap();\n\n        let downloaded_global_model = client.global_model(&id).await.unwrap().unwrap();\n        assert_eq!(global_model, downloaded_global_model)\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_test_get_global_model_non_existent() {\n        let mut client = init_client().await;\n\n        let id = Client::create_global_model_id(1, &RoundSeed::generate());\n        let res = client.global_model(&id).await.unwrap();\n        assert!(res.is_none())\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_test_global_model_already_exists() {\n        let mut client = init_client().await;\n\n        let global_model = create_global_model(10);\n        let round_seed = RoundSeed::generate();\n        let id = client\n            .set_global_model(1, &round_seed, &global_model)\n            .await\n            .unwrap();\n\n        let global_model_2 = create_global_model(20);\n        let res = client\n            .set_global_model(1, &round_seed, &global_model_2)\n            .await\n            .unwrap_err();\n        assert!(matches!(\n            res.downcast_ref::<ClientError>().unwrap(),\n            ClientError::ObjectAlreadyExists(_)\n        ));\n\n        let downloaded_global_model = client.global_model(&id).await.unwrap().unwrap();\n        assert_eq!(global_model, downloaded_global_model)\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_test_is_ready_ok() {\n        let mut client = init_client().await;\n\n        let res = client.is_ready().await;\n        assert!(res.is_ok())\n    }\n\n    
#[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_test_is_ready_ok_no_such_bucket() {\n        // test that is_ready returns Ok even if the bucket doesn't exist\n        let mut client = init_client().await;\n        client\n            .delete_bucket(&S3BucketsSettings::default().global_models)\n            .await\n            .unwrap();\n\n        let res = client.is_ready().await;\n        assert!(res.is_ok())\n    }\n\n    #[tokio::test]\n    #[serial]\n    #[ignore]\n    async fn integration_test_is_ready_err() {\n        let mut client = init_disconnected_client().await;\n\n        let res = client.is_ready().await;\n        assert!(res.is_err())\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/store.rs",
    "content": "//! A generic store.\n\nuse async_trait::async_trait;\n\nuse crate::{\n    state_machine::coordinator::CoordinatorState,\n    storage::{\n        trust_anchor::noop::NoOp,\n        CoordinatorStorage,\n        LocalSeedDictAdd,\n        MaskScoreIncr,\n        ModelStorage,\n        Storage,\n        StorageResult,\n        SumPartAdd,\n        TrustAnchor,\n    },\n};\nuse xaynet_core::{\n    common::RoundSeed,\n    mask::{MaskObject, Model},\n    LocalSeedDict,\n    SeedDict,\n    SumDict,\n    SumParticipantEphemeralPublicKey,\n    SumParticipantPublicKey,\n    UpdateParticipantPublicKey,\n};\n\n#[derive(Clone)]\n/// A generic store.\npub struct Store<C, M, T>\nwhere\n    C: CoordinatorStorage,\n    M: ModelStorage,\n    T: TrustAnchor,\n{\n    /// A coordinator store.\n    coordinator: C,\n    /// A model store.\n    model: M,\n    /// A trust anchor.\n    trust_anchor: T,\n}\n\nimpl<C, M, T> Store<C, M, T>\nwhere\n    C: CoordinatorStorage,\n    M: ModelStorage,\n    T: TrustAnchor,\n{\n    pub fn new_with_trust_anchor(coordinator: C, model: M, trust_anchor: T) -> Self {\n        Self {\n            coordinator,\n            model,\n            trust_anchor,\n        }\n    }\n}\n\nimpl<C, M> Store<C, M, NoOp>\nwhere\n    C: CoordinatorStorage,\n    M: ModelStorage,\n{\n    /// Creates a new [`Store`].\n    pub fn new(coordinator: C, model: M) -> Self {\n        Self {\n            coordinator,\n            model,\n            trust_anchor: NoOp,\n        }\n    }\n}\n\n#[async_trait]\nimpl<C, M, T> CoordinatorStorage for Store<C, M, T>\nwhere\n    C: CoordinatorStorage,\n    M: ModelStorage,\n    T: TrustAnchor,\n{\n    async fn set_coordinator_state(&mut self, state: &CoordinatorState) -> StorageResult<()> {\n        self.coordinator.set_coordinator_state(state).await\n    }\n\n    async fn coordinator_state(&mut self) -> StorageResult<Option<CoordinatorState>> {\n        self.coordinator.coordinator_state().await\n    }\n\n    async fn 
add_sum_participant(\n        &mut self,\n        pk: &SumParticipantPublicKey,\n        ephm_pk: &SumParticipantEphemeralPublicKey,\n    ) -> StorageResult<SumPartAdd> {\n        self.coordinator.add_sum_participant(pk, ephm_pk).await\n    }\n\n    async fn sum_dict(&mut self) -> StorageResult<Option<SumDict>> {\n        self.coordinator.sum_dict().await\n    }\n\n    async fn add_local_seed_dict(\n        &mut self,\n        update_pk: &UpdateParticipantPublicKey,\n        local_seed_dict: &LocalSeedDict,\n    ) -> StorageResult<LocalSeedDictAdd> {\n        self.coordinator\n            .add_local_seed_dict(update_pk, local_seed_dict)\n            .await\n    }\n\n    async fn seed_dict(&mut self) -> StorageResult<Option<SeedDict>> {\n        self.coordinator.seed_dict().await\n    }\n\n    async fn incr_mask_score(\n        &mut self,\n        pk: &SumParticipantPublicKey,\n        mask: &MaskObject,\n    ) -> StorageResult<MaskScoreIncr> {\n        self.coordinator.incr_mask_score(pk, mask).await\n    }\n\n    async fn best_masks(&mut self) -> StorageResult<Option<Vec<(MaskObject, u64)>>> {\n        self.coordinator.best_masks().await\n    }\n\n    async fn number_of_unique_masks(&mut self) -> StorageResult<u64> {\n        self.coordinator.number_of_unique_masks().await\n    }\n\n    async fn delete_coordinator_data(&mut self) -> StorageResult<()> {\n        self.coordinator.delete_coordinator_data().await\n    }\n\n    async fn delete_dicts(&mut self) -> StorageResult<()> {\n        self.coordinator.delete_dicts().await\n    }\n\n    async fn set_latest_global_model_id(&mut self, id: &str) -> StorageResult<()> {\n        self.coordinator.set_latest_global_model_id(id).await\n    }\n\n    async fn latest_global_model_id(&mut self) -> StorageResult<Option<String>> {\n        self.coordinator.latest_global_model_id().await\n    }\n\n    async fn is_ready(&mut self) -> StorageResult<()> {\n        self.coordinator.is_ready().await\n    
}\n}\n\n#[async_trait]\nimpl<C, M, T> ModelStorage for Store<C, M, T>\nwhere\n    C: CoordinatorStorage,\n    M: ModelStorage,\n    T: TrustAnchor,\n{\n    async fn set_global_model(\n        &mut self,\n        round_id: u64,\n        round_seed: &RoundSeed,\n        global_model: &Model,\n    ) -> StorageResult<String> {\n        self.model\n            .set_global_model(round_id, round_seed, global_model)\n            .await\n    }\n\n    async fn global_model(&mut self, id: &str) -> StorageResult<Option<Model>> {\n        self.model.global_model(id).await\n    }\n\n    async fn is_ready(&mut self) -> StorageResult<()> {\n        self.model.is_ready().await\n    }\n}\n\n#[async_trait]\nimpl<C, M, T> TrustAnchor for Store<C, M, T>\nwhere\n    C: CoordinatorStorage,\n    M: ModelStorage,\n    T: TrustAnchor,\n{\n    async fn publish_proof(&mut self, global_model: &Model) -> StorageResult<()> {\n        self.trust_anchor.publish_proof(global_model).await\n    }\n\n    async fn is_ready(&mut self) -> StorageResult<()> {\n        self.trust_anchor.is_ready().await\n    }\n}\n\n#[async_trait]\nimpl<C, M, T> Storage for Store<C, M, T>\nwhere\n    C: CoordinatorStorage,\n    M: ModelStorage,\n    T: TrustAnchor,\n{\n    async fn is_ready(&mut self) -> StorageResult<()> {\n        tokio::try_join!(\n            self.coordinator.is_ready(),\n            self.model.is_ready(),\n            self.trust_anchor.is_ready()\n        )\n        .map(|_| ())\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/tests/mod.rs",
    "content": "use crate::{\n    state_machine::coordinator::CoordinatorState,\n    storage::{\n        coordinator_storage::redis,\n        model_storage,\n        CoordinatorStorage,\n        LocalSeedDictAdd,\n        MaskScoreIncr,\n        ModelStorage,\n        Storage,\n        StorageResult,\n        Store,\n        SumPartAdd,\n        TrustAnchor,\n    },\n};\nuse async_trait::async_trait;\nuse mockall::*;\nuse xaynet_core::{\n    common::RoundSeed,\n    mask::{MaskObject, Model},\n    LocalSeedDict,\n    SeedDict,\n    SumDict,\n    SumParticipantEphemeralPublicKey,\n    SumParticipantPublicKey,\n    UpdateParticipantPublicKey,\n};\n\npub mod utils;\n\npub async fn init_store() -> impl Storage {\n    let coordinator_store = redis::tests::init_client().await;\n\n    let model_store = {\n        #[cfg(not(feature = \"model-persistence\"))]\n        {\n            model_storage::noop::NoOp\n        }\n\n        #[cfg(feature = \"model-persistence\")]\n        {\n            model_storage::s3::tests::init_client().await\n        }\n    };\n\n    Store::new(coordinator_store, model_store)\n}\n\nmock! 
{\n    pub CoordinatorStore {}\n\n    #[async_trait]\n    impl CoordinatorStorage for CoordinatorStore {\n        async fn set_coordinator_state(&mut self, state: &CoordinatorState) -> StorageResult<()>;\n        async fn coordinator_state(&mut self) -> StorageResult<Option<CoordinatorState>>;\n        async fn add_sum_participant(\n            &mut self,\n            pk: &SumParticipantPublicKey,\n            ephm_pk: &SumParticipantEphemeralPublicKey,\n        ) -> StorageResult<SumPartAdd>;\n        async fn sum_dict(&mut self) -> StorageResult<Option<SumDict>>;\n        async fn add_local_seed_dict(\n            &mut self,\n            update_pk: &UpdateParticipantPublicKey,\n            local_seed_dict: &LocalSeedDict,\n        ) -> StorageResult<LocalSeedDictAdd>;\n        async fn seed_dict(&mut self) -> StorageResult<Option<SeedDict>>;\n        async fn incr_mask_score(\n            &mut self,\n            pk: &SumParticipantPublicKey,\n            mask: &MaskObject,\n        ) -> StorageResult<MaskScoreIncr>;\n        async fn best_masks(&mut self) -> StorageResult<Option<Vec<(MaskObject, u64)>>>;\n        async fn number_of_unique_masks(&mut self) -> StorageResult<u64>;\n        async fn delete_coordinator_data(&mut self) -> StorageResult<()>;\n        async fn delete_dicts(&mut self) -> StorageResult<()>;\n        async fn set_latest_global_model_id(&mut self, id: &str) -> StorageResult<()>;\n        async fn latest_global_model_id(&mut self) -> StorageResult<Option<String>>;\n        async fn is_ready(&mut self) -> StorageResult<()>;\n    }\n\n    impl Clone for CoordinatorStore {\n        fn clone(&self) -> Self;\n    }\n}\n\nmock! 
{\n    pub ModelStore {}\n\n    #[async_trait]\n    impl ModelStorage for ModelStore {\n        async fn set_global_model(\n            &mut self,\n            round_id: u64,\n            round_seed: &RoundSeed,\n            global_model: &Model,\n        ) -> StorageResult<String>;\n        async fn global_model(&mut self, id: &str) -> StorageResult<Option<Model>>;\n        async fn is_ready(&mut self) -> StorageResult<()>;\n    }\n\n    impl Clone for ModelStore {\n        fn clone(&self) -> Self;\n    }\n\n}\n\nmock! {\n    pub TrustAnchor {}\n\n    #[async_trait]\n    impl TrustAnchor for TrustAnchor {\n        async fn publish_proof(&mut self, global_model: &Model) -> StorageResult<()>;\n        async fn is_ready(&mut self) -> StorageResult<()>;\n    }\n\n    impl Clone for TrustAnchor {\n        fn clone(&self) -> Self;\n    }\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/tests/utils.rs",
    "content": "use num::{bigint::BigUint, traits::identities::Zero};\n\nuse crate::{\n    state_machine::tests::utils::mask_settings,\n    storage::{CoordinatorStorage, LocalSeedDictAdd},\n};\nuse xaynet_core::{\n    crypto::{ByteObject, EncryptKeyPair, SigningKeyPair},\n    mask::{EncryptedMaskSeed, MaskConfig, MaskObject},\n    LocalSeedDict,\n    SeedDict,\n    SumDict,\n    SumParticipantEphemeralPublicKey,\n    SumParticipantPublicKey,\n    UpdateParticipantPublicKey,\n};\n\npub fn create_sum_participant_entry() -> (SumParticipantPublicKey, SumParticipantEphemeralPublicKey)\n{\n    let SigningKeyPair { public: pk, .. } = SigningKeyPair::generate();\n    let EncryptKeyPair {\n        public: ephm_pk, ..\n    } = EncryptKeyPair::generate();\n    (pk, ephm_pk)\n}\n\npub fn create_local_seed_entries(\n    sum_pks: &[SumParticipantPublicKey],\n) -> Vec<(UpdateParticipantPublicKey, LocalSeedDict)> {\n    let mut entries = Vec::new();\n\n    for _ in 0..sum_pks.len() {\n        let SigningKeyPair {\n            public: update_pk, ..\n        } = SigningKeyPair::generate();\n\n        let mut local_seed_dict = LocalSeedDict::new();\n        for sum_pk in sum_pks {\n            let seed = EncryptedMaskSeed::zeroed();\n            local_seed_dict.insert(*sum_pk, seed);\n        }\n        entries.push((update_pk, local_seed_dict))\n    }\n\n    entries\n}\n\npub fn create_mask_zeroed(model_length: usize) -> MaskObject {\n    MaskObject::new(\n        MaskConfig::from(mask_settings()).into(),\n        vec![BigUint::zero(); model_length],\n        BigUint::zero(),\n    )\n    .unwrap()\n}\n\npub fn create_mask(model_length: usize, number: u32) -> MaskObject {\n    MaskObject::new(\n        MaskConfig::from(mask_settings()).into(),\n        vec![BigUint::from(number); model_length],\n        BigUint::zero(),\n    )\n    .unwrap()\n}\n\npub fn create_seed_dict(\n    sum_dict: SumDict,\n    seed_updates: &[(UpdateParticipantPublicKey, LocalSeedDict)],\n) -> SeedDict {\n    
let mut seed_dict: SeedDict = sum_dict\n        .keys()\n        .map(|pk| (*pk, LocalSeedDict::new()))\n        .collect();\n\n    for (pk, local_seed_dict) in seed_updates {\n        for (sum_pk, seed) in local_seed_dict {\n            seed_dict.get_mut(sum_pk).unwrap().insert(*pk, seed.clone());\n        }\n    }\n\n    seed_dict\n}\n\npub async fn create_and_add_sum_participant_entries(\n    client: &mut impl CoordinatorStorage,\n    n: u32,\n) -> Vec<SumParticipantPublicKey> {\n    let mut sum_pks = Vec::new();\n    for _ in 0..n {\n        let (pk, ephm_pk) = create_sum_participant_entry();\n\n        let _ = client.add_sum_participant(&pk, &ephm_pk).await.unwrap();\n        sum_pks.push(pk);\n    }\n\n    sum_pks\n}\n\npub async fn add_local_seed_entries(\n    client: &mut impl CoordinatorStorage,\n    local_seed_entries: &[(UpdateParticipantPublicKey, LocalSeedDict)],\n) -> Vec<LocalSeedDictAdd> {\n    let mut update_result = Vec::new();\n\n    for (update_pk, local_seed_dict) in local_seed_entries {\n        let res = client.add_local_seed_dict(update_pk, local_seed_dict).await;\n        assert!(res.is_ok());\n        update_result.push(res.unwrap())\n    }\n\n    update_result\n}\n\nuse xaynet_core::mask::{FromPrimitives, Model};\n\npub fn create_global_model(model_length: usize) -> Model {\n    Model::from_primitives(vec![0; model_length].into_iter()).unwrap()\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/traits.rs",
    "content": "//! Storage API.\n\nuse async_trait::async_trait;\nuse derive_more::Deref;\nuse displaydoc::Display;\nuse num_enum::TryFromPrimitive;\nuse thiserror::Error;\n\nuse crate::state_machine::coordinator::CoordinatorState;\nuse xaynet_core::{\n    common::RoundSeed,\n    crypto::ByteObject,\n    mask::{MaskObject, Model},\n    LocalSeedDict,\n    SeedDict,\n    SumDict,\n    SumParticipantEphemeralPublicKey,\n    SumParticipantPublicKey,\n    UpdateParticipantPublicKey,\n};\n\n/// The error type for storage operations that are not directly related to application domain.\n/// These include, for example IO errors like broken pipe, file not found, out-of-memory, etc.\npub type StorageError = anyhow::Error;\n\n/// The result of the storage operation.\npub type StorageResult<T> = Result<T, StorageError>;\n\n#[async_trait]\n/// An abstract coordinator storage.\npub trait CoordinatorStorage\nwhere\n    Self: Clone + Send + Sync + 'static,\n{\n    /// Sets a [`CoordinatorState`].\n    ///\n    /// # Behavior\n    ///\n    /// - If no state has been set yet, set the state and return `StorageResult::Ok(())`.\n    /// - If a state already exists, override the state and return `StorageResult::Ok(())`.\n    async fn set_coordinator_state(&mut self, state: &CoordinatorState) -> StorageResult<()>;\n\n    /// Returns a [`CoordinatorState`].\n    ///\n    /// # Behavior\n    ///\n    /// - If no state has been set yet, return `StorageResult::Ok(Option::None)`.\n    /// - If a state exists, return `StorageResult::Ok(Some(CoordinatorState))`.\n    async fn coordinator_state(&mut self) -> StorageResult<Option<CoordinatorState>>;\n\n    /// Adds a sum participant entry to the [`SumDict`].\n    ///\n    /// # Behavior\n    ///\n    /// - If a sum participant has been successfully added, return `StorageResult::Ok(SumPartAdd)`\n    ///   containing a `Result::Ok(())`.\n    /// - If the participant could not be added due to a PET protocol error, return\n    ///   the 
corresponding `StorageResult::Ok(SumPartAdd)` containing a\n    ///   `Result::Err(SumPartAddError)`.\n    async fn add_sum_participant(\n        &mut self,\n        pk: &SumParticipantPublicKey,\n        ephm_pk: &SumParticipantEphemeralPublicKey,\n    ) -> StorageResult<SumPartAdd>;\n\n    /// Returns the [`SumDict`].\n    ///\n    /// # Behavior\n    ///\n    /// - If the sum dict does not exist, return `StorageResult::Ok(Option::None)`.\n    /// - If the sum dict exists, return `StorageResult::Ok(Option::Some(SumDict))`.\n    async fn sum_dict(&mut self) -> StorageResult<Option<SumDict>>;\n\n    /// Adds a local [`LocalSeedDict`] of the given [`UpdateParticipantPublicKey`] to the [`SeedDict`].\n    ///\n    /// # Behavior\n    ///\n    /// - If the local seed dict has been successfully added, return\n    ///   `StorageResult::Ok(LocalSeedDictAdd)` containing a `Result::Ok(())`.\n    /// - If the local seed dict could not be added due to a PET protocol error, return\n    ///   the corresponding `StorageResult::Ok(LocalSeedDictAdd)` containing a\n    ///   `Result::Err(LocalSeedDictAddError)`.\n    async fn add_local_seed_dict(\n        &mut self,\n        update_pk: &UpdateParticipantPublicKey,\n        local_seed_dict: &LocalSeedDict,\n    ) -> StorageResult<LocalSeedDictAdd>;\n\n    /// Returns the [`SeedDict`].\n    ///\n    /// # Behavior\n    ///\n    /// - If the seed dict does not exist, return `StorageResult::Ok(Option::None)`.\n    /// - If the seed dict exists, return `StorageResult::Ok(Option::Some(SeedDict))`.\n    async fn seed_dict(&mut self) -> StorageResult<Option<SeedDict>>;\n\n    /// Increments the mask score with the given [`MaskObject`] by one.\n    ///\n    /// # Behavior\n    ///\n    /// - If the mask score has been successfully incremented, return\n    ///   `StorageResult::Ok(MaskScoreIncr)` containing a `Result::Ok(())`.\n    /// - If the mask score could not be incremented due to a PET protocol error,\n    ///   return the 
corresponding `StorageResult::Ok(MaskScoreIncr)` containing a\n    ///   `Result::Err(MaskScoreIncrError)`.\n    async fn incr_mask_score(\n        &mut self,\n        pk: &SumParticipantPublicKey,\n        mask: &MaskObject,\n    ) -> StorageResult<MaskScoreIncr>;\n\n    /// Returns the two masks with the highest score.\n    ///\n    /// # Behavior\n    ///\n    /// - If no masks exist, return `StorageResult::Ok(Option::None)`.\n    /// - If only one mask exists, return this mask\n    ///   `StorageResult::Ok(Option::Some(Vec<(MaskObject, u64)>))`.\n    /// - If two masks exist with the same score, return both\n    ///   `StorageResult::Ok(Option::Some(Vec<(MaskObject, u64)>))`.\n    /// - If two masks exist with different scores, return\n    ///   both in descending order `StorageResult::Ok(Option::Some(Vec<(MaskObject, u64)>))`.\n    async fn best_masks(&mut self) -> StorageResult<Option<Vec<(MaskObject, u64)>>>;\n\n    /// Returns the number of unique masks.\n    async fn number_of_unique_masks(&mut self) -> StorageResult<u64>;\n\n    /// Deletes all coordinator data. 
This includes the coordinator\n    /// state as well as the [`SumDict`], [`SeedDict`] and `mask` dictionary.\n    async fn delete_coordinator_data(&mut self) -> StorageResult<()>;\n\n    /// Deletes the [`SumDict`], [`SeedDict`] and `mask` dictionary.\n    async fn delete_dicts(&mut self) -> StorageResult<()>;\n\n    /// Sets the latest global model id.\n    ///\n    /// # Behavior\n    ///\n    /// - If no global model id has been set yet, set the new id and return `StorageResult::Ok(())`.\n    /// - If the global model id already exists, override with the new id and\n    ///   return `StorageResult::Ok(())`.\n    async fn set_latest_global_model_id(&mut self, id: &str) -> StorageResult<()>;\n\n    /// Returns the latest global model id.\n    ///\n    /// # Behavior\n    ///\n    /// - If the global model id does not exist, return `StorageResult::Ok(None)`.\n    /// - If the global model id exists, return `StorageResult::Ok(Some(String))`.\n    async fn latest_global_model_id(&mut self) -> StorageResult<Option<String>>;\n\n    /// Checks if the [`CoordinatorStorage`] is ready to process requests.\n    ///\n    /// # Behavior\n    ///\n    /// If the [`CoordinatorStorage`] is ready to process requests, return `StorageResult::Ok(())`.\n    /// If the [`CoordinatorStorage`] cannot process requests because of a connection error,\n    /// for example, return `StorageResult::Err(error)`.\n    async fn is_ready(&mut self) -> StorageResult<()>;\n}\n\n#[async_trait]\n/// An abstract model storage.\npub trait ModelStorage\nwhere\n    Self: Clone + Send + Sync + 'static,\n{\n    /// Sets a global model.\n    ///\n    /// # Behavior\n    ///\n    /// - If the global model already exists (has the same model id), return\n    ///   `StorageResult::Err(StorageError)`.\n    /// - If the global model does not exist, set the model and return `StorageResult::Ok(String)`.\n    async fn set_global_model(\n        &mut self,\n        round_id: u64,\n        round_seed: &RoundSeed,\n    
    global_model: &Model,\n    ) -> StorageResult<String>;\n\n    /// Returns a global model.\n    ///\n    /// # Behavior\n    ///\n    /// - If the global model does not exist, return `StorageResult::Ok(Option::None)`.\n    /// - If the global model exists, return `StorageResult::Ok(Option::Some(Model))`.\n    async fn global_model(&mut self, id: &str) -> StorageResult<Option<Model>>;\n\n    /// Creates a unique global model id by using the round id and the round seed in which\n    /// the global model was created.\n    ///\n    /// The format of the default implementation is `roundid_roundseed`,\n    /// where the [`RoundSeed`] is encoded in hexadecimal.\n    fn create_global_model_id(round_id: u64, round_seed: &RoundSeed) -> String {\n        let round_seed = hex::encode(round_seed.as_slice());\n        format!(\"{}_{}\", round_id, round_seed)\n    }\n\n    /// Checks if the [`ModelStorage`] is ready to process requests.\n    ///\n    /// # Behavior\n    ///\n    /// If the [`ModelStorage`] is ready to process requests, return `StorageResult::Ok(())`.\n    /// If the [`ModelStorage`] cannot process requests because of a connection error,\n    /// for example, return `StorageResult::Err(error)`.\n    async fn is_ready(&mut self) -> StorageResult<()>;\n}\n\n#[async_trait]\n/// An abstract trust anchor provider.\npub trait TrustAnchor\nwhere\n    Self: Clone + Send + Sync + 'static,\n{\n    /// Publishes a proof of the global model.\n    ///\n    /// # Behavior\n    ///\n    /// Return `StorageResult::Ok(())` if the proof was published successfully,\n    /// otherwise return `StorageResult::Err(error)`.\n    async fn publish_proof(&mut self, global_model: &Model) -> StorageResult<()>;\n\n    /// Checks if the [`TrustAnchor`] is ready to process requests.\n    ///\n    /// # Behavior\n    ///\n    /// If the [`TrustAnchor`] is ready to process requests, return `StorageResult::Ok(())`.\n    /// If the [`TrustAnchor`] cannot process requests because of a connection 
error,\n    /// for example, return `StorageResult::Err(error)`.\n    async fn is_ready(&mut self) -> StorageResult<()>;\n}\n\n#[async_trait]\npub trait Storage: CoordinatorStorage + ModelStorage + TrustAnchor {\n    /// Checks if the [`CoordinatorStorage`], [`ModelStorage`] and  [`TrustAnchor`]\n    /// are ready to process requests.\n    ///\n    /// # Behavior\n    ///\n    /// If all inner services are ready to process requests,\n    /// return `StorageResult::Ok(())`.\n    /// If any inner service cannot process requests because of a connection error,\n    /// for example, return `StorageResult::Err(error)`.\n    async fn is_ready(&mut self) -> StorageResult<()>;\n}\n\n/// A wrapper that contains the result of the \"add sum participant\" operation.\n#[derive(Deref)]\npub struct SumPartAdd(pub(crate) Result<(), SumPartAddError>);\n\nimpl SumPartAdd {\n    /// Unwraps this wrapper, returning the underlying result.\n    pub fn into_inner(self) -> Result<(), SumPartAddError> {\n        self.0\n    }\n}\n\n/// Error that can occur when adding a sum participant to the [`SumDict`].\n#[derive(Display, Error, Debug, TryFromPrimitive)]\n#[repr(i64)]\npub enum SumPartAddError {\n    /// sum participant already exists\n    AlreadyExists = 0,\n}\n\n/// A wrapper that contains the result of the \"add local seed dict\" operation.\n#[derive(Deref)]\npub struct LocalSeedDictAdd(pub(crate) Result<(), LocalSeedDictAddError>);\n\nimpl LocalSeedDictAdd {\n    /// Unwraps this wrapper, returning the underlying result.\n    pub fn into_inner(self) -> Result<(), LocalSeedDictAddError> {\n        self.0\n    }\n}\n\n/// Error that can occur when adding a local seed dict to the [`SeedDict`].\n#[derive(Display, Error, Debug, TryFromPrimitive)]\n#[repr(i64)]\npub enum LocalSeedDictAddError {\n    /// the length of the local seed dict and the length of sum dict are not equal\n    LengthMisMatch = -1,\n    /// local dict contains an unknown sum participant\n    UnknownSumParticipant = 
-2,\n    /// update participant already submitted an update\n    UpdatePkAlreadySubmitted = -3,\n    /// update participant already exists in the inner update seed dict\n    UpdatePkAlreadyExistsInUpdateSeedDict = -4,\n}\n\n/// A wrapper that contains the result of the \"increment mask score\" operation.\n#[derive(Deref)]\npub struct MaskScoreIncr(pub(crate) Result<(), MaskScoreIncrError>);\n\nimpl MaskScoreIncr {\n    /// Unwraps this wrapper, returning the underlying result.\n    pub fn into_inner(self) -> Result<(), MaskScoreIncrError> {\n        self.0\n    }\n}\n\n/// Error that can occur when incrementing a mask score.\n#[derive(Display, Error, Debug, TryFromPrimitive)]\n#[repr(i64)]\npub enum MaskScoreIncrError {\n    /// unknown sum participant\n    UnknownSumPk = -1,\n    /// sum participant submitted a mask already\n    MaskAlreadySubmitted = -2,\n}\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/trust_anchor/mod.rs",
    "content": "pub mod noop;\n"
  },
  {
    "path": "rust/xaynet-server/src/storage/trust_anchor/noop.rs",
    "content": "use crate::storage::traits::{StorageResult, TrustAnchor};\nuse async_trait::async_trait;\nuse xaynet_core::mask::Model;\n\n#[derive(Clone)]\npub struct NoOp;\n\n#[async_trait]\nimpl TrustAnchor for NoOp {\n    async fn publish_proof(&mut self, _global_model: &Model) -> StorageResult<()> {\n        Ok(())\n    }\n\n    async fn is_ready(&mut self) -> StorageResult<()> {\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "scripts/bump_version.sh",
    "content": "#!/usr/bin/env bash\n\nWORKDIR=\"$(git rev-parse --show-toplevel)\"\n# Save the git HEAD before running the script\nHEAD=\n# Latest tag\nPREV_TAG=\n# Commit that corresponds to the latest tag\nPREV_TAGGED_COMMIT=\n\n# Latest version numbers\nPREV_MAJOR=\nPREV_MINOR=\nPREV_PATCH=\n\n# New version numbers\nMAJOR=\nMINOR=\nPATCH=\n\n# Return the new version number\nversion() {\n    echo \"${MAJOR}.${MINOR}.${PATCH}\"\n}\n\n# Return the previous version number\nprev_version() {\n    echo \"${PREV_MAJOR}.${PREV_MINOR}.${PREV_PATCH}\"\n}\n\n\n# Find and parse the latest tag, and populate the global variables\nfetch_latest_version() {\n    local tag_regex='^v[0-9]\\.[0-9]\\.[0-9]$'\n\n    PREV_TAG=$(git describe --tags --abbrev=0)\n    PREV_TAGGED_COMMIT=$(git rev-list -n 1 \"${PREV_TAG}\")\n    echo \"latest tag found: ${PREV_TAG} (commit ${PREV_TAGGED_COMMIT})\"\n\n    if ! [[ ${PREV_TAG} =~ ${tag_regex} ]] ; then\n        echo \"error: invalid tag ${PREV_TAG}\" >&2\n        exit 1\n    fi\n\n    PREV_MAJOR=${PREV_TAG:1:1}\n    PREV_MINOR=${PREV_TAG:3:1}\n    PREV_PATCH=${PREV_TAG:5:1}\n\n    MAJOR=${PREV_MAJOR}\n    MINOR=${PREV_MINOR}\n    PATCH=${PREV_PATCH}\n}\n\n# Check that the working directory doesn't have un-committed changes. 
If it\n# does, error out.\ncheck_workdir_is_clean() {\n    if [ -z \"$(git status --untracked-files=no --porcelain)\" ]; then\n        echo \"git working directory is clean, continuing\"\n    else\n        echo \"git working directory is dirty, aborting\" 2>&1\n        exit 1\n    fi\n}\n\n# A helper function for interactively asking the users whether the script\n# should continue or not\nask_yes_or_no() {\n    select yn in \"Yes\" \"No\"; do\n        case $yn in\n            Yes )\n                echo \"continuing\"\n                break\n                ;;\n            No )\n                echo \"aborting\" 2>&1\n                exit 1\n                ;;\n        esac\n    done\n}\n\n# Print a message explaining what the script does, and how to undo the changes\n# if necessary\ndisclaimer() {\n    cat << EOF\n***********************************\n        IMPORTANT\n***********************************\n\nThis script modifies the git commit history. If anything goes wrong, or if you\nhave a doubt, you can always rollback to where this script start by running:\n\n    git reset --hard ${HEAD}\n\nThis script will:\n\n1. Find the latest tag on the current branch\n2. Make sure that the CHANGELOG.md file was updated since this tag was pushed\n3. Update the version number in various files in the repository, and commit\n   these changes\n4. 
Create a new annotated tag\n\nEOF\n}\n\n# Print a help message\nusage() {\n    cat << EOF\n./bump_version.sh [-h|--help] [-M|--major] [-m|--minor] [-p|--patch]\n\nbump_version.sh is used for bumping the previous version number and creating a\nnew tag.\n\nOPTIONS:\n\n    -h|--help:\n        print this help message\n\n    -M|--major:\n        bump the major version number\n\n    -m|--minor:\n        bump the minor version number\n\n    -p|--patch:\n        bump the patch version number\nEOF\n}\n\n# Make sure the CHANGELOG was updated, and ask the user to double check the\n# changes\ncheck_changelog_was_updated() {\n    diff() {\n        git --no-pager diff \"${PREV_TAGGED_COMMIT}\" HEAD CHANGELOG.md\n    }\n\n    if [ \"$(diff | wc -l)\" -eq 0 ] ; then\n        echo \"error: the CHANGELOG has not been updated since ${PREV_TAG}\" 2>&1\n        echo \"Do you want to continue anyway?\"\n        ask_yes_or_no\n    else\n        echo \"The CHANGELOG has been updated since ${PREV_TAG}\"\n        diff\n        echo \"Does the change above look correct for v$(version)\"\n        ask_yes_or_no\n    fi\n}\n\n# Small helper to update the version number in a file, using sed\nset_version_in_file() {\n    local sed_expr=${1}\n    local file=${2}\n\n    echo \"Setting version to $(version) in ${file}\"\n    sed -i \"${sed_expr}\" \"${file}\"\n}\n\n# Update the version numbers in various files, and ask confirmation from the\n# user before committing these changes.\nupdate_versions() {\n    set_version_in_file 's/^version = \".*\"$/version = \"'\"$(version)\"'\"/g' rust/Cargo.toml\n    (cd rust && cargo update -v)\n\n    if [ \"$(git --no-pager diff | wc -l)\" -eq 0 ] ; then\n        echo \"No changes were made, it seems that the version files were already updated to $(version)\"\n        echo \"Do you want to continue?\"\n        ask_yes_or_no\n    else\n        git --no-pager diff\n        echo \"Do you want to commit the changes above?\"\n        ask_yes_or_no\n        git add 
rust/Cargo.toml rust/Cargo.lock\n        git commit -m \"bump version $(prev_version) -> $(version)\"\n    fi\n}\n\ncargo_publish_dry_run() {\n    echo \"Checking that the Rust package is ready to be published\"\n    cargo publish --dry-run\n    echo \"The Rust package is ready to be published\"\n}\n\nmain() {\n    local bump_major=false\n    local bump_minor=false\n    local bump_patch=false\n\n    if [ \"$#\" -eq 0 ]; then\n        usage\n        exit 0\n    fi\n\n    cd \"${WORKDIR}\"\n    check_workdir_is_clean\n\n    while (( \"$#\" )); do\n        case \"$1\" in\n            -M|--major)\n                bump_major=true\n                shift\n                ;;\n            -m|--minor)\n                bump_minor=true\n                shift\n                ;;\n            -p|--patch)\n                bump_patch=true\n                shift\n                ;;\n            -h|--help|help)\n                usage\n                exit 0\n                ;;\n            *)\n                echo \"error: unsupported argument \\\"$1\\\"\" 2>&1\n                usage\n                exit 1\n                ;;\n        esac\n    done\n\n\n    HEAD=$(git rev-parse HEAD)\n    disclaimer\n    fetch_latest_version\n\n    if [ \"$bump_major\" = true ] ; then\n        MAJOR=$((PREV_MAJOR + 1))\n    fi\n\n    if [ \"$bump_minor\" = true ] ; then\n        MINOR=$((PREV_MINOR + 1))\n    fi\n\n    if [ \"$bump_patch\" = true ] ; then\n        PATCH=$((PREV_PATCH + 1))\n    fi\n\n    if [ \"$(prev_version)\" = \"$(version)\" ] ; then\n        echo \"error: new version is the same than previous version\" 2>&1\n        exit 1\n    fi\n\n    echo \"Bumping version from $(prev_version) to $(version)\"\n    ask_yes_or_no\n\n    update_versions\n    check_changelog_was_updated\n\n    (cd rust && cargo_publish_dry_run)\n\n    echo \"Tagging ${HEAD} as \\\"v$(version)\\\"\"\n    git tag -a \"v$(version)\" -m \"release v$(version)\"\n\n    echo \"Done!\"\n\n    cat << EOF\nYou can now 
publish the Rust package:\n\n    (cd rust && cargo publish)\n\nFinally: push the new tag to Github:\n\n    git push <remote> master --tags\n\nEOF\n}\n\nset -e\nmain \"$@\"\n"
  }
]